from .exp_basic import ExpBasic
from dataloader.dataset_short_term_forecasting import M4Meta
from autoaugment import AutoAugmentBasic
from dataloader import get_dataset
from torch import optim
import torch.nn as nn
import os
from utils.tools import EarlyStopping
import time
import torch
import numpy as np
import pandas as pd
from utils.losses import smape_loss

from utils.metrics import metric
from utils.m4summary import M4Summary

import os
import time

class ExpShortTermForecasting(ExpBasic):
    """
    """
    def _maybe_set_m4_defaults(self):
        if getattr(self.config.args, "dataset_type", "").upper() == "M4" or \
            getattr(self.config.args, "data", "").upper() == "M4":
            sp = getattr(self.config.args, "seasonal_patterns", "Monthly")
            self.config.args.pred_len = M4Meta.horizons_map[sp]
            self.config.args.seq_len = 2 * self.config.args.pred_len
            self.config.args.label_len = self.config.args.pred_len
            self.config.args.frequency_map = M4Meta.frequency_map[sp]
    
    def _build_model(self) -> AutoAugmentBasic:
        """Instantiate the model after syncing data-derived sizes into the config.

        Applies the M4 defaults (if relevant), fetches the TRAIN and TEST
        datasets, and copies their dimensions into both
        ``self.config.dimensions`` and ``self.config.args``. ``seq_len`` is
        the larger of the train and test sequence lengths; every other value
        comes from the training dataset.

        Returns:
            ``self.model_class(self.config)`` cast to float and moved to
            ``self.device``.
        """
        self._maybe_set_m4_defaults()

        train_set, _ = self._get_data('TRAIN', 'TRAIN')
        test_set, _ = self._get_data('TEST', 'TEST')

        # Values mirrored into both the dimensions mapping and the args.
        shared = {
            "n_channels": train_set.n_channels,
            "seq_len": max(train_set.seq_len, test_set.seq_len),
            "n_features": train_set.n_features,  # = D_enc
            "pred_len": train_set.pred_len,
        }
        self.config.dimensions.update(**shared)

        args = self.config.args
        for key, value in shared.items():
            setattr(args, key, value)
        # label_len is only tracked on the args, not in the dimensions.
        args.label_len = train_set.label_len

        # model init
        network: AutoAugmentBasic = self.model_class(self.config).float().to(self.device)
        return network
    
    def _get_data(self, flag: str, load_as: str):
        if load_as not in self.loaded_data:
            dataset, data_loader = get_dataset(self.config, flag)
            self.loaded_data[load_as] = (dataset, data_loader)
        else:
            dataset, data_loader = self.loaded_data[load_as]
        return dataset, data_loader

    def _select_optimizer(self):
        return optim.RAdam(self.model.parameters(), lr=self.config.args.learning_rate)
    
    def _select_criterion(self):
        """Return the training/validation criterion: a ``smape_loss`` instance."""
        # Alternative kept for reference: let the model choose its criterion,
        # i.e. self.model.get_criterion(default_criterion=nn.MSELoss()).
        criterion = smape_loss()
        return criterion

    
    # shared tools
    @torch.no_grad()
    def _make_dec_inp(self, batch_x, batch_y):
        """
        dec_inp: [B, C, Llabel_len+pred_len] -
        """
        B, C, _ = batch_x.shape
        Llabel = self.config.args.label_len
        Lpred = self.config.args.pred_len

        # print(f"label_len: {Llabel}")
        # print(f"pred_len: {Lpred}")
        # exit(0)
        hist = batch_x[:, :, -Llabel:] # [B, C, Llabel]
        zeros = torch.zeros(B, C, Lpred, device=batch_x.device, dtype=batch_x.dtype)
        return torch.cat([hist, zeros], dim=2)

    def vali(self, vali_data, vali_loader, criterion):
        total_loss = []
        self.model.eval()
        with torch.no_grad():
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                dec_inp = self._make_dec_inp(batch_x, batch_y)
                outputs, _, _ = self.model(batch_x, dec_inp, batch_x_mark, batch_y_mark)

                pred = outputs.detach()
                true = batch_y.detach()

                loss = criterion(pred, true)
                total_loss.append(loss.item())
        
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss




    def train(self):
        """Fit the model and return it.

        Each epoch runs one pass over the TRAIN loader, then evaluates the
        criterion on the VALI and TEST loaders via ``vali``. The validation
        loss is passed to ``EarlyStopping`` together with the model and the
        checkpoint path; training stops early once it sets ``early_stop``.

        Returns:
            ``self.model`` in its state after the last executed epoch. No
            checkpoint is reloaded here; call ``test(load_checkpoint=True)``
            to evaluate the weights stored at the checkpoint path.
        """
        _, train_loader = self._get_data('TRAIN', 'TRAIN')
        vali_data, vali_loader = self._get_data('VALI', 'VALI')
        test_data, test_loader = self._get_data('TEST', 'TEST')

        ckpt_path = self.config.get_checkpoint_path()
        # exist_ok avoids the check-then-create race; the guard covers a
        # checkpoint path with no directory component (dirname == "").
        ckpt_dir = os.path.dirname(ckpt_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)

        time_now = time.time()
        early_stopping = EarlyStopping(patience=self.config.args.patience, verbose=True)
        model_optim = self._select_optimizer()
        criterion = self._select_criterion()
        train_steps = len(train_loader)

        for epoch in range(self.config.args.train_epochs):
            print("[Model State]", end="")
            print(self.model.summarize_state())
            iter_count = 0
            epoch_time = time.time()
            losses = []

            self.model.train()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)
                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                dec_inp = self._make_dec_inp(batch_x, batch_y)  # [B, C, Llabel+Ly]

                # The model returns a 3-tuple; only the first item is trained on.
                outputs, _, _ = self.model(batch_x, dec_inp, batch_x_mark, batch_y_mark)

                loss = criterion(outputs, batch_y)
                loss.backward()
                model_optim.step()
                losses.append(loss.item())

                # Progress report every 100 iterations, with a remaining-time
                # estimate based on the speed since the previous report.
                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.config.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(losses)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping(vali_loss, self.model, ckpt_path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        # NOTE: the checkpoint at ckpt_path is intentionally not reloaded here,
        # so the returned model carries the last-epoch weights.
        return self.model


    def test(self, load_checkpoint: bool = False):
        """Evaluate on the TEST split and persist metrics and forecasts.

        Steps:
          1. Optionally load the weights stored at the checkpoint path.
          2. Run the model over the TEST loader and collect the last
             ``pred_len`` steps of predictions and targets. When
             ``args.features == "MS"`` only the last channel is kept,
             otherwise all channels are kept.
          3. Compute metrics with ``metric`` on arrays transposed to
             ``[B, L, C]`` and append MSE/MAE to the test-result file.
          4. Write the first channel of every forecast to
             ``./m4_results_AutoTSA/<run_tag>/<seasonal_patterns>_forecast.csv``,
             indexed by ``test_loader.dataset._ids``. ``run_tag`` defaults to
             the current timestamp when not set on the args.

        Args:
            load_checkpoint: load the model state from
                ``self.config.get_checkpoint_path()`` before evaluating.

        Returns:
            None.
        """
        _, test_loader = self._get_data('TEST', 'TEST')
        if load_checkpoint:
            checkpoint_path = self.config.get_checkpoint_path()
            print(f"try load model checkpoint from {checkpoint_path}")
            self.model.load_state_dict(torch.load(checkpoint_path,
                                                  weights_only=True,
                                                  map_location=self.device))

        args = self.config.args
        pred_len = args.pred_len
        # Loop-invariant: -1 keeps only the last channel, 0 keeps all of them.
        f_dim = -1 if getattr(args, "features", "S") == "MS" else 0

        preds, trues = [], []

        self.model.eval()
        with torch.no_grad():
            for batch_x, batch_y, batch_x_mark, batch_y_mark in test_loader:
                batch_x = batch_x.float().to(self.device)
                batch_y = batch_y.float().to(self.device)

                batch_x_mark = batch_x_mark.float().to(self.device)
                batch_y_mark = batch_y_mark.float().to(self.device)

                # decoder input
                dec_inp = self._make_dec_inp(batch_x, batch_y)
                outputs, _, _ = self.model(batch_x, dec_inp, batch_x_mark, batch_y_mark)

                # Normalize both to [B, C, Ly].
                pred = outputs[:, f_dim:, -pred_len:]
                true = batch_y[:, f_dim:, -pred_len:]

                preds.append(pred.detach().cpu().numpy())
                trues.append(true.detach().cpu().numpy())

        preds = np.concatenate(preds, axis=0)  # [B, C, L]
        trues = np.concatenate(trues, axis=0)  # [B, C, L]
        print("test shape: ", preds.shape, trues.shape)

        # metric() is fed [B, L, C].
        preds_eval = np.transpose(preds, (0, 2, 1))
        trues_eval = np.transpose(trues, (0, 2, 1))

        mae, mse, _rmse, _mape, _mspe = metric(preds_eval, trues_eval)
        print(f"mse: {mse}, mae: {mae}")

        out = self.config.get_test_result_path()
        # Guard against a result path with no directory component:
        # os.makedirs("") raises.
        out_dir = os.path.dirname(out)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        with open(out, 'a') as f:
            f.write(f"{self.config.get_keyword()} \n MSE: {mse}, MAE: {mae}\n\n")

        run_tag = getattr(args, "run_tag", time.strftime("%Y%m%d_%H%M%S"))

        base_folder = "./m4_results_AutoTSA"
        run_dir = os.path.join(base_folder, str(run_tag))
        os.makedirs(run_dir, exist_ok=True)

        season = str(args.seasonal_patterns)
        csv_path = os.path.join(run_dir, f"{season}_forecast.csv")

        _, _, L = preds.shape
        # Only the first channel is exported, whatever the channel count.
        pred_matrix = preds[:, 0, :]  # [B, L]

        assert L == pred_len, f"pred_len不一致: got L={L}, expected {pred_len}"
        columns = [f"V{i+1}" for i in range(pred_len)]

        ids = test_loader.dataset._ids[:pred_matrix.shape[0]]
        forecasts_df = pd.DataFrame(pred_matrix, columns=columns, index=ids)
        forecasts_df.index.name = 'id'

        # Write to a temp file and rename, so the final path is only ever
        # replaced by a fully written file.
        tmp_path = csv_path + ".tmp"
        forecasts_df.to_csv(tmp_path)
        os.replace(tmp_path, csv_path)

        print(f"[INFO] Saved forecast to {csv_path}")

        return