import numpy as np
import pandas as pd
from matplotlib import figure
import matplotlib.pyplot as plt
import seaborn as sns
import math
import os
import time
import random
import gc
import statistics
import json
import argparse
from tqdm.auto import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.utils import class_weight
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score

import warnings
warnings.filterwarnings(action='ignore')

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import sys
sys.setrecursionlimit(15000)

from model.ConvPool import ConvPool, LRP_ConvPool
from model.ConvSwitch import ConvSwitch, LRP_ConvSwitch
from utils.data_load import TimeSeriesWithLabels, load_uea_dataset, load_ucr_dataset


# ---------------------------------------------------------------------------
# Fixed hyper-parameters (not exposed on the command line)
# ---------------------------------------------------------------------------
# NOTE(review): none of these names is referenced by the code visible in this
# file; presumably they are read by modules that import it — confirm before
# changing or removing them.
num_workers: int = 4      # presumably DataLoader worker count
pin_memory: bool = True   # presumably DataLoader pin_memory flag
device: str = 'cuda'      # target device string


def main(args, train_dataset, valid_dataset, test_dataset, num, data_type, model_folder, result_folder, name):
    """Build the model selected by ``args.model`` and run its LRP routine.

    ``args.model == 'ConvPool'`` selects ``ConvPool`` / ``LRP_ConvPool``; any
    other value selects ``ConvSwitch`` / ``LRP_ConvSwitch``.

    Args:
        args: parsed command-line namespace; forwarded to the model and the
            LRP routine.
        train_dataset: training split; its ``input_size``, ``timelength`` and
            ``num_classes`` attributes size the model.
        valid_dataset: validation split, forwarded unchanged.
        test_dataset: test split, forwarded unchanged.
        num: index of the dataset in the sweep, forwarded unchanged.
        data_type: data-type tag (the callers pass ``'uni'`` or ``'mul'``).
        model_folder: directory forwarded to the LRP routine.
        result_folder: directory forwarded to the LRP routine.
        name: dataset name, forwarded unchanged.
    """
    if args.model == 'ConvPool':
        model_cls, run_lrp = ConvPool, LRP_ConvPool
    else:
        # Every value other than 'ConvPool' is treated as ConvSwitch.
        model_cls, run_lrp = ConvSwitch, LRP_ConvSwitch

    model = model_cls(
        input_size=train_dataset.input_size,
        time_length=train_dataset.timelength,
        classes=train_dataset.num_classes,
        data_type=data_type,
        args=args,
    )
    run_lrp(args, train_dataset, valid_dataset, test_dataset, num, data_type,
            model, model_folder, result_folder, name)
        
        
