from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, adjust_learning_rate, visual
from utils.losses import RmseLoss
from utils.metrics import metric
from utils.constants_LamaH import Constants
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np
import pandas as pd
from tqdm import trange

# Suppress all warnings emitted after this point (modifies the global warnings filter list).
warnings.filterwarnings('ignore')
# Absolute path of the directory that contains this module.
current_dir = os.path.dirname(os.path.abspath(__file__))


class Batcher:
    """Sample fixed-length (time, station) windows from dense [T, N, C] tensors.

    Every valid window start ``t`` is paired with every station ``n``. One
    epoch walks an (optionally shuffled) ordering of these pairs in chunks of
    ``batch_size``. The number of chunks is sized for roughly 99% coverage of
    the series and is capped at one full pass so that no batch is ever empty.
    """

    def __init__(self, args, x, y, c=None, buff_time=0, shuffle=True, device='cuda'):
        """
        Args:
            args: provides ``seq_len``, ``label_len``, ``pred_len`` and ``batch_size``.
            x: [T, N, C_in] input series.
            y: [T, N, C_out] target series.
            c: optional [N, C_static] static per-station features. When given,
                they are repeated along time and prepended to the channels of x.
            buff_time: warm-up length used during training. Each input window
                is ``seq_len + buff_time`` steps long.
            shuffle: reshuffle the (time, station) pairs at the start of every epoch.
            device: device that every tensor is moved to.
        """
        self.x = x
        self.y = y
        self.c = c
        self.seq_len = args.seq_len
        self.label_len = args.label_len
        self.pred_len = args.pred_len
        self.buff_time = buff_time
        self.batch_size = args.batch_size
        self.device = device
        self.shuffle = shuffle

        self.T, self.N, self.C_in = x.shape
        _, _, self.C_out = y.shape

        # Every (window start, station) combination: column 0 = time, column 1 = station.
        # A window needs seq_len + buff_time input steps followed by pred_len target steps.
        self.station_ids = np.arange(self.N)
        max_time_idx = self.T - (self.seq_len + self.buff_time + self.pred_len) + 1
        self.time_indices = np.arange(max_time_idx)
        self.all_indices = np.array(np.meshgrid(self.time_indices, self.station_ids)).T.reshape(-1, 2)
        self.n_samples = len(self.all_indices)

        # Relative offsets of the input and decoder windows, precomputed once.
        self.seq_range = torch.arange(self.seq_len + self.buff_time, device=self.device)
        self.label_range = torch.arange(self.label_len + self.pred_len, device=self.device)

        if self.batch_size >= self.N:
            # The batch size is capped at the number of stations.
            self.batch_size = self.N
        self.n_iter_epoch = self._iterations_per_epoch()

        # Move the data to the device once instead of copying it for every batch.
        self.x = self.x.to(self.device)
        self.y = self.y.to(self.device)
        if self.c is not None:
            self.c = self.c.to(self.device)

        # Current position for sequential iteration (not read within this class).
        self.current_idx = 0

    def _iterations_per_epoch(self):
        """Return the number of batches to yield per epoch.

        One batch of ``batch_size`` windows of length ``seq_len`` covers a
        fraction ``p`` of all station time steps. The smallest ``k`` with
        ``(1 - p) ** k <= 0.01`` gives ~99% coverage. The result is capped at
        one full pass over ``all_indices``, because slicing past the end would
        produce empty batches.
        """
        if self.n_samples == 0 or self.batch_size <= 0:
            # The series is too short for even one window: nothing to yield.
            return 0
        full_pass = int(np.ceil(self.n_samples / self.batch_size))
        span = self.N * (self.T - self.buff_time)
        p = self.batch_size * self.seq_len / span if span > 0 else 1.0
        if not 0.0 < p < 1.0:
            # The coverage formula is undefined (log of a non-positive number, or
            # division by log(1) = 0), so fall back to one full pass.
            return full_pass
        n_iter = int(np.ceil(np.log(0.01) / np.log(1 - p)))
        return min(n_iter, full_pass)

    def get_epoch_batches(self):
        """Yield ``n_iter_epoch`` batches for one epoch.

        Yields:
            batch_x: [B, seq_len + buff_time, C_static + C_in]. The C_static
                channels are present only when ``c`` was given.
            batch_y: [B, label_len + pred_len, C_out]. It starts ``label_len``
                steps before the end of the input window, so its last
                ``pred_len`` steps come strictly after the input.
            batch_time: [B] window start indices (long).
            batch_station: [B] station indices (long).
        """
        if self.shuffle:
            order = self.all_indices[np.random.permutation(self.n_samples)]
        else:
            order = self.all_indices

        input_len = self.seq_len + self.buff_time
        # The decoder window is anchored to the END of the full input window
        # (warm-up included). Ignoring buff_time here would put the targets
        # inside the input window during training.
        y_offset = input_len - self.label_len

        for i in range(self.n_iter_epoch):
            start_idx = i * self.batch_size
            batch_idx = order[start_idx:start_idx + self.batch_size]

            batch_time = torch.tensor(batch_idx[:, 0], device=self.device, dtype=torch.long)
            batch_station = torch.tensor(batch_idx[:, 1], device=self.device, dtype=torch.long)

            # batch_x: [B, seq_len + buff_time, C_in]
            x_idx = batch_time[:, None] + self.seq_range[None, :]
            batch_x = self.x[x_idx, batch_station[:, None], :]

            # batch_y: [B, label_len + pred_len, C_out]
            y_idx = batch_time[:, None] + y_offset + self.label_range[None, :]
            batch_y = self.y[y_idx, batch_station[:, None], :]

            if self.c is not None:
                # Repeat the static features along time and prepend them as channels.
                batch_c = self.c[batch_station][:, None, :].expand(-1, input_len, -1)
                batch_x = torch.cat([batch_c, batch_x], dim=-1)

            yield batch_x, batch_y, batch_time, batch_station


