# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import time

import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import numpy as np
from datetime import datetime
import pandas as pd
from distutils.util import strtobool
from statsmodels.distributions.empirical_distribution import ECDF
from torch.utils.data import WeightedRandomSampler
from einops import rearrange
from scipy.io import arff
from torch import Tensor
# import numpy as np
from sklearn.preprocessing import MinMaxScaler
from types import SimpleNamespace

class TSGDataset(Dataset):  # For generation task. Unified Univariate Generation Dataset
    """Concatenation of several item arrays behind a single integer index.

    ``data_dict`` maps a dataset key to an array whose first axis enumerates
    items (``__getitem__`` indexes it as ``data[i, :, :]``, so arrays are
    expected to be 3-D). Global index ``i`` addresses the datasets in the
    dict's iteration order, one after the other.
    """
    def __init__(self, data_dict: dict):
        self.data_dict = data_dict
        self.cal_data_stats()

    def cal_data_stats(self):
        """Compute per-dataset item counts and cumulative end offsets."""
        self.key_list = list(self.data_dict.keys())
        # dataset key -> number of items (length of its first axis)
        self.items_dict = {key: data.shape[0] for key, data in self.data_dict.items()}
        counts = [self.items_dict[key] for key in self.key_list]
        # key_idx_list[k] is the exclusive end offset of dataset k in the global index.
        self.key_idx_list = np.cumsum(np.asarray(counts, dtype=np.int64))
        self.total_items = int(self.key_idx_list[-1]) if counts else 0

    def get_reweight_sampler(self):
        """Return a sampler that gives every dataset the same total sampling weight.

        Each item of a dataset with ``n`` items is weighted ``1 / n``; empty
        datasets get weight 0 (they contribute no items anyway), which avoids a
        division by zero.
        """
        counts = np.array([self.items_dict[key] for key in self.key_list], dtype=np.float64)
        weights = np.zeros_like(counts)
        np.divide(1.0, counts, out=weights, where=counts > 0)
        dataset_weights = weights.astype(np.float32)
        sample_weights = np.repeat(dataset_weights, [self.items_dict[key] for key in self.key_list])
        sampler = WeightedRandomSampler(weights=sample_weights, num_samples=self.total_items, replacement=True)
        return sampler

    def __len__(self):
        return self.total_items  # self.num_slices

    def __getitem__(self, idx):
        """Return ``{'context': item, 'data_key': k}`` for global index ``idx``.

        Negative indices count from the end. Out-of-range indices raise
        IndexError (not AssertionError), which is also what lets plain
        ``for``-iteration over the dataset terminate.
        """
        position = idx + self.total_items if idx < 0 else idx
        if not 0 <= position < self.total_items:
            raise IndexError(f"Index({idx}) is out of range for {self.total_items} items.")
        # First dataset whose exclusive end offset is greater than ``position``.
        data_key = int(np.searchsorted(self.key_idx_list, position, side='right'))
        data_start_idx = self.key_idx_list[data_key - 1] if data_key > 0 else 0
        data: np.ndarray = self.data_dict[self.key_list[data_key]]
        context = data[position - data_start_idx, :, :]

        return {
            'context': context,  # one item of the selected dataset, shape data.shape[1:]
            'data_key': data_key,  # index into self.key_list
            }

    
