from collections import defaultdict
from typing import Dict
import logging
import os
import json

import hydra
from omegaconf import DictConfig
import wandb
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns
from scipy.stats import bootstrap

from inference_rlhf.code.helpers.io import json_dump

# Apply seaborn's "whitegrid" theme globally at import time.
sns.set_theme(style="whitegrid")

# Module-level logger used by the data-loading helpers below.
log = logging.getLogger(__name__)

# Directory containing the cached JSON read/written by load_data_from_wandb.
DATA_DIR = "./data"

# Shared line/marker styling for the plots in this module.
FACE_COLOR = "#F7F7FF"  # passed as `markeredgecolor`
MARKER = "o"
LINEWIDTH = 1.5
MARKERSIZE = 9
MARKEREDGEWIDTH = 1.2

# Line colour per beta. Keys are the beta values as strings (e.g. "0.5"),
# so any dict indexed with these maps must use string keys as well.
BETA_TO_COLOR = {
    "0.0": sns.color_palette()[1],
    "0.5": sns.color_palette()[0],
    "1.0": sns.color_palette()[2],
}

# Legend label per beta (same string keys as BETA_TO_COLOR).
BETA_TO_LABEL = {
    "0.0": "Vanilla",
    "0.5": "Low RepExp",
    "1.0": "High RepExp",
}

def plot_cdf(beta_data, beta, ax: plt.Axes):
    """
    Plot the empirical CDF of ``beta_data`` for one beta on ``ax``.

    The CDF is evaluated on the grid k = 0, 64, 128, ..., 1024: each point is
    the fraction of entries of ``beta_data`` that are <= k.

    Args:
        beta_data: array of samples-to-correct values (must support
            element-wise ``<=`` against a scalar, e.g. an ``np.ndarray``)
        beta: key into BETA_TO_LABEL / BETA_TO_COLOR (a string such as "0.5")
        ax: axes to draw on; its y-limits are set to (0.175, 0.35)

    Returns:
        None
    """
    xs = np.arange(0, 1025, 64)
    # Fraction of entries at or below each grid point.
    ys = [np.mean(beta_data <= x) for x in xs]

    ax.plot(
        xs,
        ys,
        label=f'{BETA_TO_LABEL[beta]} ($\\beta={beta}$)',
        marker=MARKER,
        markersize=MARKERSIZE,
        linewidth=LINEWIDTH,
        markeredgewidth=MARKEREDGEWIDTH,
        markeredgecolor=FACE_COLOR,
        color=BETA_TO_COLOR[beta],
    )
    ax.set_ylim(0.175, 0.35)

def plot_beta_cdfs(beta_to_correct_np: Dict[float, np.ndarray], cfg: DictConfig, ax: plt.Axes) -> None:
    """
    Draw one samples-to-correct CDF curve per beta on a shared axes.

    Args:
        beta_to_correct_np: mapping beta -> array of samples-to-correct.
            Each key is forwarded to ``plot_cdf``, which looks it up in
            BETA_TO_LABEL / BETA_TO_COLOR (string keys such as "0.5").
        cfg: Hydra config (not read by this function)
        ax: axes to draw on

    Returns:
        None. The figure is not saved here.
    """
    for beta, samples_to_correct in beta_to_correct_np.items():
        plot_cdf(samples_to_correct, beta, ax)

    # Axis cosmetics: the x-axis shows k from 128 to 1024 in steps of 128.
    tick_positions = np.arange(128, 1025, 128)
    ax.set_xlabel('k')
    ax.set_ylabel('Pass@k')
    ax.set_xlim(128, 1024)
    ax.set_xticks(tick_positions)
    ax.legend()

