import hydra
import torch
import logging
import numpy as np
import pandas as pd
from pathlib import Path
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
from tt_sbi.tasks import get_task
from tt_sbi.inference.npe import load_npe_model, sample_npe_posterior
from tt_sbi.inference.nn import get_embedding_net
from tt_sbi.tta.adapters import TTAPosterior, NPEPFNPosterior
from tt_sbi.tta.rff import RFFTTAAdapter, RFFTTAConfig
from tt_sbi.tta.ncpp import NCPPTTAAdapter, NCPPTTAConfig
from tt_sbi.tta.ocsvm_tta import OCSVMTTAAdapter, OCSVMTTAConfig
from tt_sbi.utils.metrics import MMDLoss, compute_posterior_mmd, compute_predictive_mmd, compute_rmse, compute_coverage
from tt_sbi.utils.misc import find_models, load_ncpp_model
from npe_pfn import TabPFN_Based_NPE_PFN

# Module-level logger, named after this module; used by main() for progress messages.
log = logging.getLogger(__name__)

def evaluate(methods, test_data, task, embedding_net, use_summary_stats, device, n_eval=50, n_samples=1000):
    """Sample every method on the first ``n_eval`` test instances and score the draws.

    Args:
        methods: mapping from method name to a posterior object. ``TTAPosterior``
            and ``NPEPFNPosterior`` instances are sampled through their own
            ``sample`` method with adaptation enabled; anything else goes
            through ``sample_npe_posterior``.
        test_data: mapping providing ``"thetas"``, ``"X_obs"``, ``"X_clean"``
            and ``"S_obs"``; each is truncated to its first ``n_eval`` entries.
        task: object providing ``compute_summary_statistics`` and, optionally,
            ``sample_posterior`` (used as the reference posterior when present).
        embedding_net: callable applied to summaries when ``use_summary_stats``
            is false; unused otherwise.
        use_summary_stats: if true, adapters receive the summaries as-is;
            if false, they receive ``embedding_net`` outputs.
        device: device passed to the embedding network and the metric helpers.
        n_eval: number of leading test instances to evaluate.
        n_samples: number of posterior draws requested per instance.

    Returns:
        A tuple ``(results, saved_summaries, sample_data)``:
        ``results[name]`` is a list with one metrics dict per instance,
        ``saved_summaries[name][i]`` holds the summaries recorded for instance
        ``i``, and ``sample_data`` holds the inputs plus stacked posterior draws.
    """
    # Same lookup order as four separate statements, so a missing key fails identically.
    thetas, X_obs, X_clean, S_obs = (
        test_data[key][:n_eval] for key in ("thetas", "X_obs", "X_clean", "S_obs")
    )

    S_clean = task.compute_summary_statistics(X_clean)
    has_analytic_posterior = hasattr(task, 'sample_posterior')

    # Representation handed to the adapters: raw summaries or their embeddings.
    if not use_summary_stats:
        with torch.no_grad():
            S_tta = embedding_net(S_obs.to(device)).cpu()
            S_clean_tta = embedding_net(S_clean.to(device)).cpu()
    else:
        S_tta, S_clean_tta = S_obs, S_clean

    results = {}
    saved_summaries = {}
    for name in methods:
        results[name] = []
        saved_summaries[name] = {}

    posterior_draws = {name: [] for name in methods}
    reference_draws = [] if has_analytic_posterior else None

    sample_data = {
        "theta_true": thetas.cpu(),
        "x_obs": X_obs.cpu(),
        "x_clean": X_clean.cpu(),
        "posterior_samples": posterior_draws,
        "true_posterior_samples": reference_draws,
    }

    for idx in tqdm(range(len(thetas)), desc="Evaluating"):
        theta_true, x_obs, x_clean = thetas[idx], X_obs[idx], X_clean[idx]
        s_obs, s_tta = S_obs[idx], S_tta[idx]

        # Reference posterior draws from the clean observation, seeded per instance.
        true_posterior_samples = None
        if has_analytic_posterior:
            true_posterior_samples = task.sample_posterior(x_clean, n_samples=n_samples, seed=42 + idx)
            reference_draws.append(true_posterior_samples.cpu())

        # Adapters receive 3-D observations with the last two axes flattened into one.
        if x_obs.ndim == 3:
            n_rows, height, width = x_obs.shape
            x_obs_tta = x_obs.reshape(n_rows, height * width)
        else:
            x_obs_tta = x_obs

        for name, method in methods.items():
            instance_summary = {}
            l2_adapted_clean = posterior_mmd_clean = float('nan')
            n_steps = 0

            if not isinstance(method, (TTAPosterior, NPEPFNPosterior)):
                samples = sample_npe_posterior(method, s_obs, n_samples, device)
            else:
                samples, info = method.sample(
                    x_obs_tta,
                    n_samples=n_samples,
                    adapt=True,
                    return_info=True,
                    s_obs=s_tta,
                )
                if info and 'best_s' in info:
                    s_adapted = info['best_s'].cpu()
                    instance_summary['s_adapted'] = s_adapted
                    n_steps = info.get('n_steps', 0)
                    s_clean_ref = S_clean_tta[idx]

                    # OC-SVM output whose shape differs from the embedded clean
                    # summary is passed through the embedding net before comparing.
                    needs_embedding = (
                        'ocsvm' in name
                        and not use_summary_stats
                        and s_adapted.shape != s_clean_ref.shape
                    )
                    if needs_embedding:
                        with torch.no_grad():
                            batched = s_adapted.unsqueeze(0).to(device)
                            s_for_l2 = embedding_net(batched).squeeze(0).cpu()
                    else:
                        s_for_l2 = s_adapted

                    squared_diff = (s_for_l2 - s_clean_ref) ** 2
                    l2_adapted_clean = torch.sqrt(squared_diff.sum()).item()

            posterior_draws[name].append(samples.cpu())

            if true_posterior_samples is not None:
                posterior_mmd_clean = compute_posterior_mmd(samples, true_posterior_samples, device)

            _, coverage_90 = compute_coverage(samples, theta_true, alpha=0.1)
            _, coverage_95 = compute_coverage(samples, theta_true, alpha=0.05)

            pred_mmd = compute_predictive_mmd(samples, x_clean, task, device, seed=42+idx)

            if torch.is_tensor(s_obs):
                instance_summary['s_obs'] = s_obs.cpu()
            else:
                instance_summary['s_obs'] = torch.tensor(s_obs)
            if torch.is_tensor(S_clean):
                instance_summary['s_clean'] = S_clean[idx].cpu()
            else:
                instance_summary['s_clean'] = torch.tensor(S_clean[idx])
            saved_summaries[name][idx] = instance_summary

            metrics = {
                "rmse": compute_rmse(samples, theta_true),
                "pred_mmd": pred_mmd,
                "l2_adapted_clean": l2_adapted_clean,
                "posterior_mmd_clean": posterior_mmd_clean,
                "n_steps": n_steps,
                "coverage_90": coverage_90,
                "coverage_95": coverage_95,
            }
            results[name].append(metrics)

    # Collapse the per-instance lists into single stacked tensors.
    for name in methods:
        posterior_draws[name] = torch.stack(posterior_draws[name])

    if has_analytic_posterior:
        sample_data["true_posterior_samples"] = torch.stack(reference_draws)

    return results, saved_summaries, sample_data


