from typing import Generator, Iterator

import numpy as np
import torch

from tabicl.config.config_pretrain import ConfigPretrain
from tabicl.data.preprocessor import Preprocessor
from tabicl.data.synthetic_generator_selector import select_synthetic_dataset_generator


class SyntheticDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that streams preprocessed support/query episodes.

    Every item is built from one ``(x, y)`` pair pulled from the generator
    returned by ``select_synthetic_dataset_generator(cfg)``. The pair is split
    at random into a support part and a query part, a ``Preprocessor`` is
    fitted on the support part only, and both parts are transformed and
    returned as tensors.
    """

    def __init__(self, cfg: "ConfigPretrain") -> None:
        """Store the pretraining configuration used to build every episode."""
        super().__init__()
        self.cfg = cfg

    def __iter__(self) -> Iterator:
        """Select a fresh synthetic generator and return the episode stream."""
        # The upstream generator is (re)created on every __iter__ call, so each
        # new iteration over the dataset starts from a newly selected source.
        self.synthetic_dataset_generator = select_synthetic_dataset_generator(self.cfg)
        return self.generator()

    def generator(self) -> Generator[dict[str, torch.Tensor], None, None]:
        """Yield one preprocessed episode per dataset drawn from the source.

        Yields:
            A dict with the keys ``'x_support'``, ``'y_support'``, ``'x_query'``
            (all ``float32`` tensors) and ``'y_query'`` (``int64`` tensor).

        The stream ends cleanly when the upstream generator is exhausted.
        """
        while True:
            try:
                x, y = next(self.synthetic_dataset_generator)
            except StopIteration:
                # A StopIteration escaping a generator body is converted into
                # RuntimeError (PEP 479); finish the stream explicitly instead.
                return

            x_support, y_support, x_query, y_query = self.split_into_support_and_query(x, y)

            # n_classes is taken from the full label array (before the split).
            # NOTE(review): assumes labels are non-negative integers 0..K-1 - confirm.
            preprocessor = Preprocessor(
                max_features=self.cfg.data.max_features,
                n_classes=int(y.max()) + 1,
                max_classes=self.cfg.data.max_classes,
                use_quantile_transformer=self.cfg.preprocessing.use_quantile_transformer,
                use_feature_count_scaling=self.cfg.preprocessing.use_feature_count_scaling,
                shuffle_classes=True,
                shuffle_features=True,
                random_mirror_x=True,
            )
            # Fit on the support split only; the query split is transformed
            # with the statistics learned from the support split.
            preprocessor.fit(x_support, y_support)

            x_support = preprocessor.transform_X(x_support)
            y_support = preprocessor.transform_y(y_support)
            x_query = preprocessor.transform_X(x_query)
            y_query = preprocessor.transform_y(y_query)

            # NOTE(review): support labels are emitted as float32 while query
            # labels are int64 - presumably support labels are model inputs and
            # query labels are loss targets; confirm against the consumer.
            yield {
                'x_support': torch.tensor(x_support, dtype=torch.float32),
                'y_support': torch.tensor(y_support, dtype=torch.float32),
                'x_query': torch.tensor(x_query, dtype=torch.float32),
                'y_query': torch.tensor(y_query, dtype=torch.int64),
            }

    def split_into_support_and_query(self, x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Randomly split one dataset into disjoint support and query parts.

        The support size is drawn uniformly from
        ``[cfg.data.min_samples_support, cfg.data.max_samples_support)``; when
        both bounds are equal that fixed size is used. The query part takes up
        to ``cfg.data.n_samples_query`` of the remaining shuffled rows.

        Args:
            x: Array indexed by sample along its first axis.
            y: Labels aligned with the first axis of ``x``.

        Returns:
            ``(x_support, y_support, x_query, y_query)``.

        Raises:
            ValueError: If ``min_samples_support > max_samples_support``
                (raised by ``np.random.randint``).
        """
        data_cfg = self.cfg.data
        low = data_cfg.min_samples_support
        high = data_cfg.max_samples_support

        if low == high:
            # np.random.randint rejects an empty half-open range; a fixed
            # support size is a valid configuration, so use it directly.
            n_samples_support = low
        else:
            # Upper bound is exclusive, as in np.random.randint.
            n_samples_support = np.random.randint(low=low, high=high)

        rand_index = np.random.permutation(x.shape[0])

        # Slicing truncates silently when the dataset has fewer rows than
        # requested, so either part may come out shorter (or empty).
        query_end = n_samples_support + data_cfg.n_samples_query
        rand_support_index = rand_index[:n_samples_support]
        rand_query_index = rand_index[n_samples_support:query_end]

        return (
            x[rand_support_index],
            y[rand_support_index],
            x[rand_query_index],
            y[rand_query_index],
        )
    
