import copy
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from transformers import BertModel, BertTokenizer

from data import next_batch, process_batch
from load_data.datasets import Dataset, collate_fn, construct_mask
from model.loss_fn import cal_id_acc, check_rn_dis_loss
from model.model_utils import toseq
from utils import cal_acc, epoch_time

class Construct_traj_prompt():
    """Build per-trajectory prompt embeddings from BERT's word-embedding table.

    For every trajectory in a batch the prompt is the concatenation (along
    dim 1) of:

    1. the task prompt;
    2. ``time_prompt_tensor`` followed by the encoded start/end time sentence;
    3. ``travel_prompt_tensor[0]`` followed by the encoded travel time;
    4. ``travel_prompt_tensor[1]`` followed by the encoded travel distance;
    5. ``traj_prompt``.
    """

    def __init__(self, time_prompt_tensor, travel_prompt_tensor, traj_prompt, device, model_path='./PLM/BERT'):
        """Load the BERT tokenizer and word-embedding matrix.

        Args:
            time_prompt_tensor: embedding prefix placed before the time sentence.
            travel_prompt_tensor: indexable pair; item 0 precedes the travel-time
                tokens and item 1 precedes the travel-distance tokens.
            traj_prompt: embedding suffix appended to every prompt.
            device: device the stacked batch of prompts is moved to.
            model_path: directory containing the pretrained BERT tokenizer and
                model files. Defaults to the previously hard-coded './PLM/BERT'.
        """
        self.time_prompt_tensor = time_prompt_tensor
        self.travel_prompt_tensor = travel_prompt_tensor
        self.traj_prompt = traj_prompt
        self.device = device
        # Replaced on every call to process_batch_data(); initialised here so the
        # attribute always exists on a constructed instance.
        self.task_prompt_tensor = None

        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_path)

        # Load the pretrained model only to extract its input word-embedding matrix;
        # that matrix is the only part kept as an attribute.
        model = BertModel.from_pretrained(model_path)
        self.bert_token = model.state_dict()['embeddings.word_embeddings.weight']

        # NOTE(review): indexed by start_time[k][2] / end_time[k][2]; this assumes
        # 0 == Monday (as with datetime.weekday()) — confirm against the data loader.
        self.weekday = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]

    def encode_prompt(self, text):
        """Tokenize ``text`` and look up its BERT word embeddings.

        Returns:
            A tensor with a leading dimension of 1, one row per token id, and the
            embedding width as the last dimension.
        """
        # NOTE(review): tokenizer.encode adds special tokens by default, so each
        # encoded segment carries its own [CLS]/[SEP] — confirm this is intended.
        indexed_tokens = self.tokenizer.encode(text)
        tokens_tensor = torch.tensor([indexed_tokens])  # token ids with a leading batch dim of 1
        return self.bert_token[tokens_tensor]

    def process_one_traj_prompt(self, length):
        """Build the prompt embedding for the trajectory at batch index ``length``.

        Reads the per-batch values stored by ``process_batch_data``, which must
        have been called first.
        """
        start, end = self.start_time[length], self.end_time[length]

        # Rendered with str(), so minutes are not zero-padded (e.g. "8:5").
        _start_time = str(start[0]) + ":" + str(start[1])
        _start_date = self.weekday[start[2]]
        _end_time = str(end[0]) + ":" + str(end[1])
        _end_date = self.weekday[end[2]]

        _time = _start_time + " on " + _start_date + " and ended at " + _end_time + " on " + _end_date + ". "
        _time_prompt = torch.cat([self.time_prompt_tensor, self.encode_prompt(_time)], dim=1)

        travel_minutes, travel_seconds = self.total_time[length][0], self.total_time[length][1]
        travel_dis = self.total_dis[length]

        travel_time_prompt = self.encode_prompt("{} minutes {} seconds.".format(travel_minutes, travel_seconds))
        travel_dis_prompt = self.encode_prompt("{:.2f} kilometers.".format(travel_dis))

        travel_prompt = torch.cat(
            [self.travel_prompt_tensor[0], travel_time_prompt, self.travel_prompt_tensor[1], travel_dis_prompt],
            dim=1,
        )

        return torch.cat([self.task_prompt_tensor, _time_prompt, travel_prompt, self.traj_prompt], dim=1)

    def process_batch_data(self, task_prompt_tensor, start_time, end_time, total_time, total_dis):
        """Build prompt embeddings for every trajectory in a batch.

        Args:
            task_prompt_tensor: embedding prefix placed first in every prompt.
            start_time, end_time: per-trajectory sequences; items [0] and [1] are
                rendered as "[0]:[1]" and item [2] indexes ``self.weekday``.
            total_time: per-trajectory (minutes, seconds) pairs.
            total_dis: per-trajectory travel distances, printed with two decimals
                followed by " kilometers.".

        Returns:
            The per-trajectory prompts stacked along a new dim 0, with dim 1
            squeezed, and moved to ``self.device``. torch.stack requires every
            prompt in the batch to have the same token length.
        """
        self.task_prompt_tensor = task_prompt_tensor
        self.start_time = start_time
        self.end_time = end_time
        self.total_time = total_time
        self.total_dis = total_dis
        res = [self.process_one_traj_prompt(i) for i in range(len(self.start_time))]
        return torch.stack(res, 0).squeeze(1).to(self.device)


