import os
import torch
import numpy as np
import jax
import pandas as pd

from scipy import io
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset
from Utils.model_utils import normalize_to_neg_one_to_one, unnormalize_to_zero_to_one
from Utils.masking_utils import noise_mask
import equinox as eqx
import jax.tree_util as jtu
from diffusion_crf import TimeSeries
from typing import Optional, Mapping, Tuple, List, Sequence, Union, Any, Callable, Dict, Iterator
from jaxtyping import Array, PRNGKeyArray, PyTree, Scalar


def _shuffle_batches(key, batches):
    """Return ``batches`` with its batch axis reordered by a random permutation.

    A permutation is used (rather than independent random draws) so that every
    batch element appears exactly once in the result.
    """
    order = jax.random.permutation(key, batches.batch_size)
    return batches[order]


def make_train_val_test_split(key: PRNGKeyArray,
                              data_series: TimeSeries,
                              train_val_test_split: Tuple[float, float, Optional[float]],
                              seq_length: int,
                              cond_length: int,
                              no_batching: bool = False):
    """Split a series chronologically into train/val(/test) parts.

    The validation and test parts start ``cond_length`` steps before their
    nominal boundary: only the predicted steps need to be unseen, so the
    conditioning prefix may overlap the previous part.

    Args:
        key: PRNG key used to shuffle the windowed batches. Only required when
            ``no_batching`` is false.
        data_series: Object supporting ``len()``, slicing, and (when batching)
            ``make_windowed_batches(window_size=...)``.
        train_val_test_split: ``(train, val, test)`` proportions of the series
            length. If ``test`` is ``None`` no test part is produced and the
            validation part extends to the end of the series (``val`` is then
            ignored).
        seq_length: Window size passed to ``make_windowed_batches``.
        cond_length: Number of steps by which val/test reach back into the
            preceding part.
        no_batching: If true, return the raw slices without windowing or
            shuffling.

    Returns:
        A tuple ``(train, val)`` or ``(train, val, test)`` of either raw
        slices (``no_batching``) or shuffled windowed batches.

    Raises:
        ValueError: If batching is requested but ``key`` is ``None``.
    """
    total_size = len(data_series)
    train_proportion, val_proportion, test_proportion = train_val_test_split
    has_test = test_proportion is not None

    train_end = int(total_size * train_proportion)
    # Without a test part, validation takes everything after the train part.
    val_end = (train_end + int(total_size * val_proportion)) if has_test else total_size

    # Clamp the look-back at 0: a negative start would silently wrap around
    # and slice from the END of the series instead.
    bounds = [(0, train_end), (max(0, train_end - cond_length), val_end)]
    if has_test:
        # The test part runs to the end so that all data is used.
        bounds.append((max(0, val_end - cond_length), total_size))

    splits = tuple(data_series[start:end] for start, end in bounds)
    if no_batching:
        return splits

    if key is None:
        raise ValueError('A PRNG key is required to shuffle the windowed batches.')

    keys = jax.random.split(key, len(splits))
    return tuple(
        _shuffle_batches(k, series.make_windowed_batches(window_size=seq_length))
        for k, series in zip(keys, splits)
    )

def get_stocks_dataset(config: dict,
                       args=None,
                       key: Optional[PRNGKeyArray] = None,
                       train_val_test_split: Optional[Tuple[int, int, int]] = None,
                       return_raw_data: bool = False):
    """Build the stocks data as a series, a split, or windowed batches.

    The legacy dataset described by ``config['old_dataloader']`` is
    instantiated and its ``data`` attribute is converted to a series via
    ``make_series``.

    Returns:
        * the whole series when ``return_raw_data`` is true;
        * otherwise the result of ``make_train_val_test_split`` when
          ``train_val_test_split`` is given;
        * otherwise the windowed batches of the whole series.
    """
    from Utils.io_utils import instantiate_from_config
    legacy_dataset = instantiate_from_config(config['old_dataloader'], args=args)
    raw = legacy_dataset.data

    # Wrap the raw values in a series object
    import jax.numpy as jnp
    import jax
    from Utils.Data_utils.my_datasets import make_series
    series = make_series(jnp.array(raw))

    if return_raw_data:
        return series

    if train_val_test_split is None:
        # No split requested: window the full series
        return series.make_windowed_batches(window_size=config['seq_length'])

    # Conditioning length is whatever part of the window is not predicted
    n_cond = config['seq_length'] - config['pred_length']
    return make_train_val_test_split(key, series, train_val_test_split, config['seq_length'], n_cond)

