

import torch
import numpy as np
import random
import warnings
import argparse
import time
warnings.filterwarnings("ignore")

from data import dataloader
from memory import Buffer
from trainer import Trainer
import utils
from copy import deepcopy

import os


from model.skl_cl import SKI_CL



def _experiment_config(args):
    """Return ``(stages, data_pre, adj_dir, node_number)`` for ``args.data_name``.

    ``data_pre`` and ``node_number`` are also stored on ``args``, as the
    original script did, since ``args`` is later handed to ``Buffer`` and
    ``Trainer``.

    Raises:
        ValueError: if ``args.data_name`` has no configuration here.
    """
    if args.data_name == 'traffic':
        stages = [1, 2, 3, 4, 5, 6, 7]
        data_pre = "PEMS3-CL_"
        adj_dir = 'data/PEMS3-cl/'
        node_number = 22
    else:
        raise ValueError(
            "Unsupported data_name {!r}; only 'traffic' is configured".format(args.data_name))

    args.data_pre = data_pre
    args.node_number = node_number
    return stages, data_pre, adj_dir, node_number


def _empty_results(stage_num):
    """Allocate one zeroed ``(stage_num, stage_num)`` matrix per metric.

    NOTE(review): presumably ``Trainer.test`` fills entry
    ``[train_stage_index, test_stage_index]`` — confirm against the trainer.
    """
    return {metric: np.zeros((stage_num, stage_num))
            for metric in ('rmse', 'mse', 'mae', 'mape', 'precision', 'recall')}


def _build_model(args, node_number):
    """Instantiate the model selected by ``args.model_name``.

    Raises:
        ValueError: if ``args.model_name`` is not a supported model.
    """
    if args.model_name == 'ski-cl':
        return SKI_CL(num_nodes=node_number,
                      input_dim=1,
                      rnn_units=64,
                      output_dim=1,
                      lag=args.lag,
                      horizon=args.horizon,
                      num_layers=1,
                      cheb_k=2)
    raise ValueError(
        "Unsupported model_name {!r}; only 'ski-cl' is available".format(args.model_name))


def _save_results(args, results):
    """Write each metric matrix as comma-separated text.

    Files go to ``result/<data_name>/<model_name>/<selection_method>/<metric>.txt``.
    """
    result_save_dirs = "result/{}/{}/{}".format(args.data_name, args.model_name, args.selection_method)
    os.makedirs(result_save_dirs, exist_ok=True)

    metric_names = ['rmse', 'mse', 'mae', 'mape']
    if args.prior_form == 'binary':
        metric_names += ['precision', 'recall']
    else:
        # NOTE(review): graph_mae / graph_rmse are not allocated by
        # _empty_results; presumably Trainer.test adds them for non-binary
        # priors — confirm.
        metric_names += ['graph_mae', 'graph_rmse']

    for metric in metric_names:
        np.savetxt(os.path.join(result_save_dirs, metric + '.txt'), results[metric], delimiter=',')