class TSGDataModule(pl.LightningDataModule):
    '''
    Data module for unified time series generation task.
    Slicing is also done with this module. So the train/val is i.i.d within train dataset.

    ``prepare_data`` fills ``norm_train_dict`` / ``norm_val_dict``
    (dataset name -> array); the dataloader methods wrap those dicts in
    ``TSGDataset``.
    '''

    # Accepted values of the ``normalize`` constructor argument.
    NORMALIZE_OPTIONS = (None, 'zscore', 'robust_iqr', 'robust_mad', 'pit', 'centered_pit', 'minmax')

    def __init__(self, data_path_dict,args, window=96, val_portion=0.1, as_tensor:bool=True, normalize="centered_pit", batch_size=128, num_workers=0, pin_memory=True, drop_last=False, reweight=False,
                 input_channels=1,**kwargs):
        """
        Args:
            data_path_dict: ``{data_name: data_path}``; only read by ``prepare_data_old``.
            args: mapping of run options, exposed as attributes via ``self.args``.
                ``prepare_data`` reads ``d_name``, ``baseline``, ``miss_``,
                ``seq_len`` and ``medical_datasets``.
            window: stored as ``self.window``; not otherwise used in this class.
            val_portion: fraction of items held out by ``split_train_val``.
            as_tensor: stored as ``self.as_tensor``; not otherwise used in this class.
            normalize: scheme used by fit_normalizer / transform / inverse_transform;
                must be one of ``NORMALIZE_OPTIONS``.
            batch_size: batch size of ``train_dataloader`` and ``val_dataloader``.
            num_workers, pin_memory: forwarded to every DataLoader.
            drop_last: currently ignored (see the note below).
            reweight: if True, training loaders use ``TSGDataset.get_reweight_sampler``
                instead of shuffling.
            input_channels: stored as ``self.input_channels``; not otherwise used in this class.
            **kwargs: extra keyword arguments forwarded verbatim to every DataLoader.
        """
        super().__init__()

        self.args = SimpleNamespace(**args)

        self.data_path_dict = data_path_dict  # {data_name: data_path}
        self.data_dict = {}
        self.norm_data_dict = {}
        self.normalizer_dict = {}
        self.norm_train_dict = {}
        self.norm_val_dict = {}
        self.window = window
        self.val_portion = val_portion
        self.as_tensor = as_tensor
        assert normalize in self.NORMALIZE_OPTIONS, f"Normalize({normalize}) must be one of {self.NORMALIZE_OPTIONS}."
        self.normalize = normalize

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.kwargs = kwargs
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.reweight = reweight
        self.key_list = []
        # NOTE(review): the ``drop_last`` argument is overridden to False here
        # (behaviour kept as-is); confirm whether callers should be able to enable it.
        self.drop_last = False
        self.input_channels = input_channels

    def _resolve_data_paths(self):
        """Return ``(data_path, data_path_fix)`` for the configured run.

        ``data_path_fix`` is only set for baseline runs of non-polynomial
        datasets; it is the directory the loader reads the MoE expert weights from.
        """
        args = self.args
        name = args.d_name
        is_polynomial = name == 'polynomial'
        # Short-circuit so ``medical_datasets`` is only consulted for
        # non-polynomial data (polynomial takes precedence).
        is_medical = not is_polynomial and name in args.medical_datasets
        # Medical run directories use a literal ``seqlen=24``.
        seq_len = 24 if is_medical else args.seq_len
        run_dir = f'{name}/MoeNcdeIrreg-seed=42-miss={args.miss_}seqlen={seq_len}epochs=30'

        if args.baseline:
            if is_polynomial:
                # NOTE(review): no data_path_fix for the polynomial baseline; the
                # loader must look for the expert weights elsewhere.
                return f'./logs_irgen_moencde_final_polynomial_baseline/{run_dir}', None
            if is_medical:
                return (f'./logs_irgen_moencde_final_ECG_womoe_baseline/{run_dir}',
                        f'./logs_irgen_moencde_final_ECG/{run_dir}')
            return (f'./logs_irgen_moencde_mndiff/{run_dir}',
                    f'./logs_irgen_moencde_final/{run_dir}')

        if is_polynomial:
            return f'./logs_irgen_moencde_final_polynomial/{run_dir}', None
        if is_medical:
            return f'./logs_irgen_moencde_final_ECG/{run_dir}', None
        return f'./logs_irgen_moencde_final/{run_dir}', None

    def prepare_data(self,) -> None:
        """Load the train/val arrays of ``self.args.d_name`` into the norm_* dicts.

        No normalization is applied here; the arrays are stored as returned by
        ``load_data_from_file_irrtsg``.
        """
        print(f"Normalizing data with: {self.normalize}")
        self.key_list = [self.args.d_name]

        data_path, data_path_fix = self._resolve_data_paths()
        train_data, val_data = load_data_from_file_irrtsg(data_path, data_name=self.args.d_name, args=self.args, data_path_fix=data_path_fix)
        self.norm_train_dict[self.args.d_name] = train_data
        self.norm_val_dict[self.args.d_name] = val_data
        print(f"Loaded data: {self.args.d_name}; Train shape: {train_data.shape}, Validation shape: {val_data.shape}.")

    def prepare_data_old(self, args=None) -> None:
        """Legacy loader: read every entry of ``data_path_dict`` with ``load_data_from_file``.

        ``args`` is accepted for compatibility and ignored.
        """
        print(f"Normalizing data with: {self.normalize}")
        self.key_list = []
        for data_name, data_path in self.data_path_dict.items():
            self.key_list.append(data_name)
            train_data, val_data = load_data_from_file(data_path, data_name=data_name)
            self.norm_train_dict[data_name] = train_data
            self.norm_val_dict[data_name] = val_data
            print(f"Loaded data: {data_name}; Train shape: {train_data.shape}, Validation shape: {val_data.shape}.")

    def split_train_val(self, data: np.ndarray):
        """Shuffle ``data`` along axis 0 (in place) and split off a validation tail.

        The last ``int(len(data) * self.val_portion)`` items form the validation
        set. When that count is 0, all items go to training and validation is
        empty. (Previously, slicing with ``-0`` swapped the two.)

        Returns:
            ``(train_data, val_data)`` -- views into the shuffled ``data``.
        """
        # By default, data are sliced into non-overlapped sequences.
        np.random.shuffle(data)
        total_instances = data.shape[0]
        num_val_instances = int(total_instances * self.val_portion)
        num_train_instances = total_instances - num_val_instances
        train_data = data[:num_train_instances]
        val_data = data[num_train_instances:]
        return train_data, val_data

    def _build_train_loader(self, batch_size):
        """Build a DataLoader over ``norm_train_dict`` with the given batch size.

        Uses the reweighting sampler when ``self.reweight`` is set and plain
        shuffling otherwise (DataLoader rejects a sampler combined with
        ``shuffle=True``).
        """
        train_dataset = TSGDataset(self.norm_train_dict)
        if self.reweight:
            order_kwargs = {'sampler': train_dataset.get_reweight_sampler()}
        else:
            order_kwargs = {'shuffle': True}
        return DataLoader(train_dataset, batch_size=batch_size, num_workers=self.num_workers,
                          pin_memory=self.pin_memory, drop_last=self.drop_last,
                          **order_kwargs, **self.kwargs)

    def train_dataloader(self):
        """Training loader with ``self.batch_size``."""
        return self._build_train_loader(self.batch_size)

    def train_dataloader_eval(self, batch_size: int = 2048):
        """Training-data loader for evaluation passes; ``batch_size`` defaults to 2048."""
        return self._build_train_loader(batch_size)

    def val_dataloader(self):
        """Validation loader over ``norm_val_dict`` (no shuffling)."""
        val_dataset = TSGDataset(self.norm_val_dict)
        return DataLoader(val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=self.pin_memory, **self.kwargs)

    def test_dataloader(self, **kwargs):
        """No test split is provided."""
        return None

    def fit_normalizer(self, data: np.ndarray):
        """Fit the statistics for ``self.normalize`` on all values of ``data``.

        ``data`` is flattened, so one set of statistics is shared by every axis
        and channel. The mean/std, median/IQR/MAD and min/max statistics ignore
        NaNs. Returns an empty dict when ``self.normalize`` is None.
        """
        normalizer = {}
        data = data.flatten()
        if self.normalize == 'zscore':
            normalizer['mean'] = np.nanmean(data)
            normalizer['std'] = np.nanstd(data)
        elif self.normalize == 'robust_iqr':
            normalizer['median'] = np.nanmedian(data)
            normalizer['iqr'] = np.subtract(*np.nanpercentile(data, [75, 25]))
        elif self.normalize == 'robust_mad':
            normalizer['median'] = np.nanmedian(data)
            normalizer['mad'] = np.nanmedian(np.abs(data - normalizer['median']))
        elif self.normalize == 'minmax':
            normalizer['min'] = np.nanmin(data)
            normalizer['max'] = np.nanmax(data)
        elif self.normalize == 'pit' or self.normalize == 'centered_pit':
            normalizer['ecdf'] = ECDF(data)
        return normalizer

    def _resolve_normalizer(self, normalizer, data_name):
        """Return the normalizer for ``data_name`` if given, otherwise ``normalizer``."""
        assert normalizer is not None or data_name is not None, "Must specify either normalizer or data name."
        if data_name is not None:
            assert data_name in self.normalizer_dict.keys(), f"Data name({data_name}) must be in normalizer dict key."
            normalizer = self.normalizer_dict[data_name]
        return normalizer

    def transform(self, data: np.ndarray, normalizer=None, data_name=None):
        """Normalize ``data`` with ``self.normalize``.

        If ``data_name`` is specified, the ``normalizer`` argument is ignored and
        ``self.normalizer_dict[data_name]`` is used. With ``normalize=None`` the
        data is returned unchanged.
        """
        normalizer = self._resolve_normalizer(normalizer, data_name)
        if self.normalize is None:
            return data
        if self.normalize == 'zscore':
            return (data - normalizer['mean']) / (normalizer['std'] + 1e-8)
        elif self.normalize == 'robust_iqr':
            return (data - normalizer['median']) / (normalizer['iqr'] + 1e-8)
        elif self.normalize == 'robust_mad':
            return (data - normalizer['median']) / (normalizer['mad'] + 1e-8)
        elif self.normalize == 'minmax':
            return (data - normalizer['min']) / (normalizer['max'] - normalizer['min'] + 1e-8)
        elif self.normalize == 'pit' or self.normalize == 'centered_pit':
            data_shape = data.shape
            norm_data = normalizer['ecdf'](data.flatten()).reshape(data_shape)
            if self.normalize == 'centered_pit':
                # map [0, 1] probabilities to [-1, 1]
                norm_data = norm_data * 2 - 1
            return norm_data

    def inverse_transform(self, data: np.ndarray, normalizer=None, data_name=None):
        """Undo ``transform``.

        Same normalizer-selection rules as ``transform``. With ``normalize=None``
        the data is returned unchanged.
        """
        normalizer = self._resolve_normalizer(normalizer, data_name)
        if self.normalize is None:
            return data
        if self.normalize == 'zscore':
            return data * normalizer['std'] + normalizer['mean']
        elif self.normalize == 'robust_iqr':
            return data * normalizer['iqr'] + normalizer['median']
        elif self.normalize == 'robust_mad':
            return data * normalizer['mad'] + normalizer['median']
        elif self.normalize == 'minmax':
            return data * (normalizer['max'] - normalizer['min']) + normalizer['min']
        elif self.normalize == 'pit' or self.normalize == 'centered_pit':
            ecdf = normalizer['ecdf']
            # The ECDF's support starts with a -inf sentinel; replace it in a COPY so
            # interpolation stays finite. Editing ecdf.x in place would make later
            # ``transform`` calls map values below the training minimum to 1.
            support = ecdf.x.copy()
            support[0] = support[1]
            if self.normalize == 'centered_pit':
                data = (data + 1) / 2
            return np.interp(data, ecdf.y, support)


