import argparse
import os
import json
from collections import defaultdict
from pathlib import Path
from typing import List

from alignment import analyze_alignment_results

# (dataset, split) pairs processed by this script, in processing order.
DATASET_PAIRS = [
    ("beavertails", "330k_test_only_unsafe"),
    ("real_toxicity", "train_only_unsafe"),
    ("ultrasafety", "train_only_unsafe"),
]
# Method names used as keys in the output JSON, in report order.
METHOD_NAMES = [
    "No Alignment",
    "SaP",
    "Re-Control",
    "Sample-BRT",
    "BRT",
]

def load_results(
    base_dir: str,
    seed: int,
    model: str,
    dataset: str,
    split: str,
    method: str,
    filename: str = "combined.json",
    max_depth: int = 2,
) -> List[dict]:
    """
    Load the merged JSON results for a given configuration.

    The search starts at ``<base_dir>/seed_<seed>/<model>/<dataset>/<split>/<method>``.

    Search strategy:
    - Look for ``filename`` in the current directory.
    - If it is absent and fewer than ``max_depth`` descents have been made,
      descend into the single remaining subdirectory. Subdirectories whose
      name contains "calibration" are ignored.
    - Raise if a level inside the search depth has zero or several
      candidate subdirectories.

    Args:
        base_dir: Root directory of all alignment results.
        seed: Seed number; selects the ``seed_<seed>`` directory.
        model: Model directory name.
        dataset: Dataset directory name.
        split: Split directory name.
        method: Method directory name.
        filename: Name of the results file to look for.
        max_depth: Maximum number of single-subdirectory descents.

    Returns:
        The parsed JSON content of the first ``filename`` found. It is
        presumably a list of records, per the annotation (TODO confirm).

    Raises:
        FileNotFoundError: The start directory does not exist, a level has no
            candidate subdirectory, or ``filename`` is not found within
            ``max_depth`` descents.
        RuntimeError: A level within the search depth has several candidate
            subdirectories (ambiguous).
    """
    path = Path(base_dir) / f"seed_{seed}" / model / dataset / split / method
    if not path.is_dir():
        raise FileNotFoundError(f"Results directory does not exist: {path}")

    for depth in range(max_depth + 1):
        candidate = path / filename
        if candidate.exists():
            print(f"candidate = {candidate}")
            with open(candidate, "r", encoding="utf-8") as f:
                return json.load(f)

        # We may not descend past max_depth. Without this check, a level just
        # beyond the limit would be inspected and could raise a misleading
        # "Ambiguous"/"No ... found" error instead of the max-depth error.
        if depth == max_depth:
            break

        # Descend only if exactly one (non-calibration) subdirectory exists.
        subdirs = sorted(
            (d for d in path.iterdir() if d.is_dir() and "calibration" not in d.name),
            key=lambda d: d.name,
        )
        if len(subdirs) == 1:
            path = subdirs[0]
        elif not subdirs:
            raise FileNotFoundError(f"No {filename} found under {path}")
        else:
            raise RuntimeError(
                f"Ambiguous: multiple subdirs under {path}: {[d.name for d in subdirs]}"
            )

    raise FileNotFoundError(
        f"Reached max depth={max_depth} without finding {filename} for {method}"
    )

# Methods evaluated from a single results directory without baseline metrics:
# (key in the output JSON, method directory under the split, banner printed).
_SINGLE_METHODS = (
    ("Re-Control", "recontrol", "Re-Control Results"),
    ("Sample-BRT", "sample_brt", "Sample BRT Results"),
    ("BRT", "brt", "BRT Results"),
)


def _metrics_to_dict(metrics) -> dict:
    """Return the JSON summary stored for one method.

    ``metrics`` is an object returned by
    ``analyze_alignment_results.compute_alignment_metrics``. Only its
    ``average_safety``, ``average_cosine_similarity`` and
    ``average_word_diversity`` attributes are read.
    """
    return {
        "average_safety_rate": metrics.average_safety,
        "average_cosine_similarity": metrics.average_cosine_similarity,
        "average_word_diversity": metrics.average_word_diversity,
    }