def get_energy_dataset(config: dict,
                       args=None,
                       key: Optional[PRNGKeyArray] = None,
                       train_val_test_split: Optional[Tuple[float, float, Optional[float]]] = None,
                       return_raw_data: bool = False):
    """Build the energy data as a series, a split, or windowed batches.

    Args:
        config: Must provide ``'old_dataloader'`` and ``'seq_length'``, plus
            ``'pred_length'`` when a split is requested.
        args: Forwarded to ``instantiate_from_config``.
        key: PRNG key forwarded to ``make_train_val_test_split``.
        train_val_test_split: Proportions forwarded to
            ``make_train_val_test_split``; ``None`` selects the deprecated
            unsplit path.
        return_raw_data: If true, return the whole series without batching.

    Returns:
        The raw series, the split produced by ``make_train_val_test_split``,
        or (deprecated) the windowed batches of the whole series.
    """
    # Function-scope import: the module never imported `warnings`, so the
    # deprecated path below used to fail with a NameError.
    import warnings

    import jax.numpy as jnp
    from Utils.io_utils import instantiate_from_config
    from Utils.Data_utils.my_datasets import make_series

    dataset = instantiate_from_config(config['old_dataloader'], args=args)

    # Create a timeseries object out of the data
    data_series = make_series(jnp.array(dataset.data))

    if return_raw_data:
        return data_series

    if train_val_test_split is not None:
        # Divide the data into train, val, test
        cond_length = config['seq_length'] - config['pred_length']
        return make_train_val_test_split(key, data_series, train_val_test_split, config['seq_length'], cond_length)

    # Legacy path: window the whole series without any split
    warnings.warn('Deprecated!  You should be passing in a train_val_test_split tuple to avoid data leakage!',
                  stacklevel=2)
    return data_series.make_windowed_batches(window_size=config['seq_length'])

def get_etth_dataset(config: dict,
                       args=None,
                       key: Optional[PRNGKeyArray] = None,
                       train_val_test_split: Optional[Tuple[float, float, Optional[float]]] = None,
                       return_raw_data: bool = False):
    """Build the ETTh data as a series, a split, or windowed batches.

    Args:
        config: Must provide ``'old_dataloader'`` and ``'seq_length'``, plus
            ``'pred_length'`` when a split is requested.
        args: Forwarded to ``instantiate_from_config``.
        key: PRNG key forwarded to ``make_train_val_test_split``.
        train_val_test_split: Proportions forwarded to
            ``make_train_val_test_split``; ``None`` selects the deprecated
            unsplit path.
        return_raw_data: If true, return the whole series without batching.

    Returns:
        The raw series, the split produced by ``make_train_val_test_split``,
        or (deprecated) the windowed batches of the whole series.
    """
    # Function-scope import: the module never imported `warnings`, so the
    # deprecated path below used to fail with a NameError.
    import warnings

    import jax.numpy as jnp
    from Utils.io_utils import instantiate_from_config
    from Utils.Data_utils.my_datasets import make_series

    dataset = instantiate_from_config(config['old_dataloader'], args=args)

    # Create a timeseries object out of the data
    data_series = make_series(jnp.array(dataset.data))

    if return_raw_data:
        return data_series

    if train_val_test_split is not None:
        # Divide the data into train, val, test
        cond_length = config['seq_length'] - config['pred_length']
        return make_train_val_test_split(key, data_series, train_val_test_split, config['seq_length'], cond_length)

    # Legacy path: window the whole series without any split. The original
    # message was cut off mid-sentence; it is completed here.
    warnings.warn('Deprecated!  You should be passing in a train_val_test_split tuple to avoid data leakage!',
                  stacklevel=2)
    return data_series.make_windowed_batches(window_size=config['seq_length'])

