from pathlib import Path

import numpy as np
import torch
from loguru import logger
from sklearn.base import BaseEstimator

from tabicl.config.config_run import ConfigRun
from tabicl.core.callbacks import Checkpoint, EarlyStopping
from tabicl.core.collator import CollatorWithPadding
from tabicl.core.enums import MetricName, ModelName, Task
from tabicl.core.get_loss import get_loss
from tabicl.core.get_optimizer import get_optimizer
from tabicl.core.get_scheduler import get_scheduler
from tabicl.core.grad_scaler import GradScaler
from tabicl.data.dataset_finetune import DatasetFinetune, DatasetFinetuneGenerator
from tabicl.data.preprocessor import Preprocessor
from tabicl.results.prediction_metrics import PredictionMetrics
from tabicl.results.prediction_metrics_tracker import PredictionMetricsTracker


class TrainerFinetune(BaseEstimator):

    def __init__(
            self, 
            cfg: ConfigRun,
            model: torch.nn.Module,
            n_classes: int
        ) -> None:

        self.cfg = cfg
        self.model = model.to("cuda", non_blocking=True)
        self.n_classes = n_classes
        
        self.loss = get_loss(self.cfg)
        self.optimizer = get_optimizer(self.cfg.hyperparams, self.model)
        self.scheduler_warmup, self.scheduler_reduce_on_plateau = get_scheduler(self.cfg.hyperparams, self.optimizer)
        self.scaler = GradScaler(
            enabled=self.cfg.hyperparams['grad_scaler_enabled'],
            scale_init=self.cfg.hyperparams['grad_scaler_scale_init'],
            scale_min=self.cfg.hyperparams['grad_scaler_scale_min'],
            growth_interval=self.cfg.hyperparams['grad_scaler_growth_interval']
        )

        self.early_stopping = EarlyStopping(patience=self.cfg.hyperparams['early_stopping_patience'])
        self.checkpoint = Checkpoint(Path("temp_weights"), id=str(self.cfg.device))
        self.preprocessor = Preprocessor(
            dim_embedding=self.cfg.hyperparams['dim_embedding'],
            n_classes=self.n_classes,
            dim_output=self.cfg.hyperparams['dim_output'],
            use_quantile_transformer=self.cfg.hyperparams['use_quantile_transformer'],
            use_feature_count_scaling=self.cfg.hyperparams['use_feature_count_scaling'],
            shuffle_classes=self.cfg.hyperparams['shuffle_classes'],
            shuffle_features=self.cfg.hyperparams['shuffle_features'],
            random_mirror_x=self.cfg.hyperparams['random_mirror_x'],
            random_mirror_regression=self.cfg.hyperparams['random_mirror_regression'],
            task=self.cfg.task
        )

        self.checkpoint.reset(self.model)



    def train(self, x_train: np.ndarray, y_train: np.ndarray, x_val: np.ndarray, y_val: np.ndarray):

        self.preprocessor.fit(x_train, y_train)       

        x_train_transformed = self.preprocessor.transform_X(x_train) 
        y_train_transformed = self.preprocessor.transform_y(y_train)
        
        dataset_train_generator = DatasetFinetuneGenerator(
            self.cfg,
            x = x_train_transformed,
            y = y_train_transformed,
            task = self.cfg.task,
            max_samples_support = self.cfg.hyperparams['max_samples_support'],
            max_samples_query = self.cfg.hyperparams['max_samples_query']
        )

        self.checkpoint.reset(self.model)

        metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
        self.log_start_metrics(metrics_valid)
        self.checkpoint(self.model, metrics_valid.loss)

        for epoch in range(1, self.cfg.hyperparams['max_epochs']+1):

            dataset_train = next(dataset_train_generator)            
            loader_train = self.make_loader(dataset_train, training=True)
            self.model.train()
            
            prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)

            for batch in loader_train:
                
                with torch.autocast(device_type="cuda", dtype=getattr(torch, self.cfg.hyperparams['precision'])):
            
                    x_support = batch['x_support'].to("cuda", non_blocking=True)
                    y_support = batch['y_support'].to("cuda", non_blocking=True)
                    x_query = batch['x_query'].to("cuda", non_blocking=True)
                    y_query = batch['y_query'].to("cuda", non_blocking=True)
                    padding_features = batch['padding_features'].to("cuda", non_blocking=True)
                    padding_obs_support = batch['padding_obs_support'].to("cuda", non_blocking=True)
                    padding_obs_query = batch['padding_obs_query'].to("cuda", non_blocking=True)
                    
                    y_hat = self.model(x_support, y_support, x_query, padding_features, padding_obs_support, padding_obs_query)
                    loss = self.loss(y_hat, y_query)

                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                y_hat = y_hat.float()
                prediction_metrics_tracker.update(y_hat, y_query)

            metrics_train = prediction_metrics_tracker.get_metrics()
            metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)   

            self.log_metrics(epoch, metrics_train, metrics_valid)

            self.checkpoint(self.model, metrics_valid.loss)
            
            self.early_stopping(metrics_valid.loss)
            if self.early_stopping.we_should_stop():
                logger.info("Early stopping")
                break

            if epoch < self.cfg.hyperparams['warmup_steps']:
                self.scheduler_warmup.step()
            else:
                self.scheduler_reduce_on_plateau.step(metrics_valid.loss)

        self.checkpoint.set_to_best(self.model)

    
    def evaluate(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray, y_query: np.ndarray) -> PredictionMetrics:
        
        self.model.eval()

        x_support_transformed = self.preprocessor.transform_X(x_support)
        x_query_transformed = self.preprocessor.transform_X(x_query)
        y_support_transformed = self.preprocessor.transform_y(y_support)
        y_query_transformed = self.preprocessor.transform_y(y_query)


        dataset = DatasetFinetune(
            self.cfg, 
            x_support = x_support_transformed, 
            y_support = y_support_transformed, 
            x_query = x_query_transformed,
            y_query = y_query_transformed,
            max_samples_support = self.cfg.hyperparams['max_samples_support'],
            max_samples_query = self.cfg.hyperparams['max_samples_query'],
        )

        loader = self.make_loader(dataset, training=False)
        prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)

        with torch.no_grad():
            for batch in loader:

                with torch.autocast(device_type="cuda", dtype=getattr(torch, self.cfg.hyperparams['precision'])):
                
                    x_s = batch['x_support'].to("cuda", non_blocking=True)
                    y_s = batch['y_support'].to("cuda", non_blocking=True)
                    x_q = batch['x_query'].to("cuda", non_blocking=True)
                    y_q = batch['y_query'].to("cuda", non_blocking=True)
                    padding_features = batch['padding_features'].to("cuda", non_blocking=True)
                    padding_obs_support = batch['padding_obs_support'].to("cuda", non_blocking=True)
                    padding_obs_query = batch['padding_obs_query'].to("cuda", non_blocking=True)
                    
                    y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)

                y_hat = y_hat.float()
                prediction_metrics_tracker.update(y_hat, y_q)

        metrics_eval = prediction_metrics_tracker.get_metrics()
        return metrics_eval
    

    def predict(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray) -> np.ndarray:

        x_support_transformed = self.preprocessor.transform_X(x_support)
        x_query_transformed = self.preprocessor.transform_X(x_query)
        y_support_transformed = self.preprocessor.transform_y(y_support)

        dataset = DatasetFinetune(
            self.cfg, 
            x_support = x_support_transformed, 
            y_support = y_support_transformed, 
            x_query = x_query_transformed,
            y_query = None,
            max_samples_support = self.cfg.hyperparams['max_samples_support'],
            max_samples_query = self.cfg.hyperparams['max_samples_query'],
        )

        loader = self.make_loader(dataset, training=False)
        self.model.eval()

        y_pred_list = []

        with torch.no_grad():
            for batch in loader:

                with torch.autocast(device_type="cuda", dtype=getattr(torch, self.cfg.hyperparams['precision'])):
                
                    x_s = batch['x_support'].to("cuda", non_blocking=True)
                    y_s = batch['y_support'].to("cuda", non_blocking=True)
                    x_q = batch['x_query'].to("cuda", non_blocking=True)
                    padding_features = batch['padding_features'].to("cuda", non_blocking=True)
                    padding_obs_support = batch['padding_obs_support'].to("cuda", non_blocking=True)
                    padding_obs_query = batch['padding_obs_query'].to("cuda", non_blocking=True)
                    
                    y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)

                y_hat = y_hat[0].float().cpu().numpy()
                y_hat = self.preprocessor.inverse_transform_y(y_hat)
                y_pred_list.append(y_hat)

        y_pred = np.concatenate(y_pred_list, axis=0)
        return y_pred
    

    def load_params(self, path):
        self.model.load_state_dict(torch.load(path))
    

    def make_loader(self, dataset: torch.utils.data.Dataset, training: bool) -> torch.utils.data.DataLoader:

        match self.cfg.model_name:
            case ModelName.TABPFN | ModelName.FOUNDATION | ModelName.FOUNDATION_FLASH:
                pad_to_max_features = True
            case ModelName.TAB2D:
                pad_to_max_features = False
            case _:
                raise NotImplementedError(f"Model {self.cfg.model_name} not implemented")

        return torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            shuffle=training,
            pin_memory=True,
            num_workers=0,
            drop_last=False,
            collate_fn=CollatorWithPadding(
                max_features=self.cfg.hyperparams['dim_embedding'],
                pad_to_max_features=pad_to_max_features
            ),
        )
    

    def log_start_metrics(self, metrics_valid: PredictionMetrics):

        match self.cfg.task:
            case Task.REGRESSION:
                logger.info((
                    f"Epoch 000 "
                    f"| Train MSE: -.---- "
                    f"| Train MAE: -.---- "
                    f"| Train r2: -.---- "
                    f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
                    f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
                    f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
                ))
            case Task.CLASSIFICATION:
                logger.info((
                    f"Epoch 000 "
                    f"| Train CE: -.---- "
                    f"| Train acc: -.---- "
                    f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
                    f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
                ))
    

    def log_metrics(self, epoch: int, metrics_train: PredictionMetrics, metrics_valid: PredictionMetrics):

        match self.cfg.task:
            case Task.REGRESSION:
                logger.info((
                    f"Epoch {epoch:03d} "
                    f"| Train MSE: {metrics_train.metrics[MetricName.MSE]:.4f} "
                    f"| Train MAE: {metrics_train.metrics[MetricName.MAE]:.4f} "
                    f"| Train r2: {metrics_train.metrics[MetricName.R2]:.4f} "
                    f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
                    f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
                    f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
                ))
            case Task.CLASSIFICATION:
                logger.info((
                    f"Epoch {epoch:03d} "
                    f"| Train CE: {metrics_train.metrics[MetricName.LOG_LOSS]:.4f} "
                    f"| Train acc: {metrics_train.metrics[MetricName.ACCURACY]:.4f} "
                    f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
                    f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
                ))