def _mean_std(values, skip_nan=False):
    """Return ``(mean, std)`` of *values*, or ``(nan, nan)`` if nothing can be averaged.

    With ``skip_nan=True`` NaN entries are ignored. An empty input, or an
    all-NaN input when skipping NaNs, yields NaN for both statistics without
    emitting NumPy's "empty slice" RuntimeWarning.
    """
    arr = np.asarray(values, dtype=float)
    nothing_to_average = arr.size == 0 or (skip_nan and np.isnan(arr).all())
    if nothing_to_average:
        return float("nan"), float("nan")
    if skip_nan:
        return np.nanmean(arr), np.nanstd(arr)
    return np.mean(arr), np.std(arr)


def aggregate(results):
    """Summarise per-instance metrics into per-method means and spreads.

    Args:
        results: mapping from method name to a list of metric dicts, each with
            the keys ``rmse``, ``pred_mmd``, ``l2_adapted_clean``,
            ``posterior_mmd_clean``, ``coverage_90`` and ``coverage_95``.

    Returns:
        Mapping from method name to a dict holding the mean and (population)
        standard deviation of ``rmse`` and ``pred_mmd``; the NaN-ignoring mean
        and standard deviation of ``l2_adapted_clean`` and
        ``posterior_mmd_clean``; and, for both coverage levels, the mean and
        the binomial standard error ``sqrt(p * (1 - p) / n)``.

        Metrics that are NaN for every instance (e.g. ``l2_adapted_clean`` for
        a method that never adapts) and methods with no instances at all are
        reported as NaN, without NumPy runtime warnings.
    """
    summary = {}
    for name, vals in results.items():
        n_samples = len(vals)

        rmse_mean, rmse_std = _mean_std([v["rmse"] for v in vals])
        mmd_mean, mmd_std = _mean_std([v["pred_mmd"] for v in vals])
        # These two are NaN whenever they were not computed for an instance.
        l2_mean, l2_std = _mean_std([v["l2_adapted_clean"] for v in vals], skip_nan=True)
        post_mean, post_std = _mean_std([v["posterior_mmd_clean"] for v in vals], skip_nan=True)
        cov_90_mean, _ = _mean_std([v["coverage_90"] for v in vals])
        cov_95_mean, _ = _mean_std([v["coverage_95"] for v in vals])

        if n_samples:
            cov_90_se = np.sqrt(cov_90_mean * (1 - cov_90_mean) / n_samples)
            cov_95_se = np.sqrt(cov_95_mean * (1 - cov_95_mean) / n_samples)
        else:
            # No instances: the standard error is undefined (avoid dividing by zero).
            cov_90_se = cov_95_se = float("nan")

        summary[name] = {
            "rmse_mean": rmse_mean,
            "rmse_std": rmse_std,
            "pred_mmd_mean": mmd_mean,
            "pred_mmd_std": mmd_std,
            "l2_adapted_clean_mean": l2_mean,
            "l2_adapted_clean_std": l2_std,
            "posterior_mmd_clean_mean": post_mean,
            "posterior_mmd_clean_std": post_std,
            "coverage_90_mean": cov_90_mean,
            "coverage_90_se": cov_90_se,
            "coverage_95_mean": cov_95_mean,
            "coverage_95_se": cov_95_se,
        }
    return summary


