import json
import pickle
import random
import torch
import numpy as np
from scipy import signal
from os.path import join, isdir
import datetime
import os
from os.path import isfile
from torch.nn import functional as F
import pandas as pd

# datasets for TSF (time-series forecasting)
DATASET_NAMES = ['electricity_nips','exchange_rate_nips','traffic_nips']

# datasets for TSC (time-series classification)
DATASET_CLASSIFICATION = ['DistalPhalanxTW', 'MiddlePhalanxTW', 'ProximalPhalanxTW']


# root dir for TSF data (the *.pkl files read and written by this module)
DATA_ROOT = 'data_new'

# root dir for TSC data
DATA_ROOT_TSC = 'Univariate_arff'

# interval lengths for Impute And Predict
DELTAS = [0.5, 0.8, 1.0, 1.2, 1.5, 2.0, 3.0]
# context-length fractions; presumably passed as get_dataset_params(len_clip=...) -- confirm
LEN_CLIPS = [0.4, 0.5, 0.6, 0.8, 1.0]


# 'tcn' is currently excluded from PRE_MODEL_TYPES
# (it still has a display name in BOOK_MODEL_TYPES).
PRE_MODEL_TYPES = ['gru', 'lstm', 'mixer', 'mlp', 'resnet18']
IMP_MODEL_TYPES = ['mixer']
TSC_MODEL_TYPES = ['fcn', 'mixer', 'mlp', 'resnet18']
PYPOTS_IMP_MODEL_TYPES = ['saits','transformer','brits']
ABLATION_STUDY_IMP_MODEL_TYPES = ['brits', 'mixer', 'saits', 'transformer']
MASK_LENGTH = [2,5,8,10,15] # sum == 40

# number of epochs per TSC model type (keys match TSC_MODEL_TYPES)
EPOCHS_TSC = {
    'mlp': 100,
    'resnet18': 200,
    'fcn': 100,
    'mixer': 100,
}

RS_LENS = [2,5,8,10,15]
RS_TYPES = ['block','random']

# Human-readable labels (presumably for result tables / the paper -- confirm).
BOOK_DADASET_NAMES = {
    'electricity_nips': 'Electricity',
    'exchange_rate_nips': 'Exchange',
    'traffic_nips': 'Traffic',
}
# Correctly spelled alias; the misspelled name above is kept for existing callers.
BOOK_DATASET_NAMES = BOOK_DADASET_NAMES

BOOK_MODEL_TYPES = {
    'brits': 'BRITS',
    'mixer': 'MLP-Mixer',
    'saits': 'SAITS',
    'transformer': 'Transformer',
    "fcn": "FCN",
    "resnet18": "ResNet-18",
    'mlp': 'MLP',
    'tcn': 'TCN',
    'lstm': 'LSTM',
    'gru': 'GRU',
}

BOOK_METHODS = {
    'MIA': 'MIA',
    'random': 'RA',
    'block': 'DeR',
}


def filtrate_dataframe(key:str, reserved:list, df:pd.DataFrame):
    '''
    Keep only the rows of ``df`` whose ``df[key]`` value is in ``reserved``.

    The rows are removed in place, so the DataFrame passed in is modified.
    The same object is also returned for convenience.

    Input:
        key: str, the column name to filter on
        reserved: list, rows whose df[key] value is in this list are kept
        df: pandas.DataFrame, modified in place
    Return:
        the same (modified) pandas.DataFrame

    Rows whose key is missing (NaN/None) are always kept. This matches the
    previous equality-based implementation, which never selected them for
    dropping because ``NaN == NaN`` is False.
    '''
    column = df[key]
    keep = (column.isin(reserved) | column.isna()).to_numpy()
    if keep.all():
        return df
    # Drop by position rather than by label. With a non-unique index, a
    # label-based drop would also delete reserved rows that share a label
    # with a rejected row.
    original_index = df.index
    df.index = pd.RangeIndex(len(df))
    df.drop(index=np.flatnonzero(~keep), inplace=True)
    # Restore the original labels of the surviving rows.
    df.index = original_index[keep]
    return df


def get_metric(y, y_hat):
    '''
    Compute mean absolute error and mean squared error between two tensors.

    Input:
        y: torch tensor of ground-truth values, shape [B,*]
        y_hat: torch tensor of predictions, same shape as ``y``
    Return:
        dict with float entries ``"mae"`` and ``"mse"``, each averaged over
        all elements
    '''
    assert y.shape == y_hat.shape
    mae = F.l1_loss(y, y_hat, reduction='mean')
    mse = F.mse_loss(y, y_hat, reduction='mean')
    return {"mae": mae.item(), "mse": mse.item()}


