import json
import sys
import torch
from ray import tune
from ray.tune.search import ConcurrencyLimiter
from ray.tune.search.optuna import OptunaSearch
from exp_ood import OODDetectionExp
from factory import (
    make_ood_dataset,
    make_lossfn,
    make_model,
    make_optimizer,
    make_reporter,
)
import copy
import numpy as np

# Fixed list of seeds used to evaluate every sampled configuration.
# `objective` runs one experiment per seed and reports the mean/std across the
# successful runs; `main` prints len(SEEDS) as the denominator of the
# "successful runs" count.
SEEDS = [42, 123, 456, 789, 1024]

def parse_param(param):
    """Convert one tuning spec from the JSON config into a Ray Tune search space.

    Supported forms:

    * any non-dict value ``v`` -> ``tune.choice([v])`` (effectively fixed);
    * ``{"min": a, "max": b, "distribution": "randint"}`` -> ``tune.randint(a, b)``;
    * ``{"min": a, "max": b, "distribution": "uniform"}`` -> ``tune.uniform(a, b)``;
    * ``{"min": a, "max": b, "distribution": "loguniform"}`` -> ``tune.loguniform(a, b)``;
    * ``{"values": [...]}`` -> ``tune.choice([...])``.

    A spec with ``"min"`` but an unrecognized (or missing) distribution still
    falls back to its ``"values"`` list when one is present.

    Raises:
        ValueError: if a dict spec matches none of the supported forms.
            Previously such specs silently produced ``None``, which Ray Tune
            would then use as a constant hyperparameter value.
    """
    if not isinstance(param, dict):
        return tune.choice([param])

    if "min" in param:
        distribution = param.get("distribution")
        if distribution == "randint":
            return tune.randint(param["min"], param["max"])
        if distribution == "uniform":
            return tune.uniform(param["min"], param["max"])
        if distribution == "loguniform":
            return tune.loguniform(param["min"], param["max"])
        # Unknown/missing distribution: fall through to the "values" fallback.

    if "values" in param:
        return tune.choice(param["values"])

    raise ValueError(
        f"Unsupported tuning spec {param!r}: expected a 'values' list, or "
        "'min'/'max' with distribution 'randint', 'uniform' or 'loguniform'."
    )


def run_single_experiment(cfg, seed):
    """Build all components from ``cfg`` and run one OOD-detection experiment.

    Args:
        cfg: Resolved trial configuration with ``"exp"``, ``"reporter"``,
            ``"dataset"``, ``"lossfn"``, ``"model"`` and ``"optimizer"``
            sections. It is deep-copied, so the caller's dict is never mutated.
        seed: Seed written into ``["exp"]["seed"]`` of the copy.

    Returns:
        Whatever ``OODDetectionExp.run()`` returns (``objective`` reads
        ``END_*`` keys from it).
    """
    # Work on a private copy so setting the seed never leaks into other runs.
    cfg_copy = copy.deepcopy(cfg)
    cfg_copy["exp"]["seed"] = seed

    reporter = make_reporter(cfg_copy["reporter"], cfg_copy)
    dataset_ind, dataset_ood_tr, dataset_ood_te = make_ood_dataset(cfg_copy["dataset"])
    loss_fn, eval_func = make_lossfn(cfg_copy["lossfn"])
    model = make_model(cfg_copy["model"], dataset_ind)

    model_name = cfg_copy["model"]["name"].lower()
    warmup_optimizer = None
    if model_name == "sgcn":
        # For SGCN, make_model returns a (teacher, model) pair; the second
        # element is the model handed to OODDetectionExp below.
        # Bug fix: this branch used to build only a teacher optimizer, leaving
        # `optimizer` unbound, so every SGCN run raised UnboundLocalError.
        # NOTE(review): the teacher is not passed to OODDetectionExp and no
        # teacher optimizer is used here; confirm whether it must be trained
        # (or attached to the model) before the main experiment runs.
        _teacher, model = model
        optimizer = make_optimizer(cfg_copy["optimizer"], model)
    elif model_name == "gpn":
        # For GPN, make_optimizer returns (optimizer, warmup_optimizer).
        optimizer, warmup_optimizer = make_optimizer(cfg_copy["optimizer"], model)
    else:
        optimizer = make_optimizer(cfg_copy["optimizer"], model)

    exp = OODDetectionExp(cfg=cfg_copy["exp"],
                          cfg_model=cfg_copy["model"],
                          model=model,
                          criterion=loss_fn,
                          eval_func=eval_func,
                          optimizer=optimizer,
                          warmup_optimizer=warmup_optimizer,
                          reporter=reporter,
                          dataset_ind=dataset_ind,
                          dataset_ood_tr=dataset_ood_tr,
                          dataset_ood_te=dataset_ood_te)

    return exp.run()


