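# Reduce CUDA out-of-memory errors caused by allocator fragmentation: the caching
# allocator will not split cached blocks larger than max_split_size_mb. Set before importing torch.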
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:1024'

import ValueHome as ValueHome
import ValueHome_utils
from torch.utils.data import DataLoader
import torch
import datetime
import minitransformer as mtrans
import numpy as np
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
import evaluate
import time
import random
import pickle


def optimizer_to(optim, device):
    """Relocate the tensors stored in an optimizer's ``state`` to ``device``.

    Each tensor is moved by rebinding its ``.data`` (and its ``_grad.data``
    when a gradient is attached), so the optimizer keeps pointing at the same
    tensor objects. Two layouts are handled: a tensor stored directly as a
    value of ``optim.state``, and tensors stored one level down inside a
    per-parameter state dict. Anything else is left untouched.

    Args:
        optim: Optimizer whose ``state`` mapping should be relocated.
        device: Target device (``torch.device`` or device string).

    Returns:
        None. The optimizer is modified in place.
    """
    def _relocate(tensor):
        # Swap the underlying storage rather than the object itself.
        tensor.data = tensor.data.to(device)
        attached_grad = tensor._grad
        if attached_grad is not None:
            attached_grad.data = attached_grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            _relocate(entry)
            continue
        if not isinstance(entry, dict):
            continue
        for value in entry.values():
            if isinstance(value, torch.Tensor):
                _relocate(value)


def Create_tester(model, eval_dataLoader, optimizer, savepth, savelabel="", batch_first=False, lr=5e-5, num_epochs=3, cur_epoch=0, loss=0, learn_decay=0,
                  encode_type=None, data_type=None, src_mask=None, tgt_mask=None):
    """Evaluate ``model`` on ``eval_dataLoader``, print a summary and return the metrics.

    The metric depends on ``encode_type`` (checked in this order):

    * contains ``"duration"``: ``ValueHome_utils.accuracy_cal_duration`` summed
      over batches and averaged, returned as ``dis``;
    * contains ``"action"``: the three values of
      ``ValueHome_utils.accuracy_cal_onehot`` (top-1/3/5);
    * contains ``"target"`` or equals ``"act2trg"``: the three values of
      ``ValueHome_utils.accuracy_cal_target`` using a pickled target template;
    * contains ``"end2end"``: the same target metric on the first 384 features.

    Args:
        model: Model called as ``model(src=..., tgt=..., src_mask=..., tgt_mask=...)``.
        eval_dataLoader: Iterable yielding ``(src, trg, trg_y)`` batches.
        optimizer: Moved to the evaluation device together with the model;
            not otherwise used.
        savepth, savelabel, num_epochs, loss, learn_decay: Accepted for call
            compatibility; unused here.
        batch_first: If False, each batch is permuted from
            [batch size, seq len, num features] to [seq len, batch size, num features].
        lr: Only printed in the summary line.
        cur_epoch: Epoch number printed in the summary line.
        encode_type, data_type, src_mask, tgt_mask: When omitted (None) they are
            read from the module-level globals of the same name, which this
            script's ``__main__`` block assigns.

    Returns:
        ``(acc_dict, dis)``: ``acc_dict`` maps 1/3/5 to the per-batch mean of the
        corresponding top-k value; ``dis`` is the per-batch mean duration
        metric. The metric that does not apply stays 0, as do both when the
        loader yields no batches.
    """
    # Fall back to the module-level configuration set by the __main__ block.
    module_globals = globals()
    if encode_type is None:
        encode_type = module_globals["encode_type"]
    if data_type is None:
        data_type = module_globals["data_type"]
    if src_mask is None:
        src_mask = module_globals["src_mask"]
    if tgt_mask is None:
        tgt_mask = module_globals["tgt_mask"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Datasets whose name contains "diff" use the GPT target template;
    # all others use the human one.
    human_flag = "diff" not in data_type
    target_template = None
    if ("target" in encode_type) or ("act2trg" in encode_type):
        template_file = "human_target_template.pkl" if human_flag else "gpt_target_template.pkl"
        # Local project file; pickle must never be used on untrusted input.
        with open(template_file, "rb") as fh:
            target_template = pickle.load(fh)

    model.to(device)
    # Not needed for evaluation, but keeps optimizer state on the model's device.
    optimizer_to(optimizer, device)
    model.eval()

    # The masks are identical for every batch: transfer them once.
    src_mask = src_mask.to(device)
    tgt_mask = tgt_mask.to(device)

    acc_dict = {1: 0, 3: 0, 5: 0}
    dis = 0
    n_batches = 0

    for n_batches, (src, trg, trg_y) in enumerate(eval_dataLoader, start=1):
        if not batch_first:
            # [batch size, seq len, num features] -> [seq len, batch size, num features]
            src = src.permute(1, 0, 2)
            trg = trg.permute(1, 0, 2)
            trg_y = trg_y.permute(1, 0, 2)
        trg_y = trg_y.to(device)

        with torch.no_grad():
            outputs = model(
                src=src.to(device),
                tgt=trg.to(device),
                src_mask=src_mask,
                tgt_mask=tgt_mask,
            )

        if "duration" in encode_type:
            dis += ValueHome_utils.accuracy_cal_duration(outputs, trg_y)
            continue

        if "action" in encode_type:
            acc1, acc3, acc5 = ValueHome_utils.accuracy_cal_onehot(outputs, trg_y)
        elif ("target" in encode_type) or (encode_type == "act2trg"):
            acc1, acc3, acc5 = ValueHome_utils.accuracy_cal_target(
                outputs, trg_y,
                userId=data_type.split("_")[0],
                human_flag=human_flag,
                templates=target_template)
        elif "end2end" in encode_type:
            # Only the first 384 output features are scored (presumably the
            # target-embedding slice of the end-to-end output -- confirm).
            acc1, acc3, acc5 = ValueHome_utils.accuracy_cal_target(
                outputs[:, :, :384], trg_y[:, :, :384],
                userId=-1,
                human_flag=False,
                templates=target_template)
        else:
            continue
        acc_dict[1] += acc1
        acc_dict[3] += acc3
        acc_dict[5] += acc5

    # Average with the same branch priority as the loop ("duration" first) so
    # the value that was accumulated is the one that gets averaged.
    if n_batches:
        if "duration" in encode_type:
            dis = dis / n_batches
        else:
            for k in acc_dict:
                acc_dict[k] = acc_dict[k] / n_batches

    if "duration" in encode_type:
        print('epoch={0}, dis={1}. lr={2}'.format(cur_epoch, dis, lr))
    else:
        print('epoch={0}, top1 acc={1}, top3 acc={2}, top5 acc={3}. lr={4}'.format(cur_epoch, acc_dict[1], acc_dict[3], acc_dict[5], lr))

    return acc_dict, dis



def setup_seed(seed):
    """Seed every random number generator the script relies on.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy's
            global RNG, and PyTorch (CPU and all CUDA devices).

    Also forces cuDNN to select deterministic algorithms.
    """
    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch generators (CPU, then every visible CUDA device).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels so repeated runs give matching results.
    torch.backends.cudnn.deterministic = True


# def Create_MiniTransformer():
if __name__ == '__main__':
    import itertools

    # Hyper-parameter grid; every combination is evaluated in turn.
    dim_list = [4]
    layer_list = [8]
    encode_layer_list = [1]
    head_list = [4]
    lr_list = [0.000001]
    window_list = [14]
    data_list = [
        # "02_diff2full",
        # "03_diff2full",
        # "04_diff2full",
        # "05_diff2full",
        # "05_full",
        # "16_full",
        # "17_full",
        "18_full",
        # "19_full"
        ]

    setup_seed(7)

    # NOTE: this code intentionally stays at module level: Create_tester reads
    # module-level globals assigned below (e.g. encode_type, data_type,
    # src_mask, tgt_mask).
    # itertools.product yields combinations in the same order as the
    # equivalent nested for-loops (last list varies fastest).
    for dim, layer, head, encode_layer, lr, window, datatype in itertools.product(
            dim_list, layer_list, head_list, encode_layer_list, lr_list, window_list, data_list):

        # Hyperparams
        PRETRAIN_FLAG = True  #True
        SHUFFLE_FLAG = True
        # encode_type = "action_onlyaction"
        # encode_type = "duration_onlyduration"
        # encode_type = "act2trg"
        # encode_type = "target_onlytarget"
        # encode_type = "0916target_onlytarget"
        encode_type = "end2end_act2trg"
        top_num = 3
        root = "."
        data_type = datatype
        if 'diff' in data_type:
            arranged_dataset_root = "/home/zhe/Documents/Code/IntentionProject/ValueHomeDataset/GPT/gpt_plan/gpt_plan/GPT_arranged"
            data_dir = arranged_dataset_root + f"/gpt4_{encode_type.split('_')[-1]}"
        else:
            arranged_dataset_root = "/home/zhe/Documents/Code/IntentionProject/ValueHomeDataset/HumanData/human_arranged"
            data_dir = arranged_dataset_root + f"/human_{encode_type.split('_')[-1]}"

        # Checkpoint to evaluate for each encode_type; one line per branch is active.
        pretrain_pth = None
        if "onlyaction" in encode_type:
            # pretrain_pth = root + "/models/action_onlyaction_02_diff2full/2023-09-26_18-23-19_epoch104_lr1e-06.pth"
            # pretrain_pth = root + "/models/action_onlyaction_03_diff2full/2023-09-26_18-26-19_epoch170_lr1e-06.pth"
            # pretrain_pth = root + "/models/action_onlyaction_04_diff2full/2023-09-26_18-28-28_epoch170_lr1e-06.pth"
            # pretrain_pth = root + "/models/action_onlyaction_05_diff2full/2023-09-26_18-29-06_epoch101_lr1e-06.pth"
            # pretrain_pth = root + "/models/action_onlyaction_05_full/2023-09-26_18-33-32_epoch170_lr1e-06.pth"
            # pretrain_pth = root + "/models/action_onlyaction_16_full/2023-09-26_18-35-38_epoch170_lr1e-06.pth"
            # pretrain_pth = root + "/models/action_onlyaction_17_full/2023-09-26_18-36-19_epoch170_lr1e-06.pth"
            # pretrain_pth = root + "/models/action_onlyaction_18_full/2023-09-26_18-37-43_epoch107_lr1e-06.pth"
            pretrain_pth = root + "/models/action_onlyaction_19_full/2023-09-26_18-46-23_epoch97_lr1e-06.pth"
        elif "onlyduration" in encode_type:
            # pretrain_pth = root + "/models/duration_onlyduration_02_diff2full/2023-09-26_18-36-58_epoch56_lr1e-06.pth"
            # pretrain_pth = root + "/models/duration_onlyduration_03_diff2full/2023-09-26_18-49-50_epoch98_lr1e-06.pth"
            # pretrain_pth = root + "/models/duration_onlyduration_04_diff2full/2023-09-26_19-03-47_epoch151_lr1e-06.pth"
            # pretrain_pth = root + "/models/duration_onlyduration_05_diff2full/2023-09-26_19-07-17_epoch91_lr1e-06.pth"
            # pretrain_pth = root + "/models/duration_onlyduration_05_full/2023-09-26_19-24-07_epoch144_lr1e-06.pth"
            # pretrain_pth = root + "/models/duration_onlyduration_16_full/2023-09-26_19-30-04_epoch151_lr1e-06.pth"
            # pretrain_pth = root + "/models/duration_onlyduration_17_full/2023-09-26_19-31-55_epoch151_lr1e-06.pth"
            # pretrain_pth = root + "/models/duration_onlyduration_18_full/2023-09-26_19-36-49_epoch151_lr1e-06.pth"
            pretrain_pth = root + "/models/duration_onlyduration_19_full/2023-09-26_19-37-23_epoch71_lr1e-06.pth"
        elif "onlytarget" in encode_type:
            # pretrain_pth = root + "/models/target_onlytarget_02_diff2full/2023-09-26_18-40-23_epoch401_lr1e-06.pth"
            # pretrain_pth = root + "/models/target_onlytarget_03_diff2full/2023-09-26_18-58-09_epoch499_lr1e-06.pth"
            # pretrain_pth = root + "/models/target_onlytarget_04_diff2full/2023-09-26_18-59-40_epoch416_lr1e-06.pth"
            # pretrain_pth = root + "/models/target_onlytarget_05_diff2full/2023-09-26_19-15-10_epoch499_lr1e-06.pth"
            # pretrain_pth = root + "/models/target_onlytarget_05_full/2023-09-26_19-18-33_epoch495_lr1e-06.pth"
            # # pretrain_pth = root + "/models/target_onlytarget_16_full/2023-09-26_19-18-44_epoch401_lr1e-06.pth"
            # # pretrain_pth = root + "/models/target_onlytarget_17_full/2023-09-26_19-18-33_epoch495_lr1e-06.pth"
            # # pretrain_pth = root + "/models/target_onlytarget_18_full/2023-09-26_19-18-33_epoch495_lr1e-06.pth"
            # # pretrain_pth = root + "/models/target_onlytarget_19_full/2023-09-26_19-18-33_epoch495_lr1e-06.pth"
            pretrain_pth = root + "/models/target_finetune_before0916/target_onlytarget_05_full/2023-09-16_11-04-03_epoch192_lr1e-05.pth"
            # pretrain_pth = root + "/models/target_finetune_before0916/target_onlytarget_16_full/2023-09-16_11-05-17_lr1e-05.pth"
            # pretrain_pth = root + "/models/target_finetune_before0916/target_onlytarget_17_full/2023-09-16_11-05-41_lr1e-05.pth"
            # pretrain_pth = root + "/models/target_finetune_before0916/target_onlytarget_18_full/2023-09-16_11-06-16_epoch171_lr1e-05.pth"

        elif encode_type == "act2trg":
            # pretrain_pth = root + "/models/act2trg_02_diff2full/2023-09-26_18-44-49_epoch561_lr1e-06.pth"
            # pretrain_pth = root + "/models/act2trg_03_diff2full/2023-09-26_18-59-30_epoch599_lr1e-06.pth"
            # pretrain_pth = root + "/models/act2trg_04_diff2full/2023-09-26_19-10-23_epoch595_lr1e-06.pth"
            # pretrain_pth = root + "/models/act2trg_05_diff2full/2023-09-26_19-11-03_epoch496_lr1e-06.pth"
            # pretrain_pth = root + "/models/act2trg_05_full/2023-09-26_19-32-23_epoch599_lr1e-06.pth"
            # pretrain_pth = root + "/models/act2trg_16_full/2023-09-26_19-38-39_epoch582_lr1e-06.pth"
            # pretrain_pth = root + "/models/act2trg_17_full/2023-09-26_19-40-37_epoch599_lr1e-06.pth"
            # pretrain_pth = root + "/models/act2trg_18_full/2023-09-26_19-41-11_epoch517_lr1e-06.pth"
            pretrain_pth = root + "/models/act2trg_19_full/2023-09-26_19-44-39_epoch599_lr1e-06.pth"

        elif "end2end_act2trg" in encode_type:
            # pretrain_pth = root + "/models/end2end_act2trg_02_diff2full/2023-09-28_20-49-56_epoch399_lr1e-06.pth"
            # pretrain_pth = root + "/models/end2end_act2trg_03_diff2full/2023-09-28_20-58-40_epoch399_lr1e-06.pth"
            # pretrain_pth = root + "/models/end2end_act2trg_04_diff2full/2023-09-28_21-08-24_epoch399_lr1e-06.pth"
            # pretrain_pth = root + "/models/end2end_act2trg_05_diff2full/2023-09-28_21-18-13_epoch399_lr1e-06.pth"
            # pretrain_pth = root + "/models/end2end_act2trg_05_full/2023-09-28_21-28-09_epoch399_lr1e-06.pth"
            # pretrain_pth = root + "/models/end2end_act2trg_16_full/2023-09-28_21-34-57_epoch399_lr1e-06.pth"
            # pretrain_pth = root + "/models/end2end_act2trg_17_full/2023-09-28_21-36-44_epoch399_lr1e-06.pth"
            pretrain_pth = root + "/models/end2end_act2trg_18_full/2023-09-28_21-37-45_epoch229_lr1e-06.pth"

        if encode_type == "act2trg":
            window = 1

        savelabel = f"lr{str(lr)}" #_dim{str(dim)}ly{str(layer)}ecd{str(encode_layer)}hd{str(head)}win{str(window)}"
        savepath = f"models/{encode_type}_{data_type}/{savelabel}"
        # makedirs creates any missing parent (including ./models) and
        # tolerates directories that already exist.
        os.makedirs(os.path.join(root, savepath), exist_ok=True)

        # print training info
        print("start training with params:")
        print("learing rate: {0}".format(lr))
        print("dim: {0}".format(dim))
        print("layer: {0}".format(layer))
        print("encode_layer: {0}".format(encode_layer))
        print("head: {0}".format(head))
        print("window: {0}".format(window))
        with open(os.path.join(root, savepath, "train_info.txt"), "w") as f:
            f.write("start training with params:\n")
            f.write("learing rate: {0}\n".format(lr))
            f.write("dim: {0}\n".format(dim))
            f.write("layer: {0}\n".format(layer))
            f.write("encode_layer: {0}\n".format(encode_layer))
            f.write("head: {0}\n".format(head))
            f.write("window: {0}\n".format(window))

        if "diff" in data_type:
            test_size = 0.01
            eval_size = 0.19 #0.07
            train_start = 0 #0.53
            batch_size = 8
        else:
            test_size = 0.01
            eval_size = 0.19
            train_start = 0
            batch_size = 8
        learning_rate = lr #0.0001
        if "onlytarget" in encode_type:
            max_epoch = 2000
        elif "act2trg" in encode_type:
            max_epoch = 1700
        elif "onlyaction" in encode_type:
            max_epoch = 2200
        else:
            max_epoch = 1100
        if "action" in encode_type:
            # creating target col name
            label_num = 0  # 19(actions)
            feature_num = 1218  # 1211(features)
        elif "duration" in encode_type:
            label_num = 0  # 1(duration)
            feature_num = 1218
        elif "target" in encode_type:
            label_num = 0
            feature_num = 1370 # 384 + 986
        elif "act2trg" in encode_type:
            label_num = 0
            feature_num = 1602  # 384 + 1211, but only use 1211
        #- start creating column names
        exogenous_vars = ["label_%s" % str(i).zfill(3) for i in range(label_num)]  # label column names
        target_col_name = np.arange(feature_num).astype(str).tolist()
        #- end creating
        cur_time, _ = ValueHome_utils.vis_time(time.time())

        ## Params
        # NOTE(review): dim_val, n_heads, n_encoder_layers, n_decoder_layers and
        # the in_features_* values are not passed to MiniTransformer below, so
        # the dim/layer/head/encode_layer grid does not affect the model built
        # here -- confirm against MiniTransformer's signature.
        dim_val = dim #32
        n_heads = head #2
        n_decoder_layers = encode_layer #1
        n_encoder_layers = encode_layer #1
        dec_seq_len = 1 #360 # length of input given to decoder
        enc_seq_len = window #20 # length of input given to encoder
        output_sequence_length = 1 #360 # target sequence length. If hourly data and length = 48, you predict 2 days ahead
        window_size = enc_seq_len + output_sequence_length # used to slice data into sub-sequences
        step_size = 1 # Step size, i.e. how many time steps does the moving window move at each step
        in_features_encoder_linear_layer = layer #64 # 2048
        in_features_decoder_linear_layer = layer  #64 # 2048
        max_seq_len = enc_seq_len
        batch_first = False
        learn_decay = 0.1

        # Define input variables
        input_variables = exogenous_vars + target_col_name
        target_idx = 0 # index position of target in batched trg_y

        if encode_type == "act2trg":
            input_size = 1218
        else:
            input_size = len(input_variables)

        # Define number of predicted features
        if "onlyaction" in encode_type:
            num_predicted_features = 26
        elif "onlyduration" in encode_type:
            num_predicted_features = 1
        elif "onlytarget" in encode_type:
            num_predicted_features = 384
        elif encode_type == "act2trg":
            num_predicted_features = 384
        elif "end2end" in encode_type:
            num_predicted_features = 1602
        else:
            num_predicted_features = len(input_variables)

        # Read data
        _, raw_eval_data, _ = ValueHome_utils.ValueHomeLoader(data_dir=data_dir,
                                                                test_size=test_size,
                                                                eval_size=eval_size,
                                                                train_start=train_start,
                                                                label_num=label_num,
                                                                feature_num=feature_num,
                                                                data_type=data_type)
        # Make list of (start_idx, end_idx) pairs that are used to slice the time series sequence into chunks.
        eval_indices = ValueHome_utils.get_indices_entire_sequence(
            data=raw_eval_data,
            window_size=window_size,
            step_size=step_size)

        if len(eval_indices) == 0:
            print(f"No eval or test data. Skipping {data_type}...")
            continue

        # Making instance of custom dataset class
        eval_data = ValueHome.ValueHome(
            data=torch.from_numpy(raw_eval_data[input_variables].values).float(),
            indices=eval_indices,
            enc_seq_len=enc_seq_len,
            dec_seq_len=dec_seq_len,
            target_seq_len=output_sequence_length
            )

        # Making dataloader
        eval_data = DataLoader(eval_data, batch_size=1, shuffle=SHUFFLE_FLAG)

        # Make src mask for decoder with size:
        # [batch_size*n_heads, output_sequence_length, enc_seq_len]
        src_mask = ValueHome_utils.generate_square_subsequent_mask(
            dim1=output_sequence_length,
            dim2=enc_seq_len
            )

        # Make tgt mask for decoder with size:
        # [batch_size*n_heads, output_sequence_length, output_sequence_length]
        tgt_mask = ValueHome_utils.generate_square_subsequent_mask(
            dim1=output_sequence_length,
            dim2=output_sequence_length
            )

        # NOTE(review): the model receives dec_seq_len=enc_seq_len, not the
        # local dec_seq_len (= 1) above -- kept as-is, confirm intended.
        model = mtrans.MiniTransformer(
                input_size=input_size,
                dec_seq_len=enc_seq_len,
                batch_first=batch_first,
                num_predicted_features=num_predicted_features
                )

        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        if PRETRAIN_FLAG:
            if pretrain_pth is None:
                raise ValueError(f"No pretrained checkpoint configured for encode_type={encode_type!r}")
            print('Attempting to load checkpoint file:', pretrain_pth)
            # Load onto CPU so GPU-saved checkpoints also open on CPU-only hosts;
            # Create_tester moves the model and optimizer state to the device.
            checkpoint = torch.load(pretrain_pth, map_location="cpu")
            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch'] + 1
            loss = checkpoint['loss']
        else:
            epoch = 0
            loss = 0

        Create_tester(
            model=model,
            eval_dataLoader=eval_data,
            optimizer=optimizer,
            savepth=os.path.join(root, savepath),
            savelabel=savelabel,
            batch_first=batch_first,
            lr=learning_rate,
            num_epochs=max_epoch,
            cur_epoch=epoch,
            loss=loss,
            learn_decay=learn_decay,
            )
