import gym
import numpy as np
import scipy.stats
import torch
import wandb
from einops import rearrange
from matplotlib import pyplot as plt
from statsmodels.stats.proportion import proportion_confint
from tqdm import trange

from envs.wrappers.tensor_wrapper import TensorWrapper
from policy.base import BasePolicy


def compute_ratemap(env: gym.Env, model: BasePolicy, n_pos: int, max_evid: int, n_eps: int):
    """Roll out ``n_eps`` greedy episodes and build per-unit context ratemaps.

    After every environment step the current ``model.net.ctx`` vector is
    accumulated into two tables (context unit x evidence bin and context unit
    x position bin), using the ``info["evidence"]`` and ``info["position"]``
    values returned by that step. At the end the accumulated sums are divided
    by the visit counts, giving the mean activation of each unit per bin.

    Args:
        env: Environment to roll out; it is wrapped in ``TensorWrapper`` here.
            Its ``step`` info dict must contain ``"evidence"`` and
            ``"position"``, and it must expose ``seq_len``.
        model: Policy providing ``greedy(obs, h)`` as well as ``net.ctx`` and
            ``net.ctx_size``. It is switched to eval mode.
        n_pos: Number of unit-width position bins, centred on ``0 .. n_pos-1``.
        max_evid: Evidence bins are unit-width and centred on
            ``-max_evid .. max_evid``.
        n_eps: Number of episodes to roll out.

    Returns:
        A tuple ``(evid_hists, evid_edges, pos_hists, pos_edges, psychometric)``:
        ``evid_hists`` has shape ``(ctx_size, 2 * max_evid + 1)`` and
        ``pos_hists`` has shape ``(ctx_size, n_pos)``; bins that were never
        visited are NaN (0 / 0). ``evid_edges`` and ``pos_edges`` are the bin
        edges along the evidence / position axis. ``psychometric`` has shape
        ``(2 * max_evid + 1, 2)`` and counts, per final evidence value, how
        often the last action was 2 (column 0) or 3 (column 1) in episodes
        whose final ``info["position"]`` equals ``env.seq_len``.
    """
    model.eval()

    env = TensorWrapper(env)

    n_ctx = model.net.ctx_size
    n_evid = max_evid * 2 + 1

    # Bin edges and unit indices do not depend on the rollout, so build them
    # once. This also keeps the returned edges defined when n_eps == 0.
    ctx_idx = np.arange(n_ctx)
    ctx_binedges = np.linspace(-0.5, n_ctx - 0.5, n_ctx + 1, endpoint=True)
    evid_binedges = np.linspace(-max_evid - 0.5, max_evid + 0.5, n_evid + 1, endpoint=True)
    pos_binedges = np.linspace(-0.5, n_pos - 0.5, n_pos + 1, endpoint=True)

    pos_hists_ctx = np.zeros([n_ctx, n_pos])
    evid_hists_ctx = np.zeros([n_ctx, n_evid])

    pos_hists_ctx_count = np.zeros([n_ctx, n_pos])
    evid_hists_ctx_count = np.zeros([n_ctx, n_evid])

    psychometric_data = np.zeros([n_evid, 2])

    with torch.no_grad():
        for _ in trange(n_eps, desc="Ratemap Generation", leave=False):
            h = None
            obs = env.reset()
            done = False

            while not done:
                obs_in = rearrange(obs, '... -> 1 1 ...')  # add batch and time dim

                # select action
                act, h, _, _, _ = model.greedy(obs_in, h)
                action = rearrange(act, '1 1 -> ').item()  # remove batch and time dim

                # apply action to environment
                obs, _, done, info = env.step(action)

                ctx = rearrange(model.net.ctx, '1 1 c -> c').detach().cpu().numpy()  # remove batch and time dim

                # Every unit shares this step's evidence / position value.
                ctx_evid = np.tile(info["evidence"], reps=n_ctx)
                ctx_pos = np.tile(info["position"], reps=n_ctx)

                # Accumulate activation sums and visit counts for both tables.
                # Values falling outside the bin edges are dropped by scipy.
                for coords, edges, sums, counts in (
                    (ctx_evid, evid_binedges, evid_hists_ctx, evid_hists_ctx_count),
                    (ctx_pos, pos_binedges, pos_hists_ctx, pos_hists_ctx_count),
                ):
                    step_sum, _, _, _ = scipy.stats.binned_statistic_2d(
                        ctx_idx, coords, ctx,
                        statistic="sum",
                        bins=[ctx_binedges, edges],
                    )
                    step_count, _, _, _ = scipy.stats.binned_statistic_2d(
                        ctx_idx, coords, None,
                        statistic="count",
                        bins=[ctx_binedges, edges],
                    )
                    sums += step_sum
                    counts += step_count

            # end of trial - gather data for psychometric plot
            if info["position"] == env.seq_len and action in (2, 3):
                evid_idx = int(info["evidence"] + max_evid)
                psychometric_data[evid_idx, action - 2] += 1

    # Mean activation per bin; unvisited bins are 0 / 0 and stay NaN so that
    # callers can tell "never visited" apart from "zero activation".
    with np.errstate(divide='ignore', invalid='ignore'):
        evid_hists_ctx = evid_hists_ctx / evid_hists_ctx_count
        pos_hists_ctx = pos_hists_ctx / pos_hists_ctx_count

    return evid_hists_ctx, evid_binedges, pos_hists_ctx, pos_binedges, psychometric_data


