import functools
import os
import pathlib
from multiprocessing import Pool
from pprint import pprint

import colorcet
import numpy as np
import torch
from matplotlib.ticker import MultipleLocator

import wandb
from matplotlib import pyplot as plt
from statsmodels.stats.proportion import proportion_confint
from tqdm import tqdm

from helpers import init_envs, init_model
from logger.ratemap import compute_ratemap, normalize_over_axis

# Silence every tqdm progress bar (including ones created inside library code)
# by making disable=True the default keyword for all tqdm instances.
tqdm.__init__ = functools.partialmethod(tqdm.__init__, disable=True)

# Put your W&B API key here, or leave the placeholder untouched to let wandb use
# whatever credentials it already finds (e.g. the WANDB_API_KEY environment
# variable or `wandb login`). The placeholder itself is never exported, so it
# cannot shadow a valid key that is already configured.
WANDB_API_KEY = "ENTER WANDB API KEY HERE"
if WANDB_API_KEY != "ENTER WANDB API KEY HERE":
    os.environ["WANDB_API_KEY"] = WANDB_API_KEY

WANDB_PROJECT = "ENTER WANDB PROJECT HERE"
WANDB_ENTITY = "ENTER WANDB USERNAME HERE"

# Checkpoint episode loaded for every run. Set to None to be prompted for an
# episode instead; the prompt needs a terminal, so that only works when main()
# is called directly rather than through the multiprocessing Pool.
CKPT_EP = 100_000

# Display names of the W&B runs to plot; main() is called once per entry.
WANDB_RUNS = [
    "run-name-1",
    "run-name-2",
    "run-name-3",
    "run-name-4",
]

N_EPS = 300  # number of environment episodes used
MODE = "evidence"  # which histograms the rate map shows: "evidence" or "position"

VALID_ENV = "len300"  # key into the validation-environment dict returned by init_envs


def normalize_and_sort(hists, axis):
    """Normalize 2-D histograms along ``axis`` and order the other axis by peak position.

    ``hists`` is normalized with ``normalize_over_axis``. Any NaNs in the result
    are replaced by 0, so an all-NaN slice "peaks" at index 0 and sorts first.
    Each slice along the *other* axis (one neuron per row when ``axis=1``) is
    then ordered by the index of its maximum along ``axis``. Ties keep their
    original relative order, which makes the output reproducible across NumPy
    versions and CPUs.

    Args:
        hists: 2-D array of histograms.
        axis: Axis (0, 1, -1 or -2) along which to normalize and locate each peak.

    Returns:
        The normalized, NaN-free array, reordered along the non-``axis`` dimension.

    Raises:
        ValueError: if the normalized histograms are not 2-D.
    """
    # Replace NaNs once, up front: both the peak search and the returned array
    # must be NaN-free.
    normed = np.nan_to_num(normalize_over_axis(hists, axis=axis), nan=0)
    if normed.ndim != 2:
        raise ValueError(f"normalize_and_sort expects 2-D histograms, got ndim={normed.ndim}")

    peak_idx = np.argmax(normed, axis=axis)
    # A stable sort keeps ties (same peak bin) in their original order.
    order = np.argsort(peak_idx, kind="stable")

    # Reorder along the axis that was NOT normalized. Indexing with
    # normed[order] would always reorder axis 0, which is wrong when axis == 0.
    sort_axis = 1 - (axis % 2)
    return np.take(normed, order, axis=sort_axis)


def _fetch_run(wandb_run: str):
    """Return the single W&B run in WANDB_ENTITY/WANDB_PROJECT with display name ``wandb_run``.

    Raises:
        LookupError: if zero runs, or more than one run, match the name.
    """
    api = wandb.Api()
    runs = api.runs(f'{WANDB_ENTITY}/{WANDB_PROJECT}', filters={"display_name": wandb_run})
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if len(runs) != 1:
        raise LookupError(f"Expected exactly one run named {wandb_run!r}, found {len(runs)}.")
    return runs[0]