class Exp_MeKong(Exp_Basic):
    """Experiment runner: trains on sampled windows and evaluates per-station metrics.

    Training and validation draw (time, station) windows from the dense
    dataset arrays through ``Batcher``. Testing iterates the test DataLoader
    built by ``data_provider``, scatters predictions into dense
    ``[time, pred_len, station]`` arrays, and scores each station with
    ``utils.metrics.metric``.
    """

    def __init__(self, args, station, setting):
        super(Exp_MeKong, self).__init__(args, station, setting)
        self.constants = Constants(args)
        self.model = self._build_model()
        run_dir = f"wise_{self.args.run_wise}"
        # Checkpoints, test artefacts and result arrays are all nested under
        # <root>/<setting>/wise_<run_wise>/<target_station>.
        self.ckpt_path = os.path.join(self.args.checkpoints, self.setting, run_dir, self.target_station)
        self.test_results_path = os.path.join('test_results', setting, run_dir, self.target_station)
        self.results_path = os.path.join(self.args.results, setting, run_dir, self.target_station)
        # exist_ok avoids a check-then-create race when several runs share a directory.
        for path in (self.ckpt_path, self.test_results_path, self.results_path):
            os.makedirs(path, exist_ok=True)

    def _build_model(self):
        """Instantiate ``args.model`` from ``model_dict`` as a float32 model on ``self.device``."""
        model = self.model_dict[self.args.model].Model(self.args).float().to(self.device)
        return model

    def _get_data(self):
        """Return (train_dataset, train_loader, val_dataset, val_loader, test_dataset, test_loader)."""
        (train_dataset, train_loader,
         val_dataset, val_loader,
         test_dataset, test_loader) = data_provider(
            self.args,
            self.verbose,
            f'{self.args.data}_mts',
            self.target_station
        )
        return train_dataset, train_loader, val_dataset, val_loader, test_dataset, test_loader

    def _select_optimizer(self):
        """Adam over all model parameters with ``args.learning_rate``."""
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        """Return the loss named by ``args.loss`` ('MSE' or 'RMSE').

        Raises:
            NotImplementedError: for any other value of ``args.loss``.
        """
        if self.args.loss == 'MSE':
            criterion = nn.MSELoss()
        elif self.args.loss == 'RMSE':
            criterion = RmseLoss()
        else:
            raise NotImplementedError("loss function is not implemented, only accept in ['MSE', 'RMSE']")
        return criterion

    def _dataset_tensors(self, dataset):
        """Return (x, y, c) from ``dataset`` as float32 tensors on ``self.device``.

        ``c`` (from ``dataset.static_tensor``) is None when the dataset has no
        static features. Casting to float32 matches the ``.float()`` model.
        """
        def to_tensor(values):
            return torch.as_tensor(values, dtype=torch.float32, device=self.device)

        static = dataset.static_tensor
        c = None if static is None else to_tensor(static)
        return to_tensor(dataset.data_x), to_tensor(dataset.data_y), c

    def _run_model(self, batch_x, batch_y, batch_x_mark=None, batch_y_mark=None):
        """Forward one batch and return (outputs, targets) for the last pred_len steps and channel.

        The decoder input is the first ``label_len`` steps of ``batch_y``
        followed by ``pred_len`` zeros.
        """
        pred_len = self.args.pred_len
        dec_inp = torch.zeros_like(batch_y[:, -pred_len:, :], device=self.device)
        dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1)
        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        return outputs[:, -pred_len:, -1:], batch_y[:, -pred_len:, -1:]

    def vali(self, criterion, dataset, data_loader):
        """Return the mean ``criterion`` loss over randomly sampled windows of ``dataset``.

        ``data_loader`` is accepted for interface compatibility but not used.
        The model is put back into train mode afterwards, even on error.
        """
        x, y, c = self._dataset_tensors(dataset)
        batcher = Batcher(self.args, x, y, c, buff_time=0, shuffle=True, device=self.device)
        losses = []
        self.model.eval()
        try:
            with torch.no_grad():
                for batch_x, batch_y, _, _ in batcher.get_epoch_batches():
                    outputs, targets = self._run_model(batch_x, batch_y)
                    losses.append(criterion(outputs.detach().cpu(), targets.detach().cpu()).item())
        finally:
            self.model.train()
        return np.average(losses)

    def train(self):
        """Train for ``args.train_epochs`` epochs and reload the last saved checkpoint.

        Side effects: sets the train/vali/test dataset and loader attributes and
        writes ``checkpoint.pth`` under ``ckpt_path`` after every epoch.
        """
        (self.train_data, self.train_loader,
         self.vali_data, self.vali_loader,
         self.test_data, self.test_loader) = self._get_data()

        x, y, c = self._dataset_tensors(self.train_data)
        model_optim = self._select_optimizer()
        criterion = self._select_criterion()
        # Training windows are prefixed with seq_len extra warm-up steps.
        batcher = Batcher(self.args, x, y, c, buff_time=self.args.seq_len, shuffle=True, device=self.device)
        checkpoint_file = os.path.join(str(self.ckpt_path), 'checkpoint.pth')

        with trange(1, self.args.train_epochs + 1) as pbar:
            pbar.set_description("Training")
            for _epoch in pbar:
                self.model.train()
                epoch_losses = []
                for batch_x, batch_y, _, _ in batcher.get_epoch_batches():
                    model_optim.zero_grad()
                    outputs, targets = self._run_model(batch_x, batch_y)
                    loss = criterion(outputs, targets)
                    epoch_losses.append(loss.item())
                    loss.backward()
                    model_optim.step()
                if self.verbose:
                    pbar.set_postfix(loss=np.average(epoch_losses))
                # No early stopping: the latest epoch's weights are what gets reloaded below.
                torch.save(self.model.state_dict(), checkpoint_file)

        self.model.load_state_dict(torch.load(checkpoint_file, map_location=self.device))

    def _predict_test_set(self):
        """Run the model over ``test_loader`` and return dense (preds, trues).

        Both arrays are [num_times, pred_len, num_stations]. Each batch entry is
        written at its (time_idx, station_id) position, and later batches
        overwrite repeated positions.
        """
        num_times = self.test_loader.dataset.data_x.shape[0] - self.args.seq_len - self.args.pred_len + 1
        preds = np.zeros((num_times, self.args.pred_len, self.args.num_stations))
        trues = np.zeros((num_times, self.args.pred_len, self.args.num_stations))

        self.model.eval()
        with torch.no_grad():
            for batch_x, batch_y, batch_x_mark, batch_y_mark, station_ids, time_idx in self.test_loader:
                batch_x = batch_x.float().to(self.device, non_blocking=True)
                # Only the last channel of batch_y is kept for the decoder input and targets.
                batch_y = batch_y[:, :, -1:].float().to(self.device, non_blocking=True)
                batch_x_mark = batch_x_mark.float().to(self.device, non_blocking=True)
                batch_y_mark = batch_y_mark.float().to(self.device, non_blocking=True)

                outputs, targets = self._run_model(batch_x, batch_y, batch_x_mark, batch_y_mark)
                outputs = outputs.detach().cpu().numpy()  # [B, H, C]
                targets = targets.detach().cpu().numpy()

                if self.args.inverse:
                    outputs = self.test_data.inverse_transform(outputs, input_station_ids=station_ids)
                    targets = self.test_data.inverse_transform(targets, input_station_ids=station_ids)

                # Vectorized scatter: preds[t_k, :, n_k] = outputs[k, :, -1] for every batch entry k.
                t_idx = np.asarray(time_idx).reshape(-1)
                n_idx = np.asarray(station_ids).reshape(-1)
                preds[t_idx, :, n_idx] = outputs[:, :, -1]
                trues[t_idx, :, n_idx] = targets[:, :, -1]
        return preds, trues

    @staticmethod
    def _summarize_station_metrics(preds, trues):
        """Score each station with ``metric`` and aggregate over stations.

        Returns:
            (means, medians): two tuples of 11 values in ``metric``'s return
            order (mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv).
        """
        per_station = [metric(preds[:, :, node], trues[:, :, node]) for node in range(preds.shape[-1])]
        # Transpose to one sequence of per-station values per metric. With no
        # stations, each metric gets an empty sequence, whose mean/median is nan.
        per_metric = list(zip(*per_station)) or [()] * 11
        means = tuple(np.mean(values) for values in per_metric)
        medians = tuple(np.median(values) for values in per_metric)
        return means, medians

    def test(self, test=0):
        """Evaluate the saved checkpoint on the test split and persist the results.

        Args:
            test: 1 forces the data splits to be reloaded. They are also loaded
                when no test loader exists yet (e.g. ``train`` was not run).

        Returns:
            A 22-tuple: the 11 station-averaged metrics followed by the 11
            station-median metrics, each in the order mae, mse, rmse, mape,
            mspe, r2, nse, pbias, kge, flv, fhv.

        Side effects: appends a summary to ``result_initial.txt`` and saves
        ``metrics.npy``, ``pred.npy`` and ``true.npy`` under ``results_path``.
        """
        if test == 1 or getattr(self, 'test_loader', None) is None:
            (self.train_data, self.train_loader,
             self.vali_data, self.vali_loader,
             self.test_data, self.test_loader) = self._get_data()

        if self.verbose:
            print('loading model')
        self.model = self._build_model().to(self.device)
        checkpoint_file = os.path.join(str(self.ckpt_path), 'checkpoint.pth')
        # map_location lets a checkpoint saved on another device be restored here.
        self.model.load_state_dict(torch.load(checkpoint_file, map_location=self.device))

        preds, trues = self._predict_test_set()
        if self.verbose:
            print('test shape:', preds.shape, trues.shape)

        avgs, medians = self._summarize_station_metrics(preds, trues)
        (avg_mae, avg_mse, avg_rmse, avg_mape, avg_mspe, avg_r2,
         avg_nse, avg_pbias, avg_kge, avg_flv, avg_fhv) = avgs
        (median_mae, median_mse, median_rmse, median_mape, median_mspe, median_r2,
         median_nse, median_pbias, median_kge, median_flv, median_fhv) = medians

        avg_line = (f'avg_nse:{avg_nse}, avg_pbias:{avg_pbias}, avg_kge:{avg_kge}, avg_flv:{avg_flv}, '
                    f'avg_fhv:{avg_fhv}, avg_mae: {avg_mae}, avg_mse: {avg_mse}, avg_rmse: {avg_rmse}, '
                    f'avg_mape: {avg_mape}, avg_mspe: {avg_mspe}, avg_r2: {avg_r2}')
        median_line = (f'median_nse: {median_nse}, median_pbias: {median_pbias}, median_kge: {median_kge}, '
                       f'median_flv: {median_flv}, median_fhv: {median_fhv}, median_mae: {median_mae}, '
                       f'median_mse: {median_mse}, median_rmse: {median_rmse}, median_mape: {median_mape}, '
                       f'median_mspe: {median_mspe}, median_r2: {median_r2}')
        if self.verbose:
            print(avg_line)
            print(median_line)

        # The context manager closes the file even if a write fails.
        with open("result_initial.txt", 'a') as f:
            f.write(self.setting + "  \n")
            f.write(avg_line)
            f.write('\n')
            f.write(median_line)
            f.write('\n')
            f.write('\n')

        np.save(os.path.join(str(self.results_path), 'metrics.npy'), np.array(avgs))
        np.save(os.path.join(str(self.results_path), 'pred.npy'), preds)
        np.save(os.path.join(str(self.results_path), 'true.npy'), trues)
        return avgs + medians
