import os


# os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import numpy as np
import pickle
from argparse import ArgumentParser
from data import split_train_val_test, next_batch, process_batch, load_data
from finetune_trainer import Trainer
from tqdm import tqdm
from model.model import PTR
from model.layer import init_weights

import torch
import torch.nn as nn
import torch.optim as optim
from load_data.datasets import Dataset, collate_fn, construct_mask, cal_prompt_token


if __name__ == '__main__':

    parser = ArgumentParser()
    parser.add_argument('--dataset', help='name of the dataset to use', type=str, default="Chengdu")
    parser.add_argument('--gpu_id', help='index of the cuda device to use', type=int, default=0)
    parser.add_argument('--batch_size', help='batch size', type=int, default=64)
    parser.add_argument('--hid_dim', help='hid size', type=int, default=512)
    parser.add_argument('--lr', help='learning rate', type=float, default=1e-4)
    parser.add_argument('--lambda1', type=int, default=10, help='weight for multi task rate')
    parser.add_argument('--kernel_size', type=int, default=9, help='kernel size in conv layer')
    parser.add_argument('--epoch', type=int, default=1, help='epoch')
    parser.add_argument('--repeat_time', type=int, default=1, help='expand dataset')
    parser.add_argument('--grid_size', type=int, default=50, help='grid size')
    parser.add_argument('--debug', type=bool, default=False, help='whether debug')
    parser.add_argument('--road_candi', type=bool, default=True, help='whether add road network')
    parser.add_argument('--soft_traj_num', type=int, default=128, help='num of  referce tokens')
    parser.add_argument('--keep_ratio', type=float, default=0.25, help='num of traj text')
    
    args = parser.parse_args()

    

    if args.dataset == "Porto":
        id_size = 2224 + 1
        mbr = {
            'min_lat':41.142,
            'min_lng':-8.652,
            'max_lat':41.174,
            'max_lng':-8.578
        }
        road_condition = np.load("./load_data/data/{}/flow.npy".format(args.dataset))
    elif args.dataset == "Chengdu":
        id_size = 2504 + 1
        mbr = {
            'min_lat':30.655,
            'min_lng':104.043,
            'max_lat':30.727,
            'max_lng':104.129
        }
        road_condition = np.load("./load_data/data/{}/flow.npy".format(args.dataset))

    # exit()
    device = "cuda:" + str(args.gpu_id)
    print("load data start...")
    # train_set, val_set, test_set = split_train_val_test(args.dataset, 0.8, 0.1)
    # train_set = load_data(args.dataset, "train", id_size=id_size, repeat_times=args.repeat_time, debug=args.debug)
    # val_set = load_data(args.dataset, "valid", id_size=id_size, repeat_times=1, debug=args.debug)
    # test_set = load_data(args.dataset, "test", id_size=id_size, repeat_times=1, debug=args.debug)
    
    road_condition = torch.tensor(road_condition, dtype=torch.float).to(device)
    print(road_condition.shape)
    # exit()

    train_set = Dataset(args.dataset, id_size, "train", mbr, repeat_times=1, debug=args.debug)
    val_set = Dataset(args.dataset, id_size, "valid", mbr, repeat_times=1, debug=args.debug)
    test_set = Dataset(args.dataset, id_size, "test", mbr, repeat_times=1, debug=args.debug)

    


    train_iterator = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size,
                                                 shuffle=True, collate_fn=collate_fn, num_workers=2, pin_memory=False)
    
    val_iterator = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size,
                                                 shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=False)


    test_iterator = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size,
                                                 shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=False)
    task_prompt_tensor, time_prompt_tensor, travel_prompt_tensor, traj_prompt = cal_prompt_token()
    
    

    print(len(train_set), len(val_set), len(test_set))
    # exit()

    save_txt_path =  "./fine_tuned/result/{}_kernel_{}/".format(args.dataset, args.kernel_size)
    if not os.path.exists(save_txt_path): os.makedirs(save_txt_path)
    model_save_path = "./checkpoints/{}/".format(args.dataset)

    model_name = "debug_False_RN_{}_repeat_3_kernel_{}_softnum_{}_val-best-model.pt".format(args.road_candi, args.kernel_size, args.soft_traj_num)

    save_txt = save_txt_path + "logging_keep_ratio_{}".format(args.keep_ratio)
    if args.road_candi:
        save_txt += "_RN_{}_softnum_{}".format(args.road_candi, args.soft_traj_num)
    save_txt += ".txt"
    if args.debug:
        save_txt = save_txt_path + "debug_logging.txt"
    with open(save_txt, "w") as f:
        f.write("Fine-tune Dataset: {}\n".format(args.dataset))
        f.write("conv kernel: {}, softnum: {}\n".format(args.kernel_size, args.soft_traj_num))

    with open(save_txt, "a+") as f:
        f.write("Training set length: {}, Validation set length: {}, Test set length: {}\n".format(len(train_set), len(val_set), len(test_set)))
    print("load data ok...")

    model = PTR(args.hid_dim, id_size, device, args.road_candi, 9, args.soft_traj_num).to(device)
    model.apply(init_weights)
    model.load_state_dict(torch.load(model_save_path + model_name))
    
    model_save_path = "./fine_tuned/checkpoints/{}/".format(args.dataset)
    model_name = "debug_{}_RN_{}_ratio_{}_kernel_{}_softnum_{}_val-best-model.pt".format(args.debug, args.road_candi, args.keep_ratio, args.kernel_size, args.soft_traj_num)

    if not os.path.exists(model_save_path): os.makedirs(model_save_path)

    print("load model ok...")

    trainer = Trainer(model, args.batch_size, device, args.lr, args.lambda1, mbr, road_condition, id_size, task_prompt_tensor,  time_prompt_tensor, travel_prompt_tensor, traj_prompt)
    
    trainer.train(args.epoch, train_iterator, val_iterator, save_txt, model_save_path, model_name)

    model.load_state_dict(torch.load(model_save_path + model_name))



    test_acc = trainer.val(model, test_iterator, args.keep_ratio, types="test")
    with open(save_txt, "a+") as f:
        f.write("Test Acc {}: {} \n".format(args.keep_ratio, test_acc)) 
    print("Test Acc {}: {}  \n".format(args.keep_ratio, test_acc)) 