from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from models import ESSformer
from .tools import EarlyStopping, adjust_learning_rate, visual, test_params_flop, metric

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.optim import lr_scheduler 

import seaborn as sns
import os, sys
import time

import warnings
import matplotlib.pyplot as plt
import numpy as np
from torch.nn.parallel import DistributedDataParallel as DDP


# Suppress every Python warning for the whole process (the filter is added to
# the global `warnings` filter list at import time of this module).
# NOTE(review): this also hides deprecation/user warnings coming from torch,
# numpy, etc. — consider narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')

class Exp_Main:
    """Distributed experiment driver: builds the model, trains it and evaluates it.

    The model is wrapped in ``DistributedDataParallel`` and the reporting code
    uses ``torch.distributed`` collectives (``gather``/``barrier``), so a
    process group must already be initialized and every rank must call the
    same methods in the same order. Only rank 0 writes result files,
    checkpoints and plots.
    """

    def __init__(self, args):
        """Store ``args`` and build the DDP-wrapped model on ``args.device``.

        Args:
            args: experiment namespace. Among others this class reads
                ``device``, ``local_gpu_id``, ``model``, ``rank``,
                ``world_size``, ``pred_len``, ``global_path`` and
                ``checkpoints``.
        """
        self.args = args
        self.device = args.device
        self.model = self._build_model().to(self.device)
        # find_unused_parameters=True lets DDP tolerate parameters that receive
        # no gradient in a step. NOTE(review): presumably required by the
        # channel-subset training path in `train` — confirm.
        self.model = DDP(module=self.model, device_ids=[args.local_gpu_id], find_unused_parameters=True)

    def _build_model(self):
        """Instantiate the model named by ``self.args.model`` and cast it to float32."""
        model_dict = {
            'ESSformer': ESSformer,
        }
        model = model_dict[self.args.model].Model(self.args).float()
        return model

    def _get_data(self, flag):
        """Return ``(dataset, loader)`` for the split ``flag`` ('train'/'val'/'test')."""
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        """Adam over all model parameters with a small fixed weight decay (1e-6)."""
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate, weight_decay=1e-6)
        return model_optim

    def _select_criterion(self):
        """Training loss: mean squared error."""
        criterion = nn.MSELoss()
        return criterion

    def _evaluate(self, data_loader):
        """Run the model over ``data_loader`` without gradients.

        Leaves the model in eval mode; callers that keep training must switch
        it back themselves.

        Returns:
            A 3-element tensor ``[mean MSE, mean MAE, elapsed seconds]`` on
            ``self.device``. The means are unweighted averages of per-batch
            losses, so a short final batch counts as much as a full one.
        """
        batch_mse = []
        batch_mae = []
        self.model.eval()
        mse = nn.MSELoss()
        mae = nn.L1Loss()
        start_time = time.time()
        with torch.no_grad():
            for _, batch_x, batch_y, _, _ in data_loader:
                batch_x = batch_x.float().to(self.device)
                outputs, _ = self.model(batch_x)

                # Only the last pred_len steps are scored; metrics are computed on CPU.
                pred = outputs[:, -self.args.pred_len:].detach().cpu()
                true = batch_y.float()[:, -self.args.pred_len:]

                batch_mse.append(mse(pred, true).item())
                batch_mae.append(mae(pred, true).item())
        inf_time = time.time() - start_time

        return torch.tensor([np.average(batch_mse), np.average(batch_mae), inf_time]).to(self.device)

    def vali(self, vali_data, vali_loader):
        """Evaluate on ``vali_loader`` and put the model back into training mode.

        ``vali_data`` is not used; it is kept for interface compatibility.

        Returns:
            A 3-element tensor ``[mean MSE, mean MAE, elapsed seconds]`` on
            ``self.device``.
        """
        info = self._evaluate(vali_loader)
        self.model.train()
        return info

    def _gather_mean(self, info):
        """Average the tensor ``info`` over all ranks onto rank 0.

        This is a collective call: every rank must call it, in the same order.

        Returns:
            The element-wise mean across ranks on rank 0, ``None`` on other ranks.
        """
        if self.args.rank != 0:
            torch.distributed.gather(info)
            return None
        gathered = [torch.zeros_like(info) for _ in range(self.args.world_size)]
        torch.distributed.gather(info, gathered)
        return sum(gathered) / self.args.world_size

    def alive_value(self, batch, c_in):
        """Choose which channels each sample in a batch keeps.

        ``split_num * split_mult`` channels are kept per sample, drawn by an
        independent random permutation per sample. If ``c_in`` equals
        ``split_num`` or ``split_mult``, all channels are kept in their
        original order instead.

        Returns:
            ``(batch_index, live_index)``, two ``(batch, k)`` long tensors
            meant for advanced indexing ``x[batch_index, live_index]``.
        """
        alive_num = self.args.split_num * self.args.split_mult
        if c_in in (self.args.split_num, self.args.split_mult):
            live_index = torch.stack([torch.arange(c_in) for _ in range(batch)], dim=0)
        else:
            live_index = torch.stack([torch.randperm(c_in)[:alive_num] for _ in range(batch)], dim=0)
        batch_index = torch.arange(batch).unsqueeze(-1).repeat(1, live_index.shape[-1])
        return batch_index, live_index

    def train(self, setting):
        """Train with DDP, evaluating every epoch, then reload the best checkpoint.

        After each epoch, validation and test metrics from all ranks are
        averaged on rank 0 and appended to ``results/<setting>/record.txt``.
        Whenever the validation MSE improves, the weights are saved to
        ``<global_path>/<checkpoints>/<setting>/checkpoint.pth``. At the end,
        every rank loads that checkpoint.

        Returns:
            The DDP-wrapped model with the best-validation weights loaded.
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.global_path, self.args.checkpoints, setting)
        folder_path = f'./{self.args.global_path}/results/' + setting + '/'
        if self.args.rank == 0:
            os.makedirs(path, exist_ok=True)
            os.makedirs(folder_path, exist_ok=True)
            # Start a fresh record file with the full experiment configuration.
            with open(f"{folder_path}/record.txt", "w") as f:
                f.write("\n".join([f"{k}:{v}" for k, v in vars(self.args).items()]))
                f.write("\n")
                f.write("\n")

        train_steps = len(train_loader)
        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        all_training_loss = []
        all_val_loss = []
        all_test_loss = []
        saved_val_loss = float('inf')

        time_now = time.time()
        start_time = time.time()

        # Other ranks must not continue until rank 0 has created the output directories.
        torch.distributed.barrier()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            train_data.set_epoch(epoch)
            vali_data.set_epoch(epoch)
            for i, (_, batch_x, batch_y, _, _) in enumerate(train_loader):
                iter_count += 1

                model_optim.zero_grad()
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)

                if "ESSformer" in self.args.model:
                    # Feed a per-sample subset of channels and score the same
                    # subset of targets. Assumes tensors are laid out as
                    # (batch, time, channel) — TODO confirm with data_provider.
                    idx1, idx2 = self.alive_value(len(batch_x), self.args.enc_in)
                    outputs, _ = self.model(batch_x.transpose(1, 2)[idx1, idx2].transpose(1, 2), idx=idx2)
                    batch_y = batch_y[:, -self.args.pred_len:].transpose(1, 2)[idx1, idx2].transpose(1, 2)
                else:
                    outputs, _ = self.model(batch_x)
                    batch_y = batch_y[:, -self.args.pred_len:]
                outputs = outputs[:, -self.args.pred_len:]

                loss = criterion(outputs, batch_y)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0 and self.args.rank == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()
                loss.backward()
                model_optim.step()

            # Average wall-clock seconds per training step since training began,
            # including evaluation time. start_time is used because time_now is
            # reset every 100 steps and does not match a cumulative step count.
            total_steps = max(len(train_loss), 1) * (epoch + 1)
            speed = format((time.time() - start_time) / total_steps, ".4f")
            train_loss = np.average(train_loss)
            vali_info = self.vali(vali_data, vali_loader)  # [mse, mae, inference seconds]
            test_info = self.vali(test_data, test_loader)

            # Collectives: every rank must reach both calls, in this order.
            vali_mean = self._gather_mean(vali_info)
            test_mean = self._gather_mean(test_info)
            if self.args.rank == 0:
                vali_mse, vali_mae, vali_inftime = vali_mean.cpu().numpy()
                test_mse, test_mae, test_inftime = test_mean.cpu().numpy()

                time_diff = int(time.time() - start_time)
                hours, rem = divmod(time_diff, 3600)
                time_diff_h = f"{hours}h {rem // 60}m {rem % 60}s"
                p_text = "Epoch: {0}, Steps: {1} | Train Loss: {2:.4f}, Vali Loss: ({3:.4f},{4:.4f}) Test Loss: ({5:.4f},{6:.4f}) | Time: {7}({8}, {9}, {10:.4f}, {11:.4f})".format(
                    epoch + 1, train_steps, train_loss, vali_mse, vali_mae, test_mse, test_mae, time_diff, time_diff_h, speed, vali_inftime, test_inftime)
                with open(f"{folder_path}/record.txt", "a") as f:
                    f.write(p_text)
                    f.write("\n")
                print(p_text)

                if vali_mse < saved_val_loss:
                    # Keep the previous best so the log shows old --> new.
                    prev_val_loss = saved_val_loss
                    saved_val_loss = vali_mse
                    with open(f"{folder_path}/record.txt", "a") as f:
                        f.write('save')
                        f.write("\n")
                    print(f'Validation loss decreased ({prev_val_loss:.6f} --> {vali_mse:.6f}).  Saving model ...')
                    torch.save(self.model.state_dict(), path + '/' + 'checkpoint.pth')

                all_test_loss.append(test_mse)
                all_val_loss.append(vali_mse)
                all_training_loss.append(train_loss)
            torch.distributed.barrier()

        if self.args.rank == 0:
            sns.lineplot(x=range(len(all_training_loss)), y=all_training_loss, label="train")
            sns.lineplot(x=range(len(all_val_loss)), y=all_val_loss, label="val")
            sns.lineplot(x=range(len(all_test_loss)), y=all_test_loss, label="test")
            plt.legend()
            plt.savefig(path + '/' + 'plot')
            plt.clf()
        # Wait for rank 0 to finish writing before every rank reads the checkpoint.
        # NOTE(review): if validation never improved (e.g. NaN losses) no
        # checkpoint exists and torch.load will raise.
        torch.distributed.barrier()
        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path, map_location=self.device))

        return self.model

    def test(self, setting, test=0):
        """Evaluate on the test split and write averaged MAE/MSE to ``result.txt`` on rank 0.

        Args:
            setting: experiment name used to build checkpoint and result paths.
            test: if truthy, first load ``checkpoints/<setting>/checkpoint.pth``.
                If that file is missing, load a checkpoint whose setting name
                has some ``_``-separated fields replaced with fixed values.

        Leaves the model in eval mode. rmse/mape/mspe are written as 0 placeholders.
        """
        _, test_loader = self._get_data(flag='test')

        if test:
            print('loading model')
            tmp_path = os.path.join(f'./{self.args.global_path}/checkpoints/' + setting, 'checkpoint.pth')
            if os.path.isfile(tmp_path):
                self.model.load_state_dict(torch.load(tmp_path, map_location=self.device))
            else:
                # NOTE(review): hard-coded fallback that overwrites positional
                # fields (-5, -6, -7) of `setting`, presumably to reuse a
                # checkpoint trained with another configuration — confirm.
                tmp_lst = setting.split("_")
                tmp_lst[-5] = "3"
                if "traffic" in setting:
                    tmp_lst[-7], tmp_lst[-6] = "20", "6"
                else:
                    tmp_lst[-7], tmp_lst[-6] = "30", "3"
                tmp_path = os.path.join(f'./{self.args.global_path}/checkpoints/' + "_".join(tmp_lst), 'checkpoint.pth')
                self.model.load_state_dict(torch.load(tmp_path, map_location=self.device))

        folder_path = f'./{self.args.global_path}/test_results/' + setting + '/'
        if self.args.rank == 0:
            os.makedirs(folder_path, exist_ok=True)
        torch.distributed.barrier()

        test_info = self._evaluate(test_loader)
        test_mean = self._gather_mean(test_info)  # collective; None on non-zero ranks
        if self.args.rank == 0:
            test_mse, test_mae, _ = test_mean.cpu().numpy()

            folder_path = f'./{self.args.global_path}/results/' + setting + '/'
            os.makedirs(folder_path, exist_ok=True)

            result_line = 'mae:{}, mse:{}, rmse:{}, mape:{}, mspe:{}'.format(test_mae, test_mse, 0, 0, 0)
            with open(folder_path + "result.txt", "w") as f:
                f.write(result_line)
                f.write('\n')
            print(result_line)

        return

    def test_(self, setting, test=0):
        """Gather full test predictions on rank 0, compute ``metric`` and save plots.

        Every 20th batch, rank 0 saves a PDF plot of the last channel of the
        first sample (input followed by target vs. input followed by prediction).

        Args:
            setting: experiment name used to build checkpoint and result paths.
            test: if truthy, load weights from ``checkpoints/<setting>`` first,
                falling back to a setting name with field -5 replaced by "3".

        Raises:
            RuntimeError: if the test loader yields no batches.
        """
        _, test_loader = self._get_data(flag='test')

        if test:
            print('loading model')
            tmp_path = os.path.join(f'./{self.args.global_path}/checkpoints/' + setting, 'checkpoint.pth')
            if os.path.isfile(tmp_path):
                self.model.load_state_dict(torch.load(tmp_path, map_location=self.device))
            else:
                tmp_lst = setting.split("_")
                tmp_lst[-5] = "3"
                tmp_path = os.path.join(f'./{self.args.global_path}/checkpoints/' + "_".join(tmp_lst), 'checkpoint.pth')
                self.model.load_state_dict(torch.load(tmp_path, map_location=self.device))

        folder_path = f'./{self.args.global_path}/test_results/' + setting + '/'
        if self.args.rank == 0:
            os.makedirs(folder_path, exist_ok=True)
        torch.distributed.barrier()

        preds = []
        trues = []
        each_idx = None
        self.model.eval()
        with torch.no_grad():
            for i, (sample_idx, batch_x, batch_y, _, _) in enumerate(test_loader):
                if i == 0:
                    # Dataset index of this rank's first sample; rank 0 uses it
                    # to order the per-rank results after gathering.
                    each_idx = sample_idx[[0]].to(self.device)
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)

                outputs, _ = self.model(batch_x)

                pred = outputs[:, -self.args.pred_len:].detach()
                true = batch_y[:, -self.args.pred_len:].detach()
                preds.append(pred)
                trues.append(true)

                if i % 20 == 0 and self.args.rank == 0:
                    inp = batch_x.detach().cpu().numpy()
                    gt = np.concatenate((inp[0, :, -1], true[0, :, -1].cpu().numpy()), axis=0)
                    pr = np.concatenate((inp[0, :, -1], pred[0, :, -1].cpu().numpy()), axis=0)
                    visual(gt, pr, os.path.join(folder_path, str(i) + '.pdf'))

        if each_idx is None:
            raise RuntimeError('test loader yielded no batches; nothing to gather')

        each_pred = torch.cat(preds, dim=0)
        each_true = torch.cat(trues, dim=0)
        # gather requires every rank to contribute tensors of identical shape.
        if self.args.rank != 0:
            torch.distributed.gather(each_pred)
            torch.distributed.gather(each_true)
            torch.distributed.gather(each_idx)
        else:
            pred_lst = [torch.zeros_like(each_pred) for _ in range(self.args.world_size)]
            true_lst = [torch.zeros_like(each_true) for _ in range(self.args.world_size)]
            idx_lst = [torch.zeros_like(each_idx) for _ in range(self.args.world_size)]
            torch.distributed.gather(each_pred, pred_lst)
            torch.distributed.gather(each_true, true_lst)
            torch.distributed.gather(each_idx, idx_lst)

            # Order ranks by their first dataset index, then interleave the
            # samples (stack on dim=1, flatten). NOTE(review): this recovers the
            # original order only if the sampler hands out indices round-robin
            # across ranks without shuffling — confirm.
            rank_order = torch.cat(idx_lst, dim=0).argsort()
            preds = torch.stack(pred_lst, dim=1)[:, rank_order].flatten(0, 1).cpu().numpy()
            trues = torch.stack(true_lst, dim=1)[:, rank_order].flatten(0, 1).cpu().numpy()

            preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
            trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])

            folder_path = f'./{self.args.global_path}/results/' + setting + '/'
            os.makedirs(folder_path, exist_ok=True)

            mae, mse, rmse, mape, mspe = metric(preds, trues, None)
            result_line = 'mae:{}, mse:{}, rmse:{}, mape:{}, mspe:{}'.format(mae, mse, rmse, mape, mspe)
            with open(folder_path + "result.txt", "w") as f:
                f.write(result_line)
                f.write('\n')
            print(result_line)
        torch.distributed.barrier()

        return