def get_fmri_dataset(config: dict,
                       args=None,
                       key: Optional[PRNGKeyArray] = None,
                       train_val_test_split: Optional[Tuple[float, float, Optional[float]]] = None,
                       return_raw_data: bool = False):
    """Build the fMRI data as a series, a split, or windowed batches.

    Args:
        config: Must provide ``'old_dataloader'`` and ``'seq_length'``, plus
            ``'pred_length'`` when a split is requested.
        args: Forwarded to ``instantiate_from_config``.
        key: PRNG key forwarded to ``make_train_val_test_split``.
        train_val_test_split: Proportions forwarded to
            ``make_train_val_test_split``; ``None`` selects the deprecated
            unsplit path.
        return_raw_data: If true, return the whole series without batching.

    Returns:
        The raw series, the split produced by ``make_train_val_test_split``,
        or (deprecated) the windowed batches of the whole series.
    """
    # Function-scope import: the module never imported `warnings`, so the
    # deprecated path below used to fail with a NameError.
    import warnings

    import jax.numpy as jnp
    from Utils.io_utils import instantiate_from_config
    from Utils.Data_utils.my_datasets import make_series

    dataset = instantiate_from_config(config['old_dataloader'], args=args)

    # Create a timeseries object out of the data
    data_series = make_series(jnp.array(dataset.data))

    if return_raw_data:
        return data_series

    if train_val_test_split is not None:
        # Divide the data into train, val, test
        cond_length = config['seq_length'] - config['pred_length']
        return make_train_val_test_split(key, data_series, train_val_test_split, config['seq_length'], cond_length)

    # Legacy path: window the whole series without any split. The original
    # message was cut off mid-sentence; it is completed here.
    warnings.warn('Deprecated!  You should be passing in a train_val_test_split tuple to avoid data leakage!',
                  stacklevel=2)
    return data_series.make_windowed_batches(window_size=config['seq_length'])


