from pathlib import Path
from typing import Optional

import numpy as np
import torch
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.multiclass import unique_labels
from sklearn.utils.validation import check_is_fitted

from tabicl.config.config_run import ConfigRun
from tabicl.core.dataset_split import make_stratified_dataset_split
from tabicl.core.enums import ModelName, Task
from tabicl.core.trainer_finetune import TrainerFinetune
from tabicl.models.foundation.foundation_transformer import FoundationTransformer


class TabForestPFNClassifier(BaseEstimator, ClassifierMixin):
    """Scikit-learn classifier wrapping a TabForestPFN-style foundation transformer.

    The constructor builds a ``ConfigRun``, a ``FoundationTransformer`` (loading
    the weight file of the chosen backbone when ``use_pretrained_weights`` is
    set) and a ``TrainerFinetune``. ``fit`` trains on a stratified
    train/validation split and stores the full ``X``/``y``. ``predict`` and
    ``predict_proba`` pass those stored samples to the trainer together with
    the query samples.

    Parameters
    ----------
    n_classes : int, default=2
        Number of classes, forwarded to ``TrainerFinetune``.
    seed : int, default=0
        Seed stored in the run configuration.
    backbone : {'tabforestpfn', 'tabforest', 'tabpfn'}, default='tabforestpfn'
        Selects which weight file is used as ``path_to_weights``.
    overwrite_hyperparams : dict or None, default=None
        Entries that replace or extend the default hyperparameters.
        ``None`` (or an empty dict) means no overrides.

    Raises
    ------
    ValueError
        If ``backbone`` is not one of the supported names.
    """

    # Weight file per supported backbone. The paths are relative, presumably
    # to the current working directory. TODO confirm.
    _PATH_TO_WEIGHTS_OPTIONS = {
        'tabforestpfn': 'weights/tabforestpfn.pt',
        'tabforest': 'weights/tabforest.pt',
        'tabpfn': 'weights/tabpfn_retrained.pt',
    }

    def __init__(
        self,
        n_classes: int = 2,
        seed: int = 0,
        backbone: str = 'tabforestpfn',
        overwrite_hyperparams: Optional[dict] = None,
    ):
        super().__init__()

        # BaseEstimator.get_params()/clone() read constructor arguments back
        # from attributes of the same name, so store them exactly as passed.
        self.n_classes = n_classes
        self.output_dir = Path('output_tabforestpfn')
        self.seed = seed
        self.backbone = backbone
        self.overwrite_hyperparams = overwrite_hyperparams

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if backbone not in self._PATH_TO_WEIGHTS_OPTIONS:
            raise ValueError(
                f"backbone must be one of {list(self._PATH_TO_WEIGHTS_OPTIONS)}, "
                f"got {backbone!r}"
            )

        hyperparams = self._default_hyperparams(self._PATH_TO_WEIGHTS_OPTIONS[backbone])
        if overwrite_hyperparams is not None:
            hyperparams.update(overwrite_hyperparams)

        cfg = ConfigRun(
            output_dir=self.output_dir,
            device=torch.device('cpu'),
            cpus=None,
            seed=seed,
            model_name=ModelName.FOUNDATION,
            task=Task.CLASSIFICATION,
            dataset_size=None,
            openml_dataset_id=0,
            openml_dataset_name='custom',
            datafile_path=Path(),
            hyperparams=hyperparams,
        )

        model = self._build_model(cfg.hyperparams)

        self.trainer = TrainerFinetune(cfg, model, n_classes=n_classes)

    @staticmethod
    def _default_hyperparams(path_to_weights: str) -> dict:
        """Return a fresh dict of default hyperparameters for the given weight file."""
        return {
            'dim_model': 512,
            'dim_embedding': 100,
            'n_layers': 12,
            'n_heads': 4,
            'y_as_float_embedding': True,
            'quantile_embedding_gpu': True,
            'feature_count_scaling_gpu': True,
            'n_classes': 10,
            'max_samples_support': 8192,
            'max_samples_query': 1024,
            'max_epochs': 300,
            'optimizer': 'adamw',
            'lr': 1.e-5,
            'weight_decay': 0,
            'lr_scheduler': False,
            'lr_scheduler_patience': 25,
            'warmup_steps': 0,
            'early_stopping_patience': 40,
            'early_stopping_data_split': 'VALID',
            'early_stopping_max_samples': 2048,
            'precision': 'float32',
            'grad_scaler_enabled': False,
            'grad_scaler_scale_init': 65536.,
            'grad_scaler_scale_min': 65536.,
            'grad_scaler_growth_interval': 1000,
            'label_smoothing': 0.0,
            'use_pretrained_weights': True,
            'path_to_weights': path_to_weights,
            'use_quantile_transformer': False,
            'use_feature_count_scaling': False,
            'shuffle_classes': True,
            'shuffle_features': False,
            'random_mirror_x': True,
        }

    @staticmethod
    def _build_model(hyperparams) -> FoundationTransformer:
        """Instantiate the ``FoundationTransformer`` described by ``hyperparams``."""
        return FoundationTransformer(
            dim_model=hyperparams['dim_model'],
            dim_embedding=hyperparams['dim_embedding'],
            # The default hyperparameters define 'n_classes' but no 'dim_output'.
            # Use 'dim_output' when ConfigRun or an override supplies it;
            # otherwise fall back to 'n_classes' instead of raising KeyError.
            # NOTE(review): confirm 'n_classes' is the intended output width.
            dim_output=hyperparams.get('dim_output', hyperparams['n_classes']),
            n_layers=hyperparams['n_layers'],
            n_heads=hyperparams['n_heads'],
            y_as_float_embedding=hyperparams['y_as_float_embedding'],
            quantile_embedding=hyperparams['quantile_embedding_gpu'],
            feature_count_scaling=hyperparams['feature_count_scaling_gpu'],
            use_pretrained_weights=hyperparams['use_pretrained_weights'],
            path_to_weights=hyperparams['path_to_weights'],
        )

    def fit(self, X, y):
        """Train on ``(X, y)`` and keep them as the support set for prediction.

        The data is split with ``make_stratified_dataset_split`` into train and
        validation parts, which are passed to ``TrainerFinetune.train``.

        Returns
        -------
        self
        """
        self.classes_ = unique_labels(y)
        self.X_ = X
        self.y_ = y

        X_train, X_valid, y_train, y_valid = make_stratified_dataset_split(X, y)
        self.trainer.train(X_train, y_train, X_valid, y_valid)

        return self

    def _predict_logits(self, X) -> np.ndarray:
        """Return the trainer's logits for ``X``, using the fitted data as support set.

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If ``fit`` has not been called. This subclasses AttributeError,
            which is what the unchecked attribute access used to raise.
        """
        check_is_fitted(self, ['X_', 'y_'])
        # Assumes the trainer returns a 2-D (n_samples, n_outputs) array-like.
        # TODO confirm against TrainerFinetune.predict.
        return np.asarray(self.trainer.predict(self.X_, self.y_, X))

    def predict(self, X):
        """Return, for each row of ``X``, the index of the highest logit.

        NOTE(review): these are output-column indices. They coincide with the
        original labels only when ``y`` was encoded as ``0..k-1``. sklearn
        convention would be ``self.classes_[indices]``; confirm how
        TrainerFinetune orders its output columns before changing this.
        """
        return self._predict_logits(X).argmax(axis=1)

    def predict_proba(self, X):
        """Return row-wise softmax probabilities of the trainer's logits for ``X``."""
        logits = self._predict_logits(X)
        # Subtracting the row maximum leaves the softmax unchanged
        # mathematically. It stops exp() from overflowing to inf, which
        # previously produced nan probabilities for large logits.
        exp_shifted = np.exp(logits - logits.max(axis=1, keepdims=True))
        return exp_shifted / exp_shifted.sum(axis=1, keepdims=True)
    