from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping_mekong_local_global, adjust_learning_rate, visual
from utils.metrics import metric
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np
from copy import deepcopy
import json

# Absolute path of the directory that contains this source file.
current_dir = os.path.dirname(os.path.abspath(__file__))

# Install an 'ignore' filter with no category/module restriction, i.e. suppress
# every warning for the rest of the process (unless a later filter overrides it).
warnings.filterwarnings('ignore')


class Exp_MeKong(Exp_Basic):
    """
    Global scheme:
    If the target station is B and its linked stations are [A, C], build the models:
     B: L_B -> H_B1
     A_2_B: L_A -> H_B2
     C_2_B: L_C -> H_B3
     pred_from_AC = avg(H_B2, H_B3)
     prediction = alpha * H_B1 + (1 - alpha) * pred_from_AC
    """

    def __init__(self, args, test_station):
        """Prepare the experiment for one candidate (test) station.

        Args:
            args: Experiment configuration object; ``args.alpha`` is copied to
                ``self.alpha`` here.
            test_station: Identifier of the station that is paired with
                ``self.target_station``.

        Note:
            ``self.target_station`` is read right after the base-class
            constructor runs; presumably it is set there (not visible in this
            file -- confirm against ``Exp_Basic``).
        """
        super(Exp_MeKong, self).__init__(args, test_station)
        self.test_station = test_station
        # Order matters: the test station is first, the target station is last.
        self.stations_list = [test_station, self.target_station]
        self.models_dic = self._build_models_dict()
        self.alpha = self.args.alpha

    def _build_models_dict(self):
        """Instantiate one model per station and return them keyed by station.

        The target station's model is built first, then the test station's.
        Before each instantiation ``args.enc_in``, ``args.dec_in`` and
        ``args.c_out`` are overwritten with that station's channel count
        (1 when ``args.features == 'S'``), so after this call ``args`` holds
        the values of the test station.

        Returns:
            dict: station identifier -> model (float, moved to ``self.device``).
        """
        registry = {}
        for station_name in (self.target_station, self.test_station):
            n_channels = (1 if self.args.features == 'S'
                          else len(self.station_channels_dict[station_name]))
            # The model constructor reads its input/output widths from args.
            self.args.enc_in = self.args.dec_in = self.args.c_out = n_channels
            network = self.model_dict[self.args.model].Model(self.args)
            registry[station_name] = network.float().to(self.device)
        return registry

    def _get_data(self, flag, stations_list=None):
        """Build the dataset and data loader for a split.

        Args:
            flag: Split selector forwarded to ``data_provider`` (this class
                passes ``'train'`` and ``'val'``).
            stations_list: Stations to load; falls back to
                ``self.stations_list`` when ``None``.

        Returns:
            tuple: ``(data_set, data_loader)`` as produced by ``data_provider``.
        """
        stations = self.stations_list if stations_list is None else stations_list
        # The multi-station loader is registered under "<data>_stations_list".
        loader_key = f'{self.args.data}_stations_list'
        dataset, loader = data_provider(self.args, flag, self.verbose,
                                        loader_key, stations_list=stations)
        return dataset, loader

    def _select_optimizer(self, model, lr=None):
        """Create an Adam optimizer.

        Args:
            model: Either a module (its ``parameters()`` are optimized) or a
                list that is handed to Adam directly.
            lr: Learning rate; ``self.args.learning_rate`` is used when ``None``.

        Returns:
            torch.optim.Adam: The configured optimizer.
        """
        step_size = self.args.learning_rate if lr is None else lr
        params = model if isinstance(model, list) else model.parameters()
        return optim.Adam(params, lr=step_size)

    def _select_criterion(self):
        """Return the training/validation loss: mean squared error."""
        return nn.MSELoss()

    def vali(self, criterion, dataloader, train_phase='local', station=None):
        """Evaluate one station's model on ``dataloader`` and return the mean loss.

        Every model in ``self.models_dic`` is switched to eval mode and is left
        in eval mode on return; the training loops re-enable train mode
        themselves.

        Args:
            criterion: Loss callable applied to ``(prediction, truth)``.
            dataloader: Iterable of batches. Each batch is indexed as
                ``[inputs_per_station, targets_per_station, x_mark, y_mark]``;
                the last entry of the per-station lists is used as the target
                station's data.
            train_phase: ``'local'`` feeds the first station's input to the
                model; any other value feeds the target station's own input.
            station: Key into ``self.models_dic`` of the model to evaluate.
                Defaults to ``self.target_station`` when ``None``.

        Returns:
            Mean of the per-batch losses (``numpy`` float).
        """
        if station is None:
            # Previously None was used as a dict key and raised KeyError.
            station = self.target_station
        for net in self.models_dic.values():
            net.eval()
        model = self.models_dic[station]
        pred_len = self.args.pred_len
        label_len = self.args.label_len

        batch_losses = []
        with torch.no_grad():
            for data_list in dataloader:
                input_list = [x.float().to(self.device) for x in data_list[0]]
                true_list = [x.float().to(self.device) for x in data_list[1]]
                batch_x_mark = data_list[2].float().to(self.device)
                batch_y_mark = data_list[3].float().to(self.device)
                target_x = deepcopy(input_list[-1])
                target_y = deepcopy(true_list[-1])
                # Decoder input: the label_len known steps of the last channel
                # followed by zeros for the pred_len steps to forecast.
                dec_inp = torch.zeros_like(target_y[:, -pred_len:, -1:]).float()
                dec_inp = torch.cat([target_y[:, :label_len, -1:], dec_inp], dim=1).float().to(self.device)
                if train_phase == 'local':
                    input_x = deepcopy(input_list[0])
                else:
                    input_x = target_x
                pred = model(input_x, batch_x_mark, dec_inp, batch_y_mark, target_x=target_x)
                # Score only the forecast horizon of the last channel.
                pred = pred[:, -pred_len:, -1:].detach().cpu()
                truth = target_y[:, -pred_len:, -1:].detach().cpu()
                del data_list, input_list, true_list
                # Store plain floats (as the training loops do) instead of
                # tensors, so the average does not rely on implicit
                # tensor -> ndarray conversion.
                batch_losses.append(criterion(pred, truth).item())
        return np.average(batch_losses)

    def _train_local(self, setting, stations_list=None):
        """Train the model of ``stations_list[0]`` to predict the target series.

        The model receives the first station's input window (plus the last
        station's window as ``target_x``) and is fitted with the criterion
        against the last station's ground truth over the forecast horizon.
        Validation uses ``vali(..., train_phase='local')``. After training, the
        weights in ``checkpoint_<station>.pth`` under
        ``<checkpoints>/<setting>/<run_type>/local/`` are loaded back into the
        model (that file is presumably written by the early-stopping helper --
        confirm in ``utils.tools``).

        Side effects: overwrites ``self.train_data``, ``self.train_loader``,
        ``self.vali_data`` and ``self.vali_loader``.

        Args:
            setting: Run identifier used as the checkpoint sub-directory.
            stations_list: Stations to load; the first is the model being
                trained, the last is used as the target. Defaults to
                ``self.stations_list`` when ``None``.
        """
        if stations_list is None:
            # Same default as _get_data; without it stations_list[0] below
            # would raise TypeError.
            stations_list = self.stations_list
        self.train_data, self.train_loader = self._get_data(flag='train', stations_list=stations_list)
        self.vali_data, self.vali_loader = self._get_data(flag='val', stations_list=stations_list)
        station = stations_list[0]
        model = self.models_dic[station]
        path = os.path.join(self.args.checkpoints, setting)
        os.makedirs(os.path.join(path, self.run_type, 'local'), exist_ok=True)
        ckpt_dir = path + f'/{self.run_type}/local/'

        pred_len = self.args.pred_len
        label_len = self.args.label_len
        time_now = time.time()

        train_steps = len(self.train_loader)
        early_stopping = EarlyStopping_mekong_local_global(patience=self.args.patience, verbose=self.verbose,
                                                           delta=1e-6)

        model_optim = self._select_optimizer(model)
        criterion = self._select_criterion()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []
            epoch_time = time.time()
            # vali() leaves every model in eval mode, so re-enable train mode each epoch.
            model.train()
            for i, data_list in enumerate(self.train_loader):
                iter_count += 1
                batch_x_mark = data_list[2].float().to(self.device)
                batch_y_mark = data_list[3].float().to(self.device)
                input_list = [x.float().to(self.device) for x in data_list[0]]
                true_list = [x.float().to(self.device) for x in data_list[1]]
                input_x = deepcopy(input_list[0])
                target_x = deepcopy(input_list[-1])
                target_y = deepcopy(true_list[-1])
                # Decoder input: label_len known steps of the last channel, then zeros.
                dec_inp = torch.zeros_like(target_y[:, -pred_len:, -1:]).float()
                dec_inp = torch.cat([target_y[:, :label_len, -1:], dec_inp], dim=1).float().to(self.device)
                model_out = model(input_x, batch_x_mark, dec_inp, batch_y_mark, target_x=target_x)
                model_out = model_out[:, -pred_len:, -1:]
                target_y = target_y[:, -pred_len:, -1:]
                model_optim.zero_grad()
                loss = criterion(model_out, target_y)
                loss.backward()
                model_optim.step()
                del data_list, input_list, true_list
                train_loss.append(loss.item())

                if self.verbose and (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

            if self.verbose:
                print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(criterion, self.vali_loader, train_phase='local', station=station)
            if self.verbose:
                print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f}".format(
                    epoch + 1, train_steps, train_loss, vali_loss))
            # The helper receives the whole model dict, not only the trained model.
            early_stopping(vali_loss, self.models_dic, ckpt_dir)
            if early_stopping.early_stop:
                if self.verbose:
                    print("Early stopping")
                break
            adjust_learning_rate(model_optim, epoch + 1, self.args, self.verbose)

        best_model_path = ckpt_dir + f'checkpoint_{station}.pth'
        model.load_state_dict(torch.load(best_model_path))

    def _train_global(self, setting):
        """Jointly fine-tune the target model and the other-station models.

        Per batch, with ``y`` the target station's ground truth over the
        forecast horizon and ``mse`` the criterion:

        * every non-target model is stepped on
          ``alpha * mse(out, y) + (1 - alpha) * mse(target_out.detach(), out)``;
        * the target model is then stepped on
          ``alpha * mse(target_out, y) + (1 - alpha) * mse(avg_other, target_out)``
          where ``avg_other`` is the mean of the detached non-target outputs.

        All optimizers use ``args.global_lr``. Validation evaluates the target
        model with ``train_phase='global'``. The current weights are written to
        ``<checkpoints>/<setting>/<run_type>/global/`` before training, and the
        ``checkpoint_<name>.pth`` files in that directory are loaded back into
        the models afterwards.

        Requires ``self.train_loader`` and ``self.vali_loader`` to be set by
        the caller. If the model dict contains no station other than the
        target, nothing is trained and the method returns after saving the
        checkpoints.

        Args:
            setting: Run identifier used as the checkpoint sub-directory.
        """
        path = os.path.join(self.args.checkpoints, setting)
        os.makedirs(os.path.join(path, self.run_type, 'global'), exist_ok=True)
        for name, model in self.models_dic.items():
            ckpt_name = f'checkpoint_{name}.pth'
            torch.save(model.state_dict(), os.path.join(path, self.run_type, 'global', ckpt_name))

        other_stations = [name for name in self.models_dic if name != self.target_station]
        if not other_stations:
            # No partner model to exchange predictions with; averaging an empty
            # list of outputs would make torch.stack raise.
            return None

        pred_len = self.args.pred_len
        label_len = self.args.label_len
        ckpt_dir = path + f'/{self.run_type}/global/'
        target_model = self.models_dic[self.target_station]
        # Position of each station's tensors inside the per-batch lists.
        target_i = self.stations_list.index(self.target_station)

        time_now = time.time()

        train_steps = len(self.train_loader)
        early_stopping = EarlyStopping_mekong_local_global(patience=self.args.patience, verbose=self.verbose,
                                                           delta=1e-6)

        model_optim_dic = {name: self._select_optimizer(model, lr=self.args.global_lr)
                           for name, model in self.models_dic.items()}
        target_optim = model_optim_dic[self.target_station]
        criterion = self._select_criterion()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []
            epoch_time = time.time()
            # vali() leaves every model in eval mode, so re-enable train mode each epoch.
            for model in self.models_dic.values():
                model.train()
            for i, data_list in enumerate(self.train_loader):
                iter_count += 1
                batch_x_mark = data_list[2].float().to(self.device)
                batch_y_mark = data_list[3].float().to(self.device)
                input_list = [x.float().to(self.device) for x in data_list[0]]
                true_list = [x.float().to(self.device) for x in data_list[1]]
                target_x = deepcopy(input_list[target_i])
                target_y = deepcopy(true_list[target_i])
                # Decoder input: label_len known steps of the last channel, then zeros.
                dec_inp = torch.zeros_like(target_y[:, -pred_len:, -1:]).float()
                dec_inp = torch.cat([target_y[:, :label_len, -1:], dec_inp], dim=1).float().to(self.device)
                y_true = target_y[:, -pred_len:, -1:]

                target_out = target_model(target_x, batch_x_mark, dec_inp, batch_y_mark)
                target_out = target_out[:, -pred_len:, -1:]

                # Update the other-station models.
                other_out_list = []
                for station in other_stations:
                    station_i = self.stations_list.index(station)
                    input_x = deepcopy(input_list[station_i])
                    model_out = self.models_dic[station](input_x, batch_x_mark, dec_inp, batch_y_mark,
                                                         target_x=target_x)
                    # Slice once so BOTH loss terms compare like-shaped tensors;
                    # the consistency term previously used the unsliced output,
                    # which broadcasts silently when it has more than one channel.
                    model_out = model_out[:, -pred_len:, -1:]
                    model_optim_dic[station].zero_grad()
                    loss = (self.alpha * criterion(model_out, y_true)
                            + (1 - self.alpha) * criterion(target_out.detach(), model_out))
                    loss.backward(retain_graph=True)
                    model_optim_dic[station].step()
                    other_out_list.append(model_out.detach())
                avg_other_out = torch.stack(other_out_list, dim=0).mean(dim=0)

                # Update the target model.
                target_optim.zero_grad()
                loss = (self.alpha * criterion(target_out, y_true)
                        + (1 - self.alpha) * criterion(avg_other_out, target_out))
                loss.backward()
                target_optim.step()
                del data_list, input_list, true_list

                # Logged loss is the plain supervised error of the target model.
                batch_loss = criterion(target_out.detach(), y_true).item()
                train_loss.append(batch_loss)

                if self.verbose and (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, batch_loss))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

            if self.verbose:
                print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(criterion, self.vali_loader, train_phase='global', station=self.target_station)
            if self.verbose:
                print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f}".format(
                    epoch + 1, train_steps, train_loss, vali_loss))
            early_stopping(vali_loss, self.models_dic, ckpt_dir)
            if early_stopping.early_stop:
                if self.verbose:
                    print("Early stopping")
                break

            # Only the first optimizer reports its new learning rate.
            for adj_lr_i, name in enumerate(self.models_dic):
                verbose = self.verbose if adj_lr_i == 0 else False
                adjust_learning_rate(model_optim_dic[name], epoch + 1, self.args, verbose)

        for name, model in self.models_dic.items():
            best_model_path = ckpt_dir + f'checkpoint_{name}.pth'
            model.load_state_dict(torch.load(best_model_path))

    def train(self, setting, test_station):
        """Run local then global training and decide whether a link exists.

        1. For every station in ``self.stations_list`` the local model is
           either loaded from ``.../local/checkpoint_<station>.pth`` (when
           ``args.load_model`` is set and the file exists) or trained with
           ``_train_local``.
        2. The target model's validation loss is recorded.
        3. The global stage is loaded from ``.../global/`` in the same way, or
           run with ``_train_global``.
        4. The target model's validation loss is recorded again.

        Args:
            setting: Run identifier used as the checkpoint sub-directory.
            test_station: Unused here; the station given at construction time
                is used instead. Kept for interface compatibility.

        Returns:
            bool: ``True`` when the post-global validation loss is below
            ``args.flow_list_vali_factor`` times the local validation loss.
        """
        criterion = self._select_criterion()
        # Load from the same root the trainers save to (args.checkpoints), not
        # from a hard-coded './checkpoints/'.
        ckpt_root = os.path.join(self.args.checkpoints, setting, self.run_type)

        for station in self.stations_list:
            if self.args.load_model:
                try:
                    ckpt_name = f'checkpoint_{station}.pth'
                    # map_location lets a checkpoint saved on another device load here.
                    self.models_dic[station].load_state_dict(torch.load(
                        os.path.join(ckpt_root, 'local', ckpt_name), map_location=self.device
                    ))
                    if self.verbose:
                        print("load local model, other station: {}".format(station))
                except FileNotFoundError:
                    self._train_local(setting, [station, self.target_station])
            else:
                self._train_local(setting, [station, self.target_station])

        if getattr(self, 'vali_loader', None) is None:
            # Every local model came from disk, so _train_local never built the
            # validation loader that the next line needs.
            self.vali_data, self.vali_loader = self._get_data(flag='val')
        local_vali_loss = self.vali(criterion, self.vali_loader, 'final', station=self.target_station)

        self.train_data, self.train_loader = self._get_data(flag='train')
        self.vali_data, self.vali_loader = self._get_data(flag='val')
        if self.args.load_model:
            try:
                for station in self.stations_list:
                    ckpt_name = f'checkpoint_{station}.pth'
                    self.models_dic[station].load_state_dict(torch.load(
                        os.path.join(ckpt_root, 'global', ckpt_name), map_location=self.device
                    ))
                    if self.verbose:
                        print("load Global model, other station: {}".format(station))
            except FileNotFoundError:
                # A missing file for any station triggers a full global run.
                self._train_global(setting)
        else:
            self._train_global(setting)
        global_vali_loss = self.vali(criterion, self.vali_loader, 'final', station=self.target_station)
        link_exists = bool(global_vali_loss < self.args.flow_list_vali_factor * local_vali_loss)
        return link_exists
