import torch
import torch.nn as nn
import torch.optim as optim
from data import next_batch, process_batch
from tqdm import tqdm
from utils import cal_acc, epoch_time
import time
import numpy as np
from load_data.datasets import Dataset, collate_fn, construct_mask
from model.loss_fn import cal_id_acc, check_rn_dis_loss
from model.model_utils import toseq
from multiprocessing import Pool
from transformers import BertModel, BertTokenizer

class Construct_traj_prompt():
    """Build per-trajectory prompt embeddings from BERT's word-embedding table.

    For each trajectory the prompt is the concatenation (along dim 1) of:
    the task prompt, the time prompt, the encoded start/end time sentence,
    travel prompt part 0, the encoded travel time, travel prompt part 1,
    the encoded travel distance, and the fixed trajectory prompt.
    """

    def __init__(self, time_prompt_tensor, travel_prompt_tensor, traj_prompt, device):
        self.time_prompt_tensor = time_prompt_tensor
        self.travel_prompt_tensor = travel_prompt_tensor
        self.traj_prompt = traj_prompt
        self.device = device

        # Folder holding the pretrained BERT tokenizer and model files.
        bert_dir = './PLM/BERT'
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=bert_dir)

        # Only the word-embedding matrix is kept; the rest of BERT is dropped.
        bert_weights = BertModel.from_pretrained(bert_dir).state_dict()
        self.bert_token = bert_weights['embeddings.word_embeddings.weight']
        self.weekday = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]

    def encode_prompt(self, text):
        """Return the embedding rows for ``text``'s token ids, shaped (1, n_tokens, emb_dim).

        Each call goes through ``tokenizer.encode`` independently (with HF
        defaults this adds special tokens to every fragment — confirm intended).
        """
        token_ids = torch.tensor([self.tokenizer.encode(text)])
        return self.bert_token[token_ids]

    def process_one_traj_prompt(self, length):
        """Build the prompt embedding for the trajectory at batch index ``length``.

        Reads the per-batch state stored by :meth:`process_batch_data`.
        ``start_time``/``end_time`` rows are indexed as [hour, minute, weekday]
        and ``total_time`` rows as [minutes, seconds].
        """
        idx = length  # despite the name, this is the trajectory's position in the batch
        start, end = self.start_time[idx], self.end_time[idx]

        # ``{!s}`` applies str(), matching plain string concatenation exactly.
        time_text = "{!s}:{!s} on {} and ended at {!s}:{!s} on {}. ".format(
            start[0], start[1], self.weekday[start[2]],
            end[0], end[1], self.weekday[end[2]],
        )
        minutes, seconds = self.total_time[idx][0], self.total_time[idx][1]

        time_emb = self.encode_prompt(time_text)
        duration_emb = self.encode_prompt("{} minutes {} seconds.".format(minutes, seconds))
        distance_emb = self.encode_prompt("{:.2f} kilometers.".format(self.total_dis[idx]))

        pieces = [
            self.task_prompt_tensor,
            self.time_prompt_tensor, time_emb,
            self.travel_prompt_tensor[0], duration_emb,
            self.travel_prompt_tensor[1], distance_emb,
            self.traj_prompt,
        ]
        return torch.cat(pieces, dim=1)

    def process_batch_data(self, task_prompt_tensor, start_time, end_time, total_time, total_dis):
        """Return stacked prompt embeddings for a whole batch on ``self.device``.

        The batch metadata is stored on the instance so that
        :meth:`process_one_traj_prompt` can read it. All per-trajectory
        prompts must have the same token length for ``torch.stack`` to work.
        """
        self.task_prompt_tensor = task_prompt_tensor
        self.start_time, self.end_time = start_time, end_time
        self.total_time, self.total_dis = total_time, total_dis

        prompts = [self.process_one_traj_prompt(idx) for idx in range(len(start_time))]
        # Each prompt is (1, L, D); stacking gives (B, 1, L, D), squeezed to (B, L, D).
        return torch.stack(prompts, 0).squeeze(1).to(self.device)


