from collections import OrderedDict
from typing import Dict

import numpy as np
import pandas as pd
from sklearn.metrics import (
    classification_report,
)

from .base import EvalBase
from .grid_search import GridSearch, _METRIC_IMB


class MLE(EvalBase):
    """
    Evaluate the efficiency of a machine learning model trained on real vs. synthetic data.
    Uses an internal GridSearch on XGBoost (built through ``get_model``) with either a
    regression (numerical target) or classification (categorical target) objective.
    The ratio of the performance of the model trained on synthetic data to that of the
    model trained on real data is returned.

    Inherits:
        EvalBase: A base class that defines an `_evaluation` interface.
    """

    def get_model(self, sdtype):
        """Return a fresh, unfitted ``GridSearch`` for a target of the given ``sdtype``.

        This is the single hook through which ``_evaluation`` builds models, so
        subclasses override it to change the search (e.g. the scoring metric).
        """
        return GridSearch(sdtype)

    def _preprocess_data(self, data):
        """Join the two parts returned by a feature transform into one DataFrame.

        ``data[0]`` is used as-is and every column of ``data[1]`` is cast to the
        pandas ``category`` dtype. The parts are placed side by side and the
        resulting columns are renumbered ``0..k-1``. Presumably ``data`` is a
        (continuous, categorical) pair -- TODO confirm against the transform.
        """
        return pd.concat(
            [pd.DataFrame(data[0]), pd.DataFrame(data[1]).astype('category')],
            ignore_index=True,
            axis=1,
        )

    def _evaluation(self, real_data: pd.DataFrame, val_data: pd.DataFrame, fake_data: pd.DataFrame) -> Dict:
        """
        Perform a machine learning efficiency evaluation. Trains a model on real data,
        evaluates on validation data, then trains a separate model on synthetic data
        and evaluates on the same validation data. The ratio of performance is computed.

        The real-data model is trained only once and cached on ``self.real_model``.
        Later calls reuse it and only re-score it on ``val_data``.

        Args:
            real_data (pd.DataFrame): The real dataset.
            val_data (pd.DataFrame): The validation dataset on which performance is measured.
            fake_data (pd.DataFrame): The synthetic dataset to be evaluated.

        Returns:
            Dict: An empty dict in any of these cases:
                - there is no target column;
                - the features and targets of any split differ in length (rows whose
                  target transforms to NaN are dropped from the targets only);
                - a categorical target is unusable: real and fake data have a different
                  number of distinct classes, or some class has fewer than 8 rows.
            Otherwise an ordered dict containing:
                {
                    "real_score": float,
                    "fake_score": float,
                    "ratio": float,      # fake_score / real_score; inf if real_score == 0
                    "real_report": ...,  # categorical targets only
                    "fake_report": ...,  # categorical targets only
                    "real_hparam": ...,  # best_kwargs of the real-data model
                    "fake_hparam": ...,  # best_kwargs of the fake-data model
                }
        """
        result = OrderedDict()

        feature_transform = self._transform_set.no_target
        target_transform = self._transform_set.target
        target_column = self._transform_set.target_column
        if target_column is None:
            return result

        # Transform real, val, and fake features into model inputs
        X_real = self._preprocess_data(feature_transform.transform(real_data, scaler='standard', onehot=False))
        X_val = self._preprocess_data(feature_transform.transform(val_data, scaler='standard', onehot=False))
        X_fake = self._preprocess_data(feature_transform.transform(fake_data, scaler='standard', onehot=False))

        # Extract targets; entries that transform to NaN are dropped below
        y_real = target_transform.transform(real_data[[target_column]], onehot=False, scaler='standard',
                                            return_as_tensor=True)
        y_val = target_transform.transform(val_data[[target_column]], onehot=False, scaler='standard',
                                           return_as_tensor=True)
        y_fake = target_transform.transform(fake_data[[target_column]], onehot=False, scaler='standard',
                                            return_as_tensor=True)
        y_real = y_real[~y_real.isnan()]
        y_val = y_val[~y_val.isnan()]
        y_fake = y_fake[~y_fake.isnan()]

        # Dropping NaN targets leaves X and y misaligned. Bail out rather than fit or
        # score on mismatched arrays. The fake split is fitted below, so check it too.
        if (len(X_real) != len(y_real)
                or len(X_val) != len(y_val)
                or len(X_fake) != len(y_fake)):
            return dict()

        # Determine if the target is numerical or categorical from metadata
        metadata = self.metadata['columns'][target_column]
        sdtype = metadata['sdtype']
        # Skip categorical targets whose class counts differ between real and fake data,
        # or where any class has fewer than 8 rows in either dataset
        if (sdtype == 'categorical' and
                (len(y_real.unique()) != len(y_fake.unique()) or
                 y_real.unique(return_counts=True)[1].min() < 8 or
                 y_fake.unique(return_counts=True)[1].min() < 8
                )):
            return result

        # Train the real-data model once (cached across calls), then score it on val
        if not hasattr(self, 'real_model') or self.real_model is None:
            self.real_model = self.get_model(sdtype)
            self.real_model.fit(X_real, y_real)
        real_score = self.real_model.score(X_val, y_val)
        result['real_score'] = real_score

        # Train/evaluate a fresh model on synthetic data, scored on the same val split
        fake_model = self.get_model(sdtype)
        fake_model.fit(X_fake, y_fake)
        fake_score = fake_model.score(X_val, y_val)
        result['fake_score'] = fake_score

        # Compute ratio of performance
        result['ratio'] = fake_score / real_score if real_score != 0 else float('inf')

        if sdtype == 'categorical':
            result['real_report'] = self.real_model.classification_report(X_val, y_val)
            result['fake_report'] = fake_model.classification_report(X_val, y_val)

        result['real_hparam'] = self.real_model.best_kwargs
        result['fake_hparam'] = fake_model.best_kwargs

        return result


