import os
import time
from collections.abc import Iterable
from pprint import pprint
from typing import Optional, Dict, Any

import pandas as pd

from impugen.scenarios import simulate_missing
from . import SeedContext
from .io import maybe_load, maybe_save
from ..metrics import (
    AlphaPrecision, C2ST, ColumnShapeTrend,
    MLE, DCR, DPIMIA, ImputationScore, ImputationMLE, ImbalanceMLE, EvalModelPrediction
)

import numpy as np
from sklearn.neighbors import KernelDensity
from scipy.stats import zscore


class Timer:
    """
    A simple context manager for measuring execution time.

    Usage:
        with Timer() as elapsed:
            # run your code
        print(elapsed.interval)  # seconds

    Attributes:
        start (float): Wall-clock time (``time.time()``) when entering the context.
        end (float): Wall-clock time (``time.time()``) when exiting the context.
        interval (float): Elapsed time in seconds, measured with the monotonic,
            high-resolution ``time.perf_counter()`` so system clock adjustments
            cannot skew it or make it negative.
    """

    def __enter__(self):
        self.start = time.time()
        # Separate monotonic reference point used only for `interval`.
        self._perf_start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.interval = time.perf_counter() - self._perf_start
        self.end = time.time()
        # Implicitly returns None, so exceptions raised inside the block propagate.


def pprint_stream(*args, stream=None, **kwargs) -> None:
    """
    Pretty-print to stdout and, optionally, mirror the same output to a second stream.

    Args:
        *args: Positional arguments forwarded to :func:`pprint.pprint`.
        stream: Optional writable file-like object (e.g. an open log file). When it is
            not None, the output is written there as well as to stdout.
        **kwargs: Extra keyword arguments forwarded to :func:`pprint.pprint`.
    """
    # Always echo to stdout first.
    pprint(*args, **kwargs)
    if stream is None:
        return
    pprint(*args, stream=stream, **kwargs)