class CustomDataset(Dataset):
    """Sliding-window dataset over a single multivariate series read from CSV.

    The raw table is scaled with a ``MinMaxScaler`` fitted on the whole file
    (and optionally passed through ``normalize_to_neg_one_to_one``), cut into
    overlapping windows of length ``window`` with stride 1, and divided
    chronologically into a leading "train" part and a trailing "test" part.

    In ``'test'`` mode every sample is paired with a boolean mask that is
    either produced by ``noise_mask`` (when ``missing_ratio`` is given) or is
    false on the last ``predict_length`` steps (when ``predict_length`` is
    given).
    """

    def __init__(
        self,
        name,
        data_root,
        window=64,
        proportion=0.8,
        save2npy=True,
        neg_one_to_one=True,
        seed=123,
        period='train',
        output_dir='./OUTPUT',
        predict_length=None,
        missing_ratio=None,
        style='separate',
        distribution='geometric',
        mean_mask_length=3,
        args=None
    ):
        """Load, normalize, window and split the data.

        Args:
            name: Dataset name; used in saved file names and passed to
                ``read_data``.
            data_root: Path handed to ``read_data``.
            window: Length of each sliding window.
            proportion: Fraction of windows assigned to the train part.
            save2npy: If true, write the windows (and masks) as ``.npy``
                files under ``<output_dir>/samples``.
            neg_one_to_one: If true, additionally apply
                ``normalize_to_neg_one_to_one`` after min-max scaling.
            seed: Seed forwarded to ``divide`` and ``mask_data``.
            period: ``'train'`` or ``'test'``.
            output_dir: Root directory for saved arrays; the ``samples``
                sub-directory is created if missing.
            predict_length: Test only. Number of trailing steps to mask out.
            missing_ratio: Test only. Passed to ``noise_mask``.
            style, distribution, mean_mask_length: Passed to ``noise_mask``.
            args: Stored on the instance; not otherwise used here.

        Raises:
            AssertionError: If ``period`` is invalid, or if ``predict_length``
                or ``missing_ratio`` is supplied with ``period='train'``.
            NotImplementedError: If ``period='test'`` and neither
                ``missing_ratio`` nor ``predict_length`` is supplied.
        """
        super(CustomDataset, self).__init__()
        assert period in ['train', 'test'], 'period must be train or test.'
        if period == 'train':
            # Logical `not` is required here: bitwise `~` on a bool yields
            # -1 or -2, both truthy, so the check could never fail.
            assert not (predict_length is not None or missing_ratio is not None), \
                'predict_length and missing_ratio may only be given when period is test.'
        self.name, self.pred_len, self.missing_ratio = name, predict_length, missing_ratio
        self.style, self.distribution, self.mean_mask_length = style, distribution, mean_mask_length
        self.rawdata, self.scaler = self.read_data(data_root, self.name)
        self.dir = os.path.join(output_dir, 'samples')
        os.makedirs(self.dir, exist_ok=True)

        self.window, self.period = window, period
        self.len, self.var_num = self.rawdata.shape[0], self.rawdata.shape[-1]
        # Number of stride-1 windows; zero when the series is shorter than one window.
        self.sample_num_total = max(self.len - self.window + 1, 0)
        self.save2npy = save2npy
        self.auto_norm = neg_one_to_one

        self.data = self.__normalize(self.rawdata)
        train, inference = self.__getsamples(self.data, proportion, seed)

        self.samples = train if period == 'train' else inference
        if period == 'test':
            if missing_ratio is not None:
                self.masking = self.mask_data(seed)
            elif predict_length is not None:
                # True everywhere except the last `predict_length` steps.
                masks = np.ones(self.samples.shape)
                masks[:, -predict_length:, :] = 0
                self.masking = masks.astype(bool)
            else:
                raise NotImplementedError()
        self.sample_num = self.samples.shape[0]

        self.args = args

    def __getsamples(self, data, proportion, seed):
        """Cut ``data`` into stride-1 windows, split them, and optionally save them.

        Returns:
            ``(train_windows, test_windows)`` as returned by ``divide``.
        """
        x = np.zeros((self.sample_num_total, self.window, self.var_num))
        for i in range(self.sample_num_total):
            start = i
            end = i + self.window
            x[i, :, :] = data[start:end, :]

        train_data, test_data = self.divide(x, proportion, seed)

        if self.save2npy:
            # The test files are only written when a test part was requested.
            if 1 - proportion > 0:
                np.save(os.path.join(self.dir, f"{self.name}_ground_truth_{self.window}_test.npy"), self.unnormalize(test_data))
            np.save(os.path.join(self.dir, f"{self.name}_ground_truth_{self.window}_train.npy"), self.unnormalize(train_data))
            if self.auto_norm:
                if 1 - proportion > 0:
                    np.save(os.path.join(self.dir, f"{self.name}_norm_truth_{self.window}_test.npy"), unnormalize_to_zero_to_one(test_data))
                np.save(os.path.join(self.dir, f"{self.name}_norm_truth_{self.window}_train.npy"), unnormalize_to_zero_to_one(train_data))
            else:
                if 1 - proportion > 0:
                    np.save(os.path.join(self.dir, f"{self.name}_norm_truth_{self.window}_test.npy"), test_data)
                np.save(os.path.join(self.dir, f"{self.name}_norm_truth_{self.window}_train.npy"), train_data)
        return train_data, test_data

    def normalize(self, sq):
        """Scale windows in original units; returns shape ``(-1, window, var_num)``."""
        d = sq.reshape(-1, self.var_num)
        d = self.scaler.transform(d)
        if self.auto_norm:
            d = normalize_to_neg_one_to_one(d)
        return d.reshape(-1, self.window, self.var_num)

    def unnormalize(self, sq):
        """Invert ``normalize``; returns shape ``(-1, window, var_num)``."""
        d = self.__unnormalize(sq.reshape(-1, self.var_num))
        return d.reshape(-1, self.window, self.var_num)

    def __normalize(self, rawdata):
        """Apply the fitted scaler (and the optional extra normalization) to a 2-D table."""
        data = self.scaler.transform(rawdata)
        if self.auto_norm:
            data = normalize_to_neg_one_to_one(data)
        return data

    def __unnormalize(self, data):
        """Undo ``__normalize`` on a 2-D table."""
        if self.auto_norm:
            data = unnormalize_to_zero_to_one(data)
        x = data
        return self.scaler.inverse_transform(x)

    @staticmethod
    def divide(data, ratio, seed=2023):
        """Split ``data`` along its first axis into a leading and a trailing part.

        The first ``ceil(size * ratio)`` entries form the first part and the
        remainder the second. The order is chronological: the random
        permutation is disabled, so ``seed`` currently has no effect on the
        result.

        Args:
            data: A NumPy array, or any indexable object exposing
                ``batch_size`` (the JAX version).
            ratio: Fraction of entries placed in the first part.
            seed: Seed applied to NumPy's global RNG for the duration of the
                call; the previous RNG state is restored afterwards.

        Returns:
            ``(first_part, second_part)``.
        """
        if isinstance(data, np.ndarray):
            size = data.shape[0]
        else:
            size = data.batch_size # For when we are using the JAX version
        # Store the state of the RNG to restore later.
        st0 = np.random.get_state()
        np.random.seed(seed)

        regular_train_num = int(np.ceil(size * ratio))
        # id_rdm = np.random.permutation(size)
        id_rdm = np.arange(size)
        regular_train_id = id_rdm[:regular_train_num]
        irregular_train_id = id_rdm[regular_train_num:]

        if isinstance(data, np.ndarray):
            regular_data = data[regular_train_id, :]
            irregular_data = data[irregular_train_id, :]
        else:
            regular_data = data[regular_train_id]
            irregular_data = data[irregular_train_id]

        # Restore RNG.
        np.random.set_state(st0)
        return regular_data, irregular_data

    @staticmethod
    def read_data(filepath, name=''):
        """Read a single CSV file and fit a ``MinMaxScaler`` on its values.

        For ``name == 'etth'`` the first column is dropped before the values
        are taken.

        Returns:
            ``(values, fitted_scaler)``.
        """
        df = pd.read_csv(filepath, header=0)
        if name == 'etth':
            df.drop(df.columns[0], axis=1, inplace=True)
        data = df.values
        scaler = MinMaxScaler()
        scaler = scaler.fit(data)
        return data, scaler

    def mask_data(self, seed=2023):
        """Generate one ``noise_mask`` per sample under a fixed NumPy seed.

        NumPy's global RNG is seeded for the duration of the call and its
        previous state is restored afterwards, even if mask generation fails.

        Returns:
            Boolean array with the same shape as ``self.samples``.
        """
        masks = np.ones_like(self.samples)
        # Store the state of the RNG to restore later.
        st0 = np.random.get_state()
        np.random.seed(seed)

        try:
            for idx in range(self.samples.shape[0]):
                x = self.samples[idx, :, :]  # (seq_length, feat_dim) array
                mask = noise_mask(x, self.missing_ratio, self.mean_mask_length, self.style,
                                  self.distribution)  # (seq_length, feat_dim) boolean array
                masks[idx, :, :] = mask

            if self.save2npy:
                np.save(os.path.join(self.dir, f"{self.name}_masking_{self.window}.npy"), masks)
        finally:
            # Restore RNG.
            np.random.set_state(st0)
        return masks.astype(bool)

    def __getitem__(self, ind):
        """Return the sample at ``ind`` as a float tensor, plus its mask in test mode."""
        if self.period == 'test':
            x = self.samples[ind, :, :]  # (seq_length, feat_dim) array
            m = self.masking[ind, :, :]  # (seq_length, feat_dim) boolean array
            return torch.from_numpy(x).float(), torch.from_numpy(m)
        x = self.samples[ind, :, :]  # (seq_length, feat_dim) array
        return torch.from_numpy(x).float()

    def __len__(self):
        """Number of samples in the selected period."""
        return self.sample_num


