import argparse
import itertools

import pathlib
from collections import defaultdict
from collections.abc import Sequence

import colorcet
import numpy as np
import scipy.stats
import torch
import wandb
from matplotlib import pyplot as plt
from matplotlib.ticker import MultipleLocator
from statsmodels.stats.proportion import proportion_confint
from tqdm import trange

from collector.strategy import record_episode
from helpers import init_envs, init_model
from logger.ratemap import compute_ratemap, normalize_over_axis

# --- Evaluation settings ---------------------------------------------------
# Key into the validation-env dict returned by ``init_envs``.
VALID_ENV: str = "len300"
# Episodes rolled out per checkpoint when computing the mean validation reward.
VALID_EPS: int = 1000
# Episodes passed to ``compute_ratemap`` for rate-map generation.
N_RATEMAP_EPS: int = 300

# --- Weights & Biases lookup (fill in before running) -----------------------
WANDB_PROJECT: str = "ENTER WANDB PROJECT HERE"
WANDB_ENTITY: str = "ENTER WANDB USERNAME HERE"
# ``run.config`` key used to group runs in the final summary.
GROUPBY_CONFIG_KEY: str = "memory_type"

# --- Checkpoint sweep: episodes EPISODE_MIN..EPISODE_MAX (inclusive) --------
# With MIN == MAX only a single checkpoint episode is evaluated.
EPISODE_MIN: int = 100000
EPISODE_MAX: int = 100000
EPISODE_STEP: int = 500

# W&B display names of the runs to evaluate.
WANDB_RUNS: list[str] = [
    "run-name-1",
    "run-name-2",
    "run-name-3",
    "run-name-4",
]


def normalize_and_sort(hists, axis):
    """Normalise ``hists`` and reorder its rows by where each row peaks.

    The histograms are normalised with ``normalize_over_axis``. Each row's
    peak is the argmax along ``axis``, with NaNs treated as 0 for this
    purpose. Rows are then reordered so earlier peaks come first. NaNs in the
    returned array are replaced by 0.

    Reordering indexes the first axis, so this is meant for 2-D input
    normalised/peaked along ``axis=1`` (how ``gen_ratemaps`` calls it).
    """
    normed = normalize_over_axis(hists, axis=axis)
    peak_idx = np.nan_to_num(normed, nan=0).argmax(axis=axis)
    row_order = peak_idx.argsort()
    return np.nan_to_num(normed[row_order], nan=0)


# Supported values of ``gen_ratemaps``'s ``mode`` argument.
_RATEMAP_MODES = ("evidence", "position")


def _plot_psychometric(xs, means, yerr, max_evid, path):
    """Save one psychometric curve (percent right turns vs. evidence) to ``path``.

    ``means`` are right-turn fractions and ``yerr`` is a (2, len(xs)) array of
    lower/upper error-bar lengths in the same units; both are scaled to percent.
    """
    fig, ax = plt.subplots(1, 1, figsize=(2, 2), dpi=600, constrained_layout=True)
    ax.set(xlim=(-max_evid, max_evid), ylim=(0, 100), xlabel="Evidence", ylabel="Turned right (%)")

    # Reference lines: 50 % right turns and zero net evidence.
    ax.axhline(y=50, c="lightgray", ls="--")
    ax.axvline(x=0, c="lightgray", ls="--")
    ax.errorbar(xs, means * 100, yerr=yerr * 100, fmt=".")

    fig.savefig(path, bbox_inches=0)
    print(f"Figure saved to {path.absolute()}")
    plt.close(fig)


