import hydra
import optuna
import os
from omegaconf import DictConfig as Config
import pandas as pd
import logging
import time
from .my_mlflow import log_training_experiment
from sklearn.model_selection import train_test_split
from src.syngen.scripts.datagen_methods import *
from src.syngen.scripts.utils import save_synthetic_data


# Module-level logger, named after this module, used by the functions below.
_log = logging.getLogger(__name__)


def instantiate_datagen_method(name, params, seed=None):
    """Build the data-generation method registered under ``name``.

    Args:
        name: Identifier of the generator method (e.g. ``"smote"``, ``"ctgan"``).
        params: Mapping of keyword arguments forwarded to the generator constructor.
        seed: Value passed to the constructor as ``random_state``.

    Returns:
        A new instance of the generator class associated with ``name``.

    Raises:
        ValueError: If ``name`` does not correspond to any known method.
    """
    # Every entry defers the class lookup until it is actually requested, so
    # only the class of the selected method is resolved.
    registry = {
        "random_oversampling": lambda: RandomOverSampling,
        "smote": lambda: Smote,
        "smoteenc": lambda: SmoteENC,
        "ctgan": lambda: CTGANGenerator,
        "ctabgan": lambda: CTABGANGenerator,
        "tvae": lambda: TVAEGenerator,
        "copulagan": lambda: CopulaGANGenerator,
        "ttvae": lambda: TTVAEGenerator,
        "ttvae_tbs": lambda: TTVAETBSGenerator,
        "cttvae": lambda: CTTVAEGenerator,
        "cttvae_tbs": lambda: CTTVAETBSGenerator,
    }

    resolve_class = registry.get(name)
    if resolve_class is None:
        raise ValueError(f"Unknown generator method: {name}")

    generator_cls = resolve_class()
    return generator_cls(**params, random_state=seed)
    

def _suggest_transformer_vae_params(trial, params, with_lambda_scale=False, with_triplet_margin=False):
    """Fill ``params`` in place with the search space shared by the TTVAE-family methods.

    The suggest calls are issued in a fixed order; ``lambda_scale`` is sampled
    right after ``l2scale`` and ``triplet_margin`` last, when enabled.
    """
    params["epochs"] = trial.suggest_categorical("epochs", [300])
    params["l2scale"] = trial.suggest_categorical("l2scale", [0.00001, 0.0001, 0.001])
    if with_lambda_scale:
        params["lambda_scale"] = trial.suggest_categorical("lambda_scale", [0.3, 0.4, 0.7])
    params["batch_size"] = trial.suggest_categorical("batch_size", [32, 64, 128])
    params["latent_dim"] = trial.suggest_categorical("latent_dim", [16, 32, 64])
    # embedding_dim and nhead are sampled jointly from a fixed list of pairs.
    valid_embed_nhead = [(64, 4), (128, 4), (128, 8), (256, 8)]
    embedding_dim, nhead = trial.suggest_categorical("embed_nhead", valid_embed_nhead)
    params["embedding_dim"] = embedding_dim
    params["nhead"] = nhead
    # The joint key is not a constructor argument; drop it if the config carried one.
    params.pop("embed_nhead", None)
    params["dim_feedforward"] = trial.suggest_categorical("dim_feedforward", [512, 1024, 2048])
    params["dropout"] = trial.suggest_float("dropout", 0.0, 0.3)
    if with_triplet_margin:
        params["triplet_margin"] = trial.suggest_float("triplet_margin", 0.1, 1.0)


