from data_provider.data_factory import data_provider
from models import PMformer
from .tools import EarlyStopping, adjust_learning_rate, visual, test_params_flop, metric

import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.optim import lr_scheduler 

import seaborn as sns
import os, sys
import time

import warnings
import matplotlib.pyplot as plt
import numpy as np

# Silence every warning process-wide (no category or module filter is given),
# so deprecation and numerical warnings from any library are hidden too.
warnings.filterwarnings('ignore')

class Exp_Main:
    def __init__(self, args):
        self.args = args
        self.device = "cpu" if int(self.args.gpu) < 0 else f"cuda:{self.args.gpu}"
        self.model = self._build_model().to(self.device)

    def _build_model(self):
        """Instantiate the network named by ``args.model`` in float32.

        Returns:
            ``<module>.Model(self.args).float()`` for the module registered
            under ``self.args.model``.

        Raises:
            KeyError: If ``self.args.model`` is not a registered name.
        """
        # Registry of model name -> module exposing a ``Model`` class.
        registry = {'PMformer': PMformer}
        model_module = registry[self.args.model]
        return model_module.Model(self.args).float()

    def _get_data(self, flag):
        """Return the ``(dataset, loader)`` pair that ``data_provider`` builds.

        Args:
            flag: Split selector forwarded unchanged to ``data_provider``
                (this class passes ``'train'``, ``'val'`` and ``'test'``).
        """
        provided = data_provider(self.args, flag)
        data_set, data_loader = provided
        return data_set, data_loader

    def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate, weight_decay = 1e-6)
        return model_optim

    def _select_criterion(self):
        criterion = nn.MSELoss()
        return criterion

    def vali(self, vali_data, vali_loader):
        total_mse = []
        total_mae = []
        self.model.eval()
        mse = nn.MSELoss()
        mae = nn.L1Loss()
        start_time = time.time()
        with torch.no_grad():
            for i, (_, batch_x, batch_y, _, _) in enumerate(vali_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float()

                outputs, _ = self.model(batch_x)

                outputs = outputs[:, -self.args.pred_len:]
                batch_y = batch_y[:, -self.args.pred_len:].to(self.device)

                pred = outputs.detach().cpu()
                true = batch_y.detach().cpu()

                mse_loss = mse(pred, true)
                mae_loss = mae(pred, true)

                total_mse.append(mse_loss)
                total_mae.append(mae_loss)
        inf_time = (time.time() - start_time)

        total_mse = np.average(total_mse)
        total_mae = np.average(total_mae)

        self.model.train()
        return total_mse, total_mae, inf_time
    
    def alive_value(self, batch, c_in):
        # split_num = np.random.randint(self.args.split_num_s, self.args.split_num_e + 1)
        alive_num = self.args.split_num * self.args.split_mult
        # print(alive_num)
        if (c_in == self.args.split_num) or (c_in ==self.args.split_mult):
            live_index = torch.stack([torch.arange(c_in) for i in range(batch)], dim = 0)
        else:
            live_index = torch.stack([torch.randperm(c_in)[:alive_num] for i in range(batch)], dim = 0)
        # import pdb ; pdb.set_trace()
        return torch.arange(batch).unsqueeze(-1).repeat(1,live_index.shape[-1]), live_index

    def train(self, setting):
        """Train ``self.model`` and return it.

        Side effects:
            * creates ``<global_path>/<checkpoints>/<setting>`` and
              ``./<global_path>/results/<setting>/``;
            * writes the run arguments, then one metrics line per epoch, to
              ``record.txt`` in the results directory;
            * for the ``electricity`` and ``traffic`` datasets, saves a
              ``checkpoint_<epoch>.pth`` after every epoch;
            * saves a train/val/test loss plot named ``plot`` in the
              checkpoint directory;
            * exits the process with status 0 when ``result.txt`` already
              exists for this setting (the run is treated as finished).

        Args:
            setting: Run identifier, used as the sub-directory name for both
                checkpoints and results.

        Returns:
            ``self.model``. When ``args.save_checkpoints`` is truthy, the
            weights stored in ``checkpoint.pth`` inside the checkpoint
            directory are loaded into it first (that file is presumably
            written by ``EarlyStopping`` -- confirm in ``tools``).
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.global_path, self.args.checkpoints, setting)
        os.makedirs(path, exist_ok=True)

        # ``test`` writes result.txt at the very end of a run, so its presence
        # means this setting was already completed: skip it.  (A leftover
        # debugger breakpoint used to sit here and hung unattended jobs.)
        for_check_path = os.path.join(self.args.global_path, "results", setting, "result.txt")
        if os.path.exists(for_check_path):
            print(f"{for_check_path} already exists; skipping training.")
            sys.exit(0)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(self.args, patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        if not ('Crossformer' in self.args.model or 'PatchTST' in self.args.model):
            # NOTE(review): scheduler.step() is never called in this method,
            # but merely constructing OneCycleLR resets the optimizer's lr to
            # the schedule's initial value (max_lr / div_factor, div_factor
            # defaulting to 25).  The construction is therefore kept so the
            # effective learning rate does not change -- confirm whether a
            # constant, reduced rate is what was intended.
            scheduler = lr_scheduler.OneCycleLR(optimizer=model_optim,
                                                steps_per_epoch=train_steps,
                                                pct_start=0.2,
                                                epochs=self.args.train_epochs,
                                                max_lr=self.args.learning_rate)
        else:
            # Debug marker; the string is Korean for "here".
            print('여기')

        all_training_loss = []
        all_val_loss = []
        all_test_loss = []

        folder_path = f'./{self.args.global_path}/results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        # Start a fresh record file with a dump of every run argument.
        with open(f"{folder_path}/record.txt", "w") as f:
            f.write("\n".join([f"{k}:{v}"for k,v in vars(self.args).items()]))
            f.write("\n")
            f.write("\n")
        start_time = time.time()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()

            for i, (_, batch_x, batch_y, _, _) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)

                if "PMformer" in self.args.model:
                    # Train on a per-sample subset of dim 2: swap dims 1 and 2,
                    # gather with the (sample, index) pairs, swap back.
                    # Presumably dim 2 is the channel axis (it is sized by
                    # args.enc_in) -- confirm against the data loader.
                    idx1, idx2 = self.alive_value(len(batch_x), self.args.enc_in)
                    outputs, _ = self.model(batch_x.transpose(1, 2)[idx1, idx2].transpose(1, 2), idx=idx2)
                    outputs = outputs[:, -self.args.pred_len:]
                    batch_y = batch_y[:, -self.args.pred_len:]
                    # The targets must be restricted to the same subset.
                    batch_y = batch_y.transpose(1, 2)[idx1, idx2].transpose(1, 2)
                else:
                    outputs, _ = self.model(batch_x)
                    outputs = outputs[:, -self.args.pred_len:]
                    batch_y = batch_y[:, -self.args.pred_len:]

                loss = criterion(outputs, batch_y)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    # Restart the throughput window.
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                model_optim.step()

            # NOTE: time_now is reset every 100 iterations above, so this
            # figure is not a plain per-iteration average; it is logged as-is.
            speed = format((time.time() - time_now) / (len(train_loss) * (epoch + 1)), ".4f")
            train_loss = np.average(train_loss)
            vali_mse, vali_mae, vali_inftime = self.vali(vali_data, vali_loader)
            test_mse, test_mae, test_inftime = self.vali(test_data, test_loader)

            time_diff = int(time.time() - start_time)
            time_diff_h = f"{time_diff//3600}h {(time_diff%3600) // 60 }m {(time_diff%3600) % 60}s"
            p_text = "Epoch: {0}, Steps: {1} | Train Loss: {2:.4f}, Vali Loss: ({3:.4f},{4:.4f}) Test Loss: ({5:.4f},{6:.4f}) | Time: {7}({8}, {9}, {10:.4f}, {11:.4f})".format(
                epoch + 1, train_steps, train_loss, vali_mse, vali_mae, test_mse, test_mae, time_diff, time_diff_h, speed, vali_inftime, test_inftime)
            with open(f"{folder_path}/record.txt", "a") as f:
                f.write(p_text)
                f.write("\n")

            # These two datasets additionally keep a snapshot of every epoch.
            if self.args.data in ['electricity', 'traffic']:
                torch.save(self.model.state_dict(), path + '/' + f'checkpoint_{epoch+1}.pth')

            all_training_loss.append(train_loss)
            all_test_loss.append(test_mse)
            all_val_loss.append(vali_mse)

            print(p_text)
            early_stopping(vali_mse, self.model, path)
            if early_stopping.early_stop:
                break

        # Loss curves over the epochs that actually ran.
        sns.lineplot(x=range(len(all_training_loss)), y=all_training_loss, label="train")
        sns.lineplot(x=range(len(all_val_loss)), y=all_val_loss, label="val")
        sns.lineplot(x=range(len(all_test_loss)), y=all_test_loss, label="test")
        plt.legend()
        plt.savefig(path + '/' + 'plot')
        plt.clf()

        best_model_path = path + '/' + 'checkpoint.pth'
        if self.args.save_checkpoints:
            self.model.load_state_dict(torch.load(best_model_path, map_location=self.device))

        return self.model

    def test(self, setting, test=0):
        """Run the model on the test split, save predictions and metrics.

        Side effects:
            * every 20th batch, a plot of the first sample's last feature
              (history followed by truth vs. prediction) is written to
              ``./<global_path>/test_results/<setting>/<batch index>.pdf``;
            * ``pred.npy``, ``true.npy``, ``x.npy`` and ``result.txt`` are
              written to ``./<global_path>/results/<setting>/``;
            * the metrics line is printed;
            * the model is left in eval mode.

        Args:
            setting: Run identifier, used as the sub-directory name.
            test: When truthy, load ``checkpoint.pth`` for this setting from
                the checkpoint directory before evaluating.

        Returns:
            None.
        """
        test_data, test_loader = self._get_data(flag='test')

        if test:
            print('loading model')
            # Same directory layout that ``train`` writes to, so a
            # non-default ``args.checkpoints`` is honoured here as well.
            checkpoint_path = os.path.join(
                self.args.global_path, self.args.checkpoints, setting, 'checkpoint.pth')
            self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))

        preds = []
        trues = []
        inputx = []
        folder_path = f'./{self.args.global_path}/test_results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        self.model.eval()
        with torch.no_grad():
            for i, (_, batch_x, batch_y, _, _) in enumerate(test_loader):
                batch_x = batch_x.float().to(self.device)

                outputs, _ = self.model(batch_x)

                # Only the last pred_len steps of dim 1 are scored.  Targets
                # are consumed as NumPy arrays, so they stay off the device.
                pred = outputs[:, -self.args.pred_len:].detach().cpu().numpy()
                true = batch_y.float()[:, -self.args.pred_len:].detach().cpu().numpy()
                history = batch_x.detach().cpu().numpy()

                preds.append(pred)
                trues.append(true)
                inputx.append(history)

                if i % 20 == 0:
                    # First sample, last feature: history + truth / forecast.
                    gt = np.concatenate((history[0, :, -1], true[0, :, -1]), axis=0)
                    forecast = np.concatenate((history[0, :, -1], pred[0, :, -1]), axis=0)
                    visual(gt, forecast, os.path.join(folder_path, str(i) + '.pdf'))

        # Concatenate along the batch axis instead of np.array(list): the
        # latter fails when the final batch is smaller than the others.
        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)
        inputx = np.concatenate(inputx, axis=0)

        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        inputx = inputx.reshape(-1, inputx.shape[-2], inputx.shape[-1])

        # result save
        folder_path = f'./{self.args.global_path}/results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        np.save(folder_path + 'pred.npy', preds)
        np.save(folder_path + 'true.npy', trues)
        np.save(folder_path + 'x.npy', inputx)

        # Compute the metrics before touching result.txt: ``train`` treats
        # that file as the "run finished" marker, so it must not be created
        # (even empty) if metric computation fails.
        mae, mse, rmse, mape, mspe = metric(preds, trues, None)
        with open(folder_path + "result.txt", "w") as f:
            f.write('mae:{}, mse:{}, rmse:{}, mape:{}, mspe:{}'.format(mae, mse, rmse, mape, mspe))
            f.write('\n')

        print('mae:{}, mse:{}, rmse:{}, mape:{}, mspe:{}'.format(mae, mse, rmse, mape, mspe))