def update_key(result: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """
    Return a copy of ``result`` whose keys are rewritten as ``"<prefix>.<key>"``.

    Values and key order are preserved; the input dictionary is not modified.

    Args:
        result: Mapping of metric names to values.
        prefix: Namespace prepended to every key, joined with a dot.

    Returns:
        A new dictionary with prefixed keys.
    """
    prefixed_names = [f'{prefix}.{key}' for key in result]
    return dict(zip(prefixed_names, result.values()))


def _run_metric(
        *args,
        metric,
        log_dir: Optional[str],
        report: dict,
        prefix: Optional[str] = None
) -> dict:
    """
    Call ``metric(*args)``, log its raw output, and merge its scalar entries into ``report``.

    Args:
        *args: Positional arguments passed to ``metric(...)``.
        metric: Callable object returning a dict of results; its class name is logged.
        log_dir: Despite the name/annotation, this is passed as the ``stream`` of
            ``pprint_stream`` — i.e. an optional open file-like object (or None).
        report: Dictionary updated in place with the metric results.
        prefix: If given, each result key becomes ``"<prefix>.<key>"``.

    Returns:
        dict: The same ``report`` object, after the update.

    Note:
        Entries whose values are ``Iterable`` (lists, arrays, DataFrames — and also
        strings) are logged but not merged into ``report``.
    """
    outcome = metric(*args)
    metric_name = type(metric).__name__
    for message in (metric_name, outcome):
        pprint_stream(message, stream=log_dir)

    if prefix is not None:
        outcome = update_key(outcome, prefix)

    scalars = {}
    for key, value in outcome.items():
        if isinstance(value, Iterable):
            continue
        scalars[key] = value
    report.update(scalars)
    return report


def _evaluate_runtime(
        model,
        output_directory: Optional[str] = None,
        seed: Optional[int] = None,
        name: Optional[str] = None,
        **kwargs
):
    """
    Record the model's training time in a one-row report.

    Reads ``model._elapsed_time`` (seconds), prints it, and — when
    ``output_directory`` is given — appends the message to ``log.txt`` and merges
    the row into ``report.csv`` inside that directory.

    Args:
        model: Trained model exposing ``_elapsed_time`` and ``name``.
        output_directory: Directory for ``log.txt`` / ``report.csv``. If None,
            nothing is written to disk.
        seed: Seed used in the report row index.
        name: Dataset name used in the log message and the row index.
        **kwargs: Ignored.

    Returns:
        pd.DataFrame: One-row report indexed by ``"{name}.{output_directory}.{seed}"``
        with column ``time.training``, or an empty DataFrame if the model has no
        ``_elapsed_time`` attribute.
    """
    # Check before touching the filesystem, so that no directory or log file is
    # created for models that do not track training time.
    if not hasattr(model, '_elapsed_time'):
        return pd.DataFrame()

    log = None
    if output_directory is not None:
        os.makedirs(output_directory, exist_ok=True)
        log = open(os.path.join(output_directory, 'log.txt'), 'a')

    try:
        report = {'time.training': model._elapsed_time}

        pprint_stream(
            f"{model.name}: {model._elapsed_time:.2f}s for '{name}' dataset training",
            stream=log
        )

        index = f"{name}.{output_directory}.{seed}"
        report_df = pd.DataFrame(report, index=[index])

        if output_directory is not None:
            report_path = os.path.join(output_directory, 'report.csv')
            if os.path.isfile(report_path):
                existing_reports = pd.read_csv(report_path, index_col=0)
                # Merge the new row into the existing DataFrame (adds the row/columns if missing).
                for col in report_df.columns:
                    existing_reports.loc[index, col] = report_df.loc[index, col]
                existing_reports.to_csv(report_path)
            else:
                report_df.to_csv(report_path)
        return report_df
    finally:
        # Always release the log file handle, even if reporting fails.
        if log is not None:
            log.close()


def _evaluate_generation(
        model,
        train: pd.DataFrame,
        test: Optional[pd.DataFrame] = None,
        holdout: Optional[pd.DataFrame] = None,
        output_directory: Optional[str] = None,
        seed: Optional[int] = None,
        name: Optional[str] = None,
        df: dict = None,
        **kwargs,
) -> pd.DataFrame:
    """
    Evaluate an unconditional generative model.

    Steps:
      1) Generate ``len(train)`` synthetic rows with ``model.generate_uncond``, or
         reuse ``df['gen']`` when it is provided and not None.
      2) Save freshly generated data as ``generate_uncond.csv`` (if
         ``output_directory`` is provided).
      3) Compute AlphaPrecision, C2ST and ColumnShapeTrend on (train, gen).
      4) If ``test`` is given, compute MLE and DCR on (train, test, gen).
      5) Compute DPIMIA on (train, holdout, test, gen). If no ``holdout`` is given,
         half of ``test`` is sampled as the holdout and the remaining rows serve as
         the test split. The step is skipped when neither is available.

    The metrics in steps 3-4 are best-effort: a metric that raises is logged and
    skipped. DPIMIA errors propagate.

    The one-row report is merged into ``report.csv`` in ``output_directory``, and
    messages are appended to ``log.txt`` there.

    Args:
        model: Model exposing ``generate_uncond``, ``model_flags``, ``_transform``
            and ``name``.
        train (pd.DataFrame): Training data; its length sets the number of samples.
        test (pd.DataFrame, optional): Test split for MLE/DCR/DPIMIA.
        holdout (pd.DataFrame, optional): Holdout split for DPIMIA.
        output_directory (str, optional): Directory for logs and CSV outputs. If
            None, nothing is written to disk.
        seed (int, optional): Seed for generation and for each metric's SeedContext.
        name (str, optional): Dataset label used in logs and the report index.
        df (dict, optional): Cache of previously generated data under key ``'gen'``.
        **kwargs: Forwarded to ``model.generate_uncond``.

    Returns:
        pd.DataFrame: One-row report, or an empty DataFrame if the model has no
        ``generate_uncond`` method or its ``drop_target`` flag is set.
    """
    # Only models that can sample unconditionally are evaluated here.
    if not hasattr(model, 'generate_uncond'):
        return pd.DataFrame()

    if model.model_flags['drop_target']:
        return pd.DataFrame()

    log = None
    if output_directory is not None:
        os.makedirs(output_directory, exist_ok=True)
        log = open(os.path.join(output_directory, 'log.txt'), 'a')

    try:
        report = {}

        # Generate (or reuse) synthetic data
        if isinstance(df, dict) and df.get('gen', None) is not None:
            gen = df['gen']
        else:
            with Timer() as elapsed:
                gen = model.generate_uncond(len(train), seed=seed, **kwargs)
            pprint_stream(
                f"{model.name}: {elapsed.interval:.2f}s for '{name}' dataset generation ({len(train)} rows)",
                stream=log
            )
            maybe_save(gen, 'generate_uncond.csv', output_directory)
            report['n_samples'] = len(gen)
            report['time.generation'] = elapsed.interval

        # Metrics comparing generated data with the training data (best-effort).
        for metric_cls in [AlphaPrecision, C2ST, ColumnShapeTrend]:
            with SeedContext(seed):
                try:
                    report = _run_metric(
                        train,
                        gen,
                        metric=metric_cls(model._transform, drop_target_column=False),
                        log_dir=log,
                        report=report
                    )
                except Exception as exc:
                    # A failing metric must not abort the evaluation, but the failure is
                    # recorded instead of silently swallowed (and Ctrl-C is no longer caught).
                    pprint_stream(f"{metric_cls.__name__} failed: {exc!r}", stream=log)

        # Metrics that also need the test split (best-effort).
        if test is not None:
            for metric_cls in [MLE, DCR]:
                with SeedContext(seed):
                    try:
                        report = _run_metric(
                            train,
                            test,
                            gen,
                            metric=metric_cls(model._transform, drop_target_column=False),
                            log_dir=log,
                            report=report
                        )
                    except Exception as exc:
                        pprint_stream(f"{metric_cls.__name__} failed: {exc!r}", stream=log)

        # DPIMIA: use the explicit holdout, or split `test` in half when none is given.
        if holdout is not None:
            with SeedContext(seed):
                for metric_cls in [DPIMIA]:
                    report = _run_metric(
                        train,
                        holdout,
                        test,
                        gen,
                        metric=metric_cls(model._transform, drop_target_column=False),
                        log_dir=log,
                        report=report
                    )
        elif test is not None:
            # The sampling and the metric share one SeedContext so the split is reproducible.
            with SeedContext(seed):
                holdout = test.sample(frac=0.5)
                for metric_cls in [DPIMIA]:
                    report = _run_metric(
                        train,
                        holdout,
                        test.drop(holdout.index),
                        gen,
                        metric=metric_cls(model._transform, drop_target_column=False),
                        log_dir=log,
                        report=report
                    )

        # Convert final report to a DataFrame and save/append
        index = f"{name}.{output_directory}.{seed}"
        report_df = pd.DataFrame(report, index=[index])

        if output_directory is not None:
            report_path = os.path.join(output_directory, 'report.csv')
            if os.path.isfile(report_path):
                existing_reports = pd.read_csv(report_path, index_col=0)
                # Merge the new row into the existing DataFrame
                for col in report_df.columns:
                    existing_reports.loc[index, col] = report_df.loc[index, col]
                existing_reports.to_csv(report_path)
            else:
                report_df.to_csv(report_path)
        return report_df
    finally:
        if log is not None:
            log.close()


def _evaluate_imbalance_learning(
        model,
        train: pd.DataFrame,
        test: Optional[pd.DataFrame] = None,
        output_directory: Optional[str] = None,
        seed: Optional[int] = None,
        name: Optional[str] = None,
        df: dict = None,
        **kwargs
) -> pd.DataFrame:
    """
    Evaluate a model's target-rebalancing ability.

    Rebalances ``train`` with ``model.rebalance_targets``, or reuses
    ``df['rebalanced']`` when it is provided and not None. Freshly rebalanced data
    is saved as ``rebalanced.csv``. When ``test`` is given, ImbalanceMLE is run on
    (train, test, rebalanced) and its scalar results are stored under the ``IMB.``
    prefix.

    Args:
        model: Model exposing ``rebalance_targets``, ``model_flags``, ``tgt``,
            ``tabular_transform``, ``_transform`` and ``name``.
        train (pd.DataFrame): Training data to rebalance.
        test (pd.DataFrame, optional): Test split for ImbalanceMLE.
        output_directory (str, optional): Directory for ``log.txt`` / ``report.csv``
            / ``rebalanced.csv``. If None, nothing is written to disk.
        seed (int, optional): Seed for rebalancing and the metric's SeedContext.
        name (str, optional): Dataset label used in logs and the report index.
        df (dict, optional): Cache of previous results under key ``'rebalanced'``.
        **kwargs: Forwarded to ``model.rebalance_targets``.

    Returns:
        pd.DataFrame: One-row report, or an empty DataFrame when the model has no
        ``rebalance_targets`` method, its ``drop_target`` flag is set, or its target
        column is not in ``model.tabular_transform.columns``.
    """
    # Only models that can rebalance targets are evaluated here.
    if not hasattr(model, 'rebalance_targets'):
        return pd.DataFrame()

    if model.model_flags['drop_target'] or model.tgt not in model.tabular_transform.columns:
        return pd.DataFrame()

    log = None
    if output_directory is not None:
        os.makedirs(output_directory, exist_ok=True)
        log = open(os.path.join(output_directory, 'log.txt'), 'a')

    try:
        report = {}

        # Rebalance (or reuse) the training data
        if isinstance(df, dict) and df.get('rebalanced', None) is not None:
            rebalanced = df['rebalanced']
        else:
            with Timer() as elapsed:
                rebalanced = model.rebalance_targets(train, seed=seed, **kwargs)
            pprint_stream(
                f"{model.name}: {elapsed.interval:.2f}s for '{name}' dataset rebalancing ({len(rebalanced)} rows)",
                stream=log
            )
            maybe_save(rebalanced, 'rebalanced.csv', output_directory)
            report['n_samples.rebalancing'] = len(rebalanced)
            report['time.rebalancing'] = elapsed.interval

        # Metrics involving the test split
        with SeedContext(seed):
            if test is not None:
                for metric_cls in [ImbalanceMLE]:
                    report = _run_metric(
                        train,
                        test,
                        rebalanced,
                        metric=metric_cls(model._transform, drop_target_column=False),
                        log_dir=log,
                        report=report,
                        prefix='IMB'
                    )

        # Convert final report to a DataFrame and save/append
        index = f"{name}.{output_directory}.{seed}"
        report_df = pd.DataFrame(report, index=[index])

        if output_directory is not None:
            report_path = os.path.join(output_directory, 'report.csv')
            if os.path.isfile(report_path):
                existing_reports = pd.read_csv(report_path, index_col=0)
                # Merge the new row into the existing DataFrame
                for col in report_df.columns:
                    existing_reports.loc[index, col] = report_df.loc[index, col]
                existing_reports.to_csv(report_path)
            else:
                report_df.to_csv(report_path)
        return report_df
    finally:
        if log is not None:
            log.close()


def _evaluate_prediction(
        model,
        train: pd.DataFrame,
        test: Optional[pd.DataFrame] = None,
        output_directory: Optional[str] = None,
        seed: Optional[int] = None,
        name: Optional[str] = None,
        df: dict = None,
        **kwargs
) -> pd.DataFrame:
    """
    Evaluate a model's prediction of the target column on the test split.

    The target column of a copy of ``test`` is set to ``pd.NA``, and
    ``model.predict_proba`` is called on it; alternatively ``df['pred']`` is reused
    when it is provided and not None. Freshly predicted values are written back into
    that copy and saved as ``prediction.csv``. EvalModelPrediction is then run on
    (train, test, pred), with results stored under the ``Pred.`` prefix.

    Args:
        model: Model exposing ``tgt``, ``predict_proba``, ``model_flags``,
            ``tabular_transform``, ``_transform`` and ``name``.
        train (pd.DataFrame): Training data, passed to the metric.
        test (pd.DataFrame, optional): Rows to predict and score against.
        output_directory (str, optional): Directory for ``log.txt`` / ``report.csv``
            / ``prediction.csv``. If None, nothing is written to disk.
        seed (int, optional): Seed for prediction and the metric's SeedContext.
        name (str, optional): Dataset label used in logs and the report index.
        df (dict, optional): Cache of previous results under key ``'pred'``.
        **kwargs: Forwarded to ``model.predict_proba``.

    Returns:
        pd.DataFrame: One-row report, or an empty DataFrame when the model has no
        target column or ``predict_proba`` method, its ``drop_target`` flag is set,
        its target is not in ``model.tabular_transform.columns``, or ``test`` is None.
    """
    target_column = model.tgt
    if target_column is None:
        return pd.DataFrame()
    # Only models that can predict the target are evaluated here.
    if not hasattr(model, 'predict_proba'):
        return pd.DataFrame()

    if model.model_flags['drop_target'] or target_column not in model.tabular_transform.columns:
        return pd.DataFrame()

    # Without test rows there is nothing to predict or score (previously this
    # crashed on `test.copy()`).
    if test is None:
        return pd.DataFrame()

    log = None
    if output_directory is not None:
        os.makedirs(output_directory, exist_ok=True)
        log = open(os.path.join(output_directory, 'log.txt'), 'a')

    try:
        report = {}

        if isinstance(df, dict) and df.get('pred', None) is not None:
            # NOTE(review): evaluate_prediciton fills df['pred'] from 'prediction.csv',
            # which (see the save below) is the whole test frame with the target column
            # overwritten, whereas a fresh `pred` is the raw `predict_proba` output.
            # Confirm EvalModelPrediction accepts both forms.
            pred = df['pred']
        else:
            # Hide the target so the model has to predict it.
            test_feature = test.copy()
            test_feature[target_column] = pd.NA
            with Timer() as elapsed:
                pred = model.predict_proba(test_feature, seed=seed, mask_target_column=False, **kwargs)
            pprint_stream(
                f"{model.name}: {elapsed.interval:.2f}s for '{name}' dataset prediction ({len(pred)} rows)",
                stream=log
            )
            test_feature[target_column] = pred
            maybe_save(test_feature, 'prediction.csv', output_directory)
            report['n_samples.prediction'] = len(pred)
            report['time.prediction'] = elapsed.interval

        for metric_cls in [EvalModelPrediction]:
            with SeedContext(seed):
                report = _run_metric(
                    train,
                    test,
                    pred,
                    metric=metric_cls(model._transform, drop_target_column=False),
                    log_dir=log,
                    report=report,
                    prefix='Pred'
                )

        # Convert final report to a DataFrame and save/append
        index = f"{name}.{output_directory}.{seed}"
        report_df = pd.DataFrame(report, index=[index])

        if output_directory is not None:
            report_path = os.path.join(output_directory, 'report.csv')
            if os.path.isfile(report_path):
                existing_reports = pd.read_csv(report_path, index_col=0)
                # Merge the new row into the existing DataFrame
                for col in report_df.columns:
                    existing_reports.loc[index, col] = report_df.loc[index, col]
                existing_reports.to_csv(report_path)
            else:
                report_df.to_csv(report_path)
        return report_df
    finally:
        if log is not None:
            log.close()


def _evaluate_missing_value_imputation(
        model,
        missing_input,
        missing_target,
        train: pd.DataFrame,
        test: pd.DataFrame,
        output_directory: Optional[str] = None,
        seed: Optional[int] = None,
        name: Optional[str] = None,
        df: dict = None,
        **kwargs
) -> pd.DataFrame:
    """
    Evaluate missing-value imputation in-sample (``train``) and, unless
    ``model.model_flags['in_sample_only']`` is set, out-of-sample (``test``).

    For each split, ``missing_input(split)`` is passed to ``model.impute``, and the
    imputed result is scored with ImputationScore against ``missing_target(split)``.
    Results are stored under the ``imputation.in_sample`` /
    ``imputation.out_of_sample`` prefixes.

    Args:
        model: Model exposing ``impute``, ``model_flags``, ``_transform`` and ``name``.
        missing_input: Callable mapping a DataFrame to the masked model input.
        missing_target: Callable mapping a DataFrame to the frame the imputation is
            scored against.
        train (pd.DataFrame): In-sample split.
        test (pd.DataFrame): Out-of-sample split.
        output_directory (str, optional): Directory for logs, intermediate CSVs and
            ``report.csv``. If None, nothing is written to disk.
        seed (int, optional): Seed for imputation and each metric's SeedContext.
        name (str, optional): Dataset label used in logs and the report index.
        df (dict, optional): Cache of previous results under ``'in_sample_impute'`` /
            ``'out_of_sample_impute'``; entries that are not None are reused instead
            of calling ``model.impute``.
        **kwargs: Forwarded to ``model.impute``.

    Returns:
        pd.DataFrame: One-row report, or an empty DataFrame if the model has no
        ``impute`` method.
    """
    # Only models that can impute are evaluated here.
    if not hasattr(model, 'impute'):
        return pd.DataFrame()

    log = None
    if output_directory is not None:
        os.makedirs(output_directory, exist_ok=True)
        log = open(os.path.join(output_directory, 'log.txt'), 'a')

    try:
        report = {}
        # All four frames are built up front, in this order, regardless of
        # `in_sample_only` (kept as-is: the masking callables may consume randomness).
        in_sample_input = missing_input(train)
        in_sample_target = missing_target(train)
        out_of_sample_input = missing_input(test)
        out_of_sample_target = missing_target(test)

        if isinstance(df, dict) and df.get('in_sample_impute', None) is not None:
            in_sample_impute = df['in_sample_impute']
        else:
            maybe_save(in_sample_target, 'in_sample.target.csv', output_directory)
            maybe_save(in_sample_input, 'in_sample.masked.csv', output_directory)
            # impute in-sample data
            with Timer() as elapsed:
                in_sample_impute = model.impute(in_sample_input, seed=seed, mask_target_column=True, **kwargs)
            pprint_stream(
                f"{model.name}: {elapsed.interval:.2f}s for '{name}' dataset imputation ({len(in_sample_input)} rows)",
                stream=log
            )
            maybe_save(in_sample_impute, 'in_sample.imputed.csv', output_directory)
            report['n_samples.imputation.in_sample'] = len(in_sample_input)
            report['time.imputation.in_sample'] = elapsed.interval

        for metric_cls in [ImputationScore]:
            with SeedContext(seed):
                report = _run_metric(
                    in_sample_target,
                    in_sample_impute,
                    metric=metric_cls(model._transform, drop_target_column=True),
                    log_dir=log,
                    report=report,
                    prefix='imputation.in_sample'
                )

        if not model.model_flags['in_sample_only']:
            if isinstance(df, dict) and df.get('out_of_sample_impute', None) is not None:
                out_of_sample_impute = df['out_of_sample_impute']
            else:
                maybe_save(out_of_sample_target, 'out_of_sample.target.csv', output_directory)
                maybe_save(out_of_sample_input, 'out_of_sample.masked.csv', output_directory)
                with Timer() as elapsed:
                    out_of_sample_impute = model.impute(out_of_sample_input, seed=seed, mask_target_column=True, **kwargs)
                pprint_stream(
                    f"{model.name}: {elapsed.interval:.2f}s for '{name}' dataset imputation ({len(out_of_sample_input)} rows)",
                    stream=log
                )
                maybe_save(out_of_sample_impute, 'out_of_sample.imputed.csv', output_directory)
                report['n_samples.imputation.out_of_sample'] = len(out_of_sample_input)
                report['time.imputation.out_of_sample'] = elapsed.interval

            for metric_cls in [ImputationScore]:
                with SeedContext(seed):
                    report = _run_metric(
                        out_of_sample_target,
                        out_of_sample_impute,
                        metric=metric_cls(model._transform, drop_target_column=True),
                        log_dir=log,
                        report=report,
                        prefix='imputation.out_of_sample'
                    )

        # Convert final report to a DataFrame and save/append
        index = f"{name}.{output_directory}.{seed}"
        report_df = pd.DataFrame(report, index=[index])

        if output_directory is not None:
            report_path = os.path.join(output_directory, 'report.csv')
            if os.path.isfile(report_path):
                existing_reports = pd.read_csv(report_path, index_col=0)
                # Merge the new row into the existing DataFrame
                for col in report_df.columns:
                    existing_reports.loc[index, col] = report_df.loc[index, col]
                existing_reports.to_csv(report_path)
            else:
                report_df.to_csv(report_path)
        return report_df
    finally:
        if log is not None:
            log.close()


def evaluate_runtime(cfg, model, output_directory=None):
    """
    Config-driven wrapper around ``_evaluate_runtime``.

    Takes the seed from ``cfg.seed``, the dataset label from ``cfg.dataset.name`` and
    extra keyword arguments from ``cfg['kwargs']`` (empty if absent). The resulting
    report is not returned.
    """
    _evaluate_runtime(
        model,
        output_directory=output_directory,
        seed=cfg.seed,
        name=cfg.dataset.name,
        **cfg.get('kwargs', dict()),
    )


def evaluate_generation(cfg, model, output_directory=None):
    """
    Config-driven wrapper around ``_evaluate_generation``.

    Loads the train/test/holdout splits from ``cfg.dataset``. When
    ``cfg.sample_on_eval`` is false and ``output_directory`` is given, previously
    generated samples are reused from ``generate_uncond.csv`` in the run directory,
    i.e. ``output_directory`` with everything from the first ``<sep>evaluation``
    onwards removed. Otherwise the evaluator samples afresh. Previously a None
    ``output_directory`` crashed here.
    """
    train = maybe_load(cfg.dataset.train_path)
    test = maybe_load(cfg.dataset.test_path)
    holdout = maybe_load(cfg.dataset.holdout_path)
    seed = cfg.seed
    name = cfg.dataset.name
    kwargs = cfg.get('kwargs', dict())

    df = None
    # Reusing cached samples needs a run directory to look in. Without one, fall back
    # to sampling, as the evaluator already does when df['gen'] is None.
    if not cfg.sample_on_eval and output_directory is not None:
        run_directory = output_directory.split(os.sep + 'evaluation')[0]
        df = dict(gen=maybe_load(os.path.join(run_directory, 'generate_uncond.csv')))

    _evaluate_generation(model, train, test, holdout, output_directory, seed, name, df, **kwargs)


def evaluate_imbalance_learning(cfg, model, output_directory=None):
    """
    Config-driven wrapper around ``_evaluate_imbalance_learning``.

    Loads the train/test splits from ``cfg.dataset``. When ``cfg.sample_on_eval`` is
    false and ``output_directory`` is given, a previous ``rebalanced.csv`` is reused
    from the run directory, i.e. ``output_directory`` with everything from the first
    ``<sep>evaluation`` onwards removed. Otherwise the evaluator rebalances afresh.
    Previously a None ``output_directory`` crashed here.
    """
    train = maybe_load(cfg.dataset.train_path)
    test = maybe_load(cfg.dataset.test_path)
    seed = cfg.seed
    name = cfg.dataset.name
    kwargs = cfg.get('kwargs', dict())

    df = None
    # Without a run directory there is nothing to reuse. The evaluator rebalances
    # when df['rebalanced'] is None.
    if not cfg.sample_on_eval and output_directory is not None:
        run_directory = output_directory.split(os.sep + 'evaluation')[0]
        df = dict(rebalanced=maybe_load(os.path.join(run_directory, 'rebalanced.csv')))

    _evaluate_imbalance_learning(model, train, test, output_directory, seed, name, df, **kwargs)


def evaluate_prediciton(cfg, model, output_directory=None):
    """
    Config-driven wrapper around ``_evaluate_prediction``.

    Loads the train/test splits from ``cfg.dataset``. When ``cfg.sample_on_eval`` is
    false, ``output_directory`` is given, and the run directory contains
    ``prediction.csv``, that file is reused as the prediction. The run directory is
    ``output_directory`` with everything from the first ``<sep>evaluation`` onwards
    removed. Otherwise the evaluator predicts afresh. Previously a None
    ``output_directory`` crashed here.

    The misspelled name is kept for existing callers; ``evaluate_prediction`` is a
    correctly spelled alias.
    """
    train = maybe_load(cfg.dataset.train_path)
    test = maybe_load(cfg.dataset.test_path)
    seed = cfg.seed
    name = cfg.dataset.name
    kwargs = cfg.get('kwargs', dict())

    df = None
    if not cfg.sample_on_eval and output_directory is not None:
        cached_path = os.path.join(output_directory.split(os.sep + 'evaluation')[0], 'prediction.csv')
        if os.path.isfile(cached_path):
            df = dict(pred=maybe_load(cached_path))

    _evaluate_prediction(model, train, test, output_directory, seed, name, df, **kwargs)


# Correctly spelled alias; the original (misspelled) name stays for backward compatibility.
evaluate_prediction = evaluate_prediciton


def evaluate_missing(cfg, model, output_directory=None):
    """
    Config-driven wrapper around ``_evaluate_missing_value_imputation``.

    Returns None without evaluating when ``cfg.scenario._target_`` is None, 'None'
    or 'none'. Otherwise it sets ``cfg['missing'] = cfg.scenario``, which is
    presumably what ``simulate_missing`` reads (TODO confirm). It then builds the
    input/target masking callables and runs the evaluation.

    When ``cfg.sample_on_eval`` is false and ``output_directory`` is given,
    previously imputed frames are reused from the run directory, i.e.
    ``output_directory`` with everything from the first ``<sep>evaluation`` onwards
    removed. Otherwise the evaluator imputes afresh. Previously a None
    ``output_directory`` crashed here.
    """
    if cfg.scenario._target_ in [None, 'None', 'none']:
        return None
    cfg['missing'] = cfg.scenario  # TODO: Refactor scenario settings handling

    train = maybe_load(cfg.dataset.train_path)
    test = maybe_load(cfg.dataset.test_path)
    seed = cfg.seed
    name = cfg.dataset.name
    kwargs = cfg.get('kwargs', dict())

    missing_input = simulate_missing(cfg, model._transform, False)
    missing_target = simulate_missing(cfg, model._transform, True)

    df = None
    # Without a run directory there is nothing to reuse. The evaluator imputes
    # whenever the cached entries are None.
    if not cfg.sample_on_eval and output_directory is not None:
        run_directory = output_directory.split(os.sep + 'evaluation')[0]
        df = dict(
            in_sample_impute=maybe_load(os.path.join(run_directory, 'in_sample.imputed.csv')),
            out_of_sample_impute=maybe_load(os.path.join(run_directory, 'out_of_sample.imputed.csv')),
        )

    _evaluate_missing_value_imputation(model,
                                       missing_input,
                                       missing_target,
                                       train,
                                       test,
                                       output_directory,
                                       seed,
                                       name,
                                       df,
                                       **kwargs)
