import numpy as np
import pandas as pd
from matplotlib import figure
import matplotlib.pyplot as plt
import seaborn as sns
import math
import os
import time
import random
import gc
import statistics
import json
import argparse
from tqdm.auto import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.utils import class_weight
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score

import warnings
warnings.filterwarnings(action='ignore')

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import sys
sys.setrecursionlimit(15000)

from model.ConvPool import ConvPool, train_ConvPool
from model.ConvSwitch import ConvSwitch, train_ConvSwitch
from utils.data_load import TimeSeriesWithLabels, load_uea_dataset, load_ucr_dataset

########### Hyper-Parameter-fix ###########

# Loader/device settings. Not referenced anywhere else in this file.
num_workers = 4
pin_memory = True
device = 'cuda'

import datetime

# Each launch writes its CSV reports into a fresh directory named after the
# launch timestamp (str(datetime.now()), e.g. '2024-01-01 12:34:56.789012').
result_folder = os.path.join('/data3/pooling/check', str(datetime.datetime.now()))
os.makedirs(result_folder, exist_ok=True)

########### Model selection / training ###########

def main(args, train_dataset, valid_dataset, test_dataset, num, data_type):
    """Build the model selected by ``args.model`` and train/evaluate it on one dataset.

    Args:
        args: Parsed command-line namespace. ``args.model`` selects the
            architecture; the whole namespace is forwarded to both the model
            constructor and its training routine.
        train_dataset: Training split. Its ``input_size``, ``timelength`` and
            ``num_classes`` attributes size the model.
        valid_dataset: Validation split, forwarded to the training routine.
        test_dataset: Test split, forwarded to the training routine.
        num: Position of this dataset in the benchmark list, forwarded to the
            training routine.
        data_type: ``'uni'`` or ``'mul'`` as passed by the script below;
            forwarded to the model and the training routine.

    Returns:
        Tuple ``(performance, train_loss_list, valid_loss_list, valid_acc_list)``
        as produced by the selected training routine. The module-level
        ``result_folder`` is handed to that routine as well.

    Raises:
        ValueError: If ``args.model`` is neither ``'ConvPool'`` nor ``'ConvSwitch'``.
    """
    if args.model == 'ConvPool':
        model_cls, train_fn = ConvPool, train_ConvPool
    elif args.model == 'ConvSwitch':
        model_cls, train_fn = ConvSwitch, train_ConvSwitch
    else:
        # Fail fast: returning empty results would only surface later as a
        # confusing DataFrame shape error in the caller.
        raise ValueError(
            f"Unknown model {args.model!r}; expected 'ConvPool' or 'ConvSwitch'")

    model = model_cls(input_size=train_dataset.input_size,
                      time_length=train_dataset.timelength,
                      classes=train_dataset.num_classes,
                      data_type=data_type,
                      args=args)

    performance, train_loss_list, valid_loss_list, valid_acc_list = train_fn(
        args, train_dataset, valid_dataset, test_dataset, num, data_type, model, result_folder)

    return performance, train_loss_list, valid_loss_list, valid_acc_list

# Row labels for the performance vector returned by main().
PERFORM_INDEX = ['loss', 'acc', 'f1macro', 'f1micro', 'f1weight', 'f1mean', 'option']