class MinMaxScaler3D():
    '''
    This class contains methods to perform a scaling of values within
    a three-dimensional dataset of size (bz, length, ch) where
    bz is the batch size, length is the sequence length and ch is
    the number of channels / features of the dataset.
    For every channel, a scaler is initialized independently and applied
    to the corresponding channel dimension.
    The range in which the features should be scaled can be determined
    using the init() argument feature_range.

    Accepts ``torch.Tensor`` or ``numpy.ndarray`` inputs. Every method writes
    the scaled values back into ``dataset`` in place and returns it, so pass a
    copy if the original values are still needed. ``transform`` and
    ``inverse_transform`` raise KeyError for channels that were never fitted.
    '''

    def __init__(self,
                 feature_range: tuple = (-1, 1)):
        # channel index -> fitted sklearn MinMaxScaler
        self.scalers = {}
        self.feature_range = feature_range

    def _apply(self, dataset, method_name: str):
        """Call ``self.scalers[ch].<method_name>`` on each channel and store the result in place."""
        is_tensor = isinstance(dataset, Tensor)
        for channel in range(dataset.shape[2]):
            scaled = getattr(self.scalers[channel], method_name)(dataset[:, :, channel])
            # sklearn returns numpy; only tensors need the result converted back.
            dataset[:, :, channel] = torch.from_numpy(scaled) if is_tensor else scaled
        return dataset

    def fit_transform(self, dataset: Tensor):
        """Fit one scaler per channel on ``dataset`` and scale it in place."""
        for channel in range(dataset.shape[2]):
            self.scalers[channel] = MinMaxScaler(feature_range=self.feature_range)
        return self._apply(dataset, 'fit_transform')

    def transform(self, dataset: Tensor):
        """Scale ``dataset`` in place with the already-fitted per-channel scalers."""
        return self._apply(dataset, 'transform')

    def inverse_transform(self, dataset: Tensor):
        """Map scaled values in ``dataset`` back to the original range, in place."""
        return self._apply(dataset, 'inverse_transform')