@hydra.main(version_base=None, config_path="../configs", config_name="gaussian_config")
def main(cfg: DictConfig):
    """Run the benchmark: load models, build TTA variants, evaluate, save results.

    Steps, in order:
      1. Seed ``random``, NumPy and PyTorch from ``cfg.seed`` (default 42).
      2. Load the training data and every model ``find_models`` reports for
         ``cfg.dim`` / ``cfg.n_obs``.
      3. Optionally wrap the base ``"npe"`` model with RFF, NCPP and OC-SVM
         test-time adapters, and optionally add NPE-PFN variants.
      4. Evaluate all methods on every ``test_d{d}_n{n}_*.pt`` file found in
         ``cfg.test_data_dir``.
      5. Write per-instance and aggregated CSVs plus the saved summaries and
         posterior samples into the output directory.

    Raises:
        ValueError: if no models are found, if ``use_summary_stats`` is false
            but there is no ``"npe"`` model to provide the embedding network,
            or if ``test_data_dir`` is not configured.
        FileNotFoundError: if no test file matches in ``test_data_dir``.
    """
    log.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")

    import random
    seed = cfg.get("seed", 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    log.info(f"Global seed set to {seed}")

    device = cfg.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(cfg.get("output_dir", "results"))
    output_dir.mkdir(parents=True, exist_ok=True)

    use_summary_stats = cfg.get("use_summary_stats", True)

    task = get_task(cfg)

    d, n = cfg.dim, cfg.n_obs
    train_path = Path(cfg.data_dir) / cfg.get("data_filename", f"train_d{d}_n{n}.pt")
    log.info(f"Loading training data from {train_path}")
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load
    # data files that come from a trusted source.
    train_data = torch.load(train_path, weights_only=False)
    X_train = train_data["X"]
    S_train = train_data["S"]

    # Fall back to the already-resolved output directory so that a config
    # without ``output_dir`` does not fail on attribute access.
    models_dir_cfg = cfg.get("models_dir")
    models_dir = Path(models_dir_cfg) if models_dir_cfg is not None else output_dir
    available = find_models(models_dir, d, n)
    log.info(f"Found models: {list(available.keys())}")

    if not available:
        raise ValueError(f"No models found in {models_dir} for d={d}, n={n}")

    # Small batches handed to the model loader alongside the checkpoint path.
    dummy_theta = train_data["thetas"][:10]
    dummy_x = S_train[:10]

    methods = {}
    for name, path in available.items():
        log.info(f"Loading {name} from {path}")
        model, _ = load_npe_model(path, dummy_theta, dummy_x, device)
        methods[name] = model

    # The embedding network comes from the base "npe" model. It is only
    # required when working in embedding space, so a missing "npe" model is
    # tolerated when raw summary statistics are used.
    if "npe" in methods:
        embedding_net = get_embedding_net(methods["npe"])
    elif use_summary_stats:
        embedding_net = None
    else:
        raise ValueError(
            f"use_summary_stats=False requires an 'npe' model to provide the "
            f"embedding network, but only {list(methods.keys())} were found in {models_dir}"
        )

    def _embed_train_summaries():
        """Return the training summaries passed through the embedding net, as a NumPy array."""
        with torch.no_grad():
            return embedding_net(S_train.to(device)).cpu().numpy()

    use_rff_tta = cfg.get("use_tta", True)
    use_ncpp_tta = cfg.get("use_ncpp_tta", True)

    rff_adapter = None
    ncpp_adapter = None
    # Stays None unless the RFF/NCPP setup below runs; later stages check this.
    S_tta_train = None

    if (use_rff_tta or use_ncpp_tta) and "npe" in methods:
        if use_summary_stats:
            S_tta_train = S_train.numpy()
        else:
            log.info("Computing embedded summaries for TTA training...")
            S_tta_train = _embed_train_summaries()

        X_for_tta = X_train.numpy()
        if X_for_tta.ndim == 4:
            # Flatten the last two axes so each observation becomes a vector.
            M, N, H, W = X_for_tta.shape
            X_for_tta = X_for_tta.reshape(M, N, H * W)

        if use_rff_tta:
            log.info("Setting up RFF-TTA adapter...")

            if cfg.model_type == "cryo_em":
                # Only the first 10,000 training instances are used for cryo_em.
                X_rff, S_rff = X_for_tta[:10_000], S_tta_train[:10_000]
            else:
                X_rff, S_rff = X_for_tta, S_tta_train

            rff_dim = cfg.get("rff_dim", 512)
            regressor_epochs = cfg.get("regressor_epochs", 100)
            log.info(f"RFF-TTA config: rff_dim={rff_dim}, regressor_epochs={regressor_epochs}")

            rff_adapter = RFFTTAAdapter(
                RFFTTAConfig(
                    rff_dim=rff_dim,
                    regressor_epochs=regressor_epochs,
                    seed=seed,
                ),
                device=device
            )
            rff_adapter.fit(X_rff, S_rff, calibration_split=cfg.get("rff_calibration_split", 0.05))
            log.info(f"RFF-TTA fitted: gamma={rff_adapter.gamma:.6f}, tau={rff_adapter.config.tau}")
            methods["npe+tta"] = TTAPosterior(methods["npe"], rff_adapter, device, bypass_embedding=not use_summary_stats)

        if use_ncpp_tta:
            ncpp_mode = cfg.get("ncpp_mode", "composite")
            ncpp_path = models_dir / f"ncpp_{ncpp_mode}_d{d}_n{n}.pt"

            if ncpp_path.exists():
                log.info(f"Setting up NCPP-TTA adapter from {ncpp_path}...")
                ncpp_model, ncpp_cfg = load_ncpp_model(ncpp_path, device)

                ncpp_adapter = NCPPTTAAdapter(
                    ncpp_model,
                    NCPPTTAConfig(
                        n_samples=cfg.get("ncpp_n_samples", 3000),
                        tta_steps=cfg.get("ncpp_tta_steps", 200),
                        lr=cfg.get("ncpp_lr", 0.1),
                        seed=seed,
                    ),
                    device=device
                )
                ncpp_adapter.fit(X_for_tta, S_tta_train)
                methods["npe+ncpp_tta"] = TTAPosterior(
                    methods["npe"], ncpp_adapter, device,
                    bypass_embedding=not use_summary_stats
                )
            else:
                log.info(f"NCPP model not found at {ncpp_path}, skipping NCPP-TTA")

    use_ocsvm_tta = cfg.get("use_ocsvm_tta", True)
    ocsvm_adapter = None

    if use_ocsvm_tta and "npe" in methods:
        ocsvm_path = models_dir / f"ocsvm_d{d}_n{n}.pkl"

        if ocsvm_path.exists():
            log.info(f"Setting up OC-SVM TTA adapter from {ocsvm_path}...")

            ocsvm_adapter = OCSVMTTAAdapter(
                OCSVMTTAConfig(
                    detector_path=str(ocsvm_path),
                    seed=seed,
                ),
                device=device
            )

            X_for_ocsvm = X_train.numpy()

            # Reuse the RFF/NCPP summaries when they were computed, otherwise
            # use the raw training summaries.
            # NOTE(review): the array passed here therefore depends on whether
            # RFF/NCPP-TTA is enabled — confirm that this is intended.
            if S_tta_train is not None:
                S_for_ocsvm = S_tta_train
            else:
                S_for_ocsvm = S_train.numpy()

            ocsvm_adapter.fit(
                X_for_ocsvm,
                S_for_ocsvm,
                summary_fn=task.compute_summary_statistics
            )

            methods["npe+ocsvm"] = TTAPosterior(
                methods["npe"], ocsvm_adapter, device,
                bypass_embedding=False
            )
        else:
            log.info(f"OC-SVM detector not found at {ocsvm_path}, skipping OC-SVM TTA")
            log.info(f"  Run: python scripts/train_ocsvm.py --config-name={cfg.model_type}_config")

    use_npe_pfn = cfg.get("use_npe_pfn", False)

    if use_npe_pfn:
        log.info("Setting up NPE-PFN...")
        prior = task.prior

        npe_pfn_estimator = TabPFN_Based_NPE_PFN(prior=prior)

        if use_summary_stats:
            S_for_pfn = S_train
        else:
            # The embedded summaries exist only if the RFF/NCPP setup ran;
            # compute them here otherwise instead of hitting an unbound name.
            if S_tta_train is None:
                log.info("Computing embedded summaries for NPE-PFN...")
                pfn_summaries = _embed_train_summaries()
            else:
                pfn_summaries = S_tta_train
            S_for_pfn = torch.from_numpy(pfn_summaries).float()
        npe_pfn_estimator.append_simulations(train_data["thetas"], S_for_pfn)

        methods["npe_pfn"] = NPEPFNPosterior(npe_pfn_estimator, adapter=None, device=device)

        if rff_adapter is not None:
            methods["npe_pfn+tta"] = NPEPFNPosterior(npe_pfn_estimator, adapter=rff_adapter, device=device)

        if ncpp_adapter is not None:
            methods["npe_pfn+ncpp_tta"] = NPEPFNPosterior(npe_pfn_estimator, adapter=ncpp_adapter, device=device)

        if ocsvm_adapter is not None:
            embed_fn = None if use_summary_stats else embedding_net
            methods["npe_pfn+ocsvm"] = NPEPFNPosterior(
                npe_pfn_estimator, adapter=ocsvm_adapter, device=device,
                embed_adapter_output=embed_fn
            )

    log.info(f"Methods to evaluate: {list(methods.keys())}")

    test_data_dir = cfg.get("test_data_dir")
    if test_data_dir is None:
        raise ValueError("Config key 'test_data_dir' must point to the directory holding the test files")
    test_dir = Path(test_data_dir)
    test_pattern = f"test_d{d}_n{n}_*.pt"
    test_files = sorted(test_dir.glob(test_pattern))

    log.info(f"Found {len(test_files)} test scenarios in {test_dir}")

    if not test_files:
        # Without this check the run would write empty outputs and then fail
        # with an opaque KeyError while grouping an empty DataFrame.
        raise FileNotFoundError(f"No test files matching '{test_pattern}' found in {test_dir}")

    # Per-instance metric names, in the column order of the detailed CSV.
    metric_columns = [
        "rmse", "pred_mmd", "l2_adapted_clean", "posterior_mmd_clean",
        "n_steps", "coverage_90", "coverage_95",
    ]

    all_results = []
    all_summaries = {}
    all_samples = {}

    for test_file in test_files:
        # The scenario name is the file stem without the common prefix.
        scenario = test_file.stem.replace(f"test_d{d}_n{n}_", "") or "well_specified"
        log.info(f"\n{'='*60}\nScenario: {scenario}\n{'='*60}")

        # NOTE(security): see the note on torch.load above — trusted files only.
        test_data = torch.load(test_file, weights_only=False)
        results, scenario_summaries, scenario_samples = evaluate(
            methods, test_data, task, embedding_net, use_summary_stats, device,
            n_eval=cfg.get("n_eval", 50),
            n_samples=cfg.get("n_samples", 1000)
        )
        summary = aggregate(results)

        all_summaries[scenario] = scenario_summaries
        all_samples[scenario] = scenario_samples

        log.info(f"{'Method':<15} {'RMSE':>10} {'Pred MMD':>12} {'Cov90':>8} {'Cov95':>8} {'Post MMD':>12}")
        log.info("-" * 70)
        for name, stats in summary.items():
            post_mmd_str = f"{stats['posterior_mmd_clean_mean']:>12.6f}" if not np.isnan(stats['posterior_mmd_clean_mean']) else f"{'N/A':>12}"
            log.info(f"{name:<15} {stats['rmse_mean']:>10.4f} {stats['pred_mmd_mean']:>12.6f} {stats['coverage_90_mean']:>8.3f} {stats['coverage_95_mean']:>8.3f} {post_mmd_str}")

            for i, res in enumerate(results[name]):
                row = {"scenario": scenario, "method": name, "sample_idx": i}
                row.update((column, res[column]) for column in metric_columns)
                all_results.append(row)

    df = pd.DataFrame(all_results)
    save_path = output_dir / f"benchmark_results_detailed_d{d}_n{n}.csv"
    df.to_csv(save_path, index=False)
    log.info(f"\nSaved detailed benchmark results to {save_path}")

    summaries_path = output_dir / f"benchmark_summaries_d{d}_n{n}.pt"
    torch.save(all_summaries, summaries_path)
    log.info(f"Saved summaries to {summaries_path}")

    samples_path = output_dir / f"benchmark_samples_d{d}_n{n}.pt"
    torch.save(all_samples, samples_path)
    log.info(f"Saved samples to {samples_path}")

    df_agg = df.groupby(["scenario", "method"])[metric_columns].agg(["mean", "std"]).reset_index()
    # Flatten the (metric, statistic) column MultiIndex into "metric_statistic" names.
    df_agg.columns = ['_'.join(col).strip() if col[1] else col[0] for col in df_agg.columns.values]
    agg_path = output_dir / f"benchmark_results_d{d}_n{n}.csv"
    df_agg.to_csv(agg_path, index=False)
    log.info(f"Saved aggregated results to {agg_path}")


# Script entry point: invoke the Hydra-decorated main() when this file is executed directly.
if __name__ == "__main__":
    main()
