import * as React from 'react';
import Box from '@mui/material/Box';
import { Container, Paper, Alert, CircularProgress, Typography, Modal, TextField } from '@mui/material';
import Stepper from '@mui/material/Stepper';
import Step from '@mui/material/Step';
import StepLabel from '@mui/material/StepLabel';
import Button from '@mui/material/Button';
import Dataset from '../components/Dataset';
import QueryStatsIcon from '@mui/icons-material/QueryStats';
import { sendMessage, TABULAR_REGRESSION, TABULAR_CLASSIFIATION } from "../loadPyodide";
import { saveAs } from 'file-saver';
import { state2df } from "../utils/dataframes";
import deepcopy from "../utils/deepcopy";
import * as pd from "danfojs";
import { ClassificationReport } from '../components/ClassificationReport';
import { RegressionReport } from '../components/RegressionReport';

const TrainingSpinner = () => {
    return (
        <>
            <CircularProgress size={100} />
            <Typography variant="h6" sx={{ mt: 2 }}>
                Training Model...
            </Typography>
        </>
    );
};

/**
 * Wizard step that trains a model from the prepared dataset.
 *
 * When the user clicks "Train", the dataset and its metadata are turned into a
 * dataframe via `state2df`. The dataframe is serialised row-wise with
 * `pd.toJSON`, and `sendMessage` sends it together with the target column.
 *
 * Props:
 *  - dataset: rows from the upload step, or null when nothing has been loaded.
 *  - datasetMeta: `{ data: { $columns, $dtypes, $filters, $target, castTypes,
 *    hiddenColumns, dropNa } }` as maintained by TabularPredictorFlow.
 *  - variant: "classification" sends TABULAR_CLASSIFIATION; any other value
 *    sends TABULAR_REGRESSION.
 *  - setModelAndEval: called with the value `sendMessage` resolves to.
 *
 * If training fails, the error is logged and shown above the Train button so
 * the user can retry; the spinner is always cleared.
 */
const TrainStep = (props) => {
    // Hooks must run on every render, so they come before the early return.
    const [isTraining, setTraining] = React.useState(false);
    const [error, setError] = React.useState(null);

    if (props.dataset === null) {
        return <Alert variant="outlined" severity="error">
            No data is available for training
        </Alert>;
    }

    const trainModel = async () => {
        setError(null);
        setTraining(true);
        try {
            const meta = props.datasetMeta.data;
            // Build the dataframe only when training is requested (not on every
            // render). Deep copies keep state2df from touching React state.
            // The literal `true` flag is interpreted by state2df.
            const df = state2df(
                deepcopy(meta.$columns),
                deepcopy(props.dataset),
                meta.$dtypes,
                meta.$filters,
                meta.$target,
                meta.castTypes,
                meta.hiddenColumns,
                true,
                meta.dropNa
            );
            const message = props.variant === "classification"
                ? TABULAR_CLASSIFIATION
                : TABULAR_REGRESSION;
            const trained = await sendMessage(message, {
                data: pd.toJSON(df, {
                    format: "row"
                }),
                target: meta.$target,
            });
            props.setModelAndEval(trained);
        } catch (err) {
            // Without this, a failed request would leave the spinner up forever.
            console.error(err);
            setError(err?.message || String(err));
        } finally {
            setTraining(false);
        }
    };

    const renderIdle = () => (
        <>
            {error !== null && (
                <Alert variant="outlined" severity="error">
                    Training failed: {error}
                </Alert>
            )}
            <Box>
                <QueryStatsIcon
                    sx={{ fontSize: 196, color: 'gray', marginBottom: '20px' }}
                />
            </Box>
            <Button
                variant="contained"
                color="primary"
                size="large"
                onClick={trainModel}
                sx={{
                    padding: '12px 24px',
                    fontSize: '14px',
                    boxShadow: '0px 4px 10px rgba(0, 0, 0, 0.1)',
                    borderRadius: '8px',
                }}
            >
                Train
            </Button>
        </>
    );

    return (
        <Container
            maxWidth="sm"
            sx={{
                display: 'flex',
                flexDirection: 'column',
                alignItems: 'center',
                justifyContent: 'center',
                height: '100vh',
                textAlign: 'center',
                gap: 4, // Space between elements
            }}
        >
            {isTraining ? <TrainingSpinner /> : renderIdle()}
        </Container>
    );
};


