import argparse
from pathlib import Path
from omegaconf import OmegaConf
import numpy as np
import torch
import random
from collections import defaultdict
from datetime import datetime
from metaworld.data_gen import collect_expert_dataset, create_loader
from metaworld.train_rl import build_vlm_model, compute_rewards_all, build_mdps_for_rewards, train_iql_on_mdps
from metaworld.eval_rl import eval_policies_across_envs, bootstrap_iqm_ci, interquartile_mean
import os
from pathlib import Path

# Make sure the local log folder exists, then export its location through
# the D3RLPY_LOGDIR environment variable.
# NOTE(review): presumably consumed by d3rlpy when it picks a log directory — confirm.
os.makedirs("d3rlpy_logs", exist_ok=True)
os.environ.update(D3RLPY_LOGDIR="d3rlpy_logs")

def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random generators with *seed*.

    Also seeds every CUDA device when CUDA is available, and switches cuDNN
    to deterministic mode with benchmarking disabled.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN kernels: slower, but runs become reproducible.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False



def _first_if_sequence(value):
    """Return the first element of a list-like config value, else the value itself.

    OmegaConf represents YAML lists as ``ListConfig`` objects, which are
    sequences but not ``list``/``tuple`` subclasses, so the check is made
    against the abstract ``Sequence`` type. Strings and bytes are treated as
    scalars so that a value such as ``"1e-4"`` is not indexed.
    """
    from collections.abc import Sequence  # function-scope: not imported at file level

    if isinstance(value, Sequence) and not isinstance(value, (str, bytes)):
        return value[0]
    return value