if __name__=='__main__':
    random_seed=10
    torch.manual_seed(random_seed) # for torch.~~
    torch.backends.cudnn.deterministic = True # for deep learning CUDA library
    torch.backends.cudnn.benchmark = False # for deep learning CUDA library
    np.random.seed(random_seed) # for numpy-based backend, scikit-learn
    random.seed(random_seed) # for python random library-based e.g., torchvision
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
    
    # Arguments parsing
    parser = argparse.ArgumentParser()
    
    ### uni-var ###
    data_name = ['ACSF1',
                 'Adiac',
                 #'AllGestureWiimoteX',
                 #'AllGestureWiimoteY',
                 #'AllGestureWiimoteZ',
                 'ArrowHead',
                 'Beef',
                 'BeetleFly',
                 'BirdChicken',
                 'BME',
                 'Car',
                 'CBF',
                 'Chinatown',
                 'ChlorineConcentration',
                 'CinCECGTorso',
                 'Coffee',
                 'Computers',
                 'CricketX',
                 'CricketY',
                 'CricketZ',
                 'Crop',
                 'DiatomSizeReduction',
                 'DistalPhalanxOutlineAgeGroup',
                 'DistalPhalanxOutlineCorrect',
                 'DistalPhalanxTW',
                 'DodgerLoopDay',
                 'DodgerLoopGame',
                 'DodgerLoopWeekend',
                 'Earthquakes',
                 'ECG200',
                 'ECG5000',
                 'ECGFiveDays',
                 'ElectricDevices',
                 'EOGHorizontalSignal',
                 'EOGVerticalSignal',
                 'EthanolLevel',
                 'FaceAll',
                 'FaceFour',
                 'FacesUCR',
                 'FiftyWords',
                 'Fish',
                 'FordA',
                 'FordB',
                 'FreezerRegularTrain',
                 'FreezerSmallTrain',
                 'Fungi',
                 #'GestureMidAirD1',
                 #'GestureMidAirD2',
                 #'GestureMidAirD3',
                 #'GesturePebbleZ1',
                 #'GesturePebbleZ2',
                 'GunPoint',
                 'GunPointAgeSpan',
                 'GunPointMaleVersusFemale',
                 'GunPointOldVersusYoung',
                 'Ham',
                 'HandOutlines',
                 'Haptics',
                 'Herring',
                 'HouseTwenty',
                 'InlineSkate',
                 'InsectEPGRegularTrain',
                 'InsectEPGSmallTrain',
                 'InsectWingbeatSound',
                 'ItalyPowerDemand',
                 'LargeKitchenAppliances',
                 'Lightning2',
                 'Lightning7',
                 'Mallat',
                 'Meat',
                 'MedicalImages',
                 #'MelbournePedestrian',
                 'MiddlePhalanxOutlineAgeGroup',
                 'MiddlePhalanxOutlineCorrect',
                 'MiddlePhalanxTW',
                 'MixedShapesRegularTrain',
                 'MixedShapesSmallTrain',
                 'MoteStrain',
                 #'NonInvasiveFatalECGThorax1',
                 #'NonInvasiveFatalECGThorax2',
                 'OliveOil',
                 'OSULeaf',
                 'PhalangesOutlinesCorrect',
                 'Phoneme',
                 #'PickupGestureWiimoteZ',
                 'PigAirwayPressure',
                 'PigArtPressure',
                 'PigCVP',
                 #'PLAID',
                 'Plane',
                 'PowerCons',
                 'ProximalPhalanxOutlineAgeGroup',
                 'ProximalPhalanxOutlineCorrect',
                 'ProximalPhalanxTW',
                 'RefrigerationDevices',
                 'Rock',
                 'ScreenType',
                 'SemgHandGenderCh2',
                 'SemgHandMovementCh2',
                 'SemgHandSubjectCh2',
                 #'ShakeGestureWiimoteZ',
                 'ShapeletSim',
                 'ShapesAll',
                 'SmallKitchenAppliances',
                 #'SmoothSubspace',
                 'SonyAIBORobotSurface1',
                 'SonyAIBORobotSurface2',
                 'StarLightCurves',
                 'Strawberry',
                 'SwedishLeaf',
                 'Symbols',
                 'SyntheticControl',
                 'ToeSegmentation1',
                 'ToeSegmentation2',
                 'Trace',
                 'TwoLeadECG',
                 'TwoPatterns',
                 'UMD',
                 'UWaveGestureLibraryAll',
                 'UWaveGestureLibraryX',
                 'UWaveGestureLibraryY',
                 'UWaveGestureLibraryZ',
                 'Wafer',
                 'Wine',
                 'WordSynonyms',
                 'Worms',
                 'WormsTwoClass',
                 'Yoga'
                ]

    only = ['CricketZ', 'ElectricDevices', 'Fungi', 'WordSynonyms']
        
    for i, dataset in enumerate(data_name):
        if dataset in only:
            train_dataset = TimeSeriesWithLabels(dataset, 'univar', 'TRAIN', 'train') 
            valid_dataset = TimeSeriesWithLabels(dataset, 'univar', 'TRAIN', 'valid') 
            test_dataset = TimeSeriesWithLabels(dataset, 'univar', 'TEST', 'test') 

            parser = argparse.ArgumentParser()        
            parser.add_argument('--gpuidx', default=1, type=int, help='gpu index')
            parser.add_argument('--model', default='ConvPool', type=str, help='ConvPool | ConvSwitch')
            parser.add_argument('--pool', default='STP', type=str, help='GTP | STP | DTP')
            parser.add_argument('--pool_op', default='MAX', type=str, help='MAX | AVG')
            parser.add_argument('--switch_op', default='sample', type=str, help='sample | batch | ensem')
            parser.add_argument('--deep_extract', default='FCN', type=str, help='FCN')
            parser.add_argument('--proto_num', default=5, type=int, help = '3 | 4 | 5 | 6 | 7')
            parser.add_argument('--cost_type', default='euclidean', type=str, help='cosine | dotprod | euclidean')

            parser.add_argument('--batch_size', default=8, type=int, help='batch size')
            parser.add_argument('--num_epoch', default=500, type=int, help='# of training epochs')
            parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
            args = parser.parse_args()
            print(i, dataset)

            os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuidx)
            
            if args.model == 'ConvPool':
                model_folder = f'/data3/pooling/check/{args.model}_FCN_{args.pool}' 
                result_folder = f'/data3/pooling/check/{args.model}_FCN_{args.pool}_LRP'
                os.makedirs(result_folder, exist_ok=True)
            else:
                model_folder = f'/data3/pooling/check/{args.model}_FCN_{args.switch_op}' 
                result_folder = f'/data3/pooling/check/{args.model}_FCN_{args.switch_op}_LRP'
                os.makedirs(result_folder, exist_ok=True)

            main(args=args, train_dataset=train_dataset, valid_dataset=valid_dataset, test_dataset=test_dataset, num=i, data_type='uni', model_folder=model_folder, result_folder=result_folder, name = dataset)
    
    
    ### multi-var ###
    data_name = ['ArticularyWordRecognition',
                 'AtrialFibrillation',
                 'BasicMotions',
                 #'CharacterTrajectories',
                 'Cricket',
                 'DuckDuckGeese',
                 'EigenWorms',
                 'Epilepsy',
                 'EthanolConcentration',
                 'ERing',
                 'FaceDetection',
                 'FingerMovements',
                 'HandMovementDirection',
                 'Handwriting',
                 'Heartbeat',
                 #'InsectWingbeat',
                 #'JapaneseVowels',
                 'Libras',
                 'LSST',
                 'MotorImagery',
                 'NATOPS',
                 #'PenDigits',
                 'PEMS-SF',
                 #'Phoneme',
                 'RacketSports',
                 'SelfRegulationSCP1',
                 'SelfRegulationSCP2',
                 #'SpokenArabicDigits',
                 'StandWalkJump',
                 #'UWaveGestureLibrary'
                ]
    data_name=[]
    
    for i, dataset in enumerate(data_name):
        train_dataset = TimeSeriesWithLabels(dataset, 'multivar', 'TRAIN', 'train') 
        valid_dataset = TimeSeriesWithLabels(dataset, 'multivar', 'TRAIN', 'valid') 
        test_dataset = TimeSeriesWithLabels(dataset, 'multivar', 'TEST', 'test') 
        
        parser = argparse.ArgumentParser()        
        parser.add_argument('--gpuidx', default=1, type=int, help='gpu index')
        parser.add_argument('--model', default='ConvPool', type=str, help='ConvPool | ConvSwitch')
        parser.add_argument('--pool', default='STP', type=str, help='GTP | STP | DTP')
        parser.add_argument('--pool_op', default='MAX', type=str, help='MAX | AVG')
        parser.add_argument('--switch_op', default='sample', type=str, help='sample | batch | ensem')
        parser.add_argument('--deep_extract', default='FCN', type=str, help='FCN | ResNet')
        parser.add_argument('--proto_num', default=5, type=int, help = '3 | 4 | 5 | 6 | 7')
        parser.add_argument('--cost_type', default='euclidean', type=str, help='cosine | dotprod | euclidean')
        
        parser.add_argument('--batch_size', default=8, type=int, help='batch size')
        parser.add_argument('--num_epoch', default=500, type=int, help='# of training epochs')
        parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
        
        args = parser.parse_args()
        print(i, dataset)
        
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuidx)
        
        model_folder = f'/data3/pooling/check/{args.model}_FCN_{args.pool}' 
        result_folder = f'/data3/pooling/check/{args.model}_FCN_{args.pool}_LRP'
        os.makedirs(result_folder, exist_ok=True)
        
        main(args=args, train_dataset=train_dataset, valid_dataset=valid_dataset, test_dataset=test_dataset, num=i, data_type='mul', model_folder=model_folder, result_folder=result_folder)
    
    
    
    
    
    