from typing import Generator, Iterator

import numpy as np
import torch

from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.core.enums import Task
from tabicl.data.preprocessor import Preprocessor
from tabicl.data.synthetic_generator_selector import select_synthetic_dataset_generator


class SyntheticDataset(torch.utils.data.IterableDataset):
    """Stream of preprocessed synthetic support/query tasks for pretraining.

    Each yielded item is a dict with the keys ``'x_support'``, ``'y_support'``,
    ``'x_query'`` and ``'y_query'``. They hold tensors built from the arrays
    returned by ``Preprocessor.transform_X`` / ``Preprocessor.transform_y``.
    """

    def __init__(self, cfg: ConfigPretrain) -> None:
        """Store the pretraining configuration.

        The synthetic dataset generator is created lazily in ``__iter__``, so
        every call to ``iter()`` gets its own generator instance.
        """
        self.cfg = cfg
        # Assigned in __iter__; declared here so the attribute always exists.
        self.synthetic_dataset_generator = None

    def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
        """Create a fresh synthetic generator and return the task stream over it."""
        self.synthetic_dataset_generator = select_synthetic_dataset_generator(self.cfg)
        return self.generator()

    def generator(self) -> Generator[dict[str, torch.Tensor], None, None]:
        """Yield preprocessed support/query tasks drawn from the synthetic generator.

        Datasets whose support or query split has fewer than two distinct target
        values are skipped. The stream ends cleanly if the underlying synthetic
        generator is exhausted.

        Yields:
            dict with tensors under 'x_support', 'y_support', 'x_query', 'y_query'.
        """
        while True:
            try:
                x, y = next(self.synthetic_dataset_generator)
            except StopIteration:
                # A StopIteration escaping a generator body would be turned into a
                # RuntimeError (PEP 479); finish the stream cleanly instead.
                return

            x_support, y_support, x_query, y_query = self.split_into_support_and_query(x, y)

            if np.unique(y_support).shape[0] < 2 or np.unique(y_query).shape[0] < 2:
                # Degenerate case: the support or query set contains only one class
                # (or one regression value), or is empty.
                continue

            if self.cfg.data.task == Task.CLASSIFICATION:
                # Presumably labels are 0-based integer class ids; the count is taken
                # over the full dataset, not only the support split.
                n_classes = int(np.max(y)) + 1
            else:
                n_classes = 0

            preprocessor = Preprocessor(
                dim_embedding=self.cfg.data.max_features,  # We assume that the number of features selected for pretraining always fits the model
                n_classes=n_classes,
                dim_output=self.cfg.data.max_classes,  # We assume that the number of classes selected for pretraining always fits the model
                use_quantile_transformer=self.cfg.preprocessing.use_quantile_transformer,
                use_feature_count_scaling=self.cfg.preprocessing.use_feature_count_scaling,
                shuffle_classes=True,
                shuffle_features=True,
                random_mirror_x=True,
                random_mirror_regression=True,
                task=self.cfg.data.task
            )
            # Fit on the support split only, then apply the same transform to both splits.
            preprocessor.fit(x_support, y_support)

            x_support = preprocessor.transform_X(x_support)
            y_support = preprocessor.transform_y(y_support)
            x_query = preprocessor.transform_X(x_query)
            y_query = preprocessor.transform_y(y_query)

            yield {
                'x_support': torch.from_numpy(x_support),
                'y_support': torch.from_numpy(y_support),
                'x_query': torch.from_numpy(x_query),
                'y_query': torch.from_numpy(y_query)
            }

    def split_into_support_and_query(self, x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Randomly split a dataset into a support set and a query set.

        The support size is drawn uniformly from
        ``[min_samples_support, max_samples_support]`` (both bounds inclusive).
        The query set takes up to ``n_samples_query`` of the remaining rows. It
        may be smaller, or empty, when the dataset has too few rows.

        Args:
            x: Feature array whose rows are indexed along axis 0.
            y: Target array aligned with ``x`` along axis 0.

        Returns:
            Tuple ``(x_support, y_support, x_query, y_query)``.
        """
        n_rows = x.shape[0]
        data_cfg = self.cfg.data

        # randint's ``high`` is exclusive: +1 makes max_samples_support reachable and
        # avoids a ValueError when min_samples_support == max_samples_support.
        n_samples_support = np.random.randint(
            low=data_cfg.min_samples_support,
            high=data_cfg.max_samples_support + 1,
        )
        perm = np.random.permutation(n_rows)

        support_idx = perm[:n_samples_support]
        query_idx = perm[n_samples_support:n_samples_support + data_cfg.n_samples_query]

        return x[support_idx], y[support_idx], x[query_idx], y[query_idx]
    
