"""Run the experiment and configuration specified in the Hydra config (config.yaml).

Supported experiments:
----------------------
- source_lora: train a universal LoRA adapter and evaluate it
- source_lora_per_domain: train per-source-domain LoRA adapters and evaluate them
- hypernet: train a hypernetwork to merge source-domain LoRA adapters
- target_baselines: evaluate simple LoRA-merging baseline strategies
"""

import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

import open_clip
import wandb
import hydra
from omegaconf import DictConfig, OmegaConf
import statistics as stats

from experiments.train_source_models import source_train
from experiments.train_hypernet import train_hypernet
from experiments.helpers import save_results, set_seeds
from experiments.eval_target_baselines import test_baselines
from experiments.eval_source_models import eval_model

# Experiment names accepted by ``main``; anything else is rejected before training.
_SUPPORTED_EXPERIMENTS = (
    "source_lora",
    "source_lora_per_domain",
    "hypernet",
    "target_baselines",
)


def _split_runner(experiment_name):
    """Return the runner invoked once per train/test split, or ``None``.

    ``source_lora``, ``hypernet`` and ``target_baselines`` share the call
    signature ``runner(cfg, trial_num, train_domains, test_domains)``.
    ``source_lora_per_domain`` trains one adapter per domain and is handled
    separately, so it (and any unknown name) maps to ``None``.
    """
    return {
        "source_lora": source_train,
        "hypernet": train_hypernet,
        "target_baselines": test_baselines,
    }.get(experiment_name)


def _mean_metrics(results_by_domain, domains):
    """Average every metric over ``results_by_domain[d]`` for ``d`` in ``domains``.

    Metric names are taken from the first domain; every listed domain must
    report the same metrics (a missing one raises ``KeyError``).

    Raises
    ------
    ValueError
        If ``domains`` is empty.
    """
    domains = list(domains)
    if not domains:
        raise ValueError("Cannot average metrics over an empty list of domains.")
    averages = {}
    for metric_name in results_by_domain[domains[0]]:
        metric_values = [results_by_domain[d][metric_name] for d in domains]
        averages[metric_name] = sum(metric_values) / len(metric_values)
    return averages


def _run_trial(cfg, trial_num):
    """Run one trial of the configured experiment and return its results.

    Returns
    -------
    dict
        For WILDS datasets with a split runner: ``{"global": metrics}``.
        Otherwise: ``{domain: metrics, ..., "domain_average": averaged}``.
    """
    runner = _split_runner(cfg.experiment.name)  # None => source_lora_per_domain

    if cfg.dataset.name in cfg.constants.wilds_datasets:
        train_domains = cfg.dataset.train_domains
        test_domains = cfg.dataset.test_domains
        if runner is not None:
            return {"global": runner(cfg, trial_num, train_domains, test_domains)}
        # One adapter per training domain, each evaluated on all test domains.
        per_domain_results = {}
        for train_domain in train_domains:
            per_domain_results[train_domain] = source_train(
                cfg, trial_num, [train_domain], test_domains
            )
        per_domain_results["domain_average"] = _mean_metrics(per_domain_results, train_domains)
        return per_domain_results

    # DomainNet-style datasets: results are keyed by domain.
    domains = cfg.dataset.domains
    per_domain_results = {}
    for domain in domains:
        if runner is None:
            # source_lora_per_domain: train and test on the same domain.
            per_domain_results[domain] = source_train(cfg, trial_num, [domain], [domain])
        else:
            # Leave-one-domain-out: ``domain`` is the held-out target.
            source_domains = [d for d in domains if d != domain]
            per_domain_results[domain] = runner(cfg, trial_num, source_domains, [domain])
    per_domain_results["domain_average"] = _mean_metrics(per_domain_results, domains)
    return per_domain_results