def _load_previous_results(json_path: str):
    """Load entries computed by an earlier run so they can be skipped.

    Returns ``(completed, results_data)``. ``completed[dataset][method]`` is
    True when ``json_path`` already holds an entry with
    ``average_safety_rate`` for that dataset/method. ``results_data`` (a
    ``defaultdict(dict)``) contains exactly those entries. Entries for other
    datasets or methods, and incomplete entries, are not carried over.
    """
    completed = {d: {m: False for m in METHOD_NAMES} for d, _ in DATASET_PAIRS}
    results_data = defaultdict(dict)
    if not os.path.exists(json_path):
        return completed, results_data

    with open(json_path, "r", encoding="utf-8") as f:
        loaded_json = json.load(f)

    for dataset_name, _ in DATASET_PAIRS:
        if dataset_name not in loaded_json:
            continue
        previous = loaded_json[dataset_name]
        for method in METHOD_NAMES:
            if method in previous and "average_safety_rate" in previous[method]:
                completed[dataset_name][method] = True
                results_data[dataset_name][method] = previous[method]
    return completed, results_data


def _save_results(json_path: str, results_data: dict) -> None:
    """Write ``results_data`` to ``json_path`` via a temp file and rename.

    The rename means an interrupted write never leaves a truncated JSON file
    that would break the resume logic on the next run.
    """
    tmp_path = f"{json_path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(results_data, f, indent=2)
    os.replace(tmp_path, json_path)


def main() -> None:
    """Compute the alignment summary for one model and seed, resuming if possible.

    Reads per-method results from ``<root_dir>/alignment_results``. Writes
    ``<root_dir>/final_alignment_results_only_unsafe/seed_<seed>/<model>.json``.
    The JSON is rewritten after every computed method, so an interrupted
    run can resume without recomputing finished entries.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model_name", type=str, default="Qwen/Qwen2-1.5B")
    # The seed is part of every input and output path. Without it, the
    # script would look under ".../seed_None".
    p.add_argument("--seed", type=int, required=True)
    p.add_argument("--root_dir", type=str, required=True)
    args = p.parse_args()

    parsed_model_name = args.model_name.replace('/', '_')

    # No leading slash: os.path.join drops all earlier components when a later
    # one is absolute, so "/alignment_results" would ignore --root_dir.
    results_root = os.path.join(args.root_dir, "alignment_results")
    save_dir = f"{args.root_dir}/final_alignment_results_only_unsafe/seed_{args.seed}"
    os.makedirs(save_dir, exist_ok=True)
    json_path = os.path.join(save_dir, f"{parsed_model_name}.json")

    # Skip the methods already computed by a previous run.
    completed, results_data = _load_previous_results(json_path)

    for dataset_name, split in DATASET_PAIRS:
        print(f"Generating results for model_name = {args.model_name}, dataset_name = {dataset_name}")
        done = completed[dataset_name]
        # Arguments shared by every load_results call for this dataset.
        location = {
            "seed": args.seed,
            "model": parsed_model_name,
            "dataset": dataset_name,
            "split": split,
        }

        # The SaP results also produce the "No Alignment" baseline
        # (compute_baseline_metrics=True), so both come from one load.
        if not done["No Alignment"] or not done["SaP"]:
            sap_data = load_results(results_root, method="sap", **location)
            print("SaP Results")
            sap_metrics, baseline_metrics = analyze_alignment_results.compute_alignment_metrics(
                sap_data, compute_baseline_metrics=True
            )
            results_data[dataset_name]["No Alignment"] = _metrics_to_dict(baseline_metrics)
            results_data[dataset_name]["SaP"] = _metrics_to_dict(sap_metrics)
            _save_results(json_path, results_data)

        for method_name, method_dir, banner in _SINGLE_METHODS:
            if done[method_name]:
                continue
            method_data = load_results(results_root, method=method_dir, **location)
            print(banner)
            method_metrics, _ = analyze_alignment_results.compute_alignment_metrics(
                method_data, compute_baseline_metrics=False
            )
            results_data[dataset_name][method_name] = _metrics_to_dict(method_metrics)
            _save_results(json_path, results_data)

    # This final write also covers a run where every entry was already complete.
    _save_results(json_path, results_data)


if __name__ == "__main__":
    main()
