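# Mitigate CUDA out-of-memory errors: configure PyTorch's CUDA allocator via
# PYTORCH_CUDA_ALLOC_CONF. This is set before torch is imported below.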
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:1024'

import ValueHome as ValueHome
import ValueHome_utils
from torch.utils.data import DataLoader
import torch
import datetime
import minitransformer as mtrans
import numpy as np
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
import evaluate
import time
from transformers import TrainingArguments, Trainer
from torch.utils.tensorboard import SummaryWriter
import random
import pickle


def optimizer_to(optim, device):
    """Move every tensor stored in an optimizer's state onto ``device``.

    Iterates over ``optim.state.values()`` and handles two layouts: a value
    that is itself a tensor, and a value that is a dict whose entries may be
    tensors. For each tensor found, its ``data`` (and the ``data`` of its
    gradient, when one is attached) is replaced by the result of
    ``.to(device)``. Values of any other type are left untouched.

    Args:
        optim: object exposing a ``state`` mapping (e.g. a torch optimizer).
        device: anything accepted by ``Tensor.to``.

    Returns:
        None. The state is modified in place.
    """

    def _relocate(tensor):
        # Rebind .data in place so existing references to the tensor stay valid.
        tensor.data = tensor.data.to(device)
        grad = tensor._grad
        if grad is not None:
            grad.data = grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            # State value that is a bare tensor.
            _relocate(entry)
            continue
        if not isinstance(entry, dict):
            continue
        # Per-parameter state dict: relocate only the tensor-valued entries.
        for value in entry.values():
            if isinstance(value, torch.Tensor):
                _relocate(value)


def _load_target_templates(path="gpt_target_template.pkl"):
    """Read the pickled target templates passed to the target-accuracy metric."""
    # NOTE: unpickling can execute arbitrary code; only point this at a file
    # produced by this project.
    with open(path, "rb") as template_file:
        return pickle.load(template_file)


def _save_checkpoint(path, model, optimizer, epoch, loss):
    """Write a checkpoint (epoch, model/optimizer state dicts, last loss) to *path*."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
    }, path)


def Create_trainer(model, train_dataLoader, eval_dataLoader, optimizer, savepth, savelabel="", batch_first=False, lr=5e-5, num_epochs=3, cur_epoch=0, loss=0, learn_decay=0):
    """Train ``model``, evaluate it after every epoch and write checkpoints.

    The loss function and the evaluation metric are selected from substrings
    of the module-level ``encode_type``. Besides its arguments, this function
    reads the module-level names ``encode_type``, ``src_mask``, ``tgt_mask``
    and ``writer`` (a TensorBoard ``SummaryWriter``); they must be assigned
    before it is called, as the ``__main__`` block of this file does.

    A checkpoint is written whenever the epoch metric improves (smaller for
    "duration" encodings, larger otherwise) and once more after the last
    epoch.

    Args:
        model: module called as ``model(src=, tgt=, src_mask=, tgt_mask=)``.
        train_dataLoader: sized iterable yielding ``(src, trg, trg_y)`` tensors.
        eval_dataLoader: iterable yielding ``(src, trg, trg_y)`` tensors.
        optimizer: optimizer stepped once per training batch.
        savepth: directory in which checkpoint files are created.
        savelabel: suffix used in checkpoint file names.
        batch_first: when false, every batch tensor is permuted with
            ``permute(1, 0, 2)`` before it is handed to the model.
        lr: value shown in the per-epoch report only; the effective learning
            rate is driven by the linear scheduler created here.
        num_epochs: exclusive upper bound of the epoch range.
        cur_epoch: first epoch index to run (non-zero when resuming).
        loss: value stored in the final checkpoint if no training step runs.
        learn_decay: accepted for call compatibility; not used.

    Returns:
        None.
    """
    loss_att = -1  # the fallback MSE loss ignores the last feature column
    target_embed_dim = 384  # leading columns scored in "end2end" eval -- presumably the target embedding, confirm
    lower_is_better = "duration" in encode_type
    best_metric = 100 if lower_is_better else -1

    steps_per_epoch = len(train_dataLoader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_epochs * steps_per_epoch,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Pick the loss once instead of re-testing encode_type on every step.
    target_template = None
    if "onlyaction" in encode_type:
        loss_cal = torch.nn.CrossEntropyLoss()
        loss_mode = "full"
    elif "onlyduration" in encode_type:
        loss_cal = torch.nn.MSELoss()
        loss_mode = "full"
    elif ("onlytarget" in encode_type) or ("act2trg" in encode_type):
        loss_cal = torch.nn.CosineEmbeddingLoss()
        loss_mode = "cosine"
        target_template = _load_target_templates()
    else:
        loss_cal = torch.nn.MSELoss()
        loss_mode = "sliced"

    # Pick the evaluation metric once (same precedence as the per-batch checks).
    if "duration" in encode_type:
        eval_mode = "duration"
    elif "action" in encode_type:
        eval_mode = "action"
    elif ("target" in encode_type) or (encode_type == "act2trg"):
        eval_mode = "target"
    elif "end2end" in encode_type:
        eval_mode = "end2end"
    else:
        eval_mode = None
    # The target metrics need the templates even when the loss branch above
    # did not load them; load them up front rather than failing during eval.
    if eval_mode in ("target", "end2end") and target_template is None:
        target_template = _load_target_templates()

    # Name and TensorBoard tag used for the per-epoch report.
    if "action" in encode_type:
        metric_name, metric_tag = "acc", "acc_val"
    elif "duration" in encode_type:
        metric_name, metric_tag = "distance", "distance_val"
    elif ("target" in encode_type) or ("act2trg" in encode_type) or ("end2end" in encode_type):
        metric_name, metric_tag = "acc", "acc_val"
    else:
        metric_name = metric_tag = None

    model.to(device)
    optimizer_to(optimizer, device)
    # The masks do not change between steps: move them to the device once.
    src_mask_dev = src_mask.to(device)
    tgt_mask_dev = tgt_mask.to(device)

    # Keeps `epoch` defined for the final checkpoint when the range below is
    # empty (e.g. resuming a run that already reached num_epochs).
    epoch = cur_epoch - 1
    for epoch in range(cur_epoch, num_epochs):
        # ---- train ----
        model.train()
        # Debugging aid; anomaly detection adds overhead to training.
        torch.autograd.set_detect_anomaly(True)
        for step, batch in enumerate(tqdm(train_dataLoader)):
            src, trg, trg_y = (tensor.to(device) for tensor in batch)

            # Permute from shape [batch size, seq len, num features] to [seq len, batch size, num features]
            if not batch_first:
                src = src.permute(1, 0, 2)
                trg = trg.permute(1, 0, 2)
                trg_y = trg_y.permute(1, 0, 2)
            outputs = model(
                src=src,
                tgt=trg,
                src_mask=src_mask_dev,
                tgt_mask=tgt_mask_dev,
            )

            if loss_mode == "full":
                loss = loss_cal(outputs, trg_y)
            elif loss_mode == "cosine":
                # A target of +1 asks the cosine loss to pull each pair together.
                target = torch.ones(outputs.shape[1], device=device)
                loss = loss_cal(outputs[0, :, :], trg_y[0, :, :], target=target)
            else:
                loss = loss_cal(outputs[0, :, :loss_att], trg_y[0, :, :loss_att])
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            if step % 100 == 0:
                global_step = epoch * steps_per_epoch + step
                writer.add_scalar("loss", loss.item(), global_step)

        # ---- eval ----
        model.eval()
        Acc = 0
        eval_batches = 0
        for batch in eval_dataLoader:
            src, trg, trg_y = (tensor.to(device) for tensor in batch)
            if not batch_first:
                src = src.permute(1, 0, 2)
                trg = trg.permute(1, 0, 2)
                trg_y = trg_y.permute(1, 0, 2)
            with torch.no_grad():
                outputs = model(
                    src=src,
                    tgt=trg,
                    src_mask=src_mask_dev,
                    tgt_mask=tgt_mask_dev,
                )
            if eval_mode == "duration":
                Acc += ValueHome_utils.accuracy_cal_duration(outputs, trg_y)
            elif eval_mode == "action":
                act_acc, _, _ = ValueHome_utils.accuracy_cal_onehot(outputs, trg_y)
                Acc += act_acc
            elif eval_mode == "target":
                tar_acc, _, _ = ValueHome_utils.accuracy_cal_target(
                    outputs, trg_y,
                    userId=-1,
                    human_flag=False,
                    templates=target_template)
                Acc += tar_acc
            elif eval_mode == "end2end":
                tar_acc, _, _ = ValueHome_utils.accuracy_cal_target(
                    outputs[:, :, :target_embed_dim], trg_y[:, :, :target_embed_dim],
                    userId=-1,
                    human_flag=False,
                    templates=target_template)
                Acc += tar_acc
            eval_batches += 1

        # Average over the eval batches actually seen (not the training
        # loop's leftover `step`); max() guards an empty eval loader.
        Acc = Acc / max(eval_batches, 1)
        if metric_name is not None:
            # `lr` is the value passed in by the caller, not the scheduler's current rate.
            print('epoch={0}, {3}={1}. lr={2}'.format(epoch, Acc, lr, metric_name))
            writer.add_scalar(metric_tag, Acc, epoch)

        # ---- keep the best checkpoint so far ----
        improved = (Acc < best_metric) if lower_is_better else (Acc > best_metric)
        if improved:
            cur_time, _ = ValueHome_utils.vis_time(time.time())
            _save_checkpoint(
                savepth + '/{0}_epoch{1}_{2}.pth'.format(cur_time, epoch, savelabel),
                model, optimizer, epoch, loss)
            best_metric = Acc

    # Final checkpoint, written regardless of the metric.
    cur_time, _ = ValueHome_utils.vis_time(time.time())
    _save_checkpoint(
        savepth + '/{0}_{1}.pth'.format(cur_time, savelabel),
        model, optimizer, epoch, loss)


def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Applies ``seed`` to torch (CPU, then all CUDA devices), NumPy and the
    standard ``random`` module, in that order, and then sets
    ``torch.backends.cudnn.deterministic`` to ``True``.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':
    # Script entry point: trains one model per hyper-parameter combination.
    #
    # Everything below is deliberately assigned at module scope (not inside a
    # main() function): Create_trainer reads `encode_type`, `src_mask`,
    # `tgt_mask` and `writer` as module-level globals.
    from itertools import product  # local import: only the entry point needs it

    # Hyper-parameter grid (one value each at the moment).
    dim_list = [4]  # other candidates: 8, 16, 32, 64, 128
    layer_list = [8]  # other candidates: 16, 32, 64, 128, 256, 512, 1024, 2048
    encode_layer_list = [1]  # other candidates: 2, 3, 4
    head_list = [4]  # other candidates: 1, 2, 3, 5, 6, 7, 8
    lr_list = [0.00001]  # other candidates: 0.0001, 0.000001
    window_list = [14]  # other candidates: 5, 8, 10, 12, 16, 18, 20, 25, 28, 30

    setup_seed(7)

    # Same iteration order as nesting the loops dim > layer > head >
    # encode_layer > lr > window.
    for dim, layer, head, encode_layer, lr, window in product(
            dim_list, layer_list, head_list, encode_layer_list, lr_list, window_list):

        # ---- run configuration ----
        PRETRAIN_FLAG = False  # True: resume from `pretrain_pth`
        SHUFFLE_FALG = True
        # Other encode types used with this script:
        #   "action_onlyaction", "duration_onlyduration",
        #   "target_onlytarget", "act2trg"
        encode_type = "end2end_act2trg"
        root = "/home/zhe/Documents/Code/IntentionProject/ValueSoftAlign-Net"
        data_type = "_diff2full"
        arranged_dataset_root = "/home/zhe/Documents/Code/IntentionProject/ValueHomeDataset/GPT/gpt_plan/gpt_plan/GPT_arranged"
        data_dir = arranged_dataset_root + f"/{encode_type}"
        pretrain_pth = root + "/models/2023-09-26_10-25-26_Act2trg_lr0.001_dim4ly8ecd1hd4win1.pth"

        if encode_type == "act2trg":
            window = 1

        savelabel = f"End2end_lr{lr}_dim{dim}ly{layer}ecd{encode_layer}hd{head}win{window}"
        savepath = f"models/{savelabel}"
        save_dir = os.path.join(root, savepath)
        # makedirs also creates a missing "models" parent and tolerates an
        # already existing directory.
        os.makedirs(save_dir, exist_ok=True)

        # ---- report the configuration on stdout and in train_info.txt ----
        info_lines = [
            "start training with params:",
            "learing rate: {0}".format(lr),
            "dim: {0}".format(dim),
            "layer: {0}".format(layer),
            "encode_layer: {0}".format(encode_layer),
            "head: {0}".format(head),
            "window: {0}".format(window),
        ]
        for info_line in info_lines:
            print(info_line)
        with open(os.path.join(save_dir, "train_info.txt"), "w") as f:
            f.writelines(info_line + "\n" for info_line in info_lines)

        test_size = 0.2
        eval_size = 0.1
        train_start = 0
        batch_size = 32
        learning_rate = lr
        max_epoch = 200

        # Number of label / feature columns expected in the data files.
        if "action" in encode_type:
            label_num = 0  # 26(actions)
            feature_num = 1218  # 1211(features)
        elif "duration" in encode_type:
            label_num = 0  # 1(duration)
            feature_num = 1218
        elif "target" in encode_type:
            label_num = 0
            feature_num = 1370  # 384 + 986
        elif "act2trg" in encode_type:
            label_num = 0
            feature_num = 1602  # 384 + 1218, but only use 1218
        else:
            # Fail here with a clear message instead of a NameError further down.
            raise ValueError("unsupported encode_type: {0!r}".format(encode_type))

        # Column names: "label_000", ... for labels, "0", "1", ... for features.
        exogenous_vars = ["label_%s" % str(i).zfill(3) for i in range(label_num)]
        target_col_name = np.arange(feature_num).astype(str).tolist()

        cur_time, _ = ValueHome_utils.vis_time(time.time())
        if PRETRAIN_FLAG:
            writer_pth = input('please complete the save path of train log: %s/train_log/' % root)
            writer = SummaryWriter(root + '/train_log/' + writer_pth)
        else:
            writer = SummaryWriter(root + '/train_log/{0}_{1}'.format(cur_time, savelabel))

        # ---- model / windowing parameters ----
        # NOTE(review): dim_val, n_heads, n_*_layers and in_features_* are
        # assigned here but never forwarded to MiniTransformer below, so the
        # dim/layer/head/encode_layer grid currently only affects the save
        # label -- confirm against MiniTransformer's signature.
        dim_val = dim
        n_heads = head
        n_decoder_layers = encode_layer
        n_encoder_layers = encode_layer
        dec_seq_len = 1  # length of input given to decoder
        enc_seq_len = window  # length of input given to encoder
        output_sequence_length = 1  # target sequence length
        window_size = enc_seq_len + output_sequence_length  # used to slice data into sub-sequences
        step_size = 1  # how many time steps the moving window advances at each step
        in_features_encoder_linear_layer = layer
        in_features_decoder_linear_layer = layer
        max_seq_len = enc_seq_len
        batch_first = False
        learn_decay = 0.1

        # Define input variables
        input_variables = exogenous_vars + target_col_name
        target_idx = 0  # index position of target in batched trg_y

        if "act2trg" in encode_type:
            input_size = 1602 if "end2end" in encode_type else 1218
        else:
            input_size = len(input_variables)

        # Define number of predicted features
        if "onlyaction" in encode_type:
            num_predicted_features = 26
        elif "onlyduration" in encode_type:
            num_predicted_features = 1
        elif "onlytarget" in encode_type:
            num_predicted_features = 384
        elif "act2trg" in encode_type:
            num_predicted_features = 1602 if "end2end" in encode_type else 384
        else:
            num_predicted_features = len(input_variables)

        # ---- data ----
        raw_training_data, raw_eval_data, raw_test_data = ValueHome_utils.ValueHomeLoader(
            data_dir=data_dir,
            test_size=test_size,
            eval_size=eval_size,
            train_start=train_start,
            label_num=label_num,
            feature_num=feature_num,
            data_type=data_type)

        # Index lists built per split by get_indices_entire_sequence; per the
        # original note these are (start_idx, end_idx) pairs used to slice each
        # sequence into windows -- confirm against ValueHome_utils.
        training_indices = ValueHome_utils.get_indices_entire_sequence(
            data=raw_training_data,
            window_size=window_size,
            step_size=step_size)
        eval_indices = ValueHome_utils.get_indices_entire_sequence(
            data=raw_eval_data,
            window_size=window_size,
            step_size=step_size)
        test_indices = ValueHome_utils.get_indices_entire_sequence(
            data=raw_test_data,
            window_size=window_size,
            step_size=step_size)

        # Making instances of the custom dataset class
        training_data = ValueHome.ValueHome(
            data=torch.from_numpy(raw_training_data[input_variables].values).float(),
            indices=training_indices,
            enc_seq_len=enc_seq_len,
            dec_seq_len=dec_seq_len,
            target_seq_len=output_sequence_length
            )
        eval_data = ValueHome.ValueHome(
            data=torch.from_numpy(raw_eval_data[input_variables].values).float(),
            indices=eval_indices,
            enc_seq_len=enc_seq_len,
            dec_seq_len=dec_seq_len,
            target_seq_len=output_sequence_length
            )
        test_data = ValueHome.ValueHome(
            data=torch.from_numpy(raw_test_data[input_variables].values).float(),
            indices=test_indices,
            enc_seq_len=enc_seq_len,
            dec_seq_len=dec_seq_len,
            target_seq_len=output_sequence_length
            )

        # Making dataloaders (the test loader is built but not used below)
        training_data = DataLoader(training_data, batch_size=batch_size, shuffle=SHUFFLE_FALG)
        eval_data = DataLoader(eval_data, batch_size=1, shuffle=SHUFFLE_FALG)
        test_data = DataLoader(test_data, batch_size=1)

        # Attention masks, read by Create_trainer as module globals.
        # Requested dims: (output_sequence_length, enc_seq_len)
        src_mask = ValueHome_utils.generate_square_subsequent_mask(
            dim1=output_sequence_length,
            dim2=enc_seq_len
            )
        # Requested dims: (output_sequence_length, output_sequence_length)
        tgt_mask = ValueHome_utils.generate_square_subsequent_mask(
            dim1=output_sequence_length,
            dim2=output_sequence_length
            )

        # NOTE(review): dec_seq_len receives enc_seq_len here, not the
        # `dec_seq_len` variable above -- confirm this is intended.
        model = mtrans.MiniTransformer(
            input_size=input_size,
            dec_seq_len=enc_seq_len,
            batch_first=batch_first,
            num_predicted_features=num_predicted_features
            )

        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        if PRETRAIN_FLAG:
            print('Attempting to load checkpoint file:', pretrain_pth)
            # Load onto the CPU so a CUDA-saved checkpoint also works on a
            # CPU-only host; Create_trainer moves model and optimizer state
            # to the training device afterwards.
            checkpoint = torch.load(pretrain_pth, map_location="cpu")
            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epoch = checkpoint['epoch'] + 1
            loss = checkpoint['loss']
        else:
            epoch = 0
            loss = 0

        try:
            Create_trainer(
                model=model,
                train_dataLoader=training_data,
                eval_dataLoader=eval_data,
                optimizer=optimizer,
                savepth=save_dir,
                savelabel=savelabel,
                batch_first=batch_first,  # keep the trainer in sync with the model
                lr=learning_rate,
                num_epochs=max_epoch,
                cur_epoch=epoch,
                loss=loss,
                learn_decay=learn_decay,
                )
        finally:
            # Close the TensorBoard log even if training fails.
            writer.close()