def _select_checkpoint(run, ckpt_ep):
    """Pick the checkpoint file of ``run`` to load.

    If ``ckpt_ep`` is given, that episode is used directly. Otherwise the user
    is prompted until they name an existing checkpoint or enter an empty line.
    Inside multiprocessing workers stdin is not the terminal, so the prompt only
    works when main() is called directly.

    Returns:
        ``(episode, file)``, or ``None`` if the user entered an empty line.

    Raises:
        FileNotFoundError: if the checkpoint for a fixed ``ckpt_ep`` reports size 0.
    """
    if ckpt_ep is not None:
        file = run.file(f"checkpoints/{ckpt_ep}.pt")
        # Same "size 0 means missing" test as the interactive branch below.
        if file.size <= 0:
            raise FileNotFoundError(f"checkpoints/{ckpt_ep}.pt not found in run {run.name}.")
        return ckpt_ep, file

    while (episode := input("Enter episode number: ")) != "":
        file = run.file(f"checkpoints/{episode}.pt")
        if file.size > 0:
            return episode, file
        print("Episode not found.\n")
    return None


def _load_checkpoint(model, file, root):
    """Download W&B checkpoint ``file`` into ``root`` and load its ``model_state_dict`` into ``model``."""
    file.download(root=root, replace=True)
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only hosts.
    # load_state_dict copies the values into the model's existing parameters.
    checkpoint = torch.load(f"{root}/{file.name}", map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])


def _plot_ratemap(edges, sorted_hists, xlabel, mode, out_dir, stem):
    """Save the sorted rate map twice: annotated (``ratemap_<stem>.png``) and bare (``clean_<stem>.png``)."""
    n_neur = len(sorted_hists)

    fig, ax = plt.subplots(1, 1, dpi=600, figsize=(2, 2.5), constrained_layout=True)
    try:
        if mode == "evidence":
            ax.xaxis.set_major_locator(MultipleLocator(10))
        ax.set_xlabel(xlabel, fontsize=10)

        ax.set_ylabel("Neuron", fontsize=10)
        # Inverted y-axis (row 0 at the top); ending at 0.01 instead of 0 hides the "0" tick.
        ax.set_ylim(n_neur, 0.01)
        ax.yaxis.set_major_locator(MultipleLocator(50))

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        pcm = ax.pcolormesh(
            edges,
            range(n_neur + 1),
            sorted_hists,
            cmap=colorcet.cm.bgy,
        )

        cb = fig.colorbar(pcm, ax=ax, location='top', ticks=[0, 1])
        cb.outline.set_visible(False)

        path = out_dir / f"ratemap_{stem}.png"
        fig.savefig(path)
        print(f"Figure saved to {path.absolute()}")

        # Re-export the same figure without axes or colorbar.
        ax.axis('off')
        cb.remove()
        path = out_dir / f"clean_{stem}.png"
        fig.savefig(path, bbox_inches=0)
        print(f"Figure saved to {path.absolute()}")
    finally:
        # 600-dpi figures are large; close them so pool workers that handle
        # several runs do not accumulate open figures.
        plt.close(fig)


def _plot_psychometric(psychometric_data, max_evid, path):
    """Plot percent right turns against evidence, with Jeffreys intervals, and save to ``path``.

    Column 0 of ``psychometric_data`` is read as right-turn counts and column 1
    as left-turn counts. Rows are plotted at evidence values -max_evid..+max_evid,
    in order (assumed to match compute_ratemap's row ordering; TODO confirm).
    """
    psy_right = psychometric_data[:, 0]
    psy_left = psychometric_data[:, 1]
    psy_total = psy_right + psy_left
    # Evidence values with no trials give 0/0; let those become NaN silently.
    with np.errstate(divide='ignore', invalid='ignore'):
        means = psy_right / psy_total
        alpha = 0.32  # one-sigma (per Pinto et al., 2018)
        ci_low, ci_upp = proportion_confint(psy_right, psy_total, alpha=alpha, method="jeffreys")

    xs = np.arange(-max_evid, max_evid + 1)
    err_low = means - ci_low
    err_upp = ci_upp - means

    fig, ax = plt.subplots(1, 1, figsize=(2, 2), dpi=600, constrained_layout=True)
    try:
        ax.set(xlim=(-max_evid, max_evid), ylim=(0, 100), xlabel="Evidence", ylabel="Turned right (%)")

        ax.axhline(y=50, c="lightgray", ls="--")
        ax.axvline(x=0, c="lightgray", ls="--")
        ax.errorbar(xs, means * 100, yerr=np.stack([err_low, err_upp]) * 100, fmt=".")

        fig.savefig(path, bbox_inches=0)
        print(f"Figure saved to {path.absolute()}")
    finally:
        plt.close(fig)