def run_with_seed(cfg, seed: int, results_dir: Path, run_id: str = None):
    """Run the collect -> reward -> train -> evaluate pipeline for one seed.

    Args:
        cfg: Configuration object read via attribute access and ``.get``;
            the ``collection``, ``model``, ``dataset_eval``, ``evaluation``
            and ``training`` sections are used.
        seed: Seed applied to the global RNGs and forwarded to data
            collection and evaluation.
        results_dir: Directory receiving the saved models and the per-seed
            summary file. Created if it does not exist.
        run_id: Tag embedded in output filenames. When ``None`` a timestamp
            of the form ``YYYYmmdd-HHMMSS`` is generated instead of writing
            the literal string ``"None"`` into the filenames.

    Returns:
        The ``table`` returned by ``eval_policies_across_envs``; its rows are
        unpacked here as ``(env, reward, success_rate, avg_return, ci_low,
        ci_high)``.

    Raises:
        ValueError: If ``cfg.collection.tasks`` is empty.

    Side effects:
        Writes ``iql_<reward>_<run_id>_seed_<seed>.d3`` per trained policy and
        ``eval_summary_<run_id>_seed_<seed>.txt`` into ``results_dir``. The
        seed is part of the model filename so that several seeds sharing one
        ``run_id`` do not overwrite each other's models.
    """
    set_seed(seed)

    if run_id is None:
        run_id = datetime.now().strftime("%Y%m%d-%H%M%S")
    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    tasks = list(cfg.collection.tasks)
    if not tasks:
        # Fail before any expensive work instead of at np.concatenate below.
        raise ValueError("cfg.collection.tasks is empty; nothing to collect")

    # Resolve settings once so that the summary header reports exactly the
    # values that were used for collection, training and evaluation.
    success_episodes = int(cfg.collection.success_episodes)
    iql_n_steps = int(cfg.training.get("iql_n_steps", 20000))
    eval_episodes = int(cfg.training.get("eval_episodes", 20))
    ttt_lr = float(_first_if_sequence(cfg.evaluation.ttt_lr))
    ttt_epochs = int(_first_if_sequence(cfg.evaluation.ttt_epochs))

    # 1) Collect per-task datasets
    collected = {}
    for t in tasks:
        # The last two returned values (base dir, max length) are not needed here.
        obs, acts, rews, terms, _, _ = collect_expert_dataset(
            env_name=t.env_name,
            desc=t.desc,
            success_episodes=success_episodes,
            max_steps=int(cfg.collection.max_steps),
            seed=seed,
            camera_name=str(cfg.collection.camera_name),
            render_mode=str(cfg.collection.render_mode),
        )
        collected[t.env_name] = {
            "obs": obs, "acts": acts, "terms": terms, "rews": rews,
            "split": f"{t.env_name}_{success_episodes}",
        }

    # 2) Build VLM model and compute rewards
    device = torch.device(cfg.model.get("device", "cuda:0") if torch.cuda.is_available() else "cpu")
    model = build_vlm_model(cfg.model, device)
    root_dir = Path(cfg.dataset_eval.root_dir)
    num_workers = int(cfg.dataset_eval.num_workers)
    rewards_by_task = {}
    for env_name, bundle in collected.items():
        loader, _ = create_loader(root_dir=root_dir, split=bundle["split"], num_workers=num_workers)
        rewards_by_task[env_name] = compute_rewards_all(
            loader, model, device,
            ttt_lr=ttt_lr,
            ttt_epochs=ttt_epochs,
            methods=cfg.evaluation.reward_methods,
        )

    # 3) Build MDPs per reward from the buffers of all envs, concatenated in
    #    the order the tasks appear in the config.
    env_order = [t.env_name for t in tasks]

    obs_all = np.concatenate([collected[e]["obs"] for e in env_order], axis=0)
    acts_all = np.concatenate([collected[e]["acts"] for e in env_order], axis=0)
    terms_all = np.concatenate([collected[e]["terms"] for e in env_order], axis=0)
    rewards_oracle = np.concatenate([collected[e]["rews"] for e in env_order], axis=0)

    # Assumes every env produced the same set of reward names; a missing name
    # surfaces as a KeyError in the concatenation below.
    reward_names = list(next(iter(rewards_by_task.values())).keys())

    # Concatenate each reward across envs in the same order as the buffers.
    rewards_all = {
        name: np.concatenate([rewards_by_task[e][name] for e in env_order], axis=0)
        for name in reward_names
    }

    mdps = build_mdps_for_rewards(
        observations=obs_all,
        actions=acts_all,
        terminals=terms_all,
        rewards_by_name=rewards_all,
        expert_rewards=rewards_oracle,
        expert=cfg.evaluation.get("expert", False),
    )

    # 4) Train IQL per reward
    trained = train_iql_on_mdps(
        mdps,
        n_steps=iql_n_steps,
        n_steps_per_epoch=int(cfg.training.get("iql_n_steps_per_epoch", 5000)),
        device=str(cfg.training.get("iql_device", "cuda")),
    )

    for reward_name, algo in trained.items():
        # Seed is part of the name: run_id alone is shared by all seeds of a run.
        save_path = results_dir / f"iql_{reward_name}_{run_id}_seed_{seed}.d3"
        algo.save_model(str(save_path))
        print(f"[IQL] saved model -> {save_path}")

    # 5) Evaluate across envs
    table, policy_stats = eval_policies_across_envs(
        trained,
        list(cfg.collection.eval_envs),
        episodes=eval_episodes,
        max_steps=int(cfg.training.get("eval_max_steps", 150)),
        seed=seed,
    )

    print("\n=== Summary (env, reward, success_rate, avg_return) ===")
    summary_file = results_dir / f"eval_summary_{run_id}_seed_{seed}.txt"
    with open(summary_file, "w", encoding="utf-8") as f:
        f.write(f"# success_episodes={success_episodes}\n")
        f.write(f"# eval_episodes={eval_episodes}\n")
        f.write(f"# iql_n_steps={iql_n_steps}\n\n")

        # Per-env rows with CI columns, followed by the across-env stats of
        # that row's reward (repeated on every row of the same reward).
        f.write("env,reward,success_rate,avg_return,ci_low,ci_high,across_env_mean,across_env_iqm,iqm_ci_low,iqm_ci_high\n")
        for env, name, sr, avg_ret, lo, hi in table:
            ps = policy_stats[name]
            f.write(f"{env},{name},{sr:.4f},{avg_ret:.2f},{lo:.4f},{hi:.4f},"
                    f"{ps['mean']:.4f},{ps['iqm']:.4f},{ps['iqm_lo']:.4f},{ps['iqm_hi']:.4f}\n")

        # ---- Aggregated stats across envs, one row per reward ----
        f.write("\n# Aggregated across envs (success rate)\n")
        f.write("reward,mean,iqm,iqm_ci_low,iqm_ci_high\n")
        for name, ps in policy_stats.items():
            f.write(f"{name},{ps['mean']:.4f},{ps['iqm']:.4f},"
                    f"{ps['iqm_lo']:.4f},{ps['iqm_hi']:.4f}\n")

    print(f"[Eval] summary written -> {summary_file}")
    return table


