import logging
import gc
from typing import Optional

import hydra
import numpy as np
from jax import random
from omegaconf import DictConfig

from fair_dp_sgd.accounting import get_number_of_steps_for_target_epsilon
from fair_dp_sgd.data import get_data_stream
from fair_dp_sgd.models import get_model
from fair_dp_sgd.training.training_routine import train_and_evaluate
from fair_dp_sgd.utils.cache import cache_results 


@hydra.main(version_base=None, config_path="conf", config_name=None)
def main(cfg: DictConfig):
    """
    Unified finetuning entrypoint.
    - Works with existing configs via --config-name (e.g., finetune_demographic_parity, finetune_equalized_odds, finetune_fnr, finetune_folkstable)
    - Parameterizes default steps if sigma == 0 via training_params.default_steps_if_sigma_zero (optional)
    - Optional metrics caching when metrics_cache.enable is True

    For every seed in ``cfg.training_params.seeds`` the function builds the
    data streams and the model, trains, keeps only the result rows whose
    ``train_hard_constraint`` is strictly below ``cfg.algorithm.gamma``, and
    records the best ``val_accuracy`` among them (0 if no row qualifies).

    Args:
        cfg: Hydra/OmegaConf configuration. ``cfg.training_params.number_of_steps``
            is overwritten in place for every seed.

    Returns:
        The mean of the per-seed best validation accuracies, or ``0`` when
        the computed number of steps is 0, when any seed raises, or when
        there are no seeds. (A scalar return from a Hydra task is typically
        consumed by a sweeper as the objective -- confirm against the launcher.)
    """
    acc_results = []
    metrics_history = []

    gamma = cfg.algorithm.gamma
    default_steps_if_sigma_zero: Optional[int] = getattr(
        cfg.training_params, "default_steps_if_sigma_zero", None
    )
    # `or {}` covers both a missing `metrics_cache` section and one that is
    # explicitly set to null, so the `.get` calls below never hit None.
    metrics_cache_cfg = getattr(cfg, "metrics_cache", None) or {}
    cache_metrics: bool = bool(metrics_cache_cfg.get("enable", False))
    cache_key_prefix: str = str(
        metrics_cache_cfg.get("key_prefix", "metrics_per_seed")
    )

    for seed in cfg.training_params.seeds:
        try:
            key = random.PRNGKey(seed)
            data_key, model_key, training_key = random.split(key, num=3)
            (train_stream, val_data, test_data) = get_data_stream(cfg, data_key, seed)

            # The step count is resolved after the data stream is built, as in
            # the original ordering (get_data_stream receives cfg and may
            # populate fields the accountant reads -- not visible from here).
            if cfg.algorithm.sigma == 0:
                if default_steps_if_sigma_zero is not None:
                    cfg.training_params.number_of_steps = int(default_steps_if_sigma_zero)
                else:
                    # Sensible default if not provided by config
                    cfg.training_params.number_of_steps = 5000
            else:
                cfg.training_params.number_of_steps = get_number_of_steps_for_target_epsilon(cfg)

            # Nothing to train: report the worst score instead of running.
            if cfg.training_params.number_of_steps == 0:
                return 0
            logging.info("Config:  %s", cfg)

            state = get_model(cfg, model_key)
            results = train_and_evaluate(
                cfg=cfg,
                state=state,
                train_stream=train_stream,
                rng=training_key,
                test_data=test_data,
                val_data=val_data,
            )

            # Keep only the rows that satisfy the training constraint.
            # `results` is presumably a pandas DataFrame (it is boolean-indexed
            # here and concatenated with pandas.concat below); `.copy()` makes
            # the column assignment write to an independent frame rather than
            # to a slice of `results`.
            train_disparity = results["train_hard_constraint"]
            filtered_results = results[train_disparity - gamma < 0].copy()
            filtered_results["seed"] = seed
            metrics_history.append(filtered_results)

            if len(filtered_results) == 0:
                max_accuracy = 0
            else:
                max_accuracy = filtered_results["val_accuracy"].max()

            acc_results.append(max_accuracy)

            # Drop the large per-seed objects before the next seed allocates.
            del results
            del train_stream
            del test_data
            del val_data
            gc.collect()
        except Exception:
            # A failure on any seed aborts the whole run with score 0;
            # logging.exception records the message together with the traceback.
            logging.exception("%s has failed on seed %s", cfg, seed)
            return 0

    if not acc_results:
        return 0

    logging.info("Accuracy results: %s", acc_results)

    if cache_metrics and cache_results is not None and metrics_history:
        try:
            import pandas

            df = pandas.concat(metrics_history)
            cache_results(cache_key_prefix, df)
            del df
        except Exception:
            # Caching is best-effort: report the failure but still return the score.
            logging.exception(
                "Failed to cache metrics under key prefix %r", cache_key_prefix
            )

    gc.collect()
    return np.mean(acc_results)


# Run the Hydra-decorated entrypoint only when this file is executed as a
# script; the config to use is selected on the command line via --config-name.
if __name__ == "__main__":
    main() 