def _aggregate_trials(all_trials_results):
    """Compute the mean and sample std-dev of every metric across trials.

    The result keys (``"global"`` or the per-domain keys plus
    ``"domain_average"``) are read from the first trial, so aggregation always
    matches what ``_run_trial`` produced regardless of dataset family.

    Returns
    -------
    tuple[dict, dict]
        ``(trial_averages, trial_stddevs)``, both shaped ``{key: {metric: value}}``.
    """
    trial_averages = {}
    trial_stddevs = {}
    for key, first_metrics in all_trials_results[0].items():
        trial_averages[key] = {}
        trial_stddevs[key] = {}
        for metric_name in first_metrics:
            metric_values = [trial_results[key][metric_name] for trial_results in all_trials_results]
            trial_averages[key][metric_name] = sum(metric_values) / len(metric_values)
            # statistics.stdev needs at least two samples; one trial has no spread.
            trial_stddevs[key][metric_name] = (
                0.0 if len(metric_values) < 2 else stats.stdev(metric_values)
            )
    return trial_averages, trial_stddevs


@hydra.main(
    config_path=str(Path(__file__).resolve().parent.parent / "config"),
    config_name="config",
    version_base=None
)
def main(cfg: DictConfig):
    """Run experiments and aggregate results per Hydra-composed config.

    Parameters
    ----------
    cfg : DictConfig
        Hydra configuration containing:
        - experiment: name and hyperparameters for the selected experiment
        - dataset: dataset name and domains/splits
        - constants: paths, W&B settings, dataset registries
        - backbone, num_trials, and other global options

    Behavior
    --------
    - Sets seeds per trial for reproducibility
    - Executes the selected experiment for WILDS or DomainNet branches
    - Aggregates metrics (mean and sample std-dev) across trials and domains
    - Prints and saves the summarized results

    Raises
    ------
    ValueError
        If ``experiment.name`` is not a supported experiment or
        ``num_trials`` is less than 1 (checked before any training runs).
    """
    dataset_name = cfg.dataset.name
    num_trials = cfg.num_trials
    batch_size = cfg.experiment.batch_size
    num_epochs = OmegaConf.select(cfg, f"experiment.{dataset_name}.num_epochs", default=None)
    lr_clip = OmegaConf.select(cfg, f"experiment.{dataset_name}.lr_clip", default=None)
    experiment_name = cfg.experiment.name

    # Fail fast instead of crashing with an opaque IndexError/KeyError during
    # aggregation after (possibly) hours of training.
    if experiment_name not in _SUPPORTED_EXPERIMENTS:
        raise ValueError(
            f"Unknown experiment '{experiment_name}'; "
            f"expected one of {list(_SUPPORTED_EXPERIMENTS)}."
        )
    if num_trials < 1:
        raise ValueError(f"num_trials must be >= 1, got {num_trials}.")

    all_trials_results = []
    for trial_num in range(num_trials):
        set_seeds(trial_num)
        all_trials_results.append(_run_trial(cfg, trial_num))

    trial_averages, trial_stddevs = _aggregate_trials(all_trials_results)

    # Print results
    print("\nAverage results across trials (mean ± std):")
    for domain, metrics in trial_averages.items():
        print(f"  {domain}:")
        for metric, value in metrics.items():
            stdv = trial_stddevs[domain][metric]
            print(f"    {metric}: {value:.4f} ± {stdv:.4f}")

    results_to_save = {
        "lr_clip": lr_clip,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "num_trials": num_trials,
        "dataset_name": dataset_name,
        "all_trials_results": all_trials_results,
        "trial_averages": trial_averages,
        "trial_stddevs": trial_stddevs,
    }
    results_save_path = os.path.join(
        cfg.constants.model_save_path,
        f"{experiment_name}_{dataset_name}_{cfg.backbone.replace('/', '-')}"
    )
    save_results(results_to_save, results_save_path)


if __name__ == "__main__":
    # ``@hydra.main`` composes the config and passes it to ``main`` as ``cfg``,
    # so no arguments are supplied here.
    main()
