import numpy as np
import torch
from aeon.datasets import load_classification

from utils import GlobalConfig
from utils.socks_proxy import ProxyContext
from .dataset_basic import BasicDataset
from .tsfeature_extractors import get_feature_extractor, FeatureExtractor


class DefaultClassificationDataset(BasicDataset):
    """Classification dataset loaded through ``aeon.datasets.load_classification``.

    ``_load_data`` loads one split (``"TRAIN"`` or ``"TEST"``) of a named
    dataset, maps the raw class labels to contiguous integer indices, runs the
    configured feature extractor over the series, and stores the results on
    ``self`` as torch tensors.
    """

    def __init__(self, config: GlobalConfig, flag):
        super().__init__(config, flag)

    def _load_data(self, root_path, dataset, flag):
        """Load one split of ``dataset`` and build its tensors.

        Args:
            root_path: Directory passed to aeon as ``extract_path``.
            dataset: Dataset name passed to ``load_classification``.
            flag: Split to load; must be ``"TRAIN"`` or ``"TEST"``.

        Returns:
            Tuple ``(X, Y, F, mask)`` of torch tensors, also stored on ``self``:
            the series values, integer class indices, extracted features, and
            an all-True boolean mask of shape ``(X.shape[0], X.shape[-1])``.

        Raises:
            ValueError: if ``flag`` is not ``"TRAIN"`` or ``"TEST"``.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if flag not in {"TRAIN", "TEST"}:
            raise ValueError(f"flag must be 'TRAIN' or 'TEST', got {flag!r}")
        print(f"[{self.__class__.__name__}] loading data (flag={flag})")
        # (batch, n_channel, seq_len) and (batch, n_channel, 1)
        with ProxyContext(self.config):
            self.X, self.Y, metadata = load_classification(
                dataset, split=flag, extract_path=root_path, return_metadata=True
            )
            self.X = np.array(self.X)

        print(f"[{self.__class__.__name__}] data loaded (flag={flag}) input size={self.X.shape}")

        # Map raw labels to 0..n_classes-1 in the order given by the metadata.
        class_values = metadata["class_values"]
        self.n_classes = len(class_values)
        class_map = {value: index for index, value in enumerate(class_values)}
        # int8 is the historical compact label dtype, but indices >= 128 do not
        # fit: NumPy 1.x wraps them to negative values and NumPy 2.x raises
        # OverflowError. Widen to int64 only when needed.
        if self.n_classes <= np.iinfo(np.int8).max + 1:
            label_dtype = np.int8
        else:
            label_dtype = np.int64
        self.Y = np.array([class_map[e] for e in self.Y], dtype=label_dtype)

        self.fe: FeatureExtractor = get_feature_extractor(self.config.args.feature_extractor)(self.config, None)
        # (batch, n_channel, cond_dim)
        self.F = self.fe(self.X)

        # Replace NaNs in the float-valued series and features with 0.
        self.X = torch.nan_to_num(torch.from_numpy(self.X), nan=0)
        # Labels are integer indices, so they cannot contain NaN.
        self.Y = torch.from_numpy(self.Y)
        self.F = torch.nan_to_num(torch.from_numpy(self.F), nan=0)
        # All positions are valid: one True per (sample, last-axis position).
        self.mask = torch.ones((self.X.shape[0], self.X.shape[-1]), dtype=torch.bool, device=self.X.device)
        # Presumably derives dimension attributes from the tensors above
        # (defined in BasicDataset; TODO confirm).
        self._parse_dimensions()

        print(f"self.X.shape: {self.X.shape}")
        print(f"self.F.shape: {self.F.shape}")

        return self.X, self.Y, self.F, self.mask


# Registry mapping a dataset-type name to the class that implements it.
# Presumably looked up by a name taken from the config; confirm against callers.
AVAILABLE_CLASSIFICATION_DATASETS = dict(
    Default=DefaultClassificationDataset,
)

