import logging
import os
import pickle
import time
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path

import pandas as pd
import polars as pl
import torch
from omegaconf import OmegaConf
from tqdm import tqdm

from data.data_preprocess import DataProcessor
from evaluation.eval_alphaprecision import AlphaPrecision
from evaluation.eval_dcr import DCRScore
from evaluation.eval_detection import DetectionScore
from evaluation.eval_mia import MIAScore
from evaluation.eval_ml_efficiency import (
    MLEScore,  # this leads to pkg_resources warning due to outdated dython dependencies
)
from evaluation.eval_similarity import SimilarityScores


class Experiment(ABC):
    """Abstract base for a generative-model experiment on one dataset.

    The base class sets up the working directory and the ``DataProcessor``,
    and implements model-agnostic sampling (``sample_datasets``) and
    evaluation (``evaluate``, ``evaluate_motivation``). Subclasses implement
    ``train``, ``sample``, ``save_model`` and ``load_model``.
    """

    # Upper bound on the number of synthetic rows drawn per evaluation iteration.
    _MAX_NUM_SAMPLES = 100_000
    # Number of rows and the seed used for the sampling-speed benchmark.
    _BENCHMARK_NUM_SAMPLES = 1000
    _BENCHMARK_SEED = 42

    def __init__(self, config, args):
        """Set up paths, the data processor and sampling settings.

        Args:
            config: Model configuration. Read attributes include ``model_name``,
                ``data.cat_encoding`` and, for the TabCascade model, the
                ``lowres`` / ``highres`` sub-configs.
            args: Parsed run arguments. Read attributes: ``seed``, ``dataset``,
                ``exp_path``, ``miss_mechanism``, ``p_miss``, ``sampling_steps``.

        Raises:
            ValueError: If ``args.sampling_steps`` is given for a model other
                than ``"tabcascade"``.
        """
        self.seed = args.seed
        self.config = config
        self.common_config = OmegaConf.load("experiments/configs/common_config.yaml")
        self.device = self.common_config.device

        # Results live under results/<dataset>/<exp_path>; a timestamp is used
        # as the experiment folder name when no explicit path is given.
        exp_path = datetime.now().strftime("%Y_%m_%d_%H-%M-%S") if args.exp_path is None else args.exp_path
        self.workdir = Path("results") / args.dataset / exp_path
        os.makedirs(self.workdir / "figures", exist_ok=True)

        logging.warning(f"=== Initializing {args.dataset} dataset ===")
        self.dataset = args.dataset
        self.data_processor = DataProcessor(
            self.dataset,
            self.config.data.cat_encoding,
            self.seed,
            self.common_config.val_prop,
            self.common_config.test_prop,
            cat_min_freq=self.common_config.cat_min_freq,
            missing_mechanism=args.miss_mechanism,
            seed_missings=self.seed,
            p_miss=args.p_miss,
            p_obs=self.common_config.p_obs,
        )

        # Sample as many rows as the real train set has, capped at _MAX_NUM_SAMPLES.
        df_trn, _, _ = self.data_processor.get_data_splits()
        self.num_samples = min(df_trn.height, self._MAX_NUM_SAMPLES)

        # Handle overwriting of sampling steps for ablation experiments.
        self.sampling_steps_overwritten = False
        if args.sampling_steps is not None:
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            if self.config.model_name != "tabcascade":
                raise ValueError("Overwriting sampling steps only supported for TabCascade model.")
            self.config.lowres.model.generation_steps = args.sampling_steps
            self.config.highres.model.generation_steps = args.sampling_steps
            self.sampling_steps_overwritten = True

        print(f"CUDA available? {torch.cuda.is_available()}")

        if self.common_config.use_tf32 and torch.cuda.is_available():
            torch.set_float32_matmul_precision("high")

    @abstractmethod
    def train(self, **kwargs): ...

    @abstractmethod
    def sample(self, num_samples, **kwargs): ...

    @abstractmethod
    def save_model(self): ...

    @abstractmethod
    def load_model(self): ...

    @staticmethod
    def _eval_seed(i):
        """Return the seed of evaluation iteration ``i`` (0, 42, 86, 132, ...).

        Sampling and evaluation must agree on this schedule because the seed
        is part of the sample file name.
        """
        return (42 + (i - 1)) * i

    @staticmethod
    def _aggregate_results(per_iter_results):
        """Turn a list of per-iteration dicts into a dict of per-key lists.

        The keys are taken from the first entry; an empty input yields ``{}``.
        """
        if not per_iter_results:
            return {}
        return {k: [d[k] for d in per_iter_results] for k in per_iter_results[0]}

    def _is_arf_lowres(self):
        """Return True for the TabCascade model with the ``"arf"`` low-res variant."""
        # Short-circuits so ``lowres`` is only touched for TabCascade configs.
        return (self.config.model_name == "tabcascade") and (self.config.lowres.model.variant == "arf")

    def _sample_folder(self):
        """Return the name of the sub-folder of ``workdir`` that holds generated samples."""
        return "samples_arf_lowres" if self._is_arf_lowres() else "samples"

    def sample_datasets(self):
        """Load the model, benchmark sampling speed and write synthetic datasets.

        For each of ``common_config.eval_sample_iter`` iterations a dataset of
        ``self.num_samples`` rows is sampled, post-processed and written as
        ``gen_data_<seed>.parquet`` into the sample folder. When the sampling
        steps were overwritten, neither the benchmark nor the parquet files
        are produced.
        """
        logging.warning("=== Loading generative model... ===")
        self.load_model()

        is_arf_lowres = self._is_arf_lowres()
        sample_dir = self.workdir / self._sample_folder()
        sample_dir.mkdir(parents=True, exist_ok=True)

        # Passed through to ``postprocess``; False only for TabCascade with
        # ``condition_on_z`` enabled (which is the default when the key is absent).
        includes_miss_ind = not (
            (self.config.model_name == "tabcascade") and self.config.highres.model.get("condition_on_z", True)
        )

        if not self.sampling_steps_overwritten:
            logging.warning("=== Benchmarking sampling speed... ===")
            sample_start_time = time.monotonic()
            self.sample(self._BENCHMARK_NUM_SAMPLES, seed=self._BENCHMARK_SEED)
            sample_duration = time.monotonic() - sample_start_time

            # The timing is measured but not stored for the ARF low-res variant.
            if not is_arf_lowres:
                with (self.workdir / "sample_time.pkl").open("wb") as f:
                    pickle.dump(sample_duration, f)

        for i in tqdm(range(self.common_config.eval_sample_iter)):
            seed = self._eval_seed(i)
            X_cat_gen, X_num_gen = self.sample(self.num_samples, seed=seed)
            df_gen = self.data_processor.postprocess(X_cat_gen, X_num_gen, includes_miss_ind=includes_miss_ind)
            # NOTE(review): with overwritten sampling steps the samples are not
            # written, while evaluate() reads samples from disk — confirm this
            # is the intended behaviour for the ablation runs.
            if not self.sampling_steps_overwritten:
                df_gen.write_parquet(sample_dir / f"gen_data_{seed}.parquet")

    def evaluate(self):
        """Evaluate the stored synthetic datasets and pickle the aggregated results.

        For every evaluation iteration the matching ``gen_data_<seed>.parquet``
        is read and passed to the ``DetectionScore``, ``MIAScore``,
        ``SimilarityScores``, ``DCRScore``, ``MLEScore`` and ``AlphaPrecision``
        evaluators. The per-iteration result dicts are merged into one dict of
        lists and written to ``workdir``. The file is ``eval_results.pkl``, or
        ``eval_results_<steps>_steps.pkl`` when the sampling steps were
        overwritten; it is prefixed with ``arf_low_res_`` for the ARF low-res
        variant.
        """
        logging.warning("=== Evaluating synthetic data... ===")

        # init evaluators and ground truth data
        df_trn, _, df_tst = self.data_processor.get_data_splits()
        cat_cols = self.data_processor.cat_cols
        num_cols = self.data_processor.num_cols

        detect_score = DetectionScore(cat_cols, num_cols)
        mia_score = MIAScore(cat_cols, num_cols)
        sim_score = SimilarityScores(df_trn, df_tst, cat_cols)
        dcr_score = DCRScore(df_trn, df_tst, cat_cols, num_cols)
        mle_score = MLEScore(cat_cols, num_cols, self.data_processor.target)
        alphaprec = AlphaPrecision(cat_cols)

        is_arf_lowres = self._is_arf_lowres()
        sample_dir = self.workdir / self._sample_folder()

        eval_results = []
        for i in tqdm(range(self.common_config.eval_sample_iter)):
            results = {}
            seed = self._eval_seed(i)
            df_gen = pl.read_parquet(sample_dir / f"gen_data_{seed}.parquet")

            results.update(detect_score.estimate_score(df_trn, df_gen, seed=seed, nfold=5))
            results.update(mia_score.estimate_score(df_trn, df_tst, df_gen, seed=seed, n_iter=5))
            results.update(sim_score.compute_similarity(df_trn, df_gen))
            shape_trend = sim_score.compute_colwise_density_metrics(df_trn, df_gen)
            # Flatten the nested result into a single level with "_"-joined keys.
            shape_trend = pd.json_normalize(shape_trend, sep="_").iloc[0].to_dict()
            results.update(shape_trend)

            dcr_res = dcr_score.compute_dcr(df_gen, seed=seed)
            # The raw entry is not kept in the aggregated results.
            del dcr_res["dcr_raw"]
            results.update(dcr_res)

            results.update(mle_score.get_score(df_trn, df_tst, df_gen, seed=seed))
            results.update(alphaprec.estimate_scores(df_trn, df_gen, seed))

            eval_results.append(results)

        eval_results = self._aggregate_results(eval_results)

        if self.sampling_steps_overwritten:
            file_name = f"eval_results_{self.config.lowres.model.generation_steps}_steps.pkl"
        else:
            file_name = "eval_results.pkl"

        if is_arf_lowres:
            file_name = "arf_low_res_" + file_name

        with (self.workdir / file_name).open("wb") as f:
            pickle.dump(eval_results, f)
        # Same level as the other progress messages of this class.
        logging.warning("=== Evaluation finished, results saved! ===")

    def evaluate_motivation(self):
        """Compute detection scores on all, categorical-only and numerical-only features.

        Reads ``gen_data_<seed>.parquet`` from ``workdir / "samples"`` for every
        evaluation iteration, runs ``DetectionScore.estimate_score`` with
        ``drop`` set to ``"none"``, ``"num"`` and ``"cat"``, and pickles the
        aggregated scores to ``motivation_results.pkl``.
        """
        logging.warning("=== Evaluating motivation experiment... ===")

        detect_score = DetectionScore(self.data_processor.cat_cols, self.data_processor.num_cols)
        df_trn, _, _ = self.data_processor.get_data_splits()

        # Result key -> value of the ``drop`` argument passed to the evaluator.
        drop_by_key = {"detection_all": "none", "detection_cat": "num", "detection_num": "cat"}

        eval_results = []
        for i in tqdm(range(self.common_config.eval_sample_iter)):
            seed = self._eval_seed(i)

            # load samples
            df_gen = pl.read_parquet(self.workdir / "samples" / f"gen_data_{seed}.parquet")

            eval_results.append(
                {
                    key: detect_score.estimate_score(df_trn, df_gen, seed=seed, nfold=5, drop=drop)["detection_score"]
                    for key, drop in drop_by_key.items()
                }
            )

        eval_results = self._aggregate_results(eval_results)
        with (self.workdir / "motivation_results.pkl").open("wb") as f:
            pickle.dump(eval_results, f)

        print("=== Motivation results saved! ===")

    def save_train_time(self, duration):
        """Save training time in minutes.

        Args:
            duration: Training duration in seconds; ``duration / 60`` is pickled
                to ``train_time.pkl`` in the working directory.
        """
        with (self.workdir / "train_time.pkl").open("wb") as f:
            pickle.dump(duration / 60, f)