def plot_quantile_diff(beta_to_correct_np: Dict[float, np.ndarray], cfg: DictConfig) -> None:
    """
    Plot, per non-baseline beta, the quantile difference to the "0.0" baseline.

    For 70 quantile levels between 0.05 and 0.35 the curve shows
    ``quantile(baseline) - quantile(beta)`` of the samples-to-correct arrays,
    with a 95% paired-bootstrap confidence band. The figure is written to
    'test-diff.pdf' in the current working directory.

    Args:
        beta_to_correct_np: mapping beta -> array of samples-to-correct.
            Must contain the key "0.0" whenever any other beta is present.
            Arrays are resampled jointly (``paired=True``), so they must have
            the same length.
        cfg: Hydra config (not read by this function)

    Returns:
        None
    """
    quantiles = np.linspace(0.05, 0.35, 70)
    n_resamples = 1000

    def quantile_diff(data, q):
        """Baseline quantile minus treated quantile at level ``q``."""
        a, b = data
        return np.quantile(a, q) - np.quantile(b, q)

    plt.figure(figsize=(10, 6))
    for beta in beta_to_correct_np:
        if beta == "0.0":
            continue

        baseline = beta_to_correct_np["0.0"]
        treated = beta_to_correct_np[beta]

        # Observed difference in quantiles.
        delta_q = np.quantile(baseline, quantiles) - np.quantile(treated, quantiles)

        # One bootstrap per quantile level; `q=q` binds the current level.
        ci_low = []
        ci_high = []
        for q in quantiles:
            res = bootstrap(
                data=(baseline, treated),
                statistic=lambda *samples, q=q: quantile_diff(samples, q),
                paired=True,
                confidence_level=0.95,
                n_resamples=n_resamples,
                vectorized=False,
            )
            ci_low.append(res.confidence_interval.low)
            ci_high.append(res.confidence_interval.high)

        # `\\beta` must be escaped: a bare `\b` in a normal string literal is
        # a backspace character and would corrupt the mathtext label.
        plt.plot(quantiles, delta_q, label=f'{BETA_TO_LABEL[beta]} ($\\beta={beta}$)')
        plt.fill_between(quantiles, ci_low, ci_high, alpha=0.3)

    # Zero reference line, drawn once rather than once per beta.
    plt.axhline(0, color='gray', linestyle='--')

    plt.xlabel('Quantile')
    plt.ylabel('# additional samples-to-correct')
    plt.legend()
    plt.tight_layout()
    plt.savefig('test-diff.pdf')

def load_data_from_wandb(cfg: DictConfig) -> Dict[str, Dict[str, Dict[str, int]]]:
    """
    Load samples-to-correct data from wandb, using a JSON cache on disk.

    If the cache file for ``cfg.plot.betas`` / ``cfg.plot.tag`` exists under
    DATA_DIR it is returned directly. Otherwise every matching finished run is
    queried, the result is written to the cache, and the cache is read back so
    that both paths return the same structure (plain dicts with string keys,
    as produced by the JSON round trip).

    Per run:
        - no correct answer and fewer than 1024 logged rows: counted as a bad
          run and dropped
        - no correct answer and at least 1024 logged rows: recorded as 1025
          ("never found")
        - otherwise: the number of logged rows is recorded

    Args:
        cfg: Hydra config. Reads cfg.plot.{betas,tag,seeds},
            cfg.user.wandb_project, cfg.policy.name, cfg.task.name and
            cfg.sampling.top_k.

    Returns:
        beta_to_correct: Dict[str, Dict[str, Dict[str, int]]]
            - beta_to_correct[beta][generation_idx][seed] = samples-to-correct

    Raises:
        Exception: any error raised while reading a run is logged and re-raised.
    """
    max_samples = 1024
    never_found = max_samples + 1

    save_path = os.path.join(DATA_DIR, f"token_level_{cfg.plot.betas}_tag_{cfg.plot.tag}.json")
    if os.path.exists(save_path):
        with open(save_path) as f:
            return json.load(f)

    beta_to_correct = defaultdict(lambda: defaultdict(dict))
    bad_runs = 0

    wandb_api = wandb.Api()
    for beta in cfg.plot.betas:
        runs = wandb_api.runs(
            cfg.user.wandb_project,
            filters={
                "config.policy.name": cfg.policy.name,
                "config.task.name": cfg.task.name,
                "config.coreset.elliptical.beta": beta,
                "config.seed": {"$in": list(cfg.plot.seeds)},
                "tags": cfg.plot.tag,
                "config.sampling.top_k": cfg.sampling.top_k
            }
        )

        for run in tqdm(runs, desc=f'Loading data for beta={beta} ...'):
            # Skip if run has not finished
            if run.state != "finished":
                continue

            try:
                history = run.history(keys=['k', 'any_strict_correct'], samples=int(1e5))
                seed = run.config['seed']
                generation_idx = run.config['task']['generation']['generation_idx']

                num_rows = len(history['any_strict_correct'])
                if np.sum(history['any_strict_correct']) == 0:
                    if num_rows < max_samples:
                        bad_runs += 1
                        log.warning(f'WARNING: No correct answers found for run {run.id} while < 1024 samples')
                    else:
                        # Sentinel above the budget: the answer was never found
                        beta_to_correct[beta][generation_idx][seed] = never_found
                    continue

                beta_to_correct[beta][generation_idx][seed] = num_rows

            except Exception:
                log.exception(f'WARNING: Error loading data for run {run.id}')
                raise

    log.info(f'Found {bad_runs} bad runs')

    # Save data
    log.info(f'Saving data to {save_path}')
    os.makedirs(DATA_DIR, exist_ok=True)
    json_dump(beta_to_correct, save_path)

    # Read the cache back so a fresh run returns exactly what a cached run
    # would: string keys, which the downstream lookups rely on.
    with open(save_path) as f:
        return json.load(f)