class ImbalanceMLE(MLE):
    """
    MLE variant whose grid search is scored with ``_METRIC_IMB``.

    Only model construction differs from ``MLE``. Judging by its name,
    ``_METRIC_IMB`` is presumably a class-imbalance-aware metric; it is defined
    in ``.grid_search`` and not visible here.
    """

    def get_model(self, sdtype):
        """Build an unfitted ``GridSearch`` for ``sdtype`` that uses ``_METRIC_IMB``."""
        imbalance_metric = _METRIC_IMB
        search = GridSearch(sdtype, metric=imbalance_metric)
        return search


class EvalModelPrediction(MLE):
    """
    Compare a model trained on real data against externally supplied predictions.

    No second model is trained. ``_evaluation`` scores the given predictions directly
    against the validation targets, using the scorer of the real-data model, so the
    two scores are computed the same way.
    """

    def _evaluation(self, real_data: pd.DataFrame, val_data: pd.DataFrame, model_prediction: np.ndarray) -> Dict:
        """
        Score ``model_prediction`` on ``val_data`` relative to a model trained on real data.

        Args:
            real_data (pd.DataFrame): Data used to train the reference model. The model
                is trained once and cached on ``self.real_model``.
            val_data (pd.DataFrame): Validation data. Its target column is the ground
                truth for both scores.
            model_prediction (np.ndarray): Predictions for the rows of ``val_data``.
                For a numerical target they are reshaped to one column and passed
                through the target transform (``scaler='standard'``) before scoring.
                Otherwise they are used as-is. The report applies
                ``np.argmax(..., axis=1)`` to them, so they are presumably per-class
                scores of shape (n_val, n_classes) -- TODO confirm against callers.

        Returns:
            Dict: Empty when there is no target column, or when features, targets and
            predictions are not all the same length. Otherwise an ordered dict with
            ``real_score``, ``fake_score`` (the score of ``model_prediction``),
            ``ratio`` (fake_score / real_score, inf when real_score == 0),
            ``real_report`` and ``fake_report`` for categorical targets, and
            ``real_hparam``.
        """
        result = OrderedDict()

        feature_transform = self._transform_set.no_target
        target_transform = self._transform_set.target
        target_column = self._transform_set.target_column
        if target_column is None:
            return result

        # Transform real and val features into model inputs
        X_real = self._preprocess_data(feature_transform.transform(real_data, scaler='standard', onehot=False))
        X_val = self._preprocess_data(feature_transform.transform(val_data, scaler='standard', onehot=False))

        # Extract targets; entries that transform to NaN are dropped below
        y_real = target_transform.transform(real_data[[target_column]], onehot=False, scaler='standard',
                                            return_as_tensor=True)
        y_val = target_transform.transform(val_data[[target_column]], onehot=False, scaler='standard',
                                           return_as_tensor=True)
        y_real = y_real[~y_real.isnan()]
        y_val = y_val[~y_val.isnan()]

        if len(X_real) != len(y_real) or len(X_val) != len(y_val):
            return dict()

        # Determine if the target is numerical or categorical from metadata
        sdtype = self.metadata['columns'][target_column]['sdtype']

        # Bring the predictions into the same space as y_val
        if sdtype == 'numerical':
            y_fake = target_transform.transform(pd.DataFrame(model_prediction.reshape(-1, 1),
                                                             columns=[target_column]), onehot=False, scaler='standard',
                                                return_as_tensor=True)
            y_fake = y_fake[~y_fake.isnan()]
        else:
            y_fake = model_prediction

        # Dropped NaN predictions (or a wrongly sized input) would misalign predictions
        # with the validation targets; bail out like the guards above.
        if len(y_fake) != len(y_val):
            return dict()

        # Train the real-data model once (cached across calls), then score it on val
        if not hasattr(self, 'real_model') or self.real_model is None:
            self.real_model = self.get_model(sdtype)
            self.real_model.fit(X_real, y_real)
        real_score = self.real_model.score(X_val, y_val)
        result['real_score'] = real_score

        # Score the supplied predictions with the real-data model's scorer (no training)
        fake_score = self.real_model._scorer(y_val, y_fake)
        result['fake_score'] = fake_score

        # Compute ratio of performance
        result['ratio'] = fake_score / real_score if real_score != 0 else float('inf')

        if sdtype == 'categorical':
            result['real_report'] = self.real_model.classification_report(X_val, y_val)
            result['fake_report'] = classification_report(y_val, np.argmax(y_fake, axis=1), digits=4)

        result['real_hparam'] = self.real_model.best_kwargs

        return result