def gen_ratemaps(run, env, model, episode, mode):
    """Plot a peak-sorted rate map of ``model`` on ``env`` and save it as a PNG.

    Histograms come from ``compute_ratemap`` over ``N_RATEMAP_EPS`` episodes
    and are normalised/sorted with ``normalize_and_sort``. For
    ``mode == "evidence"`` two psychometric curves are also saved: one with
    one-sigma Jeffreys confidence intervals and one without.

    Args:
        run: W&B run; ``run.name`` and ``run.config['memory_type']`` are used
            in the output file names.
        env: validation environment; reads ``num_interval_on_left``,
            ``num_interval_on_right`` and ``seq_len``.
        model: model passed to ``compute_ratemap``.
        episode: checkpoint episode, used only in the output file names.
        mode: ``"evidence"`` or ``"position"``, selecting which histograms to plot.

    Raises:
        ValueError: if ``mode`` is not supported. Checked before the
            expensive rate-map computation.
    """
    if mode not in _RATEMAP_MODES:
        raise ValueError(f"mode must be one of {_RATEMAP_MODES}, got {mode!r}")

    max_evid = max(env.num_interval_on_left[1], env.num_interval_on_right[1]) - 1

    (ctx_evid_hists, ctx_evid_edges,
     ctx_pos_hists, ctx_pos_edges,
     psychometric_data) = compute_ratemap(env, model, env.seq_len + 1, max_evid, N_RATEMAP_EPS)

    if mode == "evidence":
        edges, hists, xlabel = ctx_evid_edges, ctx_evid_hists, "Evidence"
    else:
        edges, hists, xlabel = ctx_pos_edges, ctx_pos_hists, "Position"
    sorted_hists = normalize_and_sort(hists, axis=1)
    n_neur = len(sorted_hists)

    fig, ax = plt.subplots(1, 1, dpi=600, figsize=(2, 2.5), constrained_layout=True)

    if mode == "evidence":
        ax.xaxis.set_major_locator(MultipleLocator(10))
    ax.set_xlabel(xlabel, fontsize=10)

    ax.set_ylabel("Neuron", fontsize=10)
    ax.set_ylim(n_neur, 0.01)  # hide the "0" tick label by making ylim 0.01
    ax.yaxis.set_major_locator(MultipleLocator(50))

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    pcm = ax.pcolormesh(
        edges,
        range(n_neur + 1),
        sorted_hists,
        cmap=colorcet.cm.bgy,
    )

    cb = fig.colorbar(pcm, ax=ax, location='top', ticks=[0, 1])
    cb.outline.set_visible(False)

    out_dir = pathlib.Path(f"./out/ratemaps/{VALID_ENV}")
    out_dir.mkdir(parents=True, exist_ok=True)
    suffix = f"{run.config['memory_type']}_{mode}_{run.name}_ep{episode}.png"

    path = out_dir / f"ratemap_{suffix}"
    fig.savefig(path)
    print(f"Figure saved to {path.absolute()}")
    plt.close(fig)

    if mode != "evidence":
        return

    # Column 0 is treated as right-turn counts and column 1 as left-turn counts.
    # Assumes one row per integer evidence value in [-max_evid, max_evid];
    # TODO confirm against compute_ratemap.
    psy_right = psychometric_data[:, 0]
    psy_left = psychometric_data[:, 1]
    psy_total = psy_right + psy_left
    with np.errstate(divide='ignore', invalid='ignore'):
        # Evidence bins with no trials give 0/0 -> NaN; suppress the warnings.
        means = psy_right / psy_total
        alpha = 0.32  # one-sigma (per Pinto et al., 2018)
        ci_low, ci_upp = proportion_confint(psy_right, psy_total, alpha=alpha, method="jeffreys")

    xs = np.arange(-max_evid, max_evid + 1)
    yerr = np.stack([means - ci_low, ci_upp - means])

    _plot_psychometric(xs, means, yerr, max_evid, out_dir / f"psychometric_{suffix}")
    # Same curve without confidence intervals (zero-length error bars).
    _plot_psychometric(xs, means, np.zeros_like(yerr), max_evid,
                       out_dir / f"psychometric_noconfint_{suffix}")


def gen_metrics(run, env, model, episode: int):
    """Return the mean total reward of ``model`` over ``VALID_EPS`` episodes on ``env``.

    Each episode is collected with ``record_episode(env, model, explore=False)``.
    The mean is also printed after ``run.name``. ``episode`` is not used here.
    """
    rewards = np.array([
        record_episode(env, model, explore=False).total_reward
        for _ in trange(VALID_EPS, leave=False)
    ])
    mean_reward = rewards.mean()
    print(run.name, mean_reward)
    return mean_reward


def _custom_valid_envs():
    """Return a fresh dict of extra validation-env overrides for ``init_envs``.

    A new dict is built on every call so that any mutation by ``init_envs``
    cannot leak from one run into the next.
    """
    return {
        "len300_towers1_30": {
            "num_interval_on_left": [1, 30],
            "num_interval_on_right": [1, 30],
            "seq_len": 300,
            "_max_episode_steps": 320
        },
        "len3000": {
            "seq_len": 3000,
            "_max_episode_steps": 3200
        },
        "len3000_towers1_30": {
            "num_interval_on_left": [1, 30],
            "num_interval_on_right": [1, 30],
            "seq_len": 3000,
            "_max_episode_steps": 3200
        },
        "len10000": {
            "seq_len": 10000,
            "_max_episode_steps": 10600
        },
    }


