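# Mitigate fragmentation-related CUDA out-of-memory errors: stop the caching allocator from
# splitting cached blocks larger than 1024 MB. Set before torch is imported so it takes effect.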
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:1024'

import ValueHome as ValueHome
import ValueHome_utils
from torch.utils.data import DataLoader
import torch
import datetime
import minitransformer as mtrans
import numpy as np
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
import evaluate
import time
from transformers import TrainingArguments, Trainer
from torch.utils.tensorboard import SummaryWriter
import random
import pickle


def optimizer_to(optim, device):
    """Move every tensor stored in ``optim.state`` (and its ``_grad``, if set) to ``device``.

    Handles state entries that are either a tensor or a dict of values (the usual
    per-parameter layout); non-tensor values are left untouched. Tensors are updated
    through ``.data`` so the objects already referenced by the optimizer are kept.
    Returns None.
    """
    def _relocate(tensor):
        tensor.data = tensor.data.to(device)
        grad = tensor._grad
        if grad is not None:
            grad.data = grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            _relocate(entry)
            continue
        if not isinstance(entry, dict):
            continue
        # Per-parameter state: only the tensor-valued slots need moving.
        for value in entry.values():
            if isinstance(value, torch.Tensor):
                _relocate(value)


def _script_global(value, name):
    """Return ``value``, or the module-level global ``name`` when ``value`` is None.

    ``Create_trainer`` originally read its task configuration from globals assigned by
    this script's ``__main__`` block; this keeps that calling convention working while
    also allowing the values to be passed explicitly.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"Create_trainer needs '{name}': pass it as an argument or define it at module level"
        ) from None


def _load_target_template(data_type):
    """Return ``(human_flag, templates)`` used by the target-accuracy metric.

    Dataset names containing "diff" use ``gpt_target_template.pkl`` (human_flag=False);
    all others use ``human_target_template.pkl`` (human_flag=True).
    """
    human_flag = "diff" not in data_type
    template_file = "human_target_template.pkl" if human_flag else "gpt_target_template.pkl"
    # Trusted, project-generated file in the working directory; never unpickle untrusted data.
    with open(template_file, "rb") as f:
        return human_flag, pickle.load(f)


def _select_loss(encode_type):
    """Return the training criterion for ``encode_type``."""
    if "onlyaction" in encode_type:
        return torch.nn.CrossEntropyLoss()
    if "onlyduration" in encode_type:
        return torch.nn.MSELoss()
    if "onlytarget" in encode_type or "act2trg" in encode_type:
        return torch.nn.CosineEmbeddingLoss()
    return torch.nn.MSELoss()


def _compute_loss(loss_cal, encode_type, outputs, trg_y, loss_att=-1):
    """Apply ``loss_cal`` to one training batch, slicing the tensors as each task expects.

    The branch order mirrors ``_select_loss`` so each criterion receives the arguments it needs.
    """
    if "onlyaction" in encode_type or "onlyduration" in encode_type:
        return loss_cal(outputs, trg_y)
    if "onlytarget" in encode_type or "act2trg" in encode_type:
        # CosineEmbeddingLoss takes a +1/-1 label per pair; all +1 pulls each prediction
        # towards its label.
        similar = torch.ones(outputs.shape[1], device=outputs.device)
        return loss_cal(outputs[0, :, :], trg_y[0, :, :], target=similar)
    # NOTE(review): loss_att=-1 excludes the last feature column from the loss; the reason
    # is not visible here.
    return loss_cal(outputs[0, :, :loss_att], trg_y[0, :, :loss_att])


def _batch_metric(encode_type, data_type, outputs, trg_y, human_flag, target_template):
    """Return one validation batch's metric, or 0 when no metric applies to ``encode_type``."""
    if "duration" in encode_type:
        return ValueHome_utils.accuracy_cal_duration(outputs, trg_y)
    if "action" in encode_type:
        acc, _, _ = ValueHome_utils.accuracy_cal_onehot(outputs, trg_y)
        return acc
    if "target" in encode_type or encode_type == "act2trg":
        acc, _, _ = ValueHome_utils.accuracy_cal_target(
            outputs, trg_y,
            userId=data_type.split("_")[0],
            human_flag=human_flag,
            templates=target_template)
        return acc
    if "end2end" in encode_type:
        # Only the first 384 output features are scored against the target templates.
        acc, _, _ = ValueHome_utils.accuracy_cal_target(
            outputs[:, :, :384], trg_y[:, :, :384],
            userId=-1,
            human_flag=False,
            templates=target_template)
        return acc
    return 0


def _save_checkpoint(model, optimizer, epoch, loss, savepth, savelabel):
    """Write ``<savepth>/<timestamp>_epoch<epoch>_<savelabel>.pth`` with model and optimizer state."""
    cur_time, _ = ValueHome_utils.vis_time(time.time())
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        savepth + '/{0}_epoch{1}_{2}.pth'.format(cur_time, epoch, savelabel),
    )


def Create_trainer(model, train_dataLoader, eval_dataLoader, optimizer, savepth, savelabel="",
                   batch_first=False, lr=5e-5, num_epochs=3, cur_epoch=0, loss=0, learn_decay=0,
                   encode_type=None, data_type=None, writer=None, src_mask=None, tgt_mask=None):
    """Train ``model``, validate after every epoch, and checkpoint the best and final weights.

    Args:
        model: network called as ``model(src=..., tgt=..., src_mask=..., tgt_mask=...)``.
        train_dataLoader, eval_dataLoader: iterables of ``(src, trg, trg_y)`` batches.
        optimizer: optimizer over ``model``'s parameters; its state is moved to the device.
        savepth: directory the ``.pth`` checkpoints are written into.
        savelabel: suffix appended to every checkpoint file name.
        batch_first: when False, each batch tensor is permuted with ``(1, 0, 2)``
            (presumably [batch, seq, features] -> [seq, batch, features]).
        lr: only printed in the per-epoch log line; the optimizer's own rate is what trains.
        num_epochs: epoch index to stop before. The linear LR schedule spans
            ``num_epochs * len(train_dataLoader)`` steps.
        cur_epoch: first epoch index to run (for resuming).
        loss: value stored in the final checkpoint if no training step runs.
        learn_decay: unused; kept for call compatibility.
        encode_type, data_type, writer, src_mask, tgt_mask: task name, dataset name,
            TensorBoard writer and attention masks. Each defaults to the module-level global
            of the same name (as set by this script's ``__main__`` block).

    The validation metric is a mean distance (lower is better) when ``encode_type``
    contains "duration", and a mean accuracy (higher is better) otherwise. A checkpoint is
    written whenever it improves, plus one final checkpoint after the last epoch.

    Raises:
        ValueError: if ``eval_dataLoader`` yields no batches.
        NameError: if a configuration value is neither passed nor defined globally.
    """
    encode_type = _script_global(encode_type, "encode_type")
    writer = _script_global(writer, "writer")
    src_mask = _script_global(src_mask, "src_mask")
    tgt_mask = _script_global(tgt_mask, "tgt_mask")

    if len(eval_dataLoader) == 0:
        raise ValueError("eval_dataLoader yields no batches; cannot compute a validation metric")

    is_duration = "duration" in encode_type
    best_metric = 100 if is_duration else -1

    # Templates are needed by every eval branch that scores targets.
    human_flag, target_template = None, None
    if any(key in encode_type for key in ("target", "act2trg", "end2end")):
        data_type = _script_global(data_type, "data_type")
        human_flag, target_template = _load_target_template(data_type)

    # NOTE(review): when resuming (cur_epoch > 0) the schedule still spans all num_epochs,
    # so the LR does not decay to zero by the end of the run. Confirm this is intended.
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_epochs * len(train_dataLoader),
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_cal = _select_loss(encode_type)
    model.to(device)
    optimizer_to(optimizer, device)
    src_mask = src_mask.to(device)
    tgt_mask = tgt_mask.to(device)

    # Debugging aid: enables anomaly detection globally, which slows every backward pass.
    torch.autograd.set_detect_anomaly(True)

    # Reported by the final checkpoint if the loop below does not run (already-finished resume).
    epoch = cur_epoch - 1
    for epoch in range(cur_epoch, num_epochs):
        # train
        model.train()
        for step, batch in enumerate(tqdm(train_dataLoader)):
            src, trg, trg_y = (t.to(device) for t in batch)
            if not batch_first:
                src, trg, trg_y = (t.permute(1, 0, 2) for t in (src, trg, trg_y))

            outputs = model(src=src, tgt=trg, src_mask=src_mask, tgt_mask=tgt_mask)
            loss = _compute_loss(loss_cal, encode_type, outputs, trg_y)
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            if step % 100 == 0:
                writer.add_scalar("loss", loss.item(), epoch * len(train_dataLoader) + step)

        # eval
        model.eval()
        metric_sum = 0
        num_eval_batches = 0
        with torch.no_grad():
            for batch in eval_dataLoader:
                src, trg, trg_y = (t.to(device) for t in batch)
                if not batch_first:
                    src, trg, trg_y = (t.permute(1, 0, 2) for t in (src, trg, trg_y))
                outputs = model(src=src, tgt=trg, src_mask=src_mask, tgt_mask=tgt_mask)
                metric_sum += _batch_metric(encode_type, data_type, outputs, trg_y,
                                            human_flag, target_template)
                num_eval_batches += 1
        val_metric = metric_sum / num_eval_batches

        if is_duration:
            print('epoch={0}, distance={1}. lr={2}'.format(epoch, val_metric, lr))
            writer.add_scalar("distance_val", val_metric, epoch)
        else:
            print('epoch={0}, acc={1}. lr={2}'.format(epoch, val_metric, lr))
            writer.add_scalar("acc_val", val_metric, epoch)

        # save the best model so far
        improved = val_metric < best_metric if is_duration else val_metric > best_metric
        if improved:
            _save_checkpoint(model, optimizer, epoch, loss, savepth, savelabel)
            best_metric = val_metric

    # always save the final model
    _save_checkpoint(model, optimizer, epoch, loss, savepth, savelabel)


def setup_seed(seed):
    """Seed all random number generators used in training so runs are reproducible.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and every CUDA device), then sets
    ``torch.backends.cudnn.deterministic`` so cuDNN uses deterministic algorithms.
    """
    random.seed(seed)
    np.random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


# def Create_MiniTransformer():
if __name__ == '__main__':
    # Hyper-parameters shared by every dataset run below.
    # NOTE(review): dim / layer / encode_layer / head are only used for logging and derived
    # locals; none of them is passed to mtrans.MiniTransformer below.
    dim = 4
    layer = 8
    encode_layer = 1
    head = 4
    lr = 0.000001
    window = 14
    # Names containing "diff" are loaded from the GPT data root, all others from the human root.
    data_list = [
        "02_diff2full",
        "03_diff2full",
        "04_diff2full",
        "05_diff2full",
        "05_full",
        "16_full",
        "17_full",
        "18_full",
        "19_full"
        ]

    setup_seed(7)

    for datatype in data_list:

        # Hyperparams
        PRETRAIN_FLAG = True  # fine-tune from pretrain_pth instead of training from scratch
        SHUFFLE_FALG = True
        # Alternative tasks:
        # encode_type = "action_onlyaction"
        # encode_type = "duration_onlyduration"
        # encode_type = "act2trg"
        # encode_type = "target_onlytarget"
        encode_type = "end2end_act2trg"
        root = "."
        data_type = datatype
        # Currently only a single data file is read per run; to be extended.
        if 'diff' in data_type:
            arranged_dataset_root = "/home/zhe/Documents/Code/IntentionProject/ValueHomeDataset/GPT/gpt_plan/gpt_plan/GPT_arranged"
            data_dir = arranged_dataset_root + f"/gpt4_{encode_type.split('_')[-1]}"
        else:
            arranged_dataset_root = "/home/zhe/Documents/Code/IntentionProject/ValueHomeDataset/HumanData/human_arranged"
            data_dir = arranged_dataset_root + f"/human_{encode_type.split('_')[-1]}"

        if "onlyaction" in encode_type:
            pretrain_pth = root + "/models/Action/2023-09-25_21-15-30_epoch71_ACT_lr1e-05_dim4ly8ecd1hd4win14.pth"
            model_id = "1"
        elif "onlyduration" in encode_type:
            pretrain_pth = root + "/models/Duration/2023-09-25_21-42-00_epoch52_Duration_lr0.0001_dim4ly8ecd1hd4win14.pth"
            model_id = "1"
        elif "onlytarget" in encode_type:
            pretrain_pth = root + "/models/Target/2023-09-26_13-15-57_epoch400_Target_lr1e-05_dim4ly8ecd1hd4win14.pth"
            model_id = "1"
        elif encode_type == "act2trg":
            pretrain_pth = root + "/models/Act2trg/2023-09-26_15-08-42_epoch494_Act2trg_lr1e-05_dim4ly8ecd1hd4win1.pth"
            model_id = "1"
        elif encode_type == "end2end_act2trg":
            pretrain_pth = root + "/models/End2end_lr1e-05_dim4ly8ecd1hd4win14/2023-09-28_15-41-51_End2end_lr1e-05_dim4ly8ecd1hd4win14.pth"
            model_id = "1"

        if encode_type == "act2trg":
            window = 1

        savelabel = f"lr{str(lr)}"
        savepath = f"models/{encode_type}_{data_type}/{savelabel}"
        # makedirs also creates a missing "models/" parent, which os.mkdir could not.
        os.makedirs(os.path.join(root, savepath), exist_ok=True)

        # print training info and keep a copy next to the checkpoints
        info_lines = [
            "start training with params:",
            "learing rate: {0}".format(lr),
            "dim: {0}".format(dim),
            "layer: {0}".format(layer),
            "encode_layer: {0}".format(encode_layer),
            "head: {0}".format(head),
            "window: {0}".format(window),
        ]
        for line in info_lines:
            print(line)
        with open(os.path.join(root, savepath, "train_info.txt"), "w") as f:
            f.writelines(line + "\n" for line in info_lines)

        test_size = 0.2
        eval_size = 0.1
        train_start = 0
        batch_size = 8
        learning_rate = lr

        if "onlytarget" in encode_type:
            max_epoch = 500
        elif encode_type == "act2trg":
            max_epoch = 600
        elif "onlyaction" in encode_type:
            max_epoch = 171
        elif "duration" in encode_type:
            max_epoch = 152
        elif "end2end" in encode_type:
            max_epoch = 400

        if "action" in encode_type:
            label_num = 0  # 26(actions)
            feature_num = 1218  # 1218(features)
        elif "duration" in encode_type:
            label_num = 0  # 1(duration)
            feature_num = 1218
        elif "target" in encode_type:
            label_num = 0
            feature_num = 1370  # 384 + 986
        elif "act2trg" in encode_type:
            label_num = 0
            feature_num = 1602
        # Column names: "label_000".. for label columns, then "0".."feature_num-1" for features.
        exogenous_vars = ["label_%s" % str(i).zfill(3) for i in range(label_num)]
        target_col_name = np.arange(feature_num).astype(str).tolist()

        cur_time, _ = ValueHome_utils.vis_time(time.time())
        if PRETRAIN_FLAG:
            writer_pth = root + f'/finetune_log/{encode_type}_{data_type}_model{model_id}'
            os.makedirs(writer_pth, exist_ok=True)
            writer = SummaryWriter(writer_pth + '/{0}_{1}'.format(cur_time, savelabel))
        else:
            writer = SummaryWriter(root + '/train_log/{0}_{1}'.format(cur_time, savelabel))

        ## Params
        dim_val = dim
        n_heads = head
        n_decoder_layers = encode_layer
        n_encoder_layers = encode_layer
        dec_seq_len = 1  # length of input given to decoder
        enc_seq_len = window  # length of input given to encoder
        output_sequence_length = 1  # target sequence length
        window_size = enc_seq_len + output_sequence_length  # used to slice data into sub-sequences
        step_size = 1  # how many time steps the moving window advances each step
        in_features_encoder_linear_layer = layer
        in_features_decoder_linear_layer = layer
        max_seq_len = enc_seq_len
        batch_first = False
        learn_decay = 0.1

        # Define input variables
        input_variables = exogenous_vars + target_col_name
        target_idx = 0  # index position of target in batched trg_y

        if encode_type == "act2trg":
            input_size = 1218
        else:
            input_size = len(input_variables)

        # Define number of predicted features
        if "onlyaction" in encode_type:
            num_predicted_features = 26
        elif "onlyduration" in encode_type:
            num_predicted_features = 1
        elif "onlytarget" in encode_type:
            num_predicted_features = 384
        elif encode_type == "act2trg":
            num_predicted_features = 384
        elif "end2end" in encode_type:
            num_predicted_features = 1602
        else:
            num_predicted_features = len(input_variables)

        # Read data
        raw_training_data, raw_eval_data, raw_test_data = ValueHome_utils.ValueHomeLoader(
            data_dir=data_dir,
            test_size=test_size,
            eval_size=eval_size,
            train_start=train_start,
            label_num=label_num,
            feature_num=feature_num,
            data_type=data_type)
        # (start_idx, end_idx) pairs used to slice each split into fixed-size windows.
        training_indices, eval_indices, test_indices = (
            ValueHome_utils.get_indices_entire_sequence(
                data=raw_split, window_size=window_size, step_size=step_size)
            for raw_split in (raw_training_data, raw_eval_data, raw_test_data)
        )

        if len(eval_indices) == 0 or len(test_indices) == 0:
            print(f"No eval or test data. Skipping {data_type}...")
            writer.close()  # the writer was opened above; don't leak it across iterations
            continue

        # Making instance of custom dataset class for each split
        training_data, eval_data, test_data = (
            ValueHome.ValueHome(
                data=torch.from_numpy(raw_split[input_variables].values).float(),
                indices=split_indices,
                enc_seq_len=enc_seq_len,
                dec_seq_len=dec_seq_len,
                target_seq_len=output_sequence_length,
            )
            for raw_split, split_indices in (
                (raw_training_data, training_indices),
                (raw_eval_data, eval_indices),
                (raw_test_data, test_indices),
            )
        )

        # Making dataloader
        training_data = DataLoader(training_data, batch_size=batch_size, shuffle=SHUFFLE_FALG, drop_last=True)
        eval_data = DataLoader(eval_data, batch_size=1, shuffle=SHUFFLE_FALG)
        test_data = DataLoader(test_data, batch_size=1)

        # Decoder masks: src mask is [output_sequence_length, enc_seq_len],
        # tgt mask is [output_sequence_length, output_sequence_length].
        src_mask = ValueHome_utils.generate_square_subsequent_mask(
            dim1=output_sequence_length,
            dim2=enc_seq_len
            )
        tgt_mask = ValueHome_utils.generate_square_subsequent_mask(
            dim1=output_sequence_length,
            dim2=output_sequence_length
            )

        model = mtrans.MiniTransformer(
                input_size=input_size,
                dec_seq_len=enc_seq_len,
                batch_first=batch_first,
                num_predicted_features=num_predicted_features
                )

        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        if PRETRAIN_FLAG:
            print('Attempting to load checkpoint file:', pretrain_pth)
            # The model is still on the CPU here (Create_trainer moves it); map_location also
            # lets GPU-saved checkpoints load on CPU-only machines.
            checkpoint = torch.load(pretrain_pth, map_location="cpu")
            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # load_state_dict restores the checkpoint's lr / initial_lr (the latter is what the
            # LR scheduler starts from), silently discarding learning_rate. Reset both so
            # fine-tuning really runs at the rate printed and recorded in savelabel.
            for group in optimizer.param_groups:
                group['lr'] = learning_rate
                group['initial_lr'] = learning_rate
            epoch = checkpoint['epoch'] + 1
            loss = checkpoint['loss']
        else:
            epoch = 0
            loss = 0

        Create_trainer(
            model=model,
            train_dataLoader=training_data,
            eval_dataLoader=eval_data,
            optimizer=optimizer,
            savepth=os.path.join(root, savepath),
            savelabel=savelabel,
            batch_first=batch_first,
            lr=learning_rate,
            num_epochs=max_epoch,
            cur_epoch=epoch,
            loss=loss,
            learn_decay=learn_decay,
            )

        writer.close()