def load_data_from_file_irrtsg(file_path: str,data_name=None,args=None,data_path_fix=None):
    """Load generated samples for ``data_name`` and append their MoE expert weights.

    Non-baseline runs read ``moe_experts_all_<name>.npy`` and
    ``moe_samples_all_<name>.npy`` from ``file_path``. Baseline runs
    (``args.baseline`` truthy) read ``samples_all_<name>_woautenc.npy`` from
    ``file_path`` and the expert weights from ``data_path_fix``. When
    ``data_path_fix`` is None, the expert weights are read from ``file_path``
    instead of crashing on ``None + str``.

    The expert weights are 2-D with one row per sample. They are repeated
    along the last axis of the samples and concatenated to them on axis 1.

    Args:
        file_path: directory holding the sample file (joined with ``'/'``).
        data_name: dataset name embedded in the file names.
        args: object with a ``baseline`` attribute; None means a non-baseline run.
        data_path_fix: optional directory holding the expert-weight file for baseline runs.

    Returns:
        ``(merged, merged)`` -- the same array object twice, used as both the
        train and the validation data.
    """
    baseline = bool(getattr(args, 'baseline', False))

    if baseline:
        moe_dir = data_path_fix if data_path_fix is not None else file_path
        moe_weights = np.load(moe_dir + f"/moe_experts_all_{data_name}.npy")
        samples_all = np.load(file_path + f"/samples_all_{data_name}_woautenc.npy")
        print(file_path)
    else:
        moe_weights = np.load(file_path + f"/moe_experts_all_{data_name}.npy")
        samples_all = np.load(file_path + f"/moe_samples_all_{data_name}.npy")

    # (N, E) -> (N, E, L): tile each weight along the samples' last axis so the
    # weights can be stacked under the samples as extra rows on axis 1.
    moe_weights = np.repeat(np.expand_dims(moe_weights, axis=-1), samples_all.shape[-1], axis=-1)
    sample_merge = np.concatenate([samples_all, moe_weights], axis=1)

    return sample_merge, sample_merge