def keep_only_shared_runs(beta_to_correct: Dict[float, Dict[int, Dict[int, int]]]) -> None:
    """
    Restrict every beta to the (generation_idx, seed) pairs common to all betas.

    The mapping is modified in place: each ``beta_to_correct[beta]`` is
    replaced by a plain dict holding only the shared pairs. The number of
    shared pairs per seed is logged.

    Args:
        beta_to_correct: Dict[float, Dict[int, Dict[int, int]]]
            - beta_to_correct[beta][generation_idx][seed] = samples-to-correct

    Returns:
        None
    """
    # One set of (generation_idx, seed) pairs per beta.
    pair_sets = []
    for per_generation in beta_to_correct.values():
        pair_sets.append({
            (gen_idx, seed)
            for gen_idx, per_seed in per_generation.items()
            for seed in per_seed
        })

    if pair_sets:
        shared_pairs = set.intersection(*pair_sets)
    else:
        shared_pairs = set()

    # Report how many shared pairs each seed contributes.
    seed_count = {}
    for _, seed in shared_pairs:
        seed_count[seed] = seed_count.get(seed, 0) + 1
    for seed, count in seed_count.items():
        log.info(f'Seed {seed} has {count} shared samples')

    # Rebuild every beta's mapping from the shared pairs only.
    for beta in list(beta_to_correct):
        source = beta_to_correct[beta]
        filtered = {}
        for gen_idx, seed in shared_pairs:
            filtered.setdefault(gen_idx, {})[seed] = source[gen_idx][seed]
        beta_to_correct[beta] = filtered

def numpify_data(
    beta_to_correct: Dict[float, Dict[int, Dict[int, int]]],
    num_generations: int = 200,
) -> Dict[float, np.ndarray]:
    """
    Convert the beta_to_correct data to a dictionary of numpy arrays.

    Generation indices are visited in increasing order from 0 to
    ``num_generations - 1``; within a generation, values are taken in the
    seed dict's iteration order. A generation index may be stored either as a
    string (data loaded from the JSON cache) or as an int (data straight from
    wandb); both are accepted. Missing indices are reported with ``print``
    and skipped. Indices >= ``num_generations`` are ignored.

    Args:
        beta_to_correct: Dict[float, Dict[int, Dict[int, int]]]
            - beta_to_correct[beta][generation_idx][seed] = samples-to-correct
        num_generations: number of generation indices to scan (default 200)

    Returns:
        beta_to_correct_np: Dict[float, np.ndarray]
            - beta_to_correct_np[beta] = [samples-to-correct]
            Betas without any value are omitted. Beta keys are passed through
            unchanged.
    """
    beta_to_correct_np = {}
    for beta, by_generation in beta_to_correct.items():
        samples = []
        for gen_idx in range(num_generations):
            # JSON round trips turn int keys into strings, so try both forms.
            by_seed = by_generation.get(str(gen_idx))
            if by_seed is None:
                by_seed = by_generation.get(gen_idx)
            if by_seed is None:
                print(f'{beta} {gen_idx} not found')
                continue
            samples.extend(by_seed.values())

        # Skip empty betas: an empty array carries nothing to plot.
        if samples:
            beta_to_correct_np[beta] = np.asarray(samples)

    return beta_to_correct_np