def main():
    """Run the pipeline once per configured seed and write an aggregate report.

    Reads the YAML config given by ``--config``, calls ``run_with_seed`` for
    every seed, and writes ``eval_summary_agg_<run_id>.txt`` into
    ``exp_metaworld/results`` with four sections: raw env x seed rows,
    per-seed summaries per reward, per-task seed-averaged rows, and the final
    per-reward statistics computed over all env x seed success rates.

    ``seeds`` in the config may be a list, a single value, or absent/null
    (in which case seed 0 is used).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", type=str, default="metaworld/config_rl.yaml")
    args = ap.parse_args()
    cfg = OmegaConf.load(args.config)
    run_id = datetime.now().strftime("%Y%m%d-%H%M%S")

    results_dir = Path("exp_metaworld/results")
    results_dir.mkdir(parents=True, exist_ok=True)

    # Accept a list, a scalar, or a missing/null entry for `seeds`.
    raw_seeds = cfg.get("seeds", [0])
    if raw_seeds is None:
        raw_seeds = [0]
    elif isinstance(raw_seeds, (int, str)):
        raw_seeds = [raw_seeds]
    seeds = [int(s) for s in raw_seeds]
    print(f"[Run] using seeds: {seeds}")

    # in-memory aggregators
    agg = defaultdict(lambda: {"sr": [], "ret": []})          # (env, reward) -> lists across seeds
    envseed_rows = []                                          # (env, reward, seed, sr, ret)
    per_seed_reward = defaultdict(lambda: defaultdict(list))   # reward -> seed -> [sr per env]

    # ---- run per seed and collect ----
    for seed in seeds:
        print(f"\n===== Running with seed {seed} =====\n")
        table = run_with_seed(cfg, seed, results_dir, run_id)  # [(env, reward, sr, avg_ret, lo, hi)]
        for env, reward, sr, avg_ret, _lo, _hi in table:
            sr, avg_ret = float(sr), float(avg_ret)
            agg[(env, reward)]["sr"].append(sr)
            agg[(env, reward)]["ret"].append(avg_ret)
            envseed_rows.append((env, reward, seed, sr, avg_ret))
            per_seed_reward[reward][seed].append(sr)

    # ---- build per-task seed-averaged table + reward-level pool over task x seed ----
    per_task_avg = []                        # (env, reward, sr_mean, ret_mean)
    per_reward_taskseed = defaultdict(list)  # reward -> [sr across all env x seed]

    for (env, reward), vals in agg.items():
        sr_mean = float(np.mean(vals["sr"]))
        ret_mean = float(np.mean(vals["ret"]))
        per_task_avg.append((env, reward, sr_mean, ret_mean))
        per_reward_taskseed[reward].extend(vals["sr"])  # keep all seed SRs

    # precompute reward-level stats (mean / IQM / bootstrap CI over task x seed)
    reward_stats = {}
    for reward, xs in per_reward_taskseed.items():
        x = np.asarray(xs, dtype=np.float64)
        mean = float(np.mean(x))
        iqm = interquartile_mean(x)
        lo, hi = bootstrap_iqm_ci(x, n_boot=10000, alpha=0.05, seed=0)
        reward_stats[reward] = (mean, iqm, lo, hi)

    # ---- single aggregate file with all sections ----
    agg_file = results_dir / f"eval_summary_agg_{run_id}.txt"
    # Explicit UTF-8: the section headers contain the non-ASCII "×" character,
    # so the file must not depend on the platform's default encoding.
    with open(agg_file, "w", encoding="utf-8") as f:
        # metadata; read with the same defaults that training/evaluation use
        f.write(f"# success_episodes={cfg.collection.success_episodes}\n")
        f.write(f"# eval_episodes={cfg.training.get('eval_episodes', 20)}\n")
        f.write(f"# iql_n_steps={cfg.training.get('iql_n_steps', 20000)}\n")
        f.write(f"# seeds={seeds}\n\n")

        # 1) env×seed rows (raw points for IQM)
        f.write("# Env×Seed results (raw points)\n")
        f.write("env,reward,seed,success_rate,avg_return\n")
        for env, reward, seed, sr, ret in sorted(envseed_rows):
            f.write(f"{env},{reward},{seed},{sr:.4f},{ret:.2f}\n")

        # 2) per-seed summary per reward (mean & IQM across envs)
        f.write("\n# Per-seed summary per reward (across envs)\n")
        f.write("reward,seed,mean_sr,iqm_sr\n")
        for reward in sorted(per_seed_reward):
            for seed in sorted(per_seed_reward[reward]):
                vals = np.asarray(per_seed_reward[reward][seed], dtype=np.float64)  # env-level SRs for this seed
                mean_sr = float(np.mean(vals))
                iqm_sr = interquartile_mean(vals)
                f.write(f"{reward},{seed},{mean_sr:.4f},{iqm_sr:.4f}\n")

        # 3) per-task (env) seed-averaged rows with across-reward stats (IQM over env×seed)
        f.write("\n# Per-task seed-averaged with across-reward stats (IQM over env×seed)\n")
        f.write("env,reward,success_rate,avg_return,ci_low,ci_high,across_env_mean,across_env_iqm,iqm_ci_low,iqm_ci_high\n")
        for env, reward, sr_mean, ret_mean in sorted(per_task_avg):
            mean, iqm, lo, hi = reward_stats[reward]  # same across all env rows of this reward
            # ci_low/ci_high are written as -1 placeholders in this section.
            f.write(f"{env},{reward},{sr_mean:.4f},{ret_mean:.2f},-1,-1,{mean:.4f},{iqm:.4f},{lo:.4f},{hi:.4f}\n")

        # 4) final aggregated block per reward (IQM over env×seed)
        f.write("\n# Aggregated across tasks×seeds (success rate)\n")
        f.write("reward,mean,iqm,iqm_ci_low,iqm_ci_high\n")
        for reward in sorted(reward_stats):
            mean, iqm, lo, hi = reward_stats[reward]
            f.write(f"{reward},{mean:.4f},{iqm:.4f},{lo:.4f},{hi:.4f}\n")

    print(f"[Eval] aggregated summary written -> {agg_file}")


if __name__ == "__main__":
    main()