def mkdir(dir):
    '''
    Create directory ``dir`` (and any missing parents) if it does not exist.

    Input:
        dir: path of the directory to create. The name shadows the builtin
            ``dir``; it is kept so keyword callers keep working.
    Return:
        True, always.
    Raises:
        FileExistsError: if ``dir`` exists but is not a directory.
    '''
    # exist_ok=True replaces the separate isdir() check. That check-then-create
    # pattern races when several processes create the same folder at once.
    os.makedirs(dir, mode=0o777, exist_ok=True)
    return True


def get_datetime():
    '''Return the current local time as a ``'YYYY-mm-dd HH:MM:SS'`` string.'''
    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')


class Logger():
    '''
    Minimal file logger that also echoes messages to stdout.

    Creating a Logger truncates ``log_file_path`` and writes the current
    timestamp as its first line. Each ``log`` call appends one line.
    '''

    def __init__(self, log_file_path) -> None:
        self.path = log_file_path
        # Take the timestamp once so the file header and the console agree.
        timestamp = get_datetime()
        # Explicit UTF-8 so non-ASCII content does not depend on the locale.
        with open(self.path, 'w', encoding='utf-8') as f:
            f.write(timestamp + "\n")
        print(timestamp)

    def log(self, content, enable_print=True):
        '''
        Append ``str(content)`` plus a newline to the log file.

        Input:
            content: any object; it is written as ``str(content)``
            enable_print: when True, also print ``content`` to stdout
        '''
        with open(self.path, 'a', encoding='utf-8') as f:
            f.write(str(content) + '\n')
        if enable_print:
            print(content)


def load_pkl(path):
    '''
    Unpickle and return the object stored at ``path``.

    Warning: unpickling can execute arbitrary code, so only call this on
    files you trust.
    '''
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def save_pkl(obj, path):
    '''Pickle ``obj`` into ``path``, overwriting any existing file.'''
    with open(path, mode='wb') as handle:
        pickle.dump(obj, handle)


def load_json(path):
    '''
    Parse the JSON document at ``path`` and return the resulting object.

    The file is decoded as UTF-8, the encoding required for JSON text,
    instead of the platform's locale default (which may differ, e.g. on
    Windows).
    '''
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def save_json(obj, path, indent=4):
    '''
    Serialise ``obj`` as JSON into ``path``, overwriting any existing file.

    ``indent`` is forwarded to :func:`json.dump` (default 4 spaces).
    '''
    with open(path, mode='w') as handle:
        json.dump(obj, handle, indent=indent)


def setup_seed(seed = 3407):
    '''
    Seed PyTorch, NumPy's global RNG and Python's ``random`` module.

    Input:
        seed: integer seed shared by all three generators (default 3407)
    '''
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


class GenerateData2():
    """
    Build (and cache) Savitzky-Golay-filtered copies of a TSF dataset.

    The sources are ``{DATA_ROOT}/{dataset_name}_{train,test}.pkl``. When
    ``use_filter`` is True, ``generate`` writes
    ``{DATA_ROOT}/{dataset_name}_{train,test}_window_{window}_order_{order}.pkl``.
    In those files every series has been passed through
    ``scipy.signal.savgol_filter(series, window, order)``.
    """

    def __init__(self, dataset_name, use_filter=True, window=15, order=5, regenerate=False) -> None:
        self.dataset_name = dataset_name
        # Filtered outputs get a window/order suffix; unfiltered ones are the raw files.
        suffix = f'_window_{window}_order_{order}' if use_filter else ''
        self.train_path = join(DATA_ROOT, f'{dataset_name}_train{suffix}.pkl')
        self.test_path = join(DATA_ROOT, f'{dataset_name}_test{suffix}.pkl')
        # When True, generate() rebuilds the filtered files even if they already exist.
        self.regenerate = regenerate
        self.use_filter = use_filter
        self.window = window
        self.order = order

    def generate(self):
        """
        Write the filtered train/test pickles to ``train_path`` / ``test_path``.

        Does nothing if both outputs already exist, unless ``regenerate`` was
        requested for a filtered dataset. With ``use_filter=False`` the
        outputs are the original files themselves, so they are never
        rewritten.

        Raises:
            FileNotFoundError: if an original ``{dataset_name}_{split}.pkl``
                file is missing.
        """
        print(f"Generating dataset {self.train_path} and {self.test_path}...")

        outputs_exist = isfile(self.train_path) and isfile(self.test_path)
        # Without filtering, the targets *are* the raw files. Rewriting them
        # would smooth and overwrite the original data, so never do it.
        if outputs_exist and (not self.regenerate or not self.use_filter):
            print(f"Dataset {self.train_path} and {self.test_path} already exist")
            return

        origin_train = join(DATA_ROOT, f'{self.dataset_name}_train.pkl')
        origin_test = join(DATA_ROOT, f'{self.dataset_name}_test.pkl')
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        for origin in (origin_train, origin_test):
            if not isfile(origin):
                raise FileNotFoundError(f"Original dataset file not found: {origin}")

        for origin, target in ((origin_train, self.train_path), (origin_test, self.test_path)):
            filtered = [
                signal.savgol_filter(series, self.window, self.order)
                for series in load_pkl(origin)
            ]
            save_pkl(filtered, target)
        

