import torch
import os
import time
from torch import optim
from torch.optim import lr_scheduler 
import numpy as np
import matplotlib.pyplot as plt

from utils.tools import EarlyStopping, adjust_learning_rate, test_params_flop
from utils.loss import type2loss, get_loss

from models.GRformer import GRformer
from models.Informer import Informer
from models.MTGNN import gtnet
from models.Linear import DLinear, NLinear, GLinear

# Registry of model classes, looked up by the value of ``args.model``.
models_dict = dict(
    Informer=Informer,
    MTGNN=gtnet,
    DLinear=DLinear,
    NLinear=NLinear,
    GLinear=GLinear,
    GRformer=GRformer,
)

# Registry of optimizer classes, looked up by the value of ``args.opt``.
optimizer_catagory = dict(adam=optim.Adam)

class Trainer():
    """Drive training, validation, testing and inference for one model.

    The model class is looked up in ``models_dict`` by ``args.model`` and is
    wrapped in ``torch.nn.DataParallel`` when ``args.use_multi_gpu`` is set.
    Every stage (train / validate / test / predict) calls the model through
    the same private dispatch helper, so the AMP and non-AMP paths cannot
    drift apart.
    """

    # Substrings of ``args.model`` marking models that are called with the
    # input window only, i.e. ``model(batch_x)``.
    _SINGLE_INPUT_TAGS = ('Linear', 'TST', 'RP', 'PG', 'GR')

    def __init__(self, args, setting, task_path, corr, high_correlated_count) -> None:
        """Build the model described by ``args`` and move it to ``args.device``.

        Args:
            args: Run configuration namespace (model name, device, flags, ...).
            setting: Run identifier used in checkpoint and result paths.
            task_path: Sub-directory used in checkpoint and result paths.
            corr: Forwarded to the model constructor when ``args.use_gcn``.
            high_correlated_count: Forwarded to the model constructor when
                ``args.use_gcn``.
        """
        self.args = args
        self.setting = setting
        self.task_path = task_path
        self.device = torch.device(args.device)
        if args.use_gcn:
            model = models_dict[self.args.model](self.args, corr=corr, high_correlated_count=high_correlated_count).float().to(self.args.device)
        else:
            model = models_dict[self.args.model](self.args).float().to(self.args.device)
        if args.use_multi_gpu:
            self.model = torch.nn.DataParallel(model, device_ids=self.args.device_ids)
        else:
            self.model = model

    def _to_device(self, *tensors):
        """Cast every tensor to float32 and move it to ``self.device``."""
        return tuple(t.float().to(self.device) for t in tensors)

    def _forward(self, batch_x, batch_y, batch_x_mark, batch_y_mark):
        """Call the model with the signature its family expects.

        Returns only the prediction tensor; auxiliary outputs (the two extra
        values returned by GCN-enabled models and MTGNN, or the attention
        output) are discarded.
        """
        model_name = self.args.model
        if any(tag in model_name for tag in self._SINGLE_INPUT_TAGS):
            if self.args.use_gcn:
                # GCN-enabled models return (prediction, A0, A).
                outputs, _, _ = self.model(batch_x)
                return outputs
            return self.model(batch_x)
        if model_name == "MTGNN":
            outputs, _, _ = self.model(batch_x)
            return outputs
        # Encoder-decoder models: the decoder input is the first ``label_len``
        # steps of batch_y followed by ``pred_len`` zeros. The zeros are sized
        # explicitly so this also works when batch_y is shorter than pred_len.
        placeholder = torch.zeros(
            (batch_y.shape[0], self.args.pred_len, batch_y.shape[2]),
            dtype=batch_y.dtype, device=batch_y.device)
        dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], placeholder], dim=1).float()
        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        if self.args.output_attention:
            return outputs[0]
        return outputs

    def _forward_window(self, batch_x, batch_y, batch_x_mark, batch_y_mark):
        """Run the model and cut prediction and target to the forecast window.

        Returns:
            ``(outputs, target)``, both restricted to the last ``pred_len``
            steps; only the last feature is kept when ``args.features`` is
            ``'MS'``.
        """
        outputs = self._forward(batch_x, batch_y, batch_x_mark, batch_y_mark)
        f_dim = -1 if self.args.features == 'MS' else 0
        outputs = outputs[:, -self.args.pred_len:, f_dim:]
        target = batch_y[:, -self.args.pred_len:, f_dim:]
        return outputs, target

    def train(self, data):
        """Train the model and save a train/validation loss plot.

        Args:
            data: Mapping providing ``'train_loader'`` and ``'val_loader'``.

        Raises:
            ValueError: If ``args.loss`` is not ``'mse'``, ``'mae'`` or
                ``'huber'``.
        """
        print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(self.setting))
        # Checkpoint and result directories.
        path = os.path.join(self.args.checkpoints, self.task_path, self.setting)
        os.makedirs(path, exist_ok=True)
        folder_path = './results/{}/{}/'.format(self.task_path, self.setting)
        os.makedirs(folder_path, exist_ok=True)
        print(f'the model store path will be: {path}')
        print(f'the result store path will be: {folder_path}')

        train_loader = data['train_loader']
        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
        optimizer = optimizer_catagory[self.args.opt](self.model.parameters(), lr=self.args.learning_rate)

        # Loss function; fail early on an unsupported name instead of hitting
        # an unbound local deep inside the training loop.
        if self.args.loss == 'mse':
            loss_func = torch.nn.MSELoss()
        elif self.args.loss == 'mae':
            loss_func = torch.nn.L1Loss()
        elif self.args.loss == 'huber':
            loss_func = torch.nn.HuberLoss(reduction='mean', delta=0.8)
        else:
            raise ValueError("unsupported loss: {!r}".format(self.args.loss))

        # Automatic mixed precision: the scaler pairs with autocast below.
        use_amp = bool(self.args.use_amp)
        scaler = torch.cuda.amp.GradScaler() if use_amp else None

        # One-cycle schedule; it is only stepped when args.lradj == 'TST'.
        scheduler = lr_scheduler.OneCycleLR(optimizer=optimizer,
                                            steps_per_epoch=train_steps,
                                            pct_start=self.args.pct_start,
                                            epochs=self.args.train_epochs,
                                            max_lr=self.args.learning_rate)

        # Per-epoch history for the loss plot.
        train_loss_list = []
        vali_loss_list = []
        epoch_list = []
        for epoch in range(self.args.train_epochs):
            epoch_list.append(epoch + 1)
            epoch_time = time.time()
            iter_count = 0
            train_loss = []

            self.model.train()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
                if i == 0:
                    current_time = time.time()
                iter_count += 1
                optimizer.zero_grad()  # avoid accumulating gradients across steps
                batch_x, batch_y, batch_x_mark, batch_y_mark = self._to_device(
                    batch_x, batch_y, batch_x_mark, batch_y_mark)

                # autocast(enabled=False) is a no-op, so one code path serves
                # both the AMP and the full-precision case.
                with torch.cuda.amp.autocast(enabled=use_amp):
                    outputs, target = self._forward_window(batch_x, batch_y, batch_x_mark, batch_y_mark)
                    loss = loss_func(outputs, target)
                loss_value = loss.item()
                train_loss.append(loss_value)

                if (i + 1) % 200 == 0:
                    speed = (time.time() - current_time) / iter_count
                    print("\tepoch_{} iter_{} | loss: {} | speed: {:.2f} (s/iter)".format(epoch + 1, i + 1, loss_value, speed))
                    # Reset both counters so the next report covers only the
                    # iterations since this one.
                    iter_count = 0
                    current_time = time.time()

                if use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    optimizer.step()
                if self.args.lradj == 'TST':
                    adjust_learning_rate(optimizer, scheduler, epoch + 1, self.args, printout=False)
                    scheduler.step()
            print("epoch {} finished. cost time {:.2f}s".format(epoch + 1, time.time() - epoch_time))

            # Validation.
            train_loss = np.average(train_loss)
            train_loss_list.append(train_loss)
            vali_loss = self.validate(data['val_loader'], loss_func)
            vali_loss_list.append(vali_loss)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.4f} Vali Loss: {3:.4f}".format(
                epoch + 1, train_steps, train_loss, vali_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            if self.args.lradj != 'TST':
                adjust_learning_rate(optimizer, scheduler, epoch + 1, self.args)
            else:
                print('Updating learning rate to {}'.format(scheduler.get_last_lr()[0]))

        fig, ax = plt.subplots(figsize=(12.8, 7.2))
        ax.plot(epoch_list, train_loss_list, 'r', marker='x', label='train_loss')
        ax.plot(epoch_list, vali_loss_list, 'b', marker='*', label='vali_loss')
        ax.legend()
        ax.set_title('loss func value')
        ax.set_xlabel('epoch')
        ax.set_ylabel('loss')
        fig.savefig(folder_path + 'loss.png', dpi=100)
        # Close (not just clear) the figure so repeated runs do not leak it.
        plt.close(fig)

    def validate(self, vali_loader, loss_func):
        """Return the mean of ``loss_func`` over the batches of ``vali_loader``.

        The model is put in eval mode for the pass and switched back to train
        mode before returning. The loss is evaluated on CPU copies of the
        prediction and the target.
        """
        total_loss = []
        use_amp = bool(self.args.use_amp)
        self.model.eval()
        with torch.no_grad():
            for batch_x, batch_y, batch_x_mark, batch_y_mark in vali_loader:
                batch_x, batch_y, batch_x_mark, batch_y_mark = self._to_device(
                    batch_x, batch_y, batch_x_mark, batch_y_mark)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    outputs, target = self._forward_window(batch_x, batch_y, batch_x_mark, batch_y_mark)

                # Cast to float32 first: under AMP the prediction may be in
                # reduced precision.
                pred = outputs.detach().float().cpu()
                true = target.detach().cpu()
                total_loss.append(loss_func(pred, true).item())
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def test(self, test_loader):
        """Evaluate the saved checkpoint on ``test_loader``.

        Loads ``checkpoint.pth`` from the directory ``train`` saves to,
        appends the metrics returned by ``get_loss`` to ``result.txt`` in the
        results folder, and returns the MSE.
        """
        print('>>>>>>>start testing : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(self.setting))
        # Same directory train() uses; map_location lets a checkpoint saved on
        # another device load onto the configured one.
        checkpoint_path = os.path.join(self.args.checkpoints, self.task_path, self.setting, 'checkpoint.pth')
        self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))

        preds = []
        trues = []
        use_amp = bool(self.args.use_amp)

        self.model.eval()
        with torch.no_grad():
            for batch_x, batch_y, batch_x_mark, batch_y_mark in test_loader:
                batch_x, batch_y, batch_x_mark, batch_y_mark = self._to_device(
                    batch_x, batch_y, batch_x_mark, batch_y_mark)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    outputs, target = self._forward_window(batch_x, batch_y, batch_x_mark, batch_y_mark)
                preds.append(outputs)
                trues.append(target)

        if self.args.test_flop:
            test_params_flop((batch_x.shape[1], batch_x.shape[2]))
            exit()

        preds = torch.cat(preds, dim=0)
        trues = torch.cat(trues, dim=0)

        # Result save; the folder may not exist when test() runs without train().
        folder_path = './results/{}/{}/'.format(self.task_path, self.setting)
        os.makedirs(folder_path, exist_ok=True)

        mse, rmse, mae, mape = get_loss(preds, trues)
        print('mse:{}, mae:{}, rmse:{}, mape:{}'.format(mse, mae, rmse, mape))
        # NOTE(review): the 'rse' label below actually carries the RMSE value;
        # the label is kept so the existing result.txt format does not change.
        with open(folder_path + "result.txt", 'a') as f:
            f.write(self.setting + "  \n")
            f.write('mse:{}, mae:{}, rse:{}, mape:{} ==== lr={}'.format(mse, mae, rmse, mape, self.args.learning_rate))
            f.write('\n')
            f.write('\n')

        return mse

    def predict(self, pred_loader):
        """Run inference on ``pred_loader`` and save the predictions as .npy.

        Uses the weights currently held by ``self.model`` (no checkpoint is
        loaded here). Per-batch outputs are concatenated along the first axis,
        so a smaller final batch is handled. The file is written to
        ``./results/<task_path>/<setting>/<model>_real_prediction.npy``.
        """
        print('>>>>>>>start predicting : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(self.setting))

        preds = []
        use_amp = bool(self.args.use_amp)

        self.model.eval()
        with torch.no_grad():
            for batch_x, batch_y, batch_x_mark, batch_y_mark in pred_loader:
                batch_x, batch_y, batch_x_mark, batch_y_mark = self._to_device(
                    batch_x, batch_y, batch_x_mark, batch_y_mark)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    outputs = self._forward(batch_x, batch_y, batch_x_mark, batch_y_mark)
                preds.append(outputs.detach().cpu().numpy())

        # assumes each batch output is (batch, steps, features) - TODO confirm
        preds = np.concatenate(preds, axis=0)

        # result save
        folder_path = './results/{}/{}/'.format(self.task_path, self.setting)
        os.makedirs(folder_path, exist_ok=True)

        np.save(folder_path + f'{self.args.model}_real_prediction.npy', preds)

        return