class Trainer():
    """Train, validate and test a trajectory-recovery model with text prompts.

    Every batch is unpacked into the 14 fields listed in ``_prepare_batch``.
    ``construct_mask`` builds the masks for a given keep ratio, and
    ``Construct_traj_prompt`` builds the per-trajectory prompt embeddings.
    ``task_prompt_tensor`` is indexed by keep ratio.
    """

    # Keep ratios cycled through during training (batch i uses
    # KEEP_RATIOS[i % 3]) and each evaluated separately after every epoch.
    KEEP_RATIOS = (0.25, 0.125, 0.0625)
    # Stop once the best validation epoch is more than this many epochs old.
    EARLY_STOP_PATIENCE = 10

    def __init__(self, model, batch_size, device, lr, lambda1, mbr, road_condition, id_size, task_prompt_tensor,  time_prompt_tensor, travel_prompt_tensor, traj_prompt, clip=1):
        """Store training configuration and build the prompt constructor.

        ``lambda1`` weights the rate-regression loss. ``clip`` is stored but
        NOTE(review): no gradient clipping is currently applied with it.
        """
        self.model = model
        self.batch_size = batch_size
        self.device = device
        self.lr = lr
        self.lambda1 = lambda1
        self.clip = clip
        self.mbr = mbr
        self.id_size = id_size
        self.task_prompt_tensor = task_prompt_tensor
        self.time_prompt_tensor = time_prompt_tensor
        self.travel_prompt_tensor = travel_prompt_tensor
        self.traj_prompt = traj_prompt

        self.best_model = model
        self.road_condition = road_condition
        self.best_acc = -1
        self.Construct_traj_prompt = Construct_traj_prompt(time_prompt_tensor, travel_prompt_tensor, traj_prompt, device)

    def _prepare_batch(self, batch, keep_ratio, mode):
        """Turn one batch into the model's positional inputs plus targets.

        ``mode`` is forwarded to ``construct_mask`` ("train", "val", "test", ...).

        Returns:
            ``(model_inputs, trg_id, trg_rate, mm_lat, mm_lng, traj_length)``.
            ``model_inputs`` is the positional-argument tuple for the model.
            ``trg_id`` is (T, B) and ``trg_rate`` is (T, B, 1).
            ``mm_lat``, ``mm_lng`` and ``traj_length`` are passed through untouched.
        """
        (road_id, road_rate, mm_lat, mm_lng, src_lat, src_lng, src_candi_id, src_time,
         road_condition_x, road_condition_y, road_condition_t, traj_length,
         start_times, end_times) = batch
        road_id, road_rate = np.array(road_id), np.array(road_rate)
        src_lat, src_lng = np.array(src_lat), np.array(src_lng)

        (mask_index, padd_index, forward_delta_t, backward_delta_t, forward_index,
         backward_index, traj_total_time, traj_total_dis) = construct_mask(
            src_lat, src_lng, self.mbr, traj_length, mode, road_id.shape[0], keep_ratio=keep_ratio)

        device = self.device
        trg_id = torch.tensor(road_id).permute(1, 0).long().to(device)  # T, B
        trg_rate = torch.tensor(road_rate, dtype=torch.float).permute(1, 0).to(device).unsqueeze(-1)
        src_candi_id = src_candi_id.float().to(device)
        mask_index = torch.tensor(mask_index).float().to(device)
        padd_index = torch.tensor(padd_index).to(device)
        src_lat = torch.tensor(src_lat, dtype=torch.float).to(device)
        src_lng = torch.tensor(src_lng, dtype=torch.float).to(device)
        road_condition_x, road_condition_y, road_condition_t = (
            torch.tensor(np.array(rc), dtype=torch.long).to(device)
            for rc in (road_condition_x, road_condition_y, road_condition_t)
        )
        forward_delta_t = torch.from_numpy(forward_delta_t).float().to(device)
        backward_delta_t = torch.from_numpy(backward_delta_t).float().to(device)
        forward_index = torch.from_numpy(forward_index).long().to(device)
        backward_index = torch.from_numpy(backward_index).long().to(device)

        # Zero the inputs wherever mask_index == 1.
        # NOTE(review): road_condition_t is not masked — confirm this is intended.
        masked = mask_index == 1
        src_lat[masked] = 0
        src_lng[masked] = 0
        src_candi_id[masked] = 0
        road_condition_x[masked] = 0
        road_condition_y[masked] = 0
        road_condition_xyt_index = torch.stack([road_condition_x, road_condition_y, road_condition_t], dim=-1)

        prompt_token = self.Construct_traj_prompt.process_batch_data(
            self.task_prompt_tensor[keep_ratio], start_times, end_times, traj_total_time, traj_total_dis)

        model_inputs = (src_lat, src_lng, src_time, mask_index, src_candi_id, traj_length,
                        padd_index, keep_ratio, prompt_token, self.road_condition,
                        road_condition_xyt_index, forward_delta_t, backward_delta_t,
                        forward_index, backward_index)
        return model_inputs, trg_id, trg_rate, mm_lat, mm_lng, traj_length

    def train(self, epochs, train_iterator, val_iterator, save_txt, save_model_path, model_name):
        """Train with per-epoch validation, checkpointing and early stopping.

        Args:
            epochs: maximum number of epochs.
            train_iterator: training batches; ``len()`` is used to average metrics.
            val_iterator: validation batches, evaluated at every keep ratio.
            save_txt: path of the text log file (appended to).
            save_model_path, model_name: the best ``state_dict`` is written to
                ``save_model_path + model_name`` (plain string concatenation).
        """
        optimizer = optim.AdamW(self.model.parameters(), lr=self.lr)
        criterion_reg = nn.MSELoss()
        criterion_ce = nn.NLLLoss()
        best_epoch = 0
        n_ratios = len(self.KEEP_RATIOS)

        for epoch in range(epochs):
            # val() switches the model to eval mode, so training mode must be
            # restored at the start of every epoch, not just the first.
            self.model.train()
            # Per-epoch accumulators and timer.
            train_loss, train_acc = 0, 0
            start_time = time.time()

            for i, batch in enumerate(train_iterator):
                keep_ratio = self.KEEP_RATIOS[i % n_ratios]
                model_inputs, trg_id, trg_rate, _, _, traj_length = self._prepare_batch(batch, keep_ratio, "train")

                optimizer.zero_grad()
                predict_ID, predict_rate = self.model(*model_inputs)  # predict_ID: T, B, F

                # NLLLoss takes (N, C) inputs against (N,) class-index targets.
                loss_train_ids = criterion_ce(predict_ID.reshape(-1, predict_ID.shape[-1]), trg_id.reshape(-1))
                loss_rates = criterion_reg(predict_rate, trg_rate) * self.lambda1
                ttl_loss = loss_train_ids + loss_rates
                ttl_loss.backward()
                optimizer.step()

                train_loss += ttl_loss.item()
                train_acc += cal_acc(predict_ID.argmax(-1), trg_id, traj_length)

            train_min, train_sec = epoch_time(start_time, time.time())
            train_acc, train_loss = train_acc / len(train_iterator), train_loss / len(train_iterator)

            start_time = time.time()
            val_acc_25, val_acc_125, val_acc_0625 = [
                self.val(self.model, val_iterator, ratio) for ratio in self.KEEP_RATIOS
            ]
            val_acc = val_acc_25 + val_acc_125 + val_acc_0625
            val_min, val_sec = epoch_time(start_time, time.time())

            is_best = val_acc > self.best_acc
            if is_best:
                # NOTE(review): this is a reference, not a copy, so best_model
                # tracks the latest weights; the saved checkpoint is the real best.
                self.best_model = self.model
                self.best_acc = val_acc
                torch.save(self.model.state_dict(), save_model_path + model_name)
                best_epoch = epoch

            time_line = "epoch: {}, train time: {}m {}s, val time: {}m {}s".format(
                epoch + 1, train_min, train_sec, val_min, val_sec)
            acc_line = "epoch: {}, train acc: {}, train loss: {}, val acc 0.25: {}, Val acc 0.125: {}, Val acc 0.0625: {}".format(
                epoch + 1, train_acc, train_loss, val_acc_25, val_acc_125, val_acc_0625)
            if is_best:
                acc_line += "   |  Best Model!!!"
            with open(save_txt, "a+") as f:
                f.write(time_line + "\n")
                f.write(acc_line + "\n")
            print(time_line)
            print(acc_line + "\n")

            # On a best epoch the difference is 0, so this only fires otherwise.
            if epoch - best_epoch > self.EARLY_STOP_PATIENCE:
                with open(save_txt, "a+") as f:
                    f.write("Early step!\n")
                print("Early step!\n")
                break

    def val(self, model, iterator, keep_ratio, types="val"):
        """Return ``model``'s mean per-batch accuracy on ``iterator`` at ``keep_ratio``.

        Runs in eval mode under ``torch.no_grad()`` and leaves the model in
        eval mode. ``types`` is forwarded to ``construct_mask`` as the mode.
        """
        model.eval()
        total_acc = 0
        with torch.no_grad():
            for batch in iterator:
                model_inputs, trg_id, _, _, _, traj_length = self._prepare_batch(batch, keep_ratio, types)
                predict_ID, _ = model(*model_inputs)  # predict_ID: T, B, F
                total_acc += cal_acc(predict_ID.argmax(-1), trg_id, traj_length)
        return total_acc / len(iterator)

    def test(self, model, iterator, keep_ratio, rn_dict, raw_rn_dict, new2raw_rid_dict, types="test", dataset="Chengdu"):
        """Run ``model`` on ``iterator`` and dump targets and predictions to text files.

        Files go to ``./save_result/<dataset>/fine_tune_final/<keep_ratio>/`` via
        ``save_txt``. Trajectories are numbered consecutively from 0.
        Batches have the same 14-field layout as in training.
        ``rn_dict``, ``raw_rn_dict`` and ``new2raw_rid_dict`` are currently unused.
        """
        import os

        model.eval()
        # One output folder per (dataset, keep_ratio).
        save_txt_path = "./save_result/{}/fine_tune_final/{}/".format(dataset, keep_ratio)
        os.makedirs(save_txt_path, exist_ok=True)

        save_traj_start_id = 0
        with torch.no_grad():
            for batch in iterator:
                model_inputs, trg_id, trg_rate, mm_lat, mm_lng, traj_length = self._prepare_batch(batch, keep_ratio, types)
                predict_ID, predict_rate = model(*model_inputs)  # predict_ID: T, B, F
                # Save with the current offset first, then advance by this
                # batch's size (dim 1 of trg_id), so ids start at 0.
                save_txt(save_traj_start_id, save_txt_path, trg_id, trg_rate, predict_ID, predict_rate, mm_lat, mm_lng, traj_length)
                save_traj_start_id += trg_id.shape[1]
            