def main(args):
    """Train one model sequentially over every stage of the configured dataset.

    For each stage in order, this function:

    1. loads the stage's normalized data and windows it;
    2. for ``joint*`` selection methods, merges in samples held by the memory
       buffer (from the second stage on);
    3. warm-starts from the previous stage's checkpoint;
    4. trains;
    5. evaluates on every stage seen so far;
    6. rewrites the metric files;
    7. hands this stage's samples to the memory buffer.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).
            ``args.data_pre``, ``args.node_number`` and
            ``args.data_name_stage`` are set on it as a side effect.

    Raises:
        ValueError: if ``args.data_name`` or ``args.model_name`` is unsupported.
    """
    stages, data_pre, adj_dir, node_number = _experiment_config(args)
    results = _empty_results(len(stages))
    model = _build_model(args, node_number)
    memory_buffer = Buffer(args)

    raw_data_set = {}  # stage -> data dict returned by get_normalized_data
    test_set = {}      # stage -> test loader, kept for re-evaluation later
    scaler_set = {}    # stage -> scaler returned by get_normalized_data
    adj_set = {}       # stage -> prior adjacency (only filled for binary priors)
    window_names = ['train', 'train_normalized', 'val', 'val_normalized', 'test', 'test_normalized']

    for i, stage in enumerate(stages):
        data_name = data_pre + str(stage)
        args.data_name_stage = data_name
        print("Now: " + str(data_name))

        if args.prior_form == 'binary':
            adj_set[stage] = np.load(adj_dir + str(stage) + '_adj.npy')

        data, scaler = dataloader.get_normalized_data(data_name)
        raw_data_set[stage] = data
        scaler_set[stage] = scaler

        window_set = {}
        for name in window_names:
            features, target = dataloader.add_window(data[name], name=data_name,
                                                     lag=args.lag, horizon=args.horizon)
            window_set[name + '_X'] = features
            window_set[name + '_Y'] = target

        # Sample counts of this stage alone, taken before any buffer merge.
        current_stage_training_num = len(window_set['train_normalized_X'])
        current_stage_val_num = len(window_set['val_normalized_X'])
        current_stage_test_num = len(window_set['test_normalized_X'])

        # Tag every sample with the stage it belongs to.
        stage_ind = {
            'train': torch.tensor([stage]).repeat(current_stage_training_num),
            'val': torch.tensor([stage]).repeat(current_stage_val_num),
            'test': torch.tensor([stage]).repeat(current_stage_test_num),
        }

        if i > 0 and args.selection_method.startswith('joint'):
            stage_ind, window_set = memory_buffer.combine(stage_ind, window_set)

        loaders = dataloader.get_forcasting_dataloader(stage_ind, window_set, batch_size=args.batch_size)
        test_set[stage] = loaders["test"]

        if i > 0:
            # Warm-start from the previous stage's best_model.pth
            # (presumably written by Trainer.train).
            prev_name = data_pre + str(stages[i - 1])
            path = "model_check_point/{}/{}/{}/best_model.pth".format(
                args.model_name, args.selection_method, prev_name)
            check_point = torch.load(path)
            model.load_state_dict(check_point['state_dict'])

        trainer = Trainer(args, model, stage, adj_set, optimizer_lr=0.0001)
        trainer.train(loaders["train"], loaders["val"], memory_buffer, data)

        # Evaluate on every stage seen so far, including the current one.
        for j, seen_stage in enumerate(stages[:i + 1]):
            print('testing:', str(seen_stage))
            trainer.test(results, test_set[seen_stage], scaler_set[seen_stage], data_name,
                         train_stage_index=i, test_stage_index=j,
                         raw_data=raw_data_set[seen_stage])

        # Metric files are rewritten after every stage, not only at the end.
        _save_results(args, results)

        # adj_set has no entry for non-binary priors; .get passes None instead
        # of raising KeyError.
        # NOTE(review): confirm Buffer.add_reservoir accepts None for the
        # adjacency.
        memory_buffer.add_reservoir(args,
                                    window_set,
                                    stage_ind,
                                    current_stage_training_num,
                                    current_stage_val_num,
                                    model,
                                    stage,
                                    raw_data_set[stage],
                                    adj_set.get(stage))
      

if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # data
    # Default to 'traffic': it is the only dataset main() configures.
    parser.add_argument('--data_name', type=str, default='traffic')  # [traffic]
    parser.add_argument('--data_pre', type=str, default='none')  # overwritten by main()
    parser.add_argument('--data_name_stage', type=str, default='')  # overwritten per stage by main()
    parser.add_argument('--prior_form', type=str, default='binary')
    parser.add_argument('--node_number', type=int, default=0)  # overwritten by main()
    parser.add_argument('--lag', type=int, default=24)
    parser.add_argument('--horizon', type=int, default=12)

    # model -- main() reads args.model_name, so it must be declared here.
    parser.add_argument('--model_name', type=str, default='ski-cl')

    # training / continual-learning
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--training_epoch', type=int, default=200)
    parser.add_argument('--alpha_ts_stage', type=float, default=1.0)
    parser.add_argument('--alpha_ts_mem', type=float, default=1.0)
    parser.add_argument('--graph_coef', type=float, default=0.0)
    parser.add_argument('--selection_method', type=str, default='random')
    parser.add_argument('--ratio', type=float, default=0.1)
    parser.add_argument('--seg', type=int, default=7)

    args = parser.parse_args()
    main(args)

    