def load_data_from_file(file_path: str,data_name=None):
    """Load a raw dataset, dispatching on ``data_name`` and the file extension.

    * ``data_name`` in ``custom_datasets`` (currently ``'stock'``): ``file_path``
      is a directory holding ``moe_experts_all_<name>.npy`` and
      ``samples_all_<name>_replace.npy``. The expert weights are appended to the
      samples on axis 1, and ``(merged, merged)`` is returned.
    * ``.csv``: ``pandas.read_csv(...).values``.
    * ``.tsf``: series stacked with ``np.stack`` and transposed.
    * ``.npy``: the loaded array, transposed.
    * ``arff``: the NATOPS train/test files next to ``file_path``, min-max
      scaled to [-1, 1], with the integer class label appended as an extra
      channel.

    NOTE(review): only the custom-dataset branch returns a ``(train, val)``
    tuple; every other branch returns a single array, while
    ``TSGDataModule.prepare_data_old`` unpacks two values. Confirm the
    intended contract.

    Raises:
        ValueError: if the file extension is not supported.
    """
    custom_datasets = ['stock']
    if data_name in custom_datasets:
        moe_weights = np.load(file_path + f"/moe_experts_all_{data_name}.npy")
        samples_all = np.load(file_path + f"/samples_all_{data_name}_replace.npy")
        # Tile each expert weight along the samples' last axis, then stack on axis 1.
        moe_weights = np.repeat(np.expand_dims(moe_weights, axis=-1), samples_all.shape[-1], axis=-1)
        sample_merge = np.concatenate([samples_all, moe_weights], axis=1)
        return sample_merge, sample_merge
    if file_path.endswith(".csv"):
        return pd.read_csv(file_path).values  # no index columns, by default.
    if file_path.endswith(".tsf"):
        loaded_data = convert_tsf_to_dataframe(
            file_path,
            replace_missing_vals_with="NaN",
            value_column_name="series_value"
            )[0]
        return np.stack(loaded_data['series_value'].values).T  # no date column
    if file_path.endswith(".npy"):
        return np.load(file_path).T
    if file_path.endswith('arff'):
        return _load_natops_arff(file_path)
    raise ValueError(f"Unsupported data file: {file_path}")