def plot_solve_rate_by_hardness(beta_to_correct_np: Dict[float, np.ndarray], cfg: DictConfig, ax: plt.Axes) -> None:
    """
    Plot the solve-rate lift over the "0.0" baseline per slice of the data.

    Each array is cut into consecutive slices of 20% of its length. For every
    slice the plotted value is the percentage of entries <= 1024 for the given
    beta minus the same percentage for the "0.0" baseline, with a 95%
    paired-bootstrap confidence interval. Betas "0.0" and "0.5" are not
    plotted. The last slice of the array is drawn leftmost on the x-axis.

    NOTE(review): the x-axis is labelled as a hardness quantile, which
    presumably relies on the arrays being ordered by question hardness —
    confirm against the data pipeline.

    Args:
        beta_to_correct_np: mapping beta -> array of samples-to-correct; must
            contain "0.0" whenever a plotted beta is present
        cfg: Hydra config (not read by this function)
        ax: axes to draw on

    Returns:
        None
    """
    delta = 0.2
    skipped_betas = ("0.0", "0.5")

    def _lift(*samples):
        """Solve rate (%) of the first sample minus that of the second."""
        current, base = samples
        return np.mean(current <= 1024) * 100 - np.mean(base <= 1024) * 100

    for beta, treated in beta_to_correct_np.items():
        if beta in skipped_betas:
            continue

        baseline = beta_to_correct_np["0.0"]
        n_treated = len(treated)
        n_baseline = len(baseline)

        xs, ys, ci_lows, ci_highs = [], [], [], []
        for frac in np.arange(1.0, 0.0, -delta):
            # Bin centre on the x-axis (slices are mirrored: 1 - frac).
            xs.append(1.0 - frac + delta / 2)

            questions = treated[int((frac - delta) * n_treated):int(frac * n_treated)]
            questions_base = baseline[int((frac - delta) * n_baseline):int(frac * n_baseline)]
            ys.append(_lift(questions, questions_base))

            interval = bootstrap(
                data=(questions, questions_base),
                statistic=_lift,
                paired=True,
                confidence_level=0.95,
                n_resamples=10000,
                vectorized=False
            ).confidence_interval
            ci_lows.append(interval.low)
            ci_highs.append(interval.high)

        # Asymmetric error bars: distance from the point estimate to each bound.
        lift = np.asarray(ys)
        ax.errorbar(
            xs, ys, label=f'{BETA_TO_LABEL[beta]} ($\\beta={beta}$)',
            yerr=(lift - np.array(ci_lows), np.array(ci_highs) - lift),
            fmt=MARKER + "-",
            markersize=MARKERSIZE,
            linewidth=LINEWIDTH,
            markeredgewidth=MARKEREDGEWIDTH,
            markeredgecolor=FACE_COLOR,
            capsize=5,
            elinewidth=1,
            capthick=1,
            color=BETA_TO_COLOR[beta]
        )

        # Shade the confidence interval across the full width of every bin.
        half_width = delta / 2
        for center, low, high in zip(xs, ci_lows, ci_highs):
            ax.fill_between(
                [center - half_width, center + half_width],
                [low, low],
                [high, high],
                alpha=0.1,
                color=ax.lines[-1].get_color()
            )

    # Zero-lift reference line and axis cosmetics.
    ax.axhline(0, color='black', linestyle='--')
    ax.legend()
    ax.set_xlim(0, 1)
    ax.set_xticks(np.arange(0, 1.2, 0.2), ['0', '20', '40', '60', '80', '100'])
    ax.set_yticks([-10, -5, 0, 5, 10, 15])
    ax.set_ylim(-10, 15)
    ax.set_xlabel('Hardness quantile (%)')
    ax.set_ylabel('Solve rate lift (%)')

@hydra.main(config_path="../../configs", config_name="master", version_base=None)
def main(cfg: DictConfig):
    """
    Load the samples-to-correct data and save the two-panel summary figure.

    Left panel: CDF of samples-to-correct per beta. Right panel: solve-rate
    lift over the baseline. The figure is written to
    figures/token_level/<policy name>/<task name>/.

    Args:
        cfg: Hydra config; cfg.policy.name and cfg.task.name select the
            output folder, the rest is consumed by the helpers.
    """
    beta_to_correct = load_data_from_wandb(cfg)

    # Drop any seed-generation_idx pairs that are not present in all betas
    keep_only_shared_runs(beta_to_correct)

    beta_to_correct_np = numpify_data(beta_to_correct)

    fig, axs = plt.subplots(1, 2, figsize=(12, 4.5))

    # Plot CDF
    plot_beta_cdfs(beta_to_correct_np, cfg, axs[0])

    # Plot solve rate by hardness
    plot_solve_rate_by_hardness(beta_to_correct_np, cfg, axs[1])

    fig.tight_layout()

    # savefig does not create missing directories, so make the folder first.
    folder_to_save = os.path.join('figures', 'token_level', cfg.policy.name, cfg.task.name)
    os.makedirs(folder_to_save, exist_ok=True)
    fig.savefig(os.path.join(folder_to_save, 'token_level_cdf_and_solve_rate_by_hardness.pdf'))

    # Optional extra figure (disabled):
    # plot_quantile_diff(beta_to_correct_np, cfg)


if __name__ == "__main__":
    main()