def objective(cfg):
    """Ray Tune trainable: score one configuration over every seed in ``SEEDS``.

    Each seed is run through ``run_single_experiment``. A seed that raises is
    logged (message plus traceback) and skipped, so the remaining seeds still
    count. The mean and standard deviation of each ``END_*`` metric are then
    computed over the successful runs that reported it.

    Args:
        cfg: Trial configuration sampled by Ray Tune from the param space.

    Returns:
        A dict whose key set is identical whether or not any run succeeded:
        ``ood_aupr``, ``ood_aupr_in``, ``ood_aupr_out``, ``ood_fpr95`` and
        ``ood_detection_acc``, each with a matching ``*_std`` key; the aliases
        ``ood_auroc`` / ``ood_auroc_std``; and ``num_successful_runs``.

        ``ood_aupr`` holds the mean of ``END_AUROC`` (despite its name). It is
        the key ``main`` optimizes, so it is kept unchanged and ``ood_auroc``
        is added as a correctly named copy. When no run reported
        ``END_AUROC`` it is 0.0, so a ``mode="max"`` search ranks the
        configuration last. Other metrics with no data are NaN rather than a
        fake 0.0, which for FPR95 (lower is better) would look perfect.
    """
    import traceback

    # (key in the experiment's result dict, key reported to Ray Tune)
    metric_keys = (
        ("END_AUROC", "ood_aupr"),
        ("END_AUPR_in", "ood_aupr_in"),
        ("END_AUPR_out", "ood_aupr_out"),
        ("END_FPR95", "ood_fpr95"),
        ("END_DETECTION_acc", "ood_detection_acc"),
    )

    all_metrics = []
    for seed in SEEDS:
        try:
            metrics = run_single_experiment(cfg, seed)
        except Exception as e:
            # Best-effort by design: one failing seed should not discard the
            # whole trial, but keep the traceback so failures are debuggable.
            print(f"Error running experiment with seed {seed}: {e}")
            traceback.print_exc()
            continue
        all_metrics.append(metrics)

    report = {}
    for result_key, report_key in metric_keys:
        # float() also handles 0-d tensors / numpy scalars returned as metrics.
        values = [float(m[result_key]) for m in all_metrics if result_key in m]
        if values:
            mean, std = float(np.mean(values)), float(np.std(values))
        elif report_key == "ood_aupr":
            # Optimized metric: worst possible value under mode="max".
            mean, std = 0.0, 0.0
        else:
            mean = std = float("nan")
        report[report_key] = mean
        report[f"{report_key}_std"] = std

    report["ood_auroc"] = report["ood_aupr"]
    report["ood_auroc_std"] = report["ood_aupr_std"]
    report["num_successful_runs"] = len(all_metrics)
    return report


def parse_tuning_subconfigs(group_name, cfg):
    """Merge the search space and the fixed settings of one config group.

    Tuned parameters come from ``cfg["<group_name>_tuning"]``, a list of specs
    that each carry a ``"name"`` key; each spec is converted with
    ``parse_param``. Fixed settings come from ``cfg[group_name]``. When a key
    appears in both places, the fixed setting wins.

    Returns a new dict; neither section of ``cfg`` is modified.
    """
    tuning_specs = cfg.get(f"{group_name}_tuning", ())
    search_space = {spec["name"]: parse_param(spec) for spec in tuning_specs}
    fixed_settings = cfg.get(group_name, {})
    # Fixed values are unpacked last so they override same-named tuned entries.
    return {**search_space, **fixed_settings}


# (label printed by main, key reported by `objective`, key saved in the JSON file)
_BEST_METRIC_FIELDS = (
    ("AUROC", "ood_aupr", "auroc"),
    ("AUPR_in", "ood_aupr_in", "aupr_in"),
    ("AUPR_out", "ood_aupr_out", "aupr_out"),
    ("FPR95", "ood_fpr95", "fpr95"),
    ("Detection Acc", "ood_detection_acc", "detection_acc"),
)