def _extract_arff_records(records):
    """Convert ``scipy.io.arff`` records into a ``(N, T, C)`` array and a label array.

    Assumes each record is ``(relational_series, label_bytes)`` where the
    relational part yields one row per channel -- TODO confirm for non-NATOPS files.
    """
    series, labels = [], []
    for t_data, t_label in records:
        series.append(np.array([d.tolist() for d in t_data]))
        labels.append(t_label.decode("utf-8"))
    # stacked as (N, C, T); swap to (N, T, C)
    return np.array(series).swapaxes(1, 2), np.array(labels)


def _load_natops_arff(file_path):
    """Load NATOPS_TRAIN/TEST.arff from the directory of ``file_path``.

    The series are scaled per channel to [-1, 1], and the class index
    (position in the sorted unique labels) is appended as the last channel.
    """
    data_dir = file_path.rpartition('/')[0]
    train_x, train_y = _extract_arff_records(arff.loadarff(data_dir + '/NATOPS_TRAIN.arff')[0])
    test_x, test_y = _extract_arff_records(arff.loadarff(data_dir + '/NATOPS_TEST.arff')[0])
    loaded_data = np.concatenate([train_x, test_x], axis=0)
    scaler = MinMaxScaler3D(feature_range=(-1, 1))
    loaded_data = scaler.fit_transform(loaded_data.copy())

    labels_all = np.concatenate([train_y, test_y], axis=0)
    _, label_ids = np.unique(labels_all, return_inverse=True)
    # Repeat the label over the sequence length so it can be stacked as a channel.
    labels_tiled = np.tile(label_ids[:, None, None], (1, loaded_data.shape[1], 1))
    return np.concatenate([loaded_data, labels_tiled], axis=-1)