def normalize_over_axis(arr, axis=1):
    """Min-max scale ``arr`` to the range [0, 1] along ``axis``, ignoring NaNs.

    The minimum and maximum are taken with ``np.nanmin`` / ``np.nanmax``, so
    NaN entries do not affect the scaling and remain NaN in the output. A
    slice whose non-NaN values are all equal has a zero range and therefore
    becomes NaN (0 / 0); the resulting floating-point warnings are suppressed.

    Args:
        arr: Array to normalise.
        axis: Axis along which the minimum and maximum are computed.

    Returns:
        Array with the same shape as ``arr``.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        # Compute each extreme once instead of re-evaluating nanmin.
        lowest = np.nanmin(arr, axis=axis, keepdims=True)
        highest = np.nanmax(arr, axis=axis, keepdims=True)
        return (arr - lowest) / (highest - lowest)


def normalize_and_sort(hists_1, hists_2):
    """Normalise two ratemap tables row-wise and order their rows by peak location.

    Both inputs are min-max normalised along axis 1. The row order is derived
    from ``hists_1`` only: rows are sorted by the column index of their
    maximum (NaNs count as 0 when locating the maximum), and the same order
    is then applied to ``hists_2``.

    Returns:
        ``(sorted_hists_1, sorted_hists_2, normed_hists_1)`` where the last
        element is the normalised but unsorted version of ``hists_1``.
    """
    first_normed, second_normed = (
        normalize_over_axis(table, axis=1) for table in (hists_1, hists_2)
    )

    # Column of each row's peak in the first table determines the row order.
    peak_columns = np.nan_to_num(first_normed, nan=0).argmax(axis=1)
    row_order = peak_columns.argsort()

    return first_normed[row_order], second_normed[row_order], first_normed


def log_ratemaps(envs_dict: dict[str, gym.Env], model: BasePolicy, n_eps: int, log_individual_ratemaps: bool, log_aggregate_ratemaps: bool):
    """Log context-unit ratemaps and a psychometric curve to wandb for each environment.

    Nothing is computed or logged when ``model.net.ctx_size`` exceeds 200 or
    when both logging flags are false. All ``wandb.log`` calls use
    ``commit=False``, so the caller is responsible for committing the step.

    Args:
        envs_dict: Mapping from a name (used in the wandb keys) to the
            environment to evaluate. Each environment must expose
            ``num_interval_on_left``, ``num_interval_on_right`` and ``seq_len``.
        model: Policy whose ``net.ctx`` activations are analysed.
        n_eps: Number of episodes per ratemap rollout.
        log_individual_ratemaps: Log one evidence and one position histogram
            per context unit.
        log_aggregate_ratemaps: Run a second, independent rollout and log
            normalised ratemap images, both unsorted and sorted by the peak
            location found in the first rollout.
    """
    if model.net.ctx_size > 200:
        # prevent large ratemaps from being computed
        return

    if not log_individual_ratemaps and not log_aggregate_ratemaps:
        return

    for name, env in envs_dict.items():
        max_evid = max(env.num_interval_on_left[1], env.num_interval_on_right[1])

        ctx_evid_hists, ctx_evid_edges, \
            ctx_pos_hists, ctx_pos_edges, \
            psychometric_data = compute_ratemap(env, model, env.seq_len + 1, max_evid, n_eps)

        if log_individual_ratemaps:
            # Fill unvisited (NaN) bins with 0 for the histograms only. Separate
            # names keep the NaNs in the arrays used by the aggregate branch, so
            # its normalisation does not depend on this flag.
            filled_evid_hists = np.nan_to_num(ctx_evid_hists, nan=0.0)
            filled_pos_hists = np.nan_to_num(ctx_pos_hists, nan=0.0)

            for i, hist in enumerate(filled_evid_hists):
                wandb.log({f"ratemap/{name}/evidence/rnn/{i}": wandb.Histogram(np_histogram=(hist, ctx_evid_edges))}, commit=False)

            for i, hist in enumerate(filled_pos_hists):
                wandb.log({f"ratemap/{name}/position/rnn/{i}": wandb.Histogram(np_histogram=(hist, ctx_pos_edges))}, commit=False)

        if log_aggregate_ratemaps:
            # Second rollout, displayed in the row order derived from the first.
            # NOTE(review): this call uses seq_len + 2 position bins while the
            # first uses seq_len + 1 - confirm the extra bin is intentional.
            ctx_evid_hists_2, _, \
                ctx_pos_hists_2, _, _ = compute_ratemap(env, model, env.seq_len + 1 + 1, max_evid, n_eps)

            # ctx_evid
            sorted_hists, sorted_hists_2, normed_hists = normalize_and_sort(ctx_evid_hists, ctx_evid_hists_2)
            wandb.log({f"ratemap/{name}/evidence/rnn_all": wandb.Image(normed_hists * 255, mode="L")}, commit=False)
            wandb.log({f"ratemap/{name}/evidence/sorted/rnn_all": wandb.Image(sorted_hists * 255, mode="L")}, commit=False)
            wandb.log({f"ratemap/{name}/evidence/sorted/rnn_all_2": wandb.Image(sorted_hists_2 * 255, mode="L")}, commit=False)

            # ctx_pos
            sorted_hists, sorted_hists_2, normed_hists = normalize_and_sort(ctx_pos_hists, ctx_pos_hists_2)
            wandb.log({f"ratemap/{name}/position/rnn_all": wandb.Image(normed_hists * 255, mode="L")}, commit=False)
            wandb.log({f"ratemap/{name}/position/sorted/rnn_all": wandb.Image(sorted_hists * 255, mode="L")}, commit=False)
            wandb.log({f"ratemap/{name}/position/sorted/rnn_all_2": wandb.Image(sorted_hists_2 * 255, mode="L")}, commit=False)

        # Generate psychometric curve (uses the counts from the first rollout;
        # column 0 is treated as right turns, column 1 as left turns)
        psy_right = psychometric_data[:, 0]
        psy_left = psychometric_data[:, 1]
        psy_total = psy_right + psy_left
        with np.errstate(divide='ignore', invalid='ignore'):
            means = psy_right / psy_total
            alpha = 0.32  # one-sigma (per Pinto et al., 2018)
            ci_low, ci_upp = proportion_confint(psy_right, psy_total, alpha=alpha, method="jeffreys")

        xs = np.arange(-max_evid, max_evid+1)
        # The Jeffreys interval need not contain the sample proportion when it
        # is exactly 0 or 1, which would give a negative error bar; errorbar()
        # rejects negative yerr, so clamp at 0. NaNs (evidence values with no
        # trials) pass through np.maximum unchanged.
        err_low = np.maximum(means - ci_low, 0.0)
        err_upp = np.maximum(ci_upp - means, 0.0)

        fig, ax = plt.subplots(1, 1, figsize=(3, 3), dpi=300, constrained_layout=True)
        ax.set(xlim=(-max_evid, max_evid), ylim=(0, 100), xlabel="Evidence", ylabel="Turned right (%)")

        ax.axhline(y=50, c="lightgray", ls="--")
        ax.axvline(x=0, c="lightgray", ls="--")
        ax.errorbar(xs, means*100, yerr=np.stack([err_low, err_upp])*100, fmt=".")

        # Render the figure off-screen and hand the pixel buffer to wandb.
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        plt.close(fig)

        wandb.log({f"psychometric_curve/{name}": wandb.Image(rgba, mode="RGBA")}, commit=False)

