import os
import torch
from torch.distributed import init_process_group
import numpy as np

import glob
import json
import time

def ddp_setup():
    """Initialise the default process group and select this process's GPU.

    The process group is created with the ``nccl`` backend. The CUDA device
    index is then read from the ``LOCAL_RANK`` environment variable
    (presumably exported by the launcher, e.g. ``torchrun`` — confirm against
    how this script is started).

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set in the environment.
        ValueError: if ``LOCAL_RANK`` is not an integer string.
    """
    init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

def aggregate_results(assets_path, cfg):
    """Collect the per-rank result files in `assets_path` and average them.

    Every ``results-*.json`` file whose *file name* does not contain
    ``finetune`` is read as a zero-shot result; every
    ``results-*-finetune.json`` file is read as a fine-tune result. Each file
    is expected to hold a flat JSON object of numeric metrics. For each group
    the metrics are averaged key by key (using the keys of the first file in
    sorted order) and the outcome is written to
    ``<assets_path>/final_results.json``.

    Parameters:
    - assets_path: directory containing the per-rank result files.
    - cfg: object exposing ``backbone`` and ``pooling`` attributes, which are
      copied into the output under ``"model"`` and ``"pooling"``.

    Returns:
    - The aggregated dictionary that was written to disk. The ``"zero_shot"``
      and ``"finetune"`` entries are ``None`` when no file of that kind exists.

    Raises:
    - KeyError: if a result file lacks a metric present in the first file of
      its group.
    """

    def load_all(paths):
        # Read every JSON file in `paths` and return the decoded objects.
        loaded = []
        for path in paths:
            with open(path, "r", encoding="utf-8") as f:
                loaded.append(json.load(f))
        return loaded

    def compute_mean(results_list):
        # Key-wise mean over a list of metric dicts; None for an empty list.
        if not results_list:
            return None
        count = len(results_list)
        return {key: sum(d[key] for d in results_list) / count for key in results_list[0]}

    # Escape the directory so that glob metacharacters in its name are taken
    # literally; only the file-name part is meant to be a pattern.
    pattern_root = glob.escape(assets_path)

    # "results-*.json" also matches the fine-tune files, so filter those out.
    # The check is made on the file name only: testing the whole path would
    # discard every zero-shot file whenever a parent directory happens to
    # contain "finetune". Sorting makes the aggregation order deterministic.
    zero_shot_files = sorted(
        path
        for path in glob.glob(os.path.join(pattern_root, "results-*.json"))
        if "finetune" not in os.path.basename(path)
    )
    finetune_files = sorted(
        glob.glob(os.path.join(pattern_root, "results-*-finetune.json"))
    )

    all_zero_shot_results = load_all(zero_shot_files)
    all_finetune_results = load_all(finetune_files)

    # A missing group usually means some rank has not written its file yet.
    if not all_zero_shot_results or not all_finetune_results:
        print("WARNING: No results found. Aggregation may be incomplete.")

    aggregated_results = {
        "model": cfg.backbone,
        "pooling": cfg.pooling,
        "zero_shot": compute_mean(all_zero_shot_results),
        "finetune": compute_mean(all_finetune_results),
    }

    # Save aggregated results
    final_results_path = os.path.join(assets_path, "final_results.json")
    with open(final_results_path, "w", encoding="utf-8") as f:
        json.dump(aggregated_results, f, indent=4)

    print("Final Aggregated Results:", aggregated_results)
    return aggregated_results


def summarise_final_results(main_assets_path, output_filename="final_summary.json"):
    """
    Go into each run folder under `main_assets_path`, extract the fine-tune
    evaluation metrics, and compute mean and std for each metric across all runs.

    A run is any sub-directory whose name starts with ``run-``. Its metrics are
    read from the ``"finetune"`` entry of ``final_results.json`` inside that
    folder. Runs without that file, or whose ``"finetune"`` entry is missing,
    null or empty, are skipped with a warning instead of aborting the summary.

    The set of metrics is taken from the first usable run; a later run that
    lacks one of those metrics is simply left out of that metric's statistics.
    The standard deviation is the population std (``np.std`` default, ddof=0).

    Parameters:
    - main_assets_path: str, path to folder containing run-{i} folders.
    - output_filename: str, name of the summary file to save.

    Returns:
    - A dictionary mapping each metric to ``{"mean": float, "std": float}``,
      or an empty dictionary (and no file written) when no run had metrics.
    """
    all_metrics = []

    # Go through each run folder (sorted so the "first run" is deterministic)
    for run_folder in sorted(os.listdir(main_assets_path)):
        run_path = os.path.join(main_assets_path, run_folder)
        if not (run_folder.startswith("run-") and os.path.isdir(run_path)):
            continue

        result_path = os.path.join(run_path, "final_results.json")  # Adjust if your file name varies
        if not os.path.exists(result_path):
            print(f"Warning: Missing result file in {run_folder}")
            continue

        with open(result_path, "r", encoding="utf-8") as f:
            metrics = json.load(f).get("finetune")

        # The entry can be absent or null (e.g. a run that produced no
        # fine-tune results); such a run carries nothing to average.
        if not metrics:
            print(f"Warning: No finetune metrics in {run_folder}")
            continue
        all_metrics.append(metrics)

    # Make sure we have some data
    if not all_metrics:
        print("No results found to summarize.")
        return {}

    summary = {}
    for key in all_metrics[0]:
        values = [run[key] for run in all_metrics if key in run]
        # Convert NumPy scalars to plain floats so the returned dict holds
        # ordinary Python numbers.
        summary[key] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
        }

    # Save the final summary
    summary_path = os.path.join(main_assets_path, output_filename)
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    print(f"Final summary saved to {summary_path} with contents: {summary}")
    return summary