def optuna_objective_for_generator(trial, cfg: Config, train_data: pd.DataFrame, method_name: str):
    """Optuna objective: sample hyperparameters for ``method_name``, train, and score.

    Starts from ``cfg.datagen_method.params``, overrides the tunable entries with
    values suggested by ``trial``, instantiates the generator and, when it has a
    ``fit`` method, trains it on ``train_data``. The full parameter dict is stored
    on the trial under the ``"full_params"`` user attribute.

    Args:
        trial: Optuna trial used to sample hyperparameters.
        cfg: Run configuration (reads ``datagen_method``, ``dataset``, ``paths``, ``seed``).
        train_data: Training split passed to the generator's ``fit``.
        method_name: Name of the generator method whose search space is used.

    Returns:
        ``datagen_method.final_loss`` when the generator exposes it, otherwise ``0.0``.

    Raises:
        ValueError: If no search space is defined for ``method_name``.
    """
    params = dict(cfg.datagen_method.params)
    params["sampling_strategy"] = cfg.datagen_method.params.sampling_strategy

    # NOTE(review): several search spaces below use tuple choices; Optuna
    # recommends None/bool/int/float/str for categorical choices -- confirm this
    # is acceptable for the storage backend in use.
    if method_name in ["ctgan", "copulagan"]:
        # Only (pac, batch_size) pairs where batch_size is divisible by pac are offered.
        valid_combinations = [(p, bs) for p in [1, 5, 10] for bs in [64, 128, 256, 500] if bs % p == 0]
        pac_batch = trial.suggest_categorical("pac_batch_size", valid_combinations)
        params["pac"] = pac_batch[0]
        params["batch_size"] = pac_batch[1]
        params["epochs"] = trial.suggest_int("epochs", 100, 300, step=50)
        params.pop("pac_batch_size", None)
    elif method_name in ["tvae"]:
        params["epochs"] = trial.suggest_categorical("epochs", [100, 150, 300])
        params["batch_size"] = trial.suggest_categorical("batch_size", [64, 128, 256, 512])
    elif method_name in ["ctabgan"]:
        params["epochs"] = trial.suggest_categorical("epochs", [150, 200, 250])
        params["batch_size"] = trial.suggest_categorical("batch_size", [64, 128, 256])
        params["class_dim"] = trial.suggest_categorical("class_dim", [(128, 128, 128), (256, 256, 256, 256)])
        params["random_dim"] = trial.suggest_categorical("random_dim", [64, 100, 128])
        params["num_channels"] = trial.suggest_categorical("num_channels", [32, 64, 128])
        params["l2scale"] = trial.suggest_float("l2scale", 1e-6, 1e-3, log=True)
        params["lr"] = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    elif method_name in ["ttvae", "ttvae_tbs", "cttvae", "cttvae_tbs"]:
        _suggest_transformer_vae_params(
            trial,
            params,
            with_lambda_scale=method_name in ("ttvae_tbs", "cttvae_tbs"),
            with_triplet_margin=method_name in ("cttvae", "cttvae_tbs"),
        )
    else:
        raise ValueError(f"Optuna tuning not supported for: {method_name}")

    datagen_method = instantiate_datagen_method(name=method_name, params=params, seed=cfg.seed)

    # Train model (if applicable)
    if hasattr(datagen_method, "fit"):
        # Use the trial's integer number (not the Trial object's repr) so the
        # checkpoint filename and log lines are stable and readable.
        trial_id = trial.number
        model_path = os.path.join(
            cfg.paths.datagen_methods_dir,
            f"{cfg.datagen_method.name}/{cfg.dataset.name}/finetuning/",
            f"{cfg.datagen_method.model_filename}_trial_{trial_id}.pt"
        )
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        _log.info(f"Trial {trial_id}: Training '{cfg.datagen_method.name}' model...")
        datagen_method.fit(
            train_data,
            categorical_columns=cfg.dataset.categorical_columns,
            conditional_column=cfg.dataset.conditional_column,
            save_path=model_path,
        )
        _log.info(f"'{cfg.datagen_method.name}' model trained. Trial {trial_id} finished.")

    # Keep the complete parameter set so the caller can rebuild the best model.
    trial.set_user_attr("full_params", params)
    return datagen_method.final_loss if hasattr(datagen_method, "final_loss") else 0.0