def main(wandb_run: str):
    """Plot the sorted rate map for one W&B run, plus the psychometric curve in evidence mode.

    The function:
    - looks up the run whose display name is ``wandb_run``;
    - rebuilds its environments and model from the logged config;
    - loads checkpoint ``CKPT_EP``, or prompts for one when ``CKPT_EP`` is None;
    - runs ``compute_ratemap`` on ``valid_envs_dict[VALID_ENV]``;
    - saves PNGs under ``./out/ratemaps/<VALID_ENV>/``.

    Raises:
        ValueError: if ``MODE`` is neither "evidence" nor "position".
        LookupError: if ``wandb_run`` does not match exactly one run.
        FileNotFoundError: if the fixed ``CKPT_EP`` checkpoint is missing.
    """
    # Validate before any network or model work, so a typo fails immediately.
    if MODE not in ("evidence", "position"):
        raise ValueError(f"Unknown MODE {MODE!r}; expected 'evidence' or 'position'.")

    run = _fetch_run(wandb_run)
    config = dict(run.config)

    train_env, valid_envs_dict = init_envs(config)
    model = init_model(config, train_env)

    print(f"{run.name} ({run.url})")

    max_episode = run.summary_metrics["episode"]
    checkpoint_freq = config['log_checkpoint_freq']

    print(f"Checkpoint Frequency: {checkpoint_freq:>7}")
    print(f"         Max Episode: {max_episode:>7}")
    print()

    selection = _select_checkpoint(run, CKPT_EP)
    if selection is None:
        # Empty input at the prompt: previously this crashed on an unbound `file`.
        print(f"No episode selected; skipping {run.name}.")
        return
    episode, file = selection

    _load_checkpoint(model, file, root=f"./tmp/{run.name}")
    model.eval()

    env = valid_envs_dict[VALID_ENV]

    max_evid = max(env.num_interval_on_left[1], env.num_interval_on_right[1]) - 1

    (ctx_evid_hists, ctx_evid_edges,
     ctx_pos_hists, ctx_pos_edges,
     psychometric_data) = compute_ratemap(env, model, env.seq_len + 1, max_evid, N_EPS)

    if MODE == "evidence":
        edges, hists, xlabel = ctx_evid_edges, ctx_evid_hists, "Evidence"
    else:  # "position" (validated above)
        edges, hists, xlabel = ctx_pos_edges, ctx_pos_hists, "Position"
    sorted_hists = normalize_and_sort(hists, axis=1)

    out_dir = pathlib.Path(f"./out/ratemaps/{VALID_ENV}")
    out_dir.mkdir(parents=True, exist_ok=True)
    stem = f"{config['memory_type']}_{MODE}_{run.name}_ep{episode}"

    _plot_ratemap(edges, sorted_hists, xlabel, MODE, out_dir, stem)

    if MODE == "evidence":
        _plot_psychometric(psychometric_data, max_evid, out_dir / f"psychometric_{stem}.png")


if __name__ == "__main__":
    # One worker per run, capped at the CPU count. Pool() alone starts one
    # worker per CPU even for a handful of runs, and under the spawn start
    # method every worker re-imports this module (torch, wandb, ...).
    n_workers = max(1, min(len(WANDB_RUNS), os.cpu_count() or 1))
    with Pool(processes=n_workers) as pool:
        pool.map(main, WANDB_RUNS)