export default function TabularPredictorFlow(props) {
    // state moved from redux
    const initialState = {
        data: {
            $columns: [],
            $dtypes: [],
            $filters: [],
            $target: null,
            empty: true,
            castTypes: {},
            hiddenColumns: [],
            dropNa: false,
        },
    };

    const [dataset, setDataset] = React.useState(null);
    const [datasetMeta, setDatasetMeta] = React.useState(initialState);
    const [model, setModel] = React.useState(null);

    const downloadModel = () => {
        if (model === null) {
            console.log("model is null");
            return;
        }
        console.log("download model");
        console.log(model);
        let blob = new Blob([model.model]);
        saveAs(blob, "model.fnnx");
    };


    const setData = (payload) => {
        console.log("setData", payload);
        let meta = deepcopy(payload);
        delete meta.$data;

        const newMeta = deepcopy(datasetMeta);
        newMeta.data.$columns = meta.$columns;
        newMeta.data.$dtypes = meta.$dtypes;
        newMeta.data.$filters = meta.$filters;
        newMeta.data.$target = meta.$target;
        console.log("newMeta")
        console.log(newMeta)
        setDatasetMeta(newMeta);
        setDataset(payload.$data);
    };

    const setFilters = (payload) => {
        const state = deepcopy(datasetMeta);
        state.data.$filters = payload;
        setDatasetMeta(state);
    }

    const setTarget = (payload) => {
        const state = deepcopy(datasetMeta);
        state.data.$target = payload;
        setDatasetMeta(state);
    }

    const addCast = (payload) => {
        const state = deepcopy(datasetMeta);
        state.data.castTypes[payload.field] = payload.type;
        const f = (filter) => {
            return filter.column != payload.field;
        }
        state.data.$filters = state.data.$filters.filter(f)
        setDatasetMeta(state);
    }

    const setHiddenColumns = (payload) => {
        const state = deepcopy(datasetMeta);
        state.data.hiddenColumns = payload.columns;
        setDatasetMeta(state);
    }

    const toggleDropNa = (payload) => {
        const state = deepcopy(datasetMeta);
        console.log('toggle Drop Na')
        state.data.dropNa = !state.data.dropNa;
        setDatasetMeta(state);
    }
    //


    const [activeStep, setActiveStep] = React.useState(0);

    const handleNext = () => {
        setActiveStep((prevActiveStep) => prevActiveStep + 1);
    };


    const handleBack = () => {
        setActiveStep((prevActiveStep) => prevActiveStep - 1);
    };

    // const handleReset = () => {
    //     setActiveStep(0);
    // };

    const setModelAndEval = (model) => {
        setModel(model);
        handleNext();
    }

    console.log(206)
    console.log(datasetMeta)
    const getPaper = (jsx) => {
        return (
            <Paper
                sx={{
                    p: 2,
                    display: "flex",
                    flexDirection: "column",
                    height: "77vh",
                    width: "100%",
                }}
            >
                {jsx}
            </Paper>
        );
    }

    const modalStyle = {
        position: 'absolute',
        top: '50%',
        left: '50%',
        transform: 'translate(-50%, -50%)',
        width: 400,
        bgcolor: 'background.paper',
        boxShadow: 24,
        p: 4,
    };

    const ReportWithPredict = (props) => {
        const [openModal, setOpenModal] = React.useState(false);

        const handleOpenModal = () => {
            setPredictionResult(null); // reset prediction result
            setOpenModal(true);
        };
        const handleCloseModal = () => setOpenModal(false);

        const [sepalLength, setSepalLength] = React.useState('');
        const [sepalWidth, setSepalWidth] = React.useState('');
        const [petalLength, setPetalLength] = React.useState('');
        const [petalWidth, setPetalWidth] = React.useState('');
        const [predictionResult, setPredictionResult] = React.useState(null);

        const handlePredict = () => {
            // Hard coded prediction
            setPredictionResult('Iris-setosa');
        };

        const modalContent = (
            <Box sx={modalStyle}>
                <h2>Predict</h2>
                <form>
                    <TextField label="Sepal Length" variant="outlined" fullWidth margin="normal" value={sepalLength} onChange={(e) => setSepalLength(e.target.value)} />
                    <TextField label="Sepal Width" variant="outlined" fullWidth margin="normal" value={sepalWidth} onChange={(e) => setSepalWidth(e.target.value)} />
                    <TextField label="Petal Length" variant="outlined" fullWidth margin="normal" value={petalLength} onChange={(e) => setPetalLength(e.target.value)} />
                    <TextField label="Petal Width" variant="outlined" fullWidth margin="normal" value={petalWidth} onChange={(e) => setPetalWidth(e.target.value)} />
                    <Button variant="contained" color="primary" onClick={handlePredict} sx={{ mt: 2 }}>
                        Predict
                    </Button>
                </form>
                {predictionResult && (
                    <Typography variant="h6" sx={{ mt: 2 }}>
                        Prediction: {predictionResult}
                    </Typography>
                )}
            </Box>
        );

        return (
            <div>
                {props.variant === "classification" ? <ClassificationReport download={props.downloadModel} predict={handleOpenModal} /> : <RegressionReport download={props.downloadModel} />}

                {/* <Button variant="contained" color="primary" onClick={handleOpenModal} sx={{ mt: 2 }}>
                    Predict
                </Button> */}
                <Modal open={openModal} onClose={handleCloseModal}>
                    {modalContent}
                </Modal>
            </div>
        );
    };

    const getReport = () => {
        return <ReportWithPredict variant={props.variant} downloadModel={downloadModel} />;
    }

    const getStep = (step) => {
        switch (step) {
            case 0:
                return getPaper(<Dataset
                    dataset={dataset}
                    datasetMeta={datasetMeta}
                    setData={setData}
                    setFilters={setFilters}
                    setTarget={setTarget}
                    addCast={addCast}
                    setHiddenColumns={setHiddenColumns}
                    toggleDropNa={toggleDropNa} />);
            case 1:
                return getPaper(<TrainStep dataset={dataset} datasetMeta={datasetMeta} setModelAndEval={setModelAndEval} variant={props.variant} />);
            case 2:
                return <Box sx={{ height: "77vh" }}>{getReport()}</Box>;
        }
    }

    return (

        <Box>
            <Stepper activeStep={activeStep} sx={{ mb: 4 }}>
                <Step>
                    <StepLabel>Upload data</StepLabel>
                </Step>
                <Step>
                    <StepLabel>Train</StepLabel>
                </Step>
                <Step>
                    <StepLabel>Evaluate</StepLabel>
                </Step>
            </Stepper>

            {getStep(activeStep)}

            <Box sx={{ display: 'flex', flexDirection: 'row', pt: 2 }}>
                <Button
                    color="inherit"
                    disabled={activeStep === 0}
                    onClick={handleBack}
                    sx={{ mr: 1 }}
                >
                    Back
                </Button>
                <Box sx={{ flex: '1 1 auto' }} />

                <Button onClick={handleNext}>
                    {activeStep === 2 ? 'Finish' : 'Next'}
                </Button>
            </Box>

        </Box>

    );
}