def get_dataset_params(dataset_name:str, len_clip:float=1.0):
    """
    Return ``(context_length, prediction_length)`` for ``dataset_name``.

    ``context_length`` is multiplied by ``len_clip`` and truncated with
    ``int``. ``prediction_length`` is returned unscaled. Datasets listed in
    ``DATASET_CLASSIFICATION`` use a context of 80 and a prediction length
    of 0.

    Raises:
        NotImplementedError: if the dataset is not recognised.
    """
    forecasting_windows = {
        "exchange_rate_nips": (120, 30),
        "m4_daily": (56, 14),
        "traffic_nips": (96, 24),
        "electricity_nips": (96, 24),
    }
    if dataset_name in forecasting_windows:
        context_length, prediction_length = forecasting_windows[dataset_name]
    elif dataset_name in DATASET_CLASSIFICATION:
        context_length, prediction_length = 80, 0
    else:
        raise NotImplementedError(f"Dataset {dataset_name} not implement")
    scaled_context = int(context_length * len_clip)
    return scaled_context, prediction_length


def filter_data(dataset_name, window=15, order=5):
    """
    Write Savitzky-Golay-filtered copies of a dataset's train and test pickles.

    For each split, every series in ``{DATA_ROOT}/{dataset_name}_{split}.pkl``
    is smoothed with ``signal.savgol_filter(series, window, order)``. The
    resulting list is saved to
    ``{DATA_ROOT}/{dataset_name}_{split}_window_{window}_order_{order}.pkl``,
    overwriting any existing file.
    """
    for split in ("train", "test"):
        source_path = join(DATA_ROOT, f"{dataset_name}_{split}.pkl")
        smoothed = [signal.savgol_filter(series, window, order) for series in load_pkl(source_path)]
        target_path = join(DATA_ROOT, f'{dataset_name}_{split}_window_{window}_order_{order}.pkl')
        save_pkl(smoothed, target_path)




def convert_csv_to_list(csv_path:str, pkl_dir:str, train_ratio:float=0.9):
    """
    Split the ``OT`` column of a CSV file into train/test pickles.

    The first ``int(train_ratio * n)`` values of ``df['OT']`` form the
    training series and the remaining values form the test series. Each
    split is saved as a one-element list to ``{pkl_dir}/{name}_train.pkl``
    and ``{pkl_dir}/{name}_test.pkl``, where ``name`` is the CSV file name
    without its extension.

    Input:
        csv_path: path to a CSV file containing an ``OT`` column
        pkl_dir: directory the two pickles are written to
        train_ratio: fraction of the values used for training (default 0.9)
    Raises:
        ValueError: if ``train_ratio`` is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
    # basename/splitext follow the platform's path rules (e.g. '\\' on Windows),
    # accept Path objects, and handle any extension length. The old approach
    # split on '/' and dropped the last 4 characters.
    name = os.path.splitext(os.path.basename(csv_path))[0]
    ot = pd.read_csv(csv_path)['OT'].values

    len_train = int(train_ratio * len(ot))

    save_pkl([ot[:len_train]], join(pkl_dir, f'{name}_train.pkl'))
    save_pkl([ot[len_train:]], join(pkl_dir, f'{name}_test.pkl'))


def show_dataset():
    """
    Print the summed series length of each TSF dataset's train split.

    For every name in ``DATASET_NAMES``, the train pickle is loaded and the
    lengths of its series are added up. ``num_sample`` is that total // 100.
    """
    for dataset_name in DATASET_NAMES:
        # Only the train split is inspected; add 'test' here to include it.
        for split in ('train',):
            path = join(DATA_ROOT, f'{dataset_name}_{split}.pkl')
            total_length = sum(len(series) for series in load_pkl(path))
            print(f'{path}.length = {total_length}, num_sample = {total_length // 100}')


if __name__ == "__main__":
    # Smoke check: exchange-rate window at half context length -> (60, 30).
    context_and_horizon = get_dataset_params('exchange_rate_nips', 0.5)
    print(context_and_horizon)