class Trainer():
    """Train and validate a prompt-conditioned trajectory-recovery model.

    Training:
        Each training batch is masked with a keep ratio cycled from
        ``KEEP_RATIOS``. A text prompt is built per trajectory with
        ``Construct_traj_prompt``. The model is optimised with an NLL loss on the
        predicted road ids plus ``lambda1`` times an MSE loss on the predicted rates.

    Per-epoch validation:
        After each epoch the model is validated. Whenever validation accuracy
        improves, the weights are saved to disk and snapshotted in
        ``best_model``. Training stops once more than ``PATIENCE`` epochs pass
        without improvement.
    """

    # Batch i of an epoch is masked with KEEP_RATIOS[i % len(KEEP_RATIOS)].
    KEEP_RATIOS = (0.25, 0.125, 0.0625)
    # Early stopping fires when epoch - best_epoch > PATIENCE.
    PATIENCE = 10

    def __init__(self, model, batch_size, device, lr, lambda1, mbr, road_condition, id_size, task_prompt_tensor,  time_prompt_tensor, travel_prompt_tensor, traj_prompt, clip=1):
        """Store the model, hyper-parameters and prompt pieces.

        Args:
            model: network to train. It is called with the inputs built by
                ``_prepare_batch`` and returns (predict_ID, predict_rate).
            batch_size: stored only; not used by this class.
            device: device every batch tensor is moved to.
            lr: AdamW learning rate.
            lambda1: weight of the rate (MSE) loss.
            mbr: forwarded to ``construct_mask``.
            road_condition: forwarded unchanged to the model on every call.
            id_size: stored only; not used by this class.
            task_prompt_tensor: indexed by keep ratio to pick the task prompt.
            time_prompt_tensor, travel_prompt_tensor, traj_prompt: prompt pieces
                handed to ``Construct_traj_prompt``.
            clip: max gradient norm passed to ``clip_grad_norm_``. A falsy value
                disables clipping.
        """
        self.model = model
        self.batch_size = batch_size
        self.device = device
        self.lr = lr
        self.lambda1 = lambda1
        self.clip = clip
        self.mbr = mbr
        self.id_size = id_size
        self.task_prompt_tensor = task_prompt_tensor
        self.time_prompt_tensor = time_prompt_tensor
        self.travel_prompt_tensor = travel_prompt_tensor
        self.traj_prompt = traj_prompt

        self.best_model = model
        self.road_condition = road_condition
        self.best_acc = -1
        self.Construct_traj_prompt = Construct_traj_prompt(time_prompt_tensor, travel_prompt_tensor, traj_prompt, device)

    def _prepare_batch(self, batch, mode, keep_ratio):
        """Turn one raw batch into device tensors plus the model's positional inputs.

        Shared by ``train`` and ``val`` so both preprocess batches identically.

        Args:
            batch: the 14-tuple yielded by the data iterator (unpacked below).
            mode: "train" or "val"; forwarded to ``construct_mask``.
            keep_ratio: forwarded to ``construct_mask``, used to pick the task
                prompt, and passed to the model.

        Returns:
            (trg_id, trg_rate, traj_length, model_inputs). ``model_inputs`` is the
            positional-argument tuple for the model's forward call.
        """
        (road_id, road_rate, _mm_lat, _mm_lng, src_lat, src_lng, src_candi_id, src_time,
         road_condition_x, road_condition_y, road_condition_t, traj_length,
         start_times, end_times) = batch
        road_id, road_rate = np.array(road_id), np.array(road_rate)
        src_lat, src_lng = np.array(src_lat), np.array(src_lng)
        road_condition_x = np.array(road_condition_x)
        road_condition_y = np.array(road_condition_y)
        road_condition_t = np.array(road_condition_t)

        (mask_index, padd_index, forward_delta_t, backward_delta_t, forward_index,
         backward_index, traj_total_time, traj_total_dis) = construct_mask(
            src_lat, src_lng, self.mbr, traj_length, mode, road_id.shape[0], keep_ratio=keep_ratio)

        device = self.device
        trg_id = torch.from_numpy(road_id).permute(1, 0).long().to(device)  # T, B
        trg_rate = torch.tensor(road_rate, dtype=torch.float).permute(1, 0).to(device).unsqueeze(-1)
        src_candi_id = src_candi_id.float().to(device)
        mask_index = torch.as_tensor(mask_index).float().to(device)
        padd_index = torch.as_tensor(padd_index).to(device)
        src_lat = torch.tensor(src_lat, dtype=torch.float).to(device)
        src_lng = torch.tensor(src_lng, dtype=torch.float).to(device)
        road_condition_x = torch.tensor(road_condition_x, dtype=torch.long).to(device)
        road_condition_y = torch.tensor(road_condition_y, dtype=torch.long).to(device)
        road_condition_t = torch.tensor(road_condition_t, dtype=torch.long).to(device)

        prompt_token = self.Construct_traj_prompt.process_batch_data(
            self.task_prompt_tensor[keep_ratio], start_times, end_times, traj_total_time, traj_total_dis)

        forward_delta_t = torch.from_numpy(forward_delta_t).float().to(device)
        backward_delta_t = torch.from_numpy(backward_delta_t).float().to(device)
        forward_index = torch.from_numpy(forward_index).long().to(device)
        backward_index = torch.from_numpy(backward_index).long().to(device)

        # Zero out the inputs at masked positions so the model cannot see them.
        masked = mask_index == 1
        src_lat[masked] = 0
        src_lng[masked] = 0
        road_condition_x[masked] = 0
        road_condition_y[masked] = 0
        src_candi_id[masked] = 0

        road_condition_xyt_index = torch.stack([road_condition_x, road_condition_y, road_condition_t], dim=-1)

        model_inputs = (src_lat, src_lng, src_time, mask_index, src_candi_id, traj_length, padd_index,
                        keep_ratio, prompt_token, self.road_condition, road_condition_xyt_index,
                        forward_delta_t, backward_delta_t, forward_index, backward_index)
        return trg_id, trg_rate, traj_length, model_inputs

    def train(self, epochs, train_iterator, val_iterator, save_txt, save_model_path, model_name):
        """Run up to ``epochs`` epochs of training, with validation and early stopping.

        Per-epoch logs are appended to ``save_txt`` and printed. Whenever
        validation accuracy improves, the weights are saved to
        ``save_model_path + model_name``.
        """
        optimizer = optim.AdamW(self.model.parameters(), lr=self.lr)

        criterion_reg = nn.MSELoss()
        criterion_ce = nn.NLLLoss()
        best_epoch = 0
        keep_ratio = self.KEEP_RATIOS[0]
        for epoch in range(epochs):
            # val() switches the model to eval mode, so training mode must be
            # re-enabled at the start of every epoch, not just once.
            self.model.train()
            # Reset the per-epoch accumulators and the timer.
            train_loss, train_acc = 0, 0
            start_time = time.time()

            for i, batch in enumerate(train_iterator):
                keep_ratio = self.KEEP_RATIOS[i % len(self.KEEP_RATIOS)]
                trg_id, trg_rate, traj_length, model_inputs = self._prepare_batch(batch, "train", keep_ratio)

                optimizer.zero_grad()
                predict_ID, predict_rate = self.model(*model_inputs)  # predict_ID: T, B, F

                output_ids = predict_ID.reshape(-1, predict_ID.shape[-1])
                loss_train_ids = criterion_ce(output_ids, trg_id.reshape(-1))
                loss_rates = criterion_reg(predict_rate, trg_rate) * self.lambda1
                ttl_loss = loss_train_ids + loss_rates
                ttl_loss.backward()

                # Apply the configured gradient-norm clipping (previously stored but ignored).
                if self.clip:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

                acc = cal_acc(predict_ID.argmax(-1), trg_id, traj_length)
                train_loss += ttl_loss.item()
                train_acc += acc
                optimizer.step()

            train_min, train_sec = epoch_time(start_time, time.time())
            n_batches = max(len(train_iterator), 1)
            train_acc, train_loss = train_acc / n_batches, train_loss / n_batches

            # NOTE(review): validation uses the keep ratio of the epoch's last
            # training batch, which depends on len(train_iterator) % 3.
            val_start = time.time()
            val_acc = self.val(self.model, val_iterator, keep_ratio)
            val_min, val_sec = epoch_time(val_start, time.time())

            is_best = val_acc > self.best_acc
            if is_best:
                # Snapshot the weights; plain assignment would alias the live
                # model, which keeps being trained.
                self.best_model = copy.deepcopy(self.model)
                self.best_acc = val_acc
                torch.save(self.model.state_dict(), save_model_path + model_name)
                best_epoch = epoch

            time_msg = "epoch: {}, train time: {}m {}s, val time: {}m {}s".format(
                epoch+1, train_min, train_sec, val_min, val_sec)
            acc_msg = "epoch: {}, train acc: {}, train loss: {}, val acc {}: {}".format(
                epoch+1, train_acc, train_loss, keep_ratio, val_acc)
            if is_best:
                acc_msg += "  |  Best Model!!!"
            with open(save_txt, "a+") as f:
                f.write(time_msg + "\n")
                f.write(acc_msg + "\n")
            print(time_msg)
            print(acc_msg + "\n")

            if epoch - best_epoch > self.PATIENCE:
                with open(save_txt, "a+") as f:
                    f.write("Early step!\n")
                print("Early step!\n")
                break

    def val(self, model, iterator, keep_ratio, types = "val"):
        """Return the mean per-batch accuracy of ``model`` over ``iterator``.

        Args:
            model: network to evaluate; it is switched to eval mode.
            iterator: validation batches, in the same layout as training batches.
            keep_ratio: keep ratio used to mask every batch and to pick the task prompt.
            types: currently unused; masks are always built in "val" mode.
        """
        model.eval()
        all_acc = 0
        # Inference only: skip autograd graph construction to save memory and time.
        with torch.no_grad():
            for batch in iterator:
                trg_id, _trg_rate, traj_length, model_inputs = self._prepare_batch(batch, "val", keep_ratio)
                predict_ID, _predict_rate = model(*model_inputs)  # predict_ID: T, B, F
                all_acc += cal_acc(predict_ID.argmax(-1), trg_id, traj_length)
        return all_acc / len(iterator)