import torch
import numpy as np
import random
import argparse
from data_load.data_loader import Dataset_classifiction
from experiment.train_pred_repr import train_repr
from experiment.train_classi_repr import train_class_repr
from experiment.train_anomaly_repr import train_anomal_repr

import os
import csv
import pandas as pd
# Global reproducibility setup.
# NOTE(review): KMP_DUPLICATE_LIB_OK is presumably set to tolerate duplicate
# OpenMP runtimes being loaded (e.g. by torch and numpy) -- confirm it is needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Single seed shared by Python's `random`, torch and numpy; the classification
# loop below re-applies it before each dataset.
fix_seed = 1
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(fix_seed)

def _str2bool(value):
    """Convert a command-line string into a bool for argparse.

    ``type=bool`` is a classic argparse pitfall: it calls ``bool(str)``, which is
    True for every non-empty string, so ``--use_gpu False`` used to enable the
    flag. This converter accepts the usual spellings instead.

    Args:
        value: the raw option string (or an already-boolean value).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='representation learning')

# data_load
parser.add_argument('--tasks', type=str, default='classification', help='support [forecasting, classification, anomal_detect_kpi, anomal_detect_yahoo]')
parser.add_argument('--method', type=str, default='model', help='the model name')
parser.add_argument('--root_path', type=str, default='data/', help='root path of the data file')
parser.add_argument('--dataset', type=str, default='UCRs', help='the datasets corresponding to the tasks')
parser.add_argument('--file_name', type=str, default='ETTm1.csv', help='data file')
parser.add_argument('--scale', type=_str2bool, default=True, help='standard scale the data')
parser.add_argument('--batch_size', type=int, default=16, help='batch size for training')
parser.add_argument('--num_workers', type=int, default=1, help='parallel workers')

# input setting
parser.add_argument('--feature', type=str, default='S', help='the prediction for univariates (S) or multivariates (M)')
parser.add_argument('--target', type=str, default='MT_321', help='the target column of the prediction')
parser.add_argument('--pred_len', type=int, default=128, help='prediction length')
parser.add_argument('--input_dim', type=int, default=1, help='input dimension (number of input variables)')

# train_param
parser.add_argument('--beta', type=float, default=2, help='beta')
parser.add_argument('--alpha', type=float, default=0.5, help='alpha')
parser.add_argument('--norm', type=_str2bool, default=False, help='if norm the features')

# encoder params
parser.add_argument('--num_heads', type=int, default=8, help='number of heads of attention')
parser.add_argument('--num_layers', type=int, default=3, help='number of layers of attention')
parser.add_argument('--dim_feedforward', type=int, default=256, help='the dimension of feedforward')
parser.add_argument('--encoding_dim', type=int, default=128, help='encoding dimension')

# model_param
parser.add_argument('--theta_size', type=int, default=8, help='mapping dimensions for decomp')
parser.add_argument('--middle_size', type=int, default=256, help='middle size of the mapping block')
parser.add_argument('--map_layers', type=int, default=2, help='linear layers for mapping block')
parser.add_argument('--degree_of_polynomial', type=int, default=3, help='degree of polynomials')
parser.add_argument('--harmonics', type=int, default=1, help='number of harmonics')
parser.add_argument('--trend_layers', type=int, default=1, help='number of trend blocks')
parser.add_argument('--season_layers', type=int, default=1, help='number of seasonal blocks')
parser.add_argument('--decoder_layers', type=int, default=1, help='number of decoder layers')
parser.add_argument('--decomp_layers', type=int, default=3, help='number of decomp layers')

# training
parser.add_argument('--learning_rate', type=float, default=1e-3, help='learning rate for decoder')
parser.add_argument('--epochs', type=int, default=10, help='training epochs')
parser.add_argument('--scope', type=_str2bool, default=False, help='limit the range of contrastive loss')

# load model
parser.add_argument('--checkpoints', type=str, default='checkpoints/', help='location of model checkpoints')
parser.add_argument('--patience', type=int, default=5, help='patience of the earlystopping')

# testing
parser.add_argument('--use_gpu', type=_str2bool, default=True, help='if the gpu is available')
parser.add_argument('--gpu', type=str, default='4', help='allocate the gpu')
args = parser.parse_args()
# Only request the GPU when one is actually present.
args.use_gpu = torch.cuda.is_available() and args.use_gpu

# Run the experiment selected by --tasks.
# Each run works on a copy of `args` rather than re-calling parser.parse_args().
# Re-parsing would discard the CUDA-availability adjustment applied to
# args.use_gpu above.
if args.tasks == 'classification':
    if args.feature == 'M':
        # NOTE(review): this reads from the filesystem root, unlike the
        # univariate branch's relative 'data/...' path -- confirm intended.
        data_disp = pd.read_csv('/DataDimensions.csv')
        print(data_disp.columns)
        data_list = data_disp['Problem'].values
        print(data_list)
        data_train = data_disp['TrainSize'].values
        print(data_train)
        data_test = data_disp['TestSize'].values
        data_dim = data_disp['NumDimensions'].values
    else:
        data_disp = pd.read_csv('data/DataSummary.csv')
        data_list = data_disp['Name'].values
        # These column keys include a trailing space.
        data_train = data_disp['Train '].values
        data_test = data_disp['Test '].values
        # Univariate datasets: one input dimension each.
        data_dim = [1] * len(data_list)
    print(data_list)
    # NOTE(review): only dataset index 20 is run; use range(len(data_list))
    # to sweep every dataset in the summary file.
    for i in [20]:
        # Re-seed before each dataset so every run is reproducible on its own.
        random.seed(fix_seed)
        torch.manual_seed(fix_seed)
        np.random.seed(fix_seed)

        configs = argparse.Namespace(**vars(args))
        name = data_list[i]
        # Small training sets get a batch size of 2.
        if data_train[i] < 100:
            configs.batch_size = 2
        configs.file_name = name
        configs.dataset = name
        configs.input_dim = data_dim[i]
        print(configs.input_dim)
        exp = train_class_repr(configs)
        setting = 'methods_{}_tasks_{}_data_{}_ed{}_id{}_alpha{}_beta_{}_loss1'.format(configs.method, configs.tasks, configs.file_name, configs.encoding_dim, configs.input_dim, configs.alpha, configs.beta)
        print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
        exp.train_model(setting)
        exp.train_class_submodel(setting)
        acc = exp.test_class_submodel(setting)
        print(acc)
elif args.tasks == 'forecasting':
    # Data files and their matching input_dim values, paired by position.
    file_names = ['electricity.csv', 'traffic.csv', 'weather.csv']
    input_dims = [321, 862, 21]
    for file_name, input_dim in zip(file_names, input_dims):
        configs = argparse.Namespace(**vars(args))
        configs.file_name = file_name
        configs.input_dim = input_dim
        # `train_repr` is what experiment.train_pred_repr exports (see imports);
        # the former call to the undefined name `train_pred_repr` raised NameError.
        exp = train_repr(configs)
        setting = 'methods_{}_tasks_{}_data_{}_ed{}_id{}'.format(configs.method, configs.tasks, configs.file_name, configs.encoding_dim, configs.input_dim)
        print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
        exp.train_model(setting)
        exp.train_fore_submodel(setting)
        mae, mse = exp.test_fore_submodel(setting)
        print(mae, mse)
else:
    # Any other --tasks value (e.g. anomal_detect_kpi / anomal_detect_yahoo)
    # is treated as anomaly detection.
    configs = argparse.Namespace(**vars(args))
    configs.norm = True
    exp = train_anomal_repr(configs)
    setting = 'methods_{}_tasks_{}_ed{}_id{}'.format(configs.method, configs.tasks, configs.encoding_dim, configs.input_dim)
    print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting))
    exp.train_model(setting)
    exp.test_anomal_detection(setting)