def _report_best_result(best_result, dataset):
    """Print the best trial's seed-averaged metrics and save them to JSON.

    Writes ``best_config_<name>_<ood_type>.json`` in the current directory,
    containing the trial config plus a ``"metrics"`` section. Missing metric
    keys are reported as NaN instead of raising KeyError. For example, a
    trial in which every seed failed may not report the ``*_std`` keys.
    """
    metrics = best_result.metrics

    def value(key):
        return metrics.get(key, float("nan"))

    num_runs = metrics.get("num_successful_runs", 0)

    print("Best config is:", best_result.config)
    print("Best averaged metrics are:")
    for label, key, _ in _BEST_METRIC_FIELDS:
        print(f"{label}: {value(key):.4f} ± {value(key + '_std'):.4f}")
    print(f"Number of successful runs: {num_runs}/{len(SEEDS)}")

    saved_metrics = {}
    for _, key, json_key in _BEST_METRIC_FIELDS:
        saved_metrics[json_key] = value(key)
        saved_metrics[f"{json_key}_std"] = value(f"{key}_std")
    saved_metrics["num_successful_runs"] = num_runs

    output_file = f"best_config_{dataset['name']}_{dataset['ood_type']}.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump({**best_result.config, "metrics": saved_metrics}, f, indent=4)
    print(f"Best configuration saved to {output_file}")


def main():
    """Run a Ray Tune hyperparameter search for each dataset in a JSON config.

    Usage: ``python <this script> CONFIG.json``.

    The config must contain ``"dataset"``, either one dict or a list of dicts.
    The ``exp``, ``model``, ``lossfn`` and ``optimizer`` groups are each built
    by ``parse_tuning_subconfigs`` from the optional ``<group>`` and
    ``<group>_tuning`` entries. ``"reporter"`` is optional.

    One independent Optuna search maximizing ``ood_aupr`` is run per dataset.
    The best trial of each search is printed and saved to JSON.
    """
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} CONFIG.json")
    config_name = sys.argv[1]
    with open(config_name, mode="r", encoding="utf-8") as f:
        cfg = json.load(f)

    dataset_cfg = cfg["dataset"]
    if isinstance(dataset_cfg, dict):
        dataset_cfg = [dataset_cfg]

    exp_config = parse_tuning_subconfigs("exp", cfg)
    model_config = parse_tuning_subconfigs("model", cfg)
    loss_fn_config = parse_tuning_subconfigs("lossfn", cfg)
    optimizer_config = parse_tuning_subconfigs("optimizer", cfg)
    reporter_cfg = cfg.get("reporter", {})

    # Loop-invariant setup: one GPU per trial when available, else two CPUs.
    resources = {"gpu": 1} if torch.cuda.is_available() else {"cpu": 2}
    trainable = tune.with_resources(objective, resources)

    # GSPDE gets a 10x larger sampling budget. The name is read defensively so
    # a config without a fixed "model" -> "name" entry does not crash here.
    model_name = str(cfg.get("model", {}).get("name", "")).lower()
    num_samples = 1000 if model_name == "gspde" else 100

    for dataset in dataset_cfg:
        _cfg = {
            "dataset": dataset,
            "model": model_config,
            "lossfn": loss_fn_config,
            "optimizer": optimizer_config,
            "reporter": reporter_cfg,
            "exp": exp_config,
        }
        print("Starting hyperparameter tuning with configuration:")
        print(_cfg)
        print(f"Each configuration will be evaluated with {len(SEEDS)} seeds: {SEEDS}")

        # A fresh, identically seeded searcher per dataset, so every dataset's
        # search starts from the same state.
        algo = OptunaSearch(seed=52519)
        tuner = tune.Tuner(
            trainable,
            tune_config=tune.TuneConfig(
                metric="ood_aupr",
                mode="max",
                search_alg=ConcurrencyLimiter(algo, max_concurrent=64),
                num_samples=num_samples,
            ),
            param_space=_cfg,
        )
        results = tuner.fit()
        _report_best_result(results.get_best_result(), dataset)


# Script entry point; main() reads the tuning config path from sys.argv[1].
if __name__ == "__main__":
    main()