# uncomment line below if running this file directly 
# @hydra.main(version_base=None, config_path='../conf', config_name="datagen")
def run_training(cfg:Config) -> tuple:
    """Split the dataset, optionally tune, train the generator and produce synthetic data.

    Steps:
        1. Read the cleaned dataset, make a stratified 80/20 train/test split and
           write both splits under ``cfg.paths.processed_data_dir``.
        2. If the method is listed in ``cfg.optuna_supported_datagen_methods`` and
           ``cfg.optimize`` is truthy, run an Optuna study and adopt the best
           trial's parameters; otherwise use ``cfg.datagen_method.params``.
        3. Instantiate the generator, train it when it has ``fit`` (optionally
           logging to MLflow) and call ``load_model`` when available.
        4. Generate synthetic data with the configured sampling strategy and save it.

    Args:
        cfg: Run configuration.

    Returns:
        A ``(training_run_id, gen_time)`` tuple. ``training_run_id`` is ``None``
        unless a training run was logged to MLflow; ``gen_time`` maps strategy
        names to the generation time in seconds.
    """
    training_run_id = None

    real_data = pd.read_csv(os.path.join(cfg.paths.clean_data_dir, f"{cfg.dataset.name}.csv"))
    train_data, test_data = train_test_split(real_data, test_size=0.2, stratify=real_data[cfg.dataset.target_column], random_state=42)

    train_path = os.path.join(cfg.paths.processed_data_dir, cfg.dataset.splits.train.path)
    test_path = os.path.join(cfg.paths.processed_data_dir, cfg.dataset.splits.test.path)
    # Make sure the destination folders exist before writing the splits.
    for split_path in (train_path, test_path):
        split_dir = os.path.dirname(split_path)
        if split_dir:
            os.makedirs(split_dir, exist_ok=True)
    train_data.to_csv(train_path, index=False)
    test_data.to_csv(test_path, index=False)

    optuna_supported = cfg.optuna_supported_datagen_methods

    if cfg.datagen_method.name in optuna_supported and cfg.optimize:
        _log.info(f"Running Optuna optimization for '{cfg.datagen_method.name}'...")
        study = optuna.create_study(direction="minimize", sampler=optuna.samplers.TPESampler(seed=cfg.seed))
        study.optimize(lambda trial: optuna_objective_for_generator(trial, cfg, train_data, cfg.datagen_method.name), n_trials=cfg.n_trials)
        _log.info("Optuna optimization completed.")
        _log.info(f"Best trial: {study.best_trial.number} with value: {study.best_trial.value}")
        best_params = study.best_trial.user_attrs["full_params"]
        # Store the tuned parameters in the config so later reads see them.
        cfg.datagen_method.params = best_params
        _log.info(f"Best hyperparameters: {best_params}")
    else:
        _log.info(f"Skipping Optuna optimization for '{cfg.datagen_method.name}'. Using default hyperparameters.")
        best_params = cfg.datagen_method.params
        _log.info(f"Using hyperparameters: {cfg.datagen_method.params}")

    # Neutral wording: these parameters come from Optuna only when tuning ran.
    _log.info(f"Instantiating generator with hyperparameters: {best_params}")
    datagen_method = instantiate_datagen_method(name=cfg.datagen_method.name, params=best_params, seed=cfg.seed)
    _log.info(f"Generator method '{cfg.datagen_method.name}' instantiated.")

    # Stays empty for methods without a ``fit`` method.
    model_path = ''

    # Train model (if applicable)
    if hasattr(datagen_method, "fit"):
        # Create model save path
        model_path = os.path.join(
            cfg.paths.datagen_methods_dir,
            f"{cfg.datagen_method.name}/{cfg.dataset.name}",
            cfg.datagen_method.model_filename
        )
        # Same safeguard as in the Optuna objective: the folder must exist
        # before the model is saved.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        _log.info(f"Training '{cfg.datagen_method.name}' model...")
        start_time_training = time.time()
        datagen_method.fit(
            train_data,
            categorical_columns=cfg.dataset.categorical_columns,
            conditional_column=cfg.dataset.conditional_column,
            save_path=model_path,
        )
        datagen_method.train_time = time.time() - start_time_training
        _log.info(f"'{cfg.datagen_method.name}' model trained.")

        if cfg.mlflow:
            # Log training experiment
            training_run_id = log_training_experiment(cfg, datagen_method)
            _log.info(f"{cfg.datagen_method.name} training metrics saved to MLflow server.")

    # Load pre-trained model
    # NOTE(review): for a method that has ``load_model`` but no ``fit``,
    # ``model_path`` is still '' here -- confirm that is the intended input.
    if hasattr(datagen_method, "load_model"):
        _log.info(f"Loading pre-trained model for '{cfg.datagen_method.name}'...")
        datagen_method.load_model(model_path)

    # Generate and save synthetic data for the configured sampling strategy;
    # only that strategy's entry in ``gen_time`` is updated.
    gen_time = {'all': 0, 'minority': 0}
    strategy = cfg.datagen_method.params.sampling_strategy

    conditional_column = cfg.dataset.conditional_column
    if conditional_column is None and strategy != "all":
        conditional_column = cfg.dataset.target_column
        _log.warning(f"No conditional column specified for '{strategy}' sampling strategy. Defaulting to target column '{conditional_column}'.")

    _log.info(f"Generating synthetic data using '{cfg.datagen_method.name}' with strategy='{strategy}'...")
    _log.info(f"Using seed {cfg.seed} for reproducibility.")
    start_time_generation = time.time()

    synthetic_data = datagen_method.generate(
        train_data=train_data,
        conditional_column=conditional_column,
        sampling_strategy=strategy,
        n_to_generate=cfg.n_to_generate
    )
    datagen_method.gen_time = time.time() - start_time_generation
    gen_time[strategy] = datagen_method.gen_time

    # With strategy "all" the conditional column may legitimately be None;
    # report class distributions on the target column in that case instead of
    # indexing the frame with None.
    distribution_column = conditional_column if conditional_column is not None else cfg.dataset.target_column

    _log.info(f"'{cfg.datagen_method.name}' synthetic data generated with strategy='{strategy}'.")
    _log.info("Original distribution:")
    _log.info(train_data[distribution_column].value_counts(normalize=True))

    _log.info(f"Synthetic distribution ({strategy}):")
    if synthetic_data is None:
        _log.warning("No synthetic data was returned.")
    else:
        _log.info(synthetic_data[distribution_column].value_counts(normalize=True))

    out_dir = os.path.join(cfg.paths.synth_data_dir, cfg.datagen_method.name, strategy)
    os.makedirs(out_dir, exist_ok=True)
    file_name = cfg.dataset.synthetic_path
    # NOTE(review): this is still called when ``synthetic_data`` is None --
    # confirm ``save_synthetic_data`` handles that case.
    save_synthetic_data(synthetic_data, output_dir=out_dir, file_name=file_name)

    return training_run_id, gen_time


if __name__ == "__main__":
    # ``run_training`` requires a composed config. Wrap it with ``hydra.main``
    # here so direct execution works without passing ``cfg`` by hand; callers
    # that import ``run_training`` and pass their own ``cfg`` are unaffected.
    hydra.main(version_base=None, config_path='../conf', config_name="datagen")(run_training)()