# The code below is taken from: https://github.com/rakshitha123/TSForecasting/blob/master/utils/data_loader.py
def convert_tsf_to_dataframe(
    full_file_path_and_name,
    replace_missing_vals_with="NaN",
    value_column_name="series_value",
):
    """Parse a ``.tsf`` time-series file into a pandas DataFrame.

    Header lines start with ``@``:

    * ``@attribute <name> <type>`` declares a per-series attribute column
      (type ``numeric``, ``string`` or ``date``).
    * ``@frequency``, ``@horizon``, ``@missing`` and ``@equallength`` set
      the returned metadata.
    * ``@data`` marks the start of the series section.

    Each data line has the form ``<attr_1>:...:<attr_k>:<v_1>,<v_2>,...``.
    A value of ``?`` is replaced by ``replace_missing_vals_with``. Blank lines
    and non-header lines starting with ``#`` are skipped. The file is decoded
    as cp1252.

    Args:
        full_file_path_and_name: path of the ``.tsf`` file.
        replace_missing_vals_with: value stored in place of ``?``. The default
            is the *string* ``"NaN"``, so series with missing values hold a mix
            of floats and that string.
        value_column_name: name of the DataFrame column holding the series.

    Returns:
        ``(loaded_data, frequency, forecast_horizon, contain_missing_values,
        contain_equal_length)``. ``loaded_data`` has one row per data line and
        one column per attribute, plus ``value_column_name``. Metadata absent
        from the file is returned as None.

    Raises:
        Exception: on malformed metadata, a missing attribute/@data section,
            attribute/value count mismatches, all-missing series, unknown
            attribute types, or an empty file.
    """
    col_names = []
    col_types = []
    all_data = {}
    line_count = 0
    frequency = None
    forecast_horizon = None
    contain_missing_values = None
    contain_equal_length = None
    found_data_tag = False
    found_data_section = False
    started_reading_data_section = False

    with open(full_file_path_and_name, "r", encoding="cp1252") as file:
        for line in file:
            # Strip white space from start/end of line
            line = line.strip()

            if line:
                if line.startswith("@"):  # Read meta-data
                    if not line.startswith("@data"):
                        line_content = line.split(" ")
                        if line.startswith("@attribute"):
                            if (
                                len(line_content) != 3
                            ):  # Attributes have both name and type
                                raise Exception("Invalid meta-data specification.")

                            col_names.append(line_content[1])
                            col_types.append(line_content[2])
                        else:
                            if (
                                len(line_content) != 2
                            ):  # Other meta-data have only values
                                raise Exception("Invalid meta-data specification.")

                            if line.startswith("@frequency"):
                                frequency = line_content[1]
                            elif line.startswith("@horizon"):
                                forecast_horizon = int(line_content[1])
                            elif line.startswith("@missing"):
                                contain_missing_values = bool(
                                    strtobool(line_content[1])
                                )
                            elif line.startswith("@equallength"):
                                contain_equal_length = bool(strtobool(line_content[1]))

                    else:
                        if len(col_names) == 0:
                            raise Exception(
                                "Missing attribute section. Attribute section must come before data."
                            )

                        found_data_tag = True
                elif not line.startswith("#"):
                    # Data line: every attribute must already be declared and @data seen.
                    if len(col_names) == 0:
                        raise Exception(
                            "Missing attribute section. Attribute section must come before data."
                        )
                    elif not found_data_tag:
                        raise Exception("Missing @data tag.")
                    else:
                        if not started_reading_data_section:
                            started_reading_data_section = True
                            found_data_section = True
                            all_series = []

                            for col in col_names:
                                all_data[col] = []

                        # Attribute values first, series values in the last field.
                        full_info = line.split(":")

                        if len(full_info) != (len(col_names) + 1):
                            raise Exception("Missing attributes/values in series.")

                        series = full_info[len(full_info) - 1]
                        series = series.split(",")

                        # NOTE(review): str.split always returns at least one element, so
                        # this check never fires; an empty value list instead fails in
                        # float('') below.
                        if len(series) == 0:
                            raise Exception(
                                "A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series. Missing values should be indicated with ? symbol"
                            )

                        numeric_series = []

                        for val in series:
                            if val == "?":
                                numeric_series.append(replace_missing_vals_with)
                            else:
                                numeric_series.append(float(val))

                        if numeric_series.count(replace_missing_vals_with) == len(
                            numeric_series
                        ):
                            raise Exception(
                                "All series values are missing. A given series should contains a set of comma separated numeric values. At least one numeric value should be there in a series."
                            )

                        all_series.append(pd.Series(numeric_series).array)

                        for i in range(len(col_names)):
                            att_val = None
                            if col_types[i] == "numeric":
                                att_val = int(full_info[i])
                            elif col_types[i] == "string":
                                att_val = str(full_info[i])
                            elif col_types[i] == "date":
                                # ':' separates fields on data lines, so the time part uses '-'.
                                att_val = datetime.strptime(
                                    full_info[i], "%Y-%m-%d %H-%M-%S"
                                )
                            else:
                                raise Exception(
                                    "Invalid attribute type."
                                )  # Currently, the code supports only numeric, string and date types. Extend this as required.

                            if att_val is None:
                                raise Exception("Invalid attribute value.")
                            else:
                                all_data[col_names[i]].append(att_val)

                # Counts non-blank lines (headers, comments and data alike).
                line_count = line_count + 1

        if line_count == 0:
            raise Exception("Empty file.")
        if len(col_names) == 0:
            raise Exception("Missing attribute section.")
        if not found_data_section:
            raise Exception("Missing series information under data section.")

        all_data[value_column_name] = all_series
        loaded_data = pd.DataFrame(all_data)

        return (
            loaded_data,
            frequency,
            forecast_horizon,
            contain_missing_values,
            contain_equal_length,
        )
