from collections import defaultdict
from multiprocessing import Pool

import numpy as np
import scipy.stats
import torch
import wandb

from collector.strategy import record_episode
from helpers import init_envs, init_model


# Key of the validation environment used for evaluation; worker() looks it up
# in the validation-env dict returned by init_envs.
VALID_ENV = "len10000_towers1_30"
# Number of evaluation episodes rolled out per run in gen_metrics().
VALID_EPS = 100

# Placeholders: fill these in before running. Runs are queried under
# "<WANDB_ENTITY>/<WANDB_PROJECT>".
WANDB_PROJECT = "ENTER WANDB PROJECT HERE"
WANDB_ENTITY = "ENTER WANDB USERNAME HERE"

# Run-config key whose value is used to group runs when aggregating results.
GROUPBY_CONFIG_KEY = "memory_type"

# Selects the checkpoint to evaluate: "checkpoints/<EPISODE>.pt" is fetched
# from each run (presumably the training episode at which it was saved).
EPISODE = 100_000

# Display names of the W&B runs to evaluate; each must match exactly one run.
WANDB_RUNS = [
    "run-name-1",
    "run-name-2",
    "run-name-3",
    "run-name-4",

]


def gen_metrics(run, env, model):
    """Evaluate ``model`` on ``env`` and return the mean total episode reward.

    Records ``VALID_EPS`` episodes via ``record_episode`` with exploration
    disabled (``explore=False``), collects each trajectory's ``total_reward``,
    prints the mean next to ``run.name`` and returns that mean.
    """
    episode_returns = np.array(
        [
            record_episode(env, model, explore=False).total_reward
            for _ in range(VALID_EPS)
        ]
    )
    mean_return = episode_returns.mean()

    print(run.name, mean_return)
    return mean_return

def worker(run_name: str) -> "tuple[str, float] | None":
    """Evaluate one W&B run's checkpoint and return its group and mean reward.

    Looks the run up by display name, rebuilds its environments and model from
    the stored run config, downloads ``checkpoints/{EPISODE}.pt`` into
    ``./tmp/<run name>``, loads the model weights and evaluates the model on
    the ``VALID_ENV`` validation environment via ``gen_metrics``.

    Args:
        run_name: Display name of the run under WANDB_ENTITY/WANDB_PROJECT.

    Returns:
        ``(group, mean_reward)`` where ``group`` is the run's config value
        stored under ``GROUPBY_CONFIG_KEY``; ``None`` if the checkpoint file
        for ``EPISODE`` is not available (reported size <= 0).

    Raises:
        ValueError: If ``run_name`` does not match exactly one run.
    """
    api = wandb.Api()
    runs = api.runs(f'{WANDB_ENTITY}/{WANDB_PROJECT}', filters={"display_name": run_name})

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(runs) != 1:
        raise ValueError(
            f"Expected exactly one run named {run_name!r}, found {len(runs)}."
        )

    run = runs[0]

    config = dict(run.config)

    # Extra validation environments handed to init_envs; VALID_ENV must be
    # one of the keys available in the dict that init_envs returns.
    custom_valid_envs = {
        "len300_towers1_30": {
            "num_interval_on_left": [1, 30],
            "num_interval_on_right": [1, 30],
            "seq_len": 300,
            "_max_episode_steps": 320
        },
        "len3000": {
            "seq_len": 3000,
            "_max_episode_steps": 3200
        },
        "len3000_towers1_30": {
            "num_interval_on_left": [1, 30],
            "num_interval_on_right": [1, 30],
            "seq_len": 3000,
            "_max_episode_steps": 3200
        },
        "len10000": {
            "seq_len": 10000,
            "_max_episode_steps": 10600
        },
        "len10000_towers1_30": {
            "num_interval_on_left": [1, 30],
            "num_interval_on_right": [1, 30],
            "seq_len": 10000,
            "_max_episode_steps": 10600
        },
    }

    train_env, valid_envs_dict = init_envs(config, custom_valid_envs)
    model = init_model(config, train_env)

    print(f"{run.name} ({EPISODE = }) ({run.url})")

    file = run.file(f"checkpoints/{EPISODE}.pt")

    # A non-positive size is treated as "this run has no such checkpoint".
    if file.size <= 0:
        print("Episode not found.")
        return None

    download_root = f"./tmp/{run.name}"
    file.download(root=download_root, replace=True)

    # Deserialize onto the CPU so a checkpoint saved from a GPU can be read on
    # any host; load_state_dict then copies the values into the model's
    # existing parameters.
    checkpoint = torch.load(f"{download_root}/{file.name}", map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    print(f"{run.name} - Trainable Params: ", sum(p.numel() for p in model.parameters() if p.requires_grad))

    valid_env = valid_envs_dict[VALID_ENV]

    mean_rew = gen_metrics(run, valid_env, model)
    run_group = run.config[GROUPBY_CONFIG_KEY]

    return run_group, mean_rew


def main():
    """Evaluate all runs in WANDB_RUNS in parallel and print per-group stats.

    Each run is evaluated by ``worker`` in a process pool. The resulting mean
    rewards are grouped by the group name ``worker`` returns, and for every
    group the mean and the standard error (``ddof=0``) over its runs are
    printed. Runs for which ``worker`` returned ``None`` are left out.
    """
    with Pool() as pool:
        results = pool.map(worker, WANDB_RUNS)

    vals: defaultdict[str, list[float]] = defaultdict(list)
    for result in results:
        # worker returns None when a run has no checkpoint for EPISODE; skip
        # it rather than crash on tuple unpacking and lose every other result.
        if result is None:
            continue
        group, mean_rew = result
        vals[group].append(mean_rew)

    print("\n\n")

    for group_name, mean_rews in vals.items():
        mean = np.mean(mean_rews)  # mean over runs
        sem = scipy.stats.sem(mean_rews, ddof=0)  # SE over runs
        print(group_name, mean, sem)


# Run the evaluation only when executed as a script. The guard also keeps
# pool worker processes that re-import this module (spawn start method) from
# launching main() again.
if __name__ == "__main__":
    main()
