import torch
import pandas as pd
import numpy as np
import os
from abc import ABC, abstractmethod
from utils import create_logger
from data_utils.timestamps import TimeCovariates, get_sin_cos_embedding
from data_utils.utils import (
    flatten_windows,
    ema_trend,
)
import math
from data_utils.utils import MinMaxArgs, MinMaxScaler, load_tsf_as_timeseries
from scipy import io as scipy_io
from types import SimpleNamespace

class TimeSeriesDataset(ABC):
    """
    Base class turning a raw time series into windowed time-frequency tensors
    plus per-window time covariates.

    This class handles:
      * loading the raw series from ``data_path`` and cutting it into windows
        of length ``seq_len`` with stride ``overlapping_seqs_stride``;
      * optional min-max scaling in the time domain (``scale_time``);
      * optional EMA trend / seasonal+residual decomposition (``decompose``);
      * caching the results (and scaling params) under ``exp_dir``.

    Subclasses implement the time-frequency transform, its inverse, and the
    normalization of the transformed data.
    """

    def __init__(
        self,
        args : SimpleNamespace,
        logger=None,
    ):
        """
        Args:
            args (SimpleNamespace): Configuration. Must provide ``exp_dir``,
                ``data_path``, ``seq_len``, ``scale_time``,
                ``overlapping_seqs_stride`` and ``decompose``.
            logger: Optional logger. When None, one is created with
                ``create_logger`` inside ``exp_dir``.
        """
        self.device = "cpu"
        self.exp_dir = args.exp_dir
        self.data_path = args.data_path
        # File stem, e.g. "data/ETTh1.csv" -> "ETTh1". os.path.basename also
        # understands the platform's native path separator.
        data_name = os.path.basename(args.data_path).split(".")[0]
        # sim4 is a fMRI dataset
        self.data_name = data_name if data_name != "sim4" else "fMRI"
        self.seq_len = args.seq_len
        self.scale_time = args.scale_time
        self.overlapping_seqs_stride = args.overlapping_seqs_stride
        self.decompose = args.decompose

        # Cache files are keyed on the raw file stem (e.g. "sim4", not "fMRI").
        self.timefreq_path = os.path.join(self.exp_dir, f"{data_name}_timefreq.pt")
        self.norm_params_path = os.path.join(self.exp_dir, f"{data_name}_norm_params.pt")
        self.logger = (
            logger
            if logger is not None
            else create_logger(
                self.exp_dir,
                f"{self.data_name}_{self.__class__.__name__.lower()}.log",
            )
        )

        # Populated lazily by __init_dataset__ (see get_data).
        self.timefreq_data, self.time_covariates = None, None
        # Time-domain min-max scaling parameters; fitted on first _scale_time call.
        self.time_min, self.time_max = None, None

    def class_params_tosave_and_load(self):
        """Names of the instance attributes stored in / restored from ``timefreq_path``."""
        return [
            "timefreq_data",
            "time_covariates",
            "torch_data",
            "orig_df",
            "data_len",
            "n_features",
        ]

    def norm_params_tosave_and_load(self):
        """Names of the instance attributes stored in / restored from ``norm_params_path``."""
        return [
            "time_min",
            "time_max",
        ]

    def save_data(self, params_to_save: list[str], save_path: str):
        """
        Save the listed instance attributes to ``save_path`` with ``torch.save``.

        Attributes not currently set on the instance are skipped, and a warning
        is logged for them.
        """
        wanted = set(params_to_save)
        dict_to_save = {k: v for k, v in self.__dict__.items() if k in wanted}
        saved = [p for p in params_to_save if p in dict_to_save]
        missing = [p for p in params_to_save if p not in dict_to_save]
        if missing:
            self.logger.warning(
                f"Attributes not set, not saved to {save_path}: {', '.join(missing)}"
            )
        torch.save(dict_to_save, save_path)
        self.logger.info(f"Saved {', '.join(saved)} to {save_path}")

    def load_data(self, params_to_load: list[str], load_path=None):
        """
        Load the listed attributes from a file written by ``save_data``.

        Only keys listed in ``params_to_load`` are copied onto the instance.
        Requested keys absent from the file are reported with a warning.

        Raises:
            ValueError: if ``load_path`` is None.
        """
        if load_path is None:
            raise ValueError("load_path must be provided to load_data")
        # NOTE: weights_only=False unpickles arbitrary objects (needed for e.g.
        # the orig_df DataFrame). Only load cache files this class wrote into a
        # trusted exp_dir.
        loaded_dict = torch.load(
            load_path,
            map_location=self.device,
            weights_only=False,
        )
        wanted = set(params_to_load)
        selected = {k: v for k, v in loaded_dict.items() if k in wanted}
        loaded = [p for p in params_to_load if p in selected]
        missing = [p for p in params_to_load if p not in selected]
        if missing:
            self.logger.warning(
                f"Attributes not found in {load_path}: {', '.join(missing)}"
            )

        self.__dict__.update(selected)
        self.logger.info(f"Loaded {', '.join(loaded)} from {load_path}")

    def get_data(self):
        """Return ``(timefreq_data, time_covariates)``, building or loading them on first use."""
        if self.timefreq_data is None:
            self.__init_dataset__()
        return self.timefreq_data, self.time_covariates

    def get_orig_df(self):
        """
        Return the unscaled series rebuilt from the stored windows as a DataFrame.

        Overlapping windows are folded with averaging. Columns are named
        ``var0 .. var{n_features-1}``. The dataset is initialized first if needed.
        """
        if self.timefreq_data is None:
            self.__init_dataset__()
        folded = self.fold_timeseries_windows(self.torch_data, avg_window=True)
        return pd.DataFrame(
            folded.numpy(),
            columns=[f"var{i}" for i in range(self.n_features)],
        )

    def compute_timestamp_data(self, timecovariates):
        """
        Window per-time-step covariates the same way as the series.

        Args:
            timecovariates (torch.Tensor): Covariates of shape (T, n_covariates).
        Returns:
            torch.Tensor: Windows of shape (n_windows, seq_len, n_covariates).
        """
        return timecovariates.unfold(
            0, self.seq_len, self.overlapping_seqs_stride
        ).permute(
            0, 2, 1
        )  # (n_samples, n_covariates, seq_len) -> (n_samples, seq_len, n_covariates)

    @abstractmethod
    def compute_timefreq_single(self, torch_data: torch.Tensor):
        """
        Compute the time-frequency representation of a batch of windows.
        Args:
            torch_data (torch.Tensor): Time series data of shape (B, L, K).
        Returns:
            torch.Tensor: Time-frequency representation of shape (B, L, n_freqs, K).
        """
        pass

    @abstractmethod
    def compute_timefreq_decomposed(self, torch_data: torch.Tensor):
        """
        Compute the time-frequency representation of a decomposed time series.
        Args:
            torch_data (torch.Tensor): Components stacked on dim 0, shape
                (2, B, L, K): index 0 is the trend, index 1 the
                seasonal+residual part.
        Returns:
            torch.Tensor: Time-frequency representation of shape (B, L, 2, n_freqs, K).
        """
        pass

    @abstractmethod
    def compute_norm_params(self, torch_data : torch.Tensor):
        pass

    @abstractmethod
    def normalize_timefreq_data(self, torch_data: torch.Tensor):
        pass

    @abstractmethod
    def unnormalize_timefreq_data(self, torch_data: torch.Tensor):
        pass

    def _scale_time(self, torch_data: torch.Tensor):
        """
        Min-max scale the time series data per feature.

        On the first call the scaler is fitted and ``time_min``/``time_max`` are
        stored. Later calls (or calls after loading norm params) reuse them.
        Args:
            torch_data (torch.Tensor): Time series data of shape (B, L, K).
        Returns:
            torch.Tensor: Scaled time series data, same shape as the input.
        """
        self.logger.info("Scaling time series data")
        original_shape = torch_data.shape
        torch_data = torch_data.reshape(-1, self.n_features)

        if self.time_min is not None and self.time_max is not None:
            torch_data = MinMaxArgs(torch_data, self.time_min, self.time_max)
        else:
            torch_data, self.time_min, self.time_max = MinMaxScaler(torch_data, return_scalers=True)
        return torch_data.reshape(original_shape)

    def _unscale_time(self, torch_data: torch.Tensor):
        """
        Invert the min-max scaling using the stored ``time_min``/``time_max``.
        Args:
            torch_data (torch.Tensor): Scaled time series data of shape (B, L, K).
        Returns:
            torch.Tensor: Unscaled time series data, same shape as the input.
        """
        self.logger.info("Unscaling time series data")
        original_shape = torch_data.shape
        torch_data = torch_data.reshape(-1, self.n_features)
        torch_data = torch_data * (self.time_max - self.time_min) + self.time_min
        return torch_data.reshape(original_shape)

    def __init_dataset__(self):
        """
        Populate the dataset attributes, from the cache if present.

        If both cache files exist in ``exp_dir``, they are loaded. Otherwise the
        pipeline runs in this order: load and window the raw data, optionally
        scale, optionally decompose, compute and normalize the time-frequency
        representation, build the time covariates, then save both cache files.
        """
        # The timefreq cache is saved before the norm params, so a timefreq file
        # without its norm params indicates an incomplete/corrupted cache.
        assert not os.path.exists(self.timefreq_path) or os.path.exists(self.norm_params_path), (
            f"{self.timefreq_path} exists but {self.norm_params_path} is missing; "
            f"delete {self.timefreq_path} to recompute both."
        )

        if os.path.exists(self.timefreq_path) and os.path.exists(self.norm_params_path):
            self.logger.info(
                f"Loading {self.data_name} timefreq_data and norm params from {self.timefreq_path} and {self.norm_params_path}"
            )
            self.load_data(self.class_params_tosave_and_load(), self.timefreq_path)
            self.load_data(self.norm_params_tosave_and_load(), self.norm_params_path)
            return

        self.logger.info(f"Computing {self.data_name} timefreq_data")

        torch_data, timestamps_available = self._load_windowed_series()
        # Keep the unscaled windows; get_orig_df() folds them back.
        self.torch_data = torch_data

        if self.scale_time:
            torch_data = self._scale_time(torch_data)

        if self.decompose:
            self.logger.info(
                "Extracting trend and seasonal+residual components using EMA trend"
            )
            trend = ema_trend(torch_data, alpha=0.1)  # B, L, K
            season_resid = torch_data - trend  # B, L, K
            torch_data = torch.stack([trend, season_resid], dim=0)

        # shape (B, L, K) if not decomposed, (2, B, L, K) if decomposed
        if self.decompose:
            self.timefreq_data = self.compute_timefreq_decomposed(torch_data)
        else:
            self.timefreq_data = self.compute_timefreq_single(torch_data)

        self.timefreq_data = self.normalize_timefreq_data(self.timefreq_data)

        time_covariates = self._build_time_covariates(timestamps_available)
        self.time_covariates = self.compute_timestamp_data(
            torch.from_numpy(time_covariates).float()
        )

        self.logger.info(f"Saving {self.data_name} timefreq_data and norm params to {self.timefreq_path} and {self.norm_params_path}")
        self.save_data(self.class_params_tosave_and_load(), self.timefreq_path)
        self.save_data(self.norm_params_tosave_and_load(), self.norm_params_path)

    def _load_windowed_series(self):
        """
        Read the raw series from ``data_path`` and return it cut into windows.

        Sets ``self.data_len``, ``self.n_features`` and ``self.orig_df``.

        Returns:
            tuple[torch.Tensor, bool]: Windows of shape (B, L, K), and whether a
            "date" column was found (only possible for tabular sources).
        """
        if self.data_name in ["mujoco", "sine"]:
            # These two datasets are stored already windowed, shape (B, L, K).
            assert self.seq_len == 24, (
                f"Sequence length must be 24 for {self.data_name} dataset"
            )
            torch_data = torch.load(self.data_path)  # B, L, K already
            # The stored windows are assumed to be consecutive with stride 1.
            # The covariate windowing and folding use overlapping_seqs_stride.
            if self.overlapping_seqs_stride != 1:
                self.logger.warning(
                    f"{self.data_name} windows are assumed to have stride 1, but "
                    f"overlapping_seqs_stride={self.overlapping_seqs_stride}; time "
                    f"covariates and folded series will not match the stored windows."
                )
            self.data_len = torch_data.shape[0] + torch_data.shape[1] - 1
            self.n_features = torch_data.shape[2]
            self.orig_df = None
            return torch_data, False

        if self.data_name == "fMRI":
            orig_df = pd.DataFrame(scipy_io.loadmat(self.data_path)["ts"])
        elif self.data_name in ["mujoco2d", "sine2d"]:
            orig_df = pd.DataFrame(torch.load(self.data_path).numpy())
        elif self.data_path.endswith(".tsf"):
            orig_df = load_tsf_as_timeseries(self.data_path)
        else:
            orig_df = pd.read_csv(self.data_path)

        timestamps_available = "date" in orig_df.columns
        if timestamps_available:
            orig_df["date"] = pd.to_datetime(orig_df["date"])
            orig_df = orig_df.set_index("date")

        self.data_len, self.n_features = orig_df.shape
        self.logger.info(
            f"Original {self.data_name} time series shape: {orig_df.shape}"
        )

        series = torch.from_numpy(orig_df.to_numpy()).float()  # (data_len, n_features)
        torch_data = series.unfold(
            dimension=0,
            size=self.seq_len,
            step=self.overlapping_seqs_stride,
        ).permute(0, 2, 1)  # (B, L, K)

        self.orig_df = orig_df
        return torch_data, timestamps_available

    def _build_time_covariates(self, timestamps_available):
        """
        Build the un-windowed time covariates as a numpy array, one row per time step.

        Calendar covariates are derived from the DataFrame's datetime index when
        timestamps are available. Otherwise a sin/cos embedding of the step
        index is used.
        """
        if timestamps_available:
            return (
                TimeCovariates(
                    self.orig_df.index,
                    normalized=True,
                    holiday=False,
                    sincos_embed_dim=4,
                )
                .get_covariates()
                .to_numpy()
            )
        return get_sin_cos_embedding(
            np.arange(self.data_len),
            dim=28,  # 28 the total number of covariates in case of timestamps available.
        )

    @abstractmethod
    def inverse_timefreq_single(self, timefreq_data: torch.Tensor):
        pass

    @abstractmethod
    def inverse_timefreq_decomposed(self, timefreq_data: torch.Tensor):
        pass

    def fold_timeseries_windows(self, timeseries_data, avg_window: bool = False):
        """
        Fold the windows of the time series data.
        Args:
            timeseries_data (torch.Tensor): Time series data of shape (B, L, K).
            avg_window (bool): Whether to average the windows or not.
        Returns:
            torch.Tensor: Folded time series data.
        """
        return flatten_windows(
            timeseries_data, self.overlapping_seqs_stride, avg_window=avg_window
        )

    def get_timeseries_from_timefreq(
        self,
        timefreq_data,
        unscale: bool = False,
        fold_windows: bool = False,
        avg_window: bool = False,
    ):
        """
        Get the time series data from the time-frequency representation.
        Args:
            timefreq_data (torch.Tensor): Time-frequency representation of shape (B, L, C, F, K).
            unscale (bool): Whether to unscale the time series data in the time domain. Notice that this is only applied if scale_time was set to True.
            fold_windows (bool): Whether to fold the windows of the time series data to obtain a single time series of shape (T, K).
            avg_window (bool): Whether to average the windows or not. If fold_windows == False, the parameter is ignored.

            Notes:
            By default, unscale = False, fold_windows = False, avg_window = False. This is the basic test setting for metrics computation.
        Returns:
            torch.Tensor: Time series data of shape (B, L, K) | (T, K).
        """
        # Unpacking also enforces a 5-D input; only the batch size is needed.
        B, _, _, _, _ = timefreq_data.shape
        timefreq_data = self.unnormalize_timefreq_data(timefreq_data)
        if self.decompose:
            trend, season_resid = self.inverse_timefreq_decomposed(timefreq_data)
            timeseries = trend + season_resid
        else:
            timeseries = self.inverse_timefreq_single(timefreq_data)

        assert (
            timeseries.shape == (B, self.seq_len, self.n_features)
        ), f"Expected timeseries shape (B, {self.seq_len}, {self.n_features}), but got {timeseries.shape}"
        if self.scale_time and unscale:
            timeseries = self._unscale_time(timeseries)

        if fold_windows:
            return self.fold_timeseries_windows(timeseries, avg_window=avg_window)
        return timeseries