class fMRIDataset(CustomDataset):
    """``CustomDataset`` variant that reads its raw data from a MATLAB file.

    ``proportion`` defaults to 1; all other arguments are forwarded
    unchanged to ``CustomDataset``.
    """

    def __init__(self, proportion=1., **kwargs):
        super(fMRIDataset, self).__init__(proportion=proportion, **kwargs)

    @staticmethod
    def read_data(filepath, name=''):
        """Load the ``'ts'`` entry of ``<filepath>/sim4.mat`` and fit a scaler on it.

        Returns:
            ``(values, fitted_scaler)``.
        """
        mat_contents = io.loadmat(filepath + '/sim4.mat')
        series = mat_contents['ts']
        fitted = MinMaxScaler().fit(series)
        return series, fitted


# Manual smoke test: load the fMRI experiment data, sample a latent series
# conditioned on one training example, plot it against the original, then run
# the same sampling over the whole training batch and check for NaNs.
if __name__ == '__main__':
    import matplotlib.pyplot as plt
    # NOTE(review): `random`, `partial`, `jnp`, `BrownianMotion`,
    # `PaddingLatentVariableEncoderWithPrior` and `ConditionedLinearSDE` are
    # not imported explicitly anywhere in this file; presumably they come
    # from the wildcard imports below - confirm.
    from debug import *
    import tqdm
    from diffusion_crf import *
    from Models.experiment_identifier import ExperimentIdentifier

    # Load the dataset
    eid = ExperimentIdentifier.make_experiment_id(config_name='fmri',
                                                  objective='ml',
                                                  model_name='my_non_probabilistic',
                                                  sde_type='brownian',
                                                  freq=0,
                                                  group='asdf',
                                                  seed=0)
    config = eid.create_config()
    datasets = eid.get_data()
    train_data, val_data, test_data = datasets['train_data'], datasets['val_data'], datasets['test_data']
    # Use the first training example for the single-series demo
    idx = 0
    data_series = train_data[idx]

    # Create the SDE
    key = random.PRNGKey(0)
    # Observation dimension, taken from the last axis of the example's values
    y_dim = data_series.yts.shape[-1]
    freq = 0
    sde_type = 'brownian'


    def to_latent(noise_std: float, latent_sigma: float, data_series: TimeSeries, key: PRNGKeyArray):
        """Sample a series from a Brownian-motion SDE conditioned on ``data_series``.

        ``noise_std`` is passed as the SDE's ``sigma`` and ``latent_sigma`` as
        the encoder's ``sigma``. The sample is drawn at the times
        ``data_series.ts``.
        """
        sde = BrownianMotion(sigma=noise_std, dim=y_dim)
        encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim,
                                                    x_dim=sde.dim,
                                                    sigma=latent_sigma)
        prob_series = encoder(data_series)
        conditioned_sde = ConditionedLinearSDE(sde, prob_series)
        sampled_series = conditioned_sde.sample(key, data_series.ts)
        return sampled_series

    # Parameter grid for the visual comparison (currently a single point)
    obs_sigmas = [0.001]
    # latent_sigmas = jnp.logspace(-4, -2, 7, base=10)
    latent_sigmas = [0.0001]
    all_series = [data_series]
    titles = ['Original']
    for noise_std in obs_sigmas:
        for latent_sigma in latent_sigmas:
            sampled_series = to_latent(noise_std, latent_sigma, data_series, key)
            all_series.append(sampled_series)
            print(f'o={noise_std}, l={latent_sigma}')
            # NOTE(review): with the values above, `.2f` renders both numbers
            # as 0.00, so the plot titles do not distinguish the settings.
            titles.append(f'o={noise_std:.2f}, l={latent_sigma:.2f}')



    # Plot the original next to every sampled series
    TimeSeries.plot_multiple_series(all_series,
                                    titles=titles,
                                    marker_size=3,
                                    width_per_series=3,
                                    height_per_dim=1.5)

    # Batched check: one key per training example, sampling vectorised with vmap
    latent_sigma = 0.0001
    noise_std = 0.001
    keys = random.split(key, train_data.batch_size)
    out = jax.vmap(partial(to_latent, noise_std, latent_sigma))(train_data, keys)
    assert not jnp.isnan(out.yts).any()


    # Drop into the debugger for interactive inspection
    import pdb; pdb.set_trace()