### uni-var ###
# Commented-out entries are excluded from the benchmark run.
UNI_DATA_NAMES = ['ACSF1',
                  'Adiac',
                  #'AllGestureWiimoteX',
                  #'AllGestureWiimoteY',
                  #'AllGestureWiimoteZ',
                  'ArrowHead',
                  'Beef',
                  'BeetleFly',
                  'BirdChicken',
                  'BME',
                  'Car',
                  'CBF',
                  'Chinatown',
                  'ChlorineConcentration',
                  'CinCECGTorso',
                  'Coffee',
                  'Computers',
                  'CricketX',
                  'CricketY',
                  'CricketZ',
                  'Crop',
                  'DiatomSizeReduction',
                  'DistalPhalanxOutlineAgeGroup',
                  'DistalPhalanxOutlineCorrect',
                  'DistalPhalanxTW',
                  'DodgerLoopDay',
                  'DodgerLoopGame',
                  'DodgerLoopWeekend',
                  'Earthquakes',
                  'ECG200',
                  'ECG5000',
                  'ECGFiveDays',
                  'ElectricDevices',
                  'EOGHorizontalSignal',
                  'EOGVerticalSignal',
                  'EthanolLevel',
                  'FaceAll',
                  'FaceFour',
                  'FacesUCR',
                  'FiftyWords',
                  'Fish',
                  'FordA',
                  'FordB',
                  'FreezerRegularTrain',
                  'FreezerSmallTrain',
                  'Fungi',
                  #'GestureMidAirD1',
                  #'GestureMidAirD2',
                  #'GestureMidAirD3',
                  #'GesturePebbleZ1',
                  #'GesturePebbleZ2',
                  'GunPoint',
                  'GunPointAgeSpan',
                  'GunPointMaleVersusFemale',
                  'GunPointOldVersusYoung',
                  'Ham',
                  'HandOutlines',
                  'Haptics',
                  'Herring',
                  'HouseTwenty',
                  'InlineSkate',
                  'InsectEPGRegularTrain',
                  'InsectEPGSmallTrain',
                  'InsectWingbeatSound',
                  'ItalyPowerDemand',
                  'LargeKitchenAppliances',
                  'Lightning2',
                  'Lightning7',
                  'Mallat',
                  'Meat',
                  'MedicalImages',
                  #'MelbournePedestrian',
                  'MiddlePhalanxOutlineAgeGroup',
                  'MiddlePhalanxOutlineCorrect',
                  'MiddlePhalanxTW',
                  'MixedShapesRegularTrain',
                  'MixedShapesSmallTrain',
                  'MoteStrain',
                  #'NonInvasiveFatalECGThorax1',
                  #'NonInvasiveFatalECGThorax2',
                  'OliveOil',
                  'OSULeaf',
                  'PhalangesOutlinesCorrect',
                  'Phoneme',
                  #'PickupGestureWiimoteZ',
                  'PigAirwayPressure',
                  'PigArtPressure',
                  'PigCVP',
                  #'PLAID',
                  'Plane',
                  'PowerCons',
                  'ProximalPhalanxOutlineAgeGroup',
                  'ProximalPhalanxOutlineCorrect',
                  'ProximalPhalanxTW',
                  'RefrigerationDevices',
                  'Rock',
                  'ScreenType',
                  'SemgHandGenderCh2',
                  'SemgHandMovementCh2',
                  'SemgHandSubjectCh2',
                  #'ShakeGestureWiimoteZ',
                  'ShapeletSim',
                  'ShapesAll',
                  'SmallKitchenAppliances',
                  #'SmoothSubspace',
                  'SonyAIBORobotSurface1',
                  'SonyAIBORobotSurface2',
                  'StarLightCurves',
                  'Strawberry',
                  'SwedishLeaf',
                  'Symbols',
                  'SyntheticControl',
                  'ToeSegmentation1',
                  'ToeSegmentation2',
                  'Trace',
                  'TwoLeadECG',
                  'TwoPatterns',
                  'UMD',
                  'UWaveGestureLibraryAll',
                  'UWaveGestureLibraryX',
                  'UWaveGestureLibraryY',
                  'UWaveGestureLibraryZ',
                  'Wafer',
                  'Wine',
                  'WordSynonyms',
                  'Worms',
                  'WormsTwoClass',
                  'Yoga',
                  ]

### multi-var ###
# Commented-out entries are excluded from the benchmark run.
MUL_DATA_NAMES = ['ArticularyWordRecognition',
                  'AtrialFibrillation',
                  'BasicMotions',
                  #'CharacterTrajectories',
                  'Cricket',
                  'DuckDuckGeese',
                  'EigenWorms',
                  'Epilepsy',
                  'EthanolConcentration',
                  'ERing',
                  'FaceDetection',
                  'FingerMovements',
                  'HandMovementDirection',
                  'Handwriting',
                  'Heartbeat',
                  #'InsectWingbeat',
                  #'JapaneseVowels',
                  'Libras',
                  'LSST',
                  'MotorImagery',
                  'NATOPS',
                  #'PenDigits',
                  'PEMS-SF',
                  #'Phoneme',
                  'RacketSports',
                  'SelfRegulationSCP1',
                  'SelfRegulationSCP2',
                  #'SpokenArabicDigits',
                  #'StandWalkJump',
                  #'UWaveGestureLibrary',
                  ]


