from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, adjust_learning_rate, visual
from utils.losses import RmseLoss
from utils.metrics import metric
from utils.constants_LamaH import Constants
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np
import pandas as pd

# Absolute directory that contains this module.
current_dir = os.path.split(os.path.abspath(__file__))[0]
# Silence every warning category for the whole process.
warnings.filterwarnings(action='ignore')


class Exp_MeKong(Exp_Basic):
    def __init__(self, args, station, setting):
        """Set up the experiment: build the model and create the output directories.

        Args:
            args: run configuration; this method reads ``checkpoints``,
                ``results`` and ``run_wise`` from ``self.args``.
            station: forwarded unchanged to ``Exp_Basic.__init__``.
            setting: experiment identifier used as a directory name.

        Raises:
            FileExistsError: if one of the output paths already exists but is
                not a directory.

        NOTE(review): ``self.args``, ``self.setting``, ``self.target_station``
        and ``self.device`` are assumed to be set by ``Exp_Basic.__init__``,
        which is not visible in this file -- confirm.
        """
        super(Exp_MeKong, self).__init__(args, station, setting)
        self.constants = Constants(args)
        self.model = self._build_model()

        wise_dir = f"wise_{self.args.run_wise}"

        # ``exist_ok=True`` replaces the former exists()/makedirs() pairs, which
        # could raise when two runs created the same directory at the same time.
        self.ckpt_path = os.path.join(self.args.checkpoints, self.setting, wise_dir,
                                      self.target_station)
        os.makedirs(self.ckpt_path, exist_ok=True)

        self.test_results_path = os.path.join('test_results', setting, wise_dir,
                                              self.target_station)
        os.makedirs(self.test_results_path, exist_ok=True)

        self.results_path = os.path.join(self.args.results, setting, wise_dir, self.target_station)
        os.makedirs(self.results_path, exist_ok=True)

        # Report only parameters that will actually be optimised.
        num_param = self.numel(self.model, True)

        print('model number of parameters:', num_param)

    def numel(self, m: torch.nn.Module, only_trainable: bool = False):
        """
        Returns the total number of parameters used by `m` (only counting
        shared parameters once); if `only_trainable` is True, then only
        includes parameters with `requires_grad = True`
        """
        parameters = list(m.parameters())
        if only_trainable:
            parameters = [p for p in parameters if p.requires_grad]
        unique = {p.data_ptr(): p for p in parameters}.values()
        return sum(p.numel() for p in unique)

    def _build_model(self):
        model = self.model_dict[self.args.model].Model(self.args).float().to(self.device)
        return model

    def _get_data(self):
        """Build the train/validation/test datasets and loaders through ``data_provider``.

        The batching mode and station selection depend on the model family:
        GNN models use full batches with no station list; other models use
        mini batches with no station list when ``run_wise == 'all'``, and
        full batches restricted to the target station otherwise.

        Returns:
            A 6-tuple ``(train_dataset, train_loader, val_dataset, val_loader,
            test_dataset, test_loader)``.
        """
        is_GNN = self.args.model in self.GNN_list
        if is_GNN:
            batch_flag, station_list = 'full_batch', None
        elif self.args.run_wise == 'all':
            batch_flag, station_list = 'mini_batch', None
        else:
            batch_flag, station_list = 'full_batch', [self.target_station]

        provided = data_provider(
            self.args,
            self.verbose,
            f'{self.args.data}_mts',
            self.target_station,
            station_list=station_list,
            batch_flag=batch_flag,
            is_GNN=is_GNN
        )
        # Unpack explicitly so exactly six items are required and a tuple is returned.
        train_set, train_ld, val_set, val_ld, test_set, test_ld = provided
        return train_set, train_ld, val_set, val_ld, test_set, test_ld

    def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        if self.args.loss == 'MSE':
            criterion = nn.MSELoss()
        elif self.args.loss == 'RMSE':
            criterion = RmseLoss()
        else:
            raise NotImplementedError("loss function is not implemented, only accept in ['MSE', 'RMSE']")
        return criterion

    def vali(self, criterion, data_loader):
        total_loss = []
        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, cycle_index, station_ids, time_idx) in enumerate(data_loader):
                if self.args.model in self.GNN_list:
                    batch_x = batch_x[:, :, :, -1]
                    batch_y = batch_y[:, :, :, -1]
                else:
                    if self.args.run_wise != 'all':
                        batch_x = batch_x[:, :, -1, :]
                        batch_y = batch_y[:, :, -1, :]
                    batch_y = batch_y[:, :, -1:]

                batch_x = batch_x.float().to(self.device, non_blocking=True)
                batch_y = batch_y.float().to(self.device, non_blocking=True)
                batch_x_mark = batch_x_mark.float().to(self.device, non_blocking=True)
                batch_y_mark = batch_y_mark.float().to(self.device, non_blocking=True)

                # decoder input
                dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :], device=self.device)
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1)
                # encoder - decoder
                if self.args.model in self.cycle_model_list:
                    cycle_index = cycle_index.to(self.device)
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark, cycle_index=cycle_index)
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                if self.args.model in self.GNN_list:
                    outputs = outputs[:, -self.args.pred_len:, :]
                    batch_y = batch_y[:, -self.args.pred_len:, :]
                else:
                    outputs = outputs[:, -self.args.pred_len:, -1:]
                    batch_y = batch_y[:, -self.args.pred_len:, -1:]
                pred = outputs.detach().cpu()
                true = batch_y.detach().cpu()
                loss = criterion(pred, true)
                total_loss.append(loss)
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def train(self):
        """Train the model with early stopping, then reload the best checkpoint.

        Loads the data through ``_get_data`` (storing the datasets and loaders on
        ``self``), optimises with the optimizer/criterion from
        ``_select_optimizer`` / ``_select_criterion``, and after every epoch
        evaluates on the validation and test loaders. ``EarlyStopping`` is
        driven by the validation loss; the test loss is only printed.
        After training, the weights stored in ``<ckpt_path>/checkpoint.pth``
        are loaded back into ``self.model``.

        NOTE(review): this assumes ``EarlyStopping`` writes ``checkpoint.pth``
        into ``self.ckpt_path`` -- its implementation is not visible here.
        """
        (self.train_data, self.train_loader,
         self.vali_data, self.vali_loader,
         self.test_data, self.test_loader) = self._get_data()

        time_now = time.time()

        train_steps = len(self.train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=self.verbose, delta=1e-6)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        # Loop invariants, computed once instead of on every iteration.
        is_gnn = self.args.model in self.GNN_list
        uses_cycle = self.args.model in self.cycle_model_list
        pred_len = self.args.pred_len
        # Progress is printed every `print_iter` iterations (about five times
        # per epoch); a value of 0 disables the progress output.
        print_iter = int(len(self.train_data) / (self.args.batch_size * 5))

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark,
                    cycle_index, _station_ids, _time_idx) in enumerate(self.train_loader):
                iter_count += 1
                model_optim.zero_grad()
                if is_gnn:
                    # Keep only the last entry of the trailing axis.
                    batch_x = batch_x[:, :, :, -1]
                    batch_y = batch_y[:, :, :, -1]
                else:
                    if self.args.run_wise != 'all':
                        # Keep only the last entry of the third axis.
                        batch_x = batch_x[:, :, -1, :]
                        batch_y = batch_y[:, :, -1, :]
                    # The target is the last channel only.
                    batch_y = batch_y[:, :, -1:]

                batch_x = batch_x.float().to(self.device, non_blocking=True)
                batch_y = batch_y.float().to(self.device, non_blocking=True)
                batch_x_mark = batch_x_mark.float().to(self.device, non_blocking=True)
                batch_y_mark = batch_y_mark.float().to(self.device, non_blocking=True)

                # decoder input: known label window followed by zeros for the horizon
                dec_inp = torch.zeros_like(batch_y[:, -pred_len:, :], device=self.device)
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1)

                # encoder - decoder
                if uses_cycle:
                    cycle_index = cycle_index.to(self.device)
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark, cycle_index=cycle_index)
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                if is_gnn:
                    outputs = outputs[:, -pred_len:, :]
                    batch_y = batch_y[:, -pred_len:, :]
                else:
                    outputs = outputs[:, -pred_len:, -1:]
                    batch_y = batch_y[:, -pred_len:, -1:]
                loss = criterion(outputs, batch_y)
                train_loss.append(loss.item())

                if print_iter != 0 and (i + 1) % print_iter == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    # Restart the speed measurement window.
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(criterion, self.vali_loader)
            test_loss = self.vali(criterion, self.test_loader)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping(vali_loss, self.model, self.ckpt_path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args, self.verbose)

        best_model_path = os.path.join(str(self.ckpt_path), 'checkpoint.pth')
        # map_location keeps loading working when the checkpoint tensors were
        # saved from a device other than the current one.
        self.model.load_state_dict(torch.load(str(best_model_path), map_location=self.device))

    def test(self, test=0):
        """Evaluate the saved checkpoint on the test split and persist the results.

        A fresh model is built and loaded from ``<ckpt_path>/checkpoint.pth``.
        Predictions and targets are written into arrays of shape
        ``(num_times, pred_len, num_stations)`` at the positions given by the
        ``time_idx`` / ``station_ids`` values yielded by the loader. ``metric``
        is then evaluated per station and summarised by mean and median.

        Side effects: prints the summaries, appends them to
        ``result_initial.txt`` in the working directory, and saves
        ``metrics.npy`` (the eleven averages), ``pred.npy`` and ``true.npy``
        under ``self.results_path``.

        Args:
            test: when ``1`` the data loaders are (re)built via ``_get_data``;
                otherwise the loaders already stored on ``self`` are used.

        Returns:
            A 22-tuple: the averages of (mae, mse, rmse, mape, mspe, r2, nse,
            pbias, kge, flv, fhv) followed by the medians in the same order.
        """
        if test == 1:
            (self.train_data, self.train_loader,
             self.vali_data, self.vali_loader,
             self.test_data, self.test_loader) = self._get_data()

        print('loading model')
        self.model = self._build_model().to(self.device)
        checkpoint_file = os.path.join(str(self.ckpt_path), 'checkpoint.pth')
        # map_location keeps loading working when the checkpoint tensors were
        # saved from a device other than the current one.
        self.model.load_state_dict(torch.load(checkpoint_file, map_location=self.device))

        num_times = self.test_loader.dataset.data_x.shape[0] - self.args.seq_len - self.args.pred_len + 1
        preds = np.zeros((num_times, self.args.pred_len, self.args.num_stations))
        trues = np.zeros((num_times, self.args.pred_len, self.args.num_stations))

        is_gnn = self.args.model in self.GNN_list
        uses_cycle = self.args.model in self.cycle_model_list
        pred_len = self.args.pred_len

        self.model.eval()
        with torch.no_grad():
            for (batch_x, batch_y, batch_x_mark, batch_y_mark,
                 cycle_index, station_ids, time_idx) in self.test_loader:
                if is_gnn:
                    # Keep only the last entry of the trailing axis.
                    batch_x = batch_x[:, :, :, -1]
                    batch_y = batch_y[:, :, :, -1]
                else:
                    if self.args.run_wise != 'all':
                        # Keep only the last entry of the third axis.
                        batch_x = batch_x[:, :, -1, :]
                        batch_y = batch_y[:, :, -1, :]
                    # The target is the last channel only.
                    batch_y = batch_y[:, :, -1:]

                batch_x = batch_x.float().to(self.device, non_blocking=True)
                batch_y = batch_y.float().to(self.device, non_blocking=True)
                batch_x_mark = batch_x_mark.float().to(self.device, non_blocking=True)
                batch_y_mark = batch_y_mark.float().to(self.device, non_blocking=True)

                # decoder input: known label window followed by zeros for the horizon
                dec_inp = torch.zeros_like(batch_y[:, -pred_len:, :], device=self.device)
                dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1)
                # encoder - decoder
                if uses_cycle:
                    cycle_index = cycle_index.to(self.device)
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark, cycle_index=cycle_index)
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                if is_gnn:
                    # [Batch size, time steps, num stations]
                    outputs = outputs[:, -pred_len:, :]
                    batch_y = batch_y[:, -pred_len:, :]
                else:
                    # [Batch size, time steps, water flow]
                    outputs = outputs[:, -pred_len:, -1:]
                    batch_y = batch_y[:, -pred_len:, -1:]
                outputs = outputs.detach().cpu().numpy()  # [B, H, C]
                batch_y = batch_y.detach().cpu().numpy()

                if self.args.inverse:
                    if is_gnn:
                        outputs = self.test_data.inverse_transform(outputs, input_station_ids=station_ids, is_GNN=True)
                        batch_y = self.test_data.inverse_transform(batch_y, input_station_ids=station_ids, is_GNN=True)
                    else:
                        outputs = self.test_data.inverse_transform(outputs, input_station_ids=station_ids)
                        batch_y = self.test_data.inverse_transform(batch_y, input_station_ids=station_ids)

                # Scatter every sample of the batch into the global result arrays.
                for idx in range(len(station_ids)):
                    n = station_ids[idx].item()
                    t = time_idx[idx].item()
                    if is_gnn:
                        preds[t, :, :] = outputs[idx, :, :]
                        trues[t, :, :] = batch_y[idx, :, :]
                    else:
                        preds[t, :, n] = outputs[idx, :, -1]
                        trues[t, :, n] = batch_y[idx, :, -1]
                # A per-batch visualisation via `visual` used to live here; it was
                # already disabled (commented out) and has been left out.

        print('test shape:', preds.shape, trues.shape)

        # NOTE(review): metrics are computed for every station column, including
        # columns the loader never wrote to (they stay all-zero) -- confirm this
        # is intended when only a subset of stations is evaluated.
        per_station = [metric(preds[:, :, node_i], trues[:, :, node_i])
                       for node_i in range(preds.shape[-1])]
        # Transpose to one sequence per metric, in the order returned by `metric`:
        # mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv.
        num_metrics = 11
        per_metric = list(zip(*per_station)) or [()] * num_metrics

        (avg_mae, avg_mse, avg_rmse, avg_mape, avg_mspe, avg_r2,
         avg_nse, avg_pbias, avg_kge, avg_flv, avg_fhv) = [np.mean(values) for values in per_metric]
        (median_mae, median_mse, median_rmse, median_mape, median_mspe, median_r2,
         median_nse, median_pbias, median_kge, median_flv, median_fhv) = [np.median(values) for values in per_metric]

        # Build each report line once and reuse it for stdout and the log file.
        avg_line = (f'avg_nse:{avg_nse}, avg_pbias:{avg_pbias}, avg_kge:{avg_kge}, avg_flv:{avg_flv}, '
                    f'avg_fhv:{avg_fhv}, avg_mae: {avg_mae}, avg_mse: {avg_mse}, avg_rmse: {avg_rmse}, '
                    f'avg_mape: {avg_mape}, avg_mspe: {avg_mspe}, avg_r2: {avg_r2}')
        median_line = (f'median_nse: {median_nse}, median_pbias: {median_pbias}, median_kge: {median_kge}, '
                       f'median_flv: {median_flv}, median_fhv: {median_fhv}, median_mae: {median_mae}, '
                       f'median_mse: {median_mse}, median_rmse: {median_rmse}, median_mape: {median_mape}, '
                       f'median_mspe: {median_mspe}, median_r2: {median_r2}')
        brief_line = f"{avg_nse},{median_nse},{avg_rmse},{avg_mae}"

        print(avg_line)
        print(median_line)
        print(brief_line)

        # The context manager closes the file even if a write fails.
        with open("result_initial.txt", 'a') as f:
            f.write(self.setting + "  \n")
            f.write(avg_line + '\n')
            # The newline after the median line was missing before, which glued
            # the compact summary onto the median_r2 value.
            f.write(median_line + '\n')
            f.write(brief_line + '\n')
            f.write('\n')

        np.save(os.path.join(str(self.results_path), 'metrics.npy'),
                np.array([avg_mae, avg_mse, avg_rmse, avg_mape, avg_mspe, avg_r2, avg_nse, avg_pbias, avg_kge, avg_flv,
                          avg_fhv]))
        np.save(os.path.join(str(self.results_path), 'pred.npy'), preds)
        np.save(os.path.join(str(self.results_path), 'true.npy'), trues)
        return avg_mae, avg_mse, avg_rmse, avg_mape, avg_mspe, avg_r2, avg_nse, avg_pbias, avg_kge, avg_flv, avg_fhv, \
            median_mae, median_mse, median_rmse, median_mape, median_mspe, median_r2, median_nse, median_pbias, median_kge, median_flv, median_fhv
