from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping_mekong_local_global, adjust_learning_rate, visual
from utils.metrics import metric
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np
from copy import deepcopy
import json
from station_spec_search_hyperparam import StationSpecSearchHyperparam
from latex.pred_table import MakeTable
# Absolute directory that contains this module.
_module_path = os.path.abspath(__file__)
current_dir = os.path.split(_module_path)[0]
# Suppress every warning category for the rest of the process.
warnings.filterwarnings('ignore')


class Exp_MeKong(Exp_Basic):
    """
    Global neighbour-distillation experiment for one target station.

    One model is built per station: the target station B, and each linked
    station that is classified as a child or a parent of B. The child and
    parent models also receive B's input through the ``target_x`` keyword.

    Training step (per batch):
     H_B        = model_B(L_B)
     child_out  = sum of the child models' outputs
     parent_out = sum of the parent models' outputs
     each non-empty neighbour group minimises
         alpha * MSE(group_out, y_B) + (1 - alpha) * MSE(H_B.detach(), group_out)
     the target model minimises
         alpha * MSE(H_B, y_B) + (1 - alpha) * MSE(avg(child_out, parent_out).detach(), H_B)
     (when B has no children and no parents, the second term uses H_B.detach()).

    Validation and test predictions are H_B alone.
    """

    def __init__(self, args, station, setting):
        super(Exp_MeKong, self).__init__(args, station, setting)
        mode = self.args.other_station
        if mode not in ('child_parent', 'lag_correlation', 'flow'):
            raise NotImplementedError
        if mode == 'child_parent':
            all_child_list = self.constants.child_stations_dict[self.target_station]
            all_parent_list = self.constants.parent_stations_dict[self.target_station]
            other_list = all_child_list + all_parent_list
            self.child_list = list(all_child_list)
            self.parent_list = list(all_parent_list)
        else:
            if mode == 'lag_correlation':
                other_list = self.constants.lag_correlation_dict[self.target_station]
            else:
                other_list = self.constants.get_other_list(self.target_station, 'flow')
            all_child_list = self.constants.child_stations_dict[self.target_station]
            all_parent_list = self.constants.parent_stations_dict[self.target_station]
            self.child_list, self.parent_list = self._split_child_parent(
                other_list, all_child_list, all_parent_list)
        # Copy before appending: ``other_list`` may be a list owned by ``self.constants``,
        # and appending to it in place would corrupt that shared lookup table.
        self.stations_list = list(other_list) + [self.target_station]
        self.models_dic = self._build_models_dict()
        self.alpha = self.args.alpha
        num_param = sum(self.numel(model, True) for model in self.models_dic.values())
        if self.verbose:
            print('model number of parameters:', num_param)

    @staticmethod
    def _split_child_parent(other_list, all_child_list, all_parent_list):
        """
        Partition ``other_list`` into (children, parents), preserving its order.

        A station found in ``all_child_list`` is a child even if it is also a parent.
        Stations in neither list are dropped.
        """
        child_list, parent_list = [], []
        for station in other_list:
            if station in all_child_list:
                child_list.append(station)
            elif station in all_parent_list:
                parent_list.append(station)
        return child_list, parent_list

    def _data_station_list(self):
        """
        Return the stations whose series are loaded: children, then parents, then the target last.

        The batch tensors are indexed along their station axis by position in this list.
        NOTE(review): assumes ``data_provider`` keeps the order of ``station_list``
        (the code already relies on the target being at index -1).
        """
        return self.child_list + self.parent_list + [self.target_station]

    def numel(self, m: torch.nn.Module, only_trainable: bool = False):
        """
        Returns the total number of parameters used by `m` (only counting
        shared parameters once); if `only_trainable` is True, then only
        includes parameters with `requires_grad = True`
        """
        parameters = list(m.parameters())
        if only_trainable:
            parameters = [p for p in parameters if p.requires_grad]
        unique = {p.data_ptr(): p for p in parameters}.values()
        return sum(p.numel() for p in unique)

    def _build_models_dict(self):
        """
        Build one fresh model per station, keyed by station name.

        The target model is created first, then the others in ``stations_list`` order.
        This keeps parameter-initialisation order (and so RNG consumption) stable.
        """
        model_cls = self.model_dict[self.args.model].Model
        models_dic = {self.target_station: model_cls(self.args).float().to(self.device)}
        for name in self.stations_list:
            if name != self.target_station:
                models_dic[name] = model_cls(self.args).float().to(self.device)
        return models_dic

    def _get_data(self, station_list=None):
        """
        Build the train/val/test datasets and loaders for ``station_list``.

        Data comes from ``data_provider`` with loader ``f'{args.data}_mts'`` and
        ``batch_flag='full_batch'``. When ``station_list`` is None, ``stations_list`` is used.
        """
        if station_list is None:
            station_list = self.stations_list
        data_loader_name = f'{self.args.data}_mts'

        (train_dataset, train_loader,
         val_dataset, val_loader,
         test_dataset, test_loader) = data_provider(
            args=self.args,
            verbose=self.verbose,
            data_loader_name=data_loader_name,
            target_station=self.target_station,
            station_list=station_list,
            batch_flag='full_batch'
        )
        return train_dataset, train_loader, val_dataset, val_loader, test_dataset, test_loader

    def _select_optimizer(self, model, lr=None):
        """Create an Adam optimizer from a module, or from a list of parameter groups / parameters."""
        if lr is None:
            lr = self.args.learning_rate
        if isinstance(model, list):
            model_optim = optim.Adam(model, lr=lr)
        else:
            model_optim = optim.Adam(model.parameters(), lr=lr)
        return model_optim

    def _select_criterion(self):
        """Mean-squared-error loss used for both supervision and neighbour agreement."""
        criterion = nn.MSELoss()
        return criterion

    def _checkpoint_dir(self, train_phase):
        """Directory holding per-station checkpoints for ``train_phase`` (same root for save and load)."""
        return os.path.join(self.args.checkpoints, self.setting, f"wise_{self.args.run_wise}",
                            train_phase, str(self.target_station))

    def _prepare_batch(self, batch_x, batch_y, batch_x_mark, batch_y_mark):
        """
        Move a batch to the device and derive the target-station tensors.

        The code indexes batch_x as [B, L, N, C] and batch_y as [B, H, N, C], with the target
        station at index -1 of N. Only the last channel of batch_y is kept.

        Returns (batch_x, batch_x_mark, batch_y_mark, target_x, target_y, dec_inp). dec_inp is
        the first ``label_len`` steps of target_y followed by ``pred_len`` zeros.
        """
        batch_x = batch_x.float().to(self.device)
        batch_y = batch_y[:, :, :, -1:].float().to(self.device)
        batch_x_mark = batch_x_mark.float().to(self.device)
        batch_y_mark = batch_y_mark.float().to(self.device)
        # .clone() copies only the slice; deepcopy of a view would copy the whole base storage.
        target_x = batch_x[:, :, -1, :].clone()
        target_y = batch_y[:, :, -1, :].clone()
        dec_inp = torch.zeros_like(target_y[:, -self.args.pred_len:, -1:]).float()
        dec_inp = torch.cat([target_y[:, :self.args.label_len, -1:], dec_inp], dim=1).float().to(self.device)
        return batch_x, batch_x_mark, batch_y_mark, target_x, target_y, dec_inp

    @staticmethod
    def _station_input(batch_x, station, data_stations):
        """Return a copy of ``station``'s series from batch_x, located by its position in ``data_stations``."""
        return batch_x[:, :, data_stations.index(station), :].clone()

    def vali(self, criterion, dataloader):
        """
        Return the average ``criterion`` loss of the target model over ``dataloader``.

        All models are left in eval mode; ``_train_global`` switches them back at each epoch.
        """
        total_loss = []
        for model in self.models_dic.values():
            model.eval()
        with torch.no_grad():
            for batch_x, batch_y, batch_x_mark, batch_y_mark, _cycle_index, _station_ids, _time_idx in dataloader:
                _, batch_x_mark, batch_y_mark, target_x, target_y, dec_inp = self._prepare_batch(
                    batch_x, batch_y, batch_x_mark, batch_y_mark)
                target_out = self.models_dic[self.target_station](target_x, batch_x_mark, dec_inp, batch_y_mark)
                # Inference output is the target model alone.
                real_out = target_out[:, -self.args.pred_len:, -1:].detach().cpu()
                true = target_y[:, -self.args.pred_len:, -1:].detach().cpu()
                total_loss.append(criterion(real_out, true).item())
        return np.average(total_loss)

    def _update_group(self, group, optimizer, criterion, batch_x, batch_x_mark, dec_inp, batch_y_mark,
                      target_x, target_y, target_out, data_stations):
        """
        Run one optimisation step for a group (children or parents) of neighbour models.

        The group output is the sum of its members' last ``pred_len`` steps. The loss is
        ``alpha * MSE(group_out, y) + (1 - alpha) * MSE(target_out.detach(), group_out)``.
        Returns the (non-detached) group output computed before the step.
        """
        pred_len = self.args.pred_len
        outs = []
        for station in group:
            station_x = self._station_input(batch_x, station, data_stations)
            out = self.models_dic[station](station_x, batch_x_mark, dec_inp, batch_y_mark, target_x=target_x)
            outs.append(out[:, -pred_len:, -1:])
        group_out = torch.stack(outs, dim=0).sum(dim=0)
        optimizer.zero_grad()
        loss = (self.alpha * criterion(group_out, target_y[:, -pred_len:, -1:])
                + (1 - self.alpha) * criterion(target_out[:, -pred_len:, -1:].detach(), group_out))
        loss.backward(retain_graph=True)
        optimizer.step()
        return group_out

    def _train_global(self, loop=0):
        """
        Jointly train the target and neighbour models for one global loop.

        The learning rate is ``global_lr * global_lr_factor ** loop``. Checkpoints are written under
        ``_checkpoint_dir('global')``. The models are reloaded from those files at the end.
        """
        path = self._checkpoint_dir('global')
        os.makedirs(path, exist_ok=True)
        # Seed the directory with the current weights so the final reload always finds a file
        # (presumably EarlyStopping only writes on improvement).
        for name, model in self.models_dic.items():
            torch.save(model.state_dict(), os.path.join(path, f'checkpoint_{name}.pth'))

        time_now = time.time()

        train_steps = len(self.train_loader)
        early_stopping = EarlyStopping_mekong_local_global(patience=self.args.patience, verbose=self.verbose,
                                                           delta=1e-6)

        lr = self.args.global_lr * (self.args.global_lr_factor ** loop)
        model_optim_dic = {}
        if self.child_list:
            model_optim_dic['child'] = self._select_optimizer(
                [{'params': self.models_dic[m].parameters()} for m in self.child_list], lr=lr)
        if self.parent_list:
            model_optim_dic['parent'] = self._select_optimizer(
                [{'params': self.models_dic[m].parameters()} for m in self.parent_list], lr=lr)
        model_optim_dic['target'] = self._select_optimizer(
            [{'params': self.models_dic[self.target_station].parameters()}], lr=lr)
        criterion = self._select_criterion()
        data_stations = self._data_station_list()
        pred_len = self.args.pred_len

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []
            epoch_time = time.time()
            for model in self.models_dic.values():
                model.train()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark, _cycle_index, _station_ids, _time_idx) in enumerate(
                    self.train_loader):
                iter_count += 1
                batch_x, batch_x_mark, batch_y_mark, target_x, target_y, dec_inp = self._prepare_batch(
                    batch_x, batch_y, batch_x_mark, batch_y_mark)
                target_y = target_y[:, -pred_len:, -1:]
                target_out = self.models_dic[self.target_station](target_x, batch_x_mark, dec_inp, batch_y_mark)
                target_out = target_out[:, -pred_len:, -1:]

                # Update children first, then parents (each group has its own optimizer).
                group_outs = []
                for group, key in ((self.child_list, 'child'), (self.parent_list, 'parent')):
                    if group:
                        group_outs.append(self._update_group(
                            group, model_optim_dic[key], criterion, batch_x, batch_x_mark, dec_inp,
                            batch_y_mark, target_x, target_y, target_out, data_stations))
                if group_outs:
                    avg_other_out = torch.stack(group_outs, dim=0).mean(dim=0)
                else:
                    # No neighbours: the agreement term degenerates to the target itself.
                    avg_other_out = target_out

                # Update target.
                model_optim_dic['target'].zero_grad()
                loss = (self.alpha * criterion(target_out, target_y)
                        + (1 - self.alpha) * criterion(avg_other_out.detach(), target_out))
                loss.backward()
                model_optim_dic['target'].step()

                # Logged loss: plain MSE of the target model's prediction.
                loss = criterion(target_out.detach(), target_y)
                train_loss.append(loss.item())

                if self.verbose and (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

            if self.verbose:
                print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(criterion, self.vali_loader)
            if self.verbose:
                print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f}".format(
                    epoch + 1, train_steps, train_loss, vali_loss))
            early_stopping(vali_loss, self.models_dic, path)
            if early_stopping.early_stop:
                if self.verbose:
                    print("Early stopping")
                break

            # Only the first optimizer reports its learning-rate change.
            for adj_lr_i, name in enumerate(model_optim_dic):
                lr_verbose = self.verbose if adj_lr_i == 0 else False
                adjust_learning_rate(model_optim_dic[name], epoch + 1, self.args, lr_verbose)

        for name, model in self.models_dic.items():
            best_model_path = os.path.abspath(os.path.join(path, f'checkpoint_{name}.pth'))
            model.load_state_dict(torch.load(best_model_path))

    def train(self):
        """
        Initialise every model from the 'all' checkpoint, then run ``args.global_loop`` global loops.

        After each loop the models are tested and the results saved.

        Raises:
            ValueError: if the pretrained 'all' checkpoint cannot be loaded (original error chained).
        """
        pretrained_path = os.path.join(self.args.checkpoints, self.setting, "wise_all", 'all', 'checkpoint.pth')
        try:
            state_dict = torch.load(pretrained_path)
            # Every station model shares the architecture, so one state dict initialises them all.
            for model in self.models_dic.values():
                model.load_state_dict(state_dict)
        except Exception as e:
            raise ValueError('no checkpoint for global to load') from e

        (self.train_data, self.train_loader,
         self.vali_data, self.vali_loader,
         self.test_data, self.test_loader) = self._get_data(self._data_station_list())

        # Global
        for loop in range(self.args.global_loop):
            self._train_global(loop)
            mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv = self.test(
                train_phase='global', loop=loop, load_data=False)[:11]
            StationSpecSearchHyperparam(self.args, self.target_station).save_station_results(rmse, mae, mape, nse, loop + 1)

            use_ablation = self.args.ablation != 'None'
            csv_parts = [
                './ablation/' if use_ablation else './baselines_results/',
                self.args.data, self.args.target,
                f"train{self.args.train_start_time}_{self.args.train_end_time}_"
                f"vali{self.args.vali_start_time}_{self.args.vali_end_time}_"
                f"test{self.args.test_start_time}_{self.args.test_end_time}",
                f'pl{self.args.pred_len}', self.args.run_type, f"wise_{self.args.run_wise}", self.args.model,
            ]
            if use_ablation:
                csv_parts.append(self.args.ablation)
            csv_parts.append(f'global_{loop}')
            csv_save_path = os.path.join(*csv_parts)
            os.makedirs(csv_save_path, exist_ok=True)

            make_table = MakeTable(self.args, self.args.model, self.args.run_type, self.args.results, self.setting,
                                   csv_save_path=csv_save_path, target_station=self.target_station, train_phase='global', loop=loop)
            make_table.save_to_csv()
            if self.verbose:
                print("Make results table done")

    def test(self, test=0, train_phase='global', loop=None, load_data=True):
        """
        Evaluate the target model on the test split and save the metrics and arrays.

        Args:
            test: unused; kept for interface compatibility.
            train_phase: checkpoint sub-directory to load from (e.g. 'global').
            loop: loop index used in the results path; defaults to
                ``args.global_loop + args.start_loop``.
            load_data: rebuild the data loaders before testing.

        Returns:
            The 11 metrics (mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv),
            repeated twice (a 22-tuple).
        """
        if loop is None:
            loop = self.args.global_loop + self.args.start_loop
        data_stations = self._data_station_list()
        if load_data:
            (self.train_data, self.train_loader,
             self.vali_data, self.vali_loader,
             self.test_data, self.test_loader) = self._get_data(data_stations)

        if self.verbose:
            print('loading model')
        self.models_dic = self._build_models_dict()
        ckpt_dir = self._checkpoint_dir(train_phase)
        for name, model in self.models_dic.items():
            model.load_state_dict(torch.load(os.path.join(ckpt_dir, f'checkpoint_{name}.pth')))

        preds, trues, inputs = [], [], []
        stations_preds = {station: [] for station in self.stations_list}
        folder_path = './test_results/' + self.setting + f'/wise_{self.args.run_wise}/{train_phase}/{self.target_station}/'
        os.makedirs(folder_path, exist_ok=True)
        for station in self.stations_list:
            os.makedirs(folder_path + f'{station}/', exist_ok=True)

        for model in self.models_dic.values():
            model.eval()
        pred_len = self.args.pred_len
        with torch.no_grad():
            for batch_x, batch_y, batch_x_mark, batch_y_mark, _cycle_index, _station_ids, _time_idx in self.test_loader:
                batch_x, batch_x_mark, batch_y_mark, target_x, target_y, dec_inp = self._prepare_batch(
                    batch_x, batch_y, batch_x_mark, batch_y_mark)
                target_out = self.models_dic[self.target_station](target_x, batch_x_mark, dec_inp, batch_y_mark)
                # Inference output is the target model alone.
                pred = target_out[:, -pred_len:, -1:].detach().cpu().numpy()
                true = target_y[:, -pred_len:, -1:].detach().cpu().numpy()
                if self.args.inverse:
                    pred = self.test_data.inverse_transform(pred)
                    true = self.test_data.inverse_transform(true)
                preds.append(pred[:, :, -1:])
                trues.append(true[:, :, -1:])

                # Per-station predictions; only stations whose series were loaded can be evaluated.
                for station in self.stations_list:
                    if station not in data_stations:
                        continue
                    station_x = self._station_input(batch_x, station, data_stations)
                    station_out = self.models_dic[station](station_x, batch_x_mark, dec_inp, batch_y_mark,
                                                           target_x=target_x)
                    station_out = station_out[:, -pred_len:, -1:].detach().cpu().numpy()
                    if self.args.inverse:
                        station_out = self.test_data.inverse_transform(station_out)
                    stations_preds[station].append(station_out[:, :, -1:])

                model_input = target_x[:, :, -1:].detach().cpu().numpy()
                if self.args.inverse:
                    model_input = self.test_data.inverse_transform(model_input)
                inputs.append(model_input[:, :, -1:])

        preds = np.concatenate(preds, axis=0)
        trues = np.concatenate(trues, axis=0)
        inputs = np.concatenate(inputs, axis=0)
        if self.verbose:
            print('test shape:', preds.shape, trues.shape)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        inputs = inputs.reshape(-1, inputs.shape[-2], inputs.shape[-1])
        if self.verbose:
            print('test shape:', preds.shape, trues.shape)

        # result save
        folder_path = self.args.results + self.setting + f'/wise_{self.args.run_wise}/{train_phase}_{loop}/{self.target_station}'
        os.makedirs(folder_path, exist_ok=True)

        mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv = metric(preds[:, :, -1], trues[:, :, -1])
        if self.verbose:
            print(f'{train_phase} {loop}, nse:{nse}, pbias:{pbias}, kge:{kge}, flv:{flv}, '
                  f'fhv:{fhv}, mae: {mae}, mse: {mse}, rmse: {rmse}, '
                  f'mape: {mape}, mspe: {mspe}, r2: {r2}')
        with open("result_mekong_phase_a.txt", 'a') as f:
            f.write(self.setting + "  \n")
            f.write('rmse:{}, mae:{}'.format(rmse, mae))
            f.write('\n')
            f.write('\n')

        metrics = (mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv)
        np.save(os.path.join(str(folder_path), 'metrics.npy'), np.array(metrics))
        np.save(os.path.join(str(folder_path), 'pred.npy'), preds)
        np.save(os.path.join(str(folder_path), 'true.npy'), trues)
        np.save(os.path.join(str(folder_path), 'input.npy'), inputs)
        return metrics + metrics