def save_txt(save_traj_start_id, save_path, trg_id, trg_rate, predict_ID, predict_rate, mm_lat, mm_lng, traj_length):
    """Append one batch of target/predicted trajectories to a text file.

    The file is ``save_path + "<save_traj_start_id // 1000>.txt"``. The whole
    batch goes into the file chosen by its first id, so each file holds
    roughly 1000 trajectories. Each trajectory ``b`` is written as a header line
    ``#,trajectory_id: <id>, length: <n>`` followed by ``n`` lines of
    ``lat,lng,target_id,target_rate,pred_id,pred_rate``.

    Args:
        save_traj_start_id: id given to the first trajectory in the batch.
        save_path: directory prefix, concatenated as-is (should end with a separator).
        trg_id: (T, B) target road ids.
        trg_rate: (T, B, 1) target rates.
        predict_ID: (T, B, F) scores; the argmax over F is written.
        predict_rate: (T, B, 1) predicted rates.
        mm_lat, mm_lng: indexed as ``[b][t]`` and written as-is.
        traj_length: number of points written for each trajectory ``b``.
    """
    # Convert to Python lists once. Calling .item() per point on a device
    # tensor would force one host/device synchronisation per value.
    trg_ids = trg_id.permute(1, 0).tolist()
    trg_rates = trg_rate[:, :, 0].permute(1, 0).tolist()
    pred_ids = predict_ID.permute(1, 0, 2).argmax(-1).tolist()
    pred_rates = predict_rate[:, :, 0].permute(1, 0).tolist()

    lines = []
    for b in range(len(trg_ids)):
        n_points = traj_length[b]
        # Header: trajectory id and its length.
        lines.append("#,trajectory_id: {}, length: {}\n".format(save_traj_start_id + b, n_points))
        # One line per point of the trajectory.
        for t in range(n_points):
            lines.append("{},{},{},{},{},{}\n".format(
                mm_lat[b][t], mm_lng[b][t], trg_ids[b][t], trg_rates[b][t], pred_ids[b][t], pred_rates[b][t]))

    file_name = "{}.txt".format(save_traj_start_id // 1000)
    with open(save_path + file_name, "a+") as f:
        f.writelines(lines)