def build_parser():
    """Return the command-line parser shared by every benchmark run."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpuidx', default=1, type=int, help='gpu index')
    parser.add_argument('--model', default='ConvPool', type=str, help='ConvPool | ConvSwitch')
    parser.add_argument('--pool', default='STP', type=str, help='GTP | STP | DTP | Switch')
    parser.add_argument('--pool_op', default='MAX', type=str, help='MAX | AVG')
    parser.add_argument('--switch_op', default='batch', type=str, help='batch')
    parser.add_argument('--deep_extract', default='FCN', type=str, help='FCN | ResNet')
    parser.add_argument('--proto_num', default=5, type=int, help = '3 | 4 | 5 | 6 | 7')
    parser.add_argument('--cost_type', default='euclidean', type=str, help='cosine | dotprod | euclidean')

    parser.add_argument('--batch_size', default=8, type=int, help='batch size')
    parser.add_argument('--num_epoch', default=300, type=int, help='# of training epochs')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    return parser


def set_seed(random_seed=10):
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN kernels."""
    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # covers multi-GPU setups


def run_benchmark(args, data_names, var_type, data_type):
    """Train and evaluate ``args.model`` on every dataset in ``data_names``.

    After each dataset, the four reports (performance, train loss, valid loss
    and valid accuracy) are rewritten in full to ``result_folder``. Partial
    results therefore survive an interrupted run.

    Args:
        args: Parsed namespace from ``build_parser``.
        data_names: Dataset names passed to ``TimeSeriesWithLabels``.
        var_type: ``'univar'`` or ``'multivar'``, passed to ``TimeSeriesWithLabels``.
        data_type: ``'uni'`` or ``'mul'``, passed to ``main`` and used as the
            tag in the output filenames.
    """
    prefix = (f'{args.model}_{args.deep_extract}_{args.pool}_'
              f'{args.pool_op}_{args.switch_op}_{data_type}')
    reports = {'performance': [], 'trainloss': [], 'validloss': [], 'validacc': []}

    for i, dataset in enumerate(data_names):
        print(i, dataset)
        # Reseed before each dataset so every result is reproducible on its
        # own, independent of which datasets ran before it. Seeding before the
        # loads also covers any randomness in the train/valid split, in case
        # TimeSeriesWithLabels uses one.
        set_seed()

        train_dataset = TimeSeriesWithLabels(dataset, var_type, 'TRAIN', 'train')
        valid_dataset = TimeSeriesWithLabels(dataset, var_type, 'TRAIN', 'valid')
        test_dataset = TimeSeriesWithLabels(dataset, var_type, 'TEST', 'test')

        perform, train_loss_list, valid_loss_list, valid_acc_list = main(
            args=args, train_dataset=train_dataset, valid_dataset=valid_dataset,
            test_dataset=test_dataset, num=i, data_type=data_type)

        # One row per dataset in each report.
        rows = {
            'performance': pd.DataFrame(perform, index=PERFORM_INDEX, columns=[dataset]).T,
            'trainloss': pd.DataFrame(train_loss_list, columns=[dataset]).T,
            'validloss': pd.DataFrame(valid_loss_list, columns=[dataset]).T,
            'validacc': pd.DataFrame(valid_acc_list, columns=[dataset]).T,
        }
        for name, row in rows.items():
            reports[name].append(row)
            # Rewrite the whole report each time so progress is checkpointed.
            pd.concat(reports[name], axis=0).to_csv(
                os.path.join(result_folder, f'{prefix}_{name}.csv'))


if __name__ == '__main__':
    args = build_parser().parse_args()

    # CUDA reads this only when it is first initialised, so set it before any
    # seeding or training touches the GPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuidx)

    run_benchmark(args, UNI_DATA_NAMES, 'univar', 'uni')
    run_benchmark(args, MUL_DATA_NAMES, 'multivar', 'mul')