def _find_run(api, run_name):
    """Return the unique run in ``WANDB_ENTITY/WANDB_PROJECT`` with display name ``run_name``.

    Raises:
        RuntimeError: if zero or several runs match. An explicit raise is
            used because asserts are stripped under ``python -O``.
    """
    runs = api.runs(f'{WANDB_ENTITY}/{WANDB_PROJECT}', filters={"display_name": run_name})
    if len(runs) != 1:
        raise RuntimeError(f"Expected exactly 1 run named {run_name!r}, found {len(runs)}.")
    return runs[0]


def _load_checkpoint(run, model, episode):
    """Download ``checkpoints/{episode}.pt`` from ``run`` and load it into ``model``.

    Returns:
        True on success, False (``model`` untouched) when the file reports a
        non-positive size, which is how this script detects a missing checkpoint.
    """
    file = run.file(f"checkpoints/{episode}.pt")
    if file.size <= 0:
        return False

    file.download(root=f"./tmp/{run.name}", replace=True)
    # Load tensors onto the CPU so GPU-saved checkpoints also work on CPU-only
    # hosts; load_state_dict copies them to the model's own device.
    checkpoint = torch.load(f"./tmp/{run.name}/{file.name}", map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    return True


def main(argv: Sequence[str] | None = None):
    """Evaluate checkpoints of every run in ``WANDB_RUNS`` and print per-group statistics.

    For each (episode, run name) pair, where episode ranges over
    ``EPISODE_MIN..EPISODE_MAX`` in steps of ``EPISODE_STEP``:
      1. look up the W&B run;
      2. rebuild its environments and model from ``run.config``;
      3. download and load the checkpoint;
      4. compute the mean validation reward on ``VALID_ENV``.

    Mean rewards are grouped by ``run.config[GROUPBY_CONFIG_KEY]``. The
    per-group mean and SEM are printed at the end. Pairs without a checkpoint
    are reported and skipped, so results already collected still reach the
    summary.

    Args:
        argv: CLI arguments; ``None`` means ``sys.argv[1:]``.
    """
    parse_args(argv)  # validates the CLI; job_array_idx is not used by this loop

    episode_range = range(EPISODE_MIN, EPISODE_MAX + 1, EPISODE_STEP)
    api = wandb.Api()  # one client for the whole sweep

    vals: defaultdict[str, list[float]] = defaultdict(list)

    for episode, run_name in itertools.product(episode_range, WANDB_RUNS):
        run = _find_run(api, run_name)
        config = dict(run.config)

        train_env, valid_envs_dict = init_envs(config, _custom_valid_envs())
        model = init_model(config, train_env)

        print(f"{run.name} ({episode = }) ({run.url})")

        if not _load_checkpoint(run, model, episode):
            print(f"Episode not found, skipping: {run.name} ep{episode}")
            continue
        model.eval()

        valid_env = valid_envs_dict[VALID_ENV]

        mean_rew = gen_metrics(run, valid_env, model, episode)
        # Rate maps can also be generated here via
        # gen_ratemaps(run, valid_env, model, episode, "evidence" | "position").

        vals[run.config[GROUPBY_CONFIG_KEY]].append(mean_rew)

    for group, mean_rews in vals.items():
        mean = np.mean(mean_rews)
        # ddof=0 uses the population std (scipy's default is ddof=1); with a
        # single run in a group this yields 0 rather than NaN.
        sem = scipy.stats.sem(mean_rews, ddof=0)
        print(group, mean, sem)


def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``job_array_idx``: an int, or None when omitted.
        The argument is optional because ``main`` currently evaluates every
        (episode, run) pair itself and does not read it.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate W&B run checkpoints on the validation environment.",
    )
    parser.add_argument(
        'job_array_idx',
        type=int,
        nargs='?',
        default=None,
        help="Index into the (episode, run) product for job-array execution (currently unused).",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: CLI arguments are read from sys.argv (argv=None).
    main(argv=None)