class ImputationMLE(MLE):
    """
    Measure how well imputed data supports a downstream model.

    One model is trained on the in-sample input data and scored on the out-of-sample
    input data. A second model is trained on the in-sample imputed data and scored
    on the out-of-sample imputed data. Both use the targets taken from the input
    frames. Models are built through ``get_model``, so subclasses can customise them.
    """

    def _evaluation(self, in_sample_input: pd.DataFrame,
                    out_of_sample_input: pd.DataFrame,
                    in_sample_imputed: pd.DataFrame,
                    out_of_sample_imputed: pd.DataFrame) -> Dict:
        """
        Compare models trained on input data vs. imputed data.

        Args:
            in_sample_input (pd.DataFrame): Training rows before imputation. They
                supply the training targets for both models.
            out_of_sample_input (pd.DataFrame): Held-out rows before imputation. They
                supply the evaluation targets for both models.
            in_sample_imputed (pd.DataFrame): Imputed counterpart of ``in_sample_input``.
            out_of_sample_imputed (pd.DataFrame): Imputed counterpart of ``out_of_sample_input``.

        Returns:
            Dict: Empty when there is no target column, or when any feature matrix
            differs in length from its targets (rows whose target transforms to NaN
            are dropped from the targets only). Otherwise an ordered dict with
            ``real_score`` (input-data model), ``fake_score`` (imputed-data model),
            ``ratio`` (fake_score / real_score, inf when real_score == 0),
            ``real_report`` / ``fake_report`` for categorical targets, and
            ``real_hparam`` / ``fake_hparam``.
        """
        result = OrderedDict()

        feature_transform = self._transform_set.no_target
        target_transform = self._transform_set.target
        target_column = self._transform_set.target_column
        if target_column is None:
            return result

        # Transform input and imputed features into model inputs
        X_in_sample_input = self._preprocess_data(
            feature_transform.transform(in_sample_input, scaler='standard', onehot=False))
        X_out_of_sample_input = self._preprocess_data(
            feature_transform.transform(out_of_sample_input, scaler='standard', onehot=False))
        X_in_sample_imputed = self._preprocess_data(
            feature_transform.transform(in_sample_imputed, scaler='standard', onehot=False))
        X_out_of_sample_imputed = self._preprocess_data(
            feature_transform.transform(out_of_sample_imputed, scaler='standard', onehot=False))

        # Extract targets from the input frames; entries that transform to NaN are dropped
        y_in_sample = target_transform.transform(in_sample_input[[target_column]], onehot=False, scaler='standard',
                                                 return_as_tensor=True)
        y_out_of_sample = target_transform.transform(out_of_sample_input[[target_column]], onehot=False,
                                                     scaler='standard', return_as_tensor=True)
        y_in_sample = y_in_sample[~y_in_sample.isnan()]
        y_out_of_sample = y_out_of_sample[~y_out_of_sample.isnan()]

        # Dropping NaN targets leaves X and y misaligned; bail out (as MLE does) rather
        # than fit or score on mismatched arrays.
        if (len(X_in_sample_input) != len(y_in_sample)
                or len(X_in_sample_imputed) != len(y_in_sample)
                or len(X_out_of_sample_input) != len(y_out_of_sample)
                or len(X_out_of_sample_imputed) != len(y_out_of_sample)):
            return dict()

        # Determine if the target is numerical or categorical from metadata
        metadata = self.metadata['columns'][target_column]
        sdtype = metadata['sdtype']

        # Train the input-data model once (cached across calls), then score it out of sample
        if not hasattr(self, 'real_model') or self.real_model is None:
            self.real_model = self.get_model(sdtype)
            self.real_model.fit(X_in_sample_input, y_in_sample)
        real_score = self.real_model.score(X_out_of_sample_input, y_out_of_sample)
        result['real_score'] = real_score

        # Train/evaluate a fresh model on the imputed data
        fake_model = self.get_model(sdtype)
        fake_model.fit(X_in_sample_imputed, y_in_sample)
        fake_score = fake_model.score(X_out_of_sample_imputed, y_out_of_sample)
        result['fake_score'] = fake_score

        # Compute ratio of performance
        result['ratio'] = fake_score / real_score if real_score != 0 else float('inf')

        if sdtype == 'categorical':
            result['real_report'] = self.real_model.classification_report(X_out_of_sample_input, y_out_of_sample)
            result['fake_report'] = fake_model.classification_report(X_out_of_sample_imputed, y_out_of_sample)

        result['real_hparam'] = self.real_model.best_kwargs
        result['fake_hparam'] = fake_model.best_kwargs

        return result
