import argparse
from operator import sub
import os
import sys
import math
from itertools import product
import time
from pathlib import Path
from turtle import color, title
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import torch
from torch import nn
import torch.nn.functional as F
import seaborn as sns
from datetime import datetime
from tqdm import trange
from scipy.stats import poisson, nbinom



def poisson_log_prob(rate, z):
    """Return the elementwise Poisson log-pmf ``log P(Z = z | rate)``.

    Args:
        rate: Tensor of Poisson rates.
        z: Tensor of counts at which the log-pmf is evaluated.

    Returns:
        Tensor holding ``z * log(rate) - rate - lgamma(z + 1)``.
    """
    # xlogy yields 0 wherever z == 0, so rate == 0 gives log P(Z = 0) = 0
    # instead of the NaN that a plain 0 * log(0) would produce.
    return torch.xlogy(z, rate) - rate - torch.lgamma(z + 1)


def nb_log_prob(z: torch.Tensor, r: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Return the elementwise negative-binomial log-pmf.

    Uses the parameterisation
    ``P(z) = Gamma(r + z) / (z! * Gamma(r)) * p**r * (1 - p)**z``.

    Args:
        z: Tensor of counts at which the log-pmf is evaluated.
        r: Tensor with the ``r`` parameter.
        p: Tensor with the ``p`` parameter.

    Returns:
        Tensor of log-probabilities.
    """
    log_binom = torch.lgamma(r + z) - torch.lgamma(z + 1) - torch.lgamma(r)
    # xlogy yields 0 wherever z == 0, so p == 1 gives log P(z = 0) = 0 instead
    # of the NaN that 0 * log(0) would produce.
    return log_binom + r * torch.log(p) + torch.xlogy(z, 1 - p)


def compute_upperbound(rate: float, percentile: float = 1e-2, r: float = None, p: float = None):
    """Return an integer truncation point taken from the upper tail quantile.

    When both ``r`` and ``p`` are supplied, the bound comes from a negative
    binomial distribution and ``rate`` is ignored; otherwise it comes from a
    Poisson distribution with the given ``rate``.

    Args:
        rate: Poisson rate; only used when ``r`` or ``p`` is missing.
        percentile: Upper-tail mass; the ``1 - percentile`` quantile is used.
        r: Negative-binomial ``n`` parameter (optional).
        p: Negative-binomial success probability (optional).

    Returns:
        Negative binomial: the truncated quantile plus 3.
        Poisson: the truncated quantile, but never less than 3.
    """
    negative_binomial = not (r is None or p is None)
    if negative_binomial:
        assert r > 0 and 0 < p < 1, f"Invalid parameters for nb: r={r}, prob={p}"
        quantile = stats.nbinom(n=r, p=p).ppf(1.0 - percentile)
        return 3 + int(quantile)

    assert rate > 0.0, f"rate must be positive, got: {rate}"
    quantile = stats.poisson(rate).ppf(1.0 - percentile)
    bound = int(quantile)
    return bound if bound > 3 else 3


def compute_exact_grad(f_fn: callable, rate: float= None, r:float = None, p:float = None):
    """Evaluate the score-function gradient of ``E[f(z)]`` by direct summation.

    The expectation ``E[f(z) * d/dtheta log p(z)]`` is summed over a truncated
    support ``z = 0..upperbound``.

    Args:
        f_fn: Callable invoked as ``f_fn(z, rate)`` with ``z`` a float64 tensor
            of support points. In negative-binomial mode it receives whatever
            was passed as ``rate`` (``None`` by default).
        rate: Poisson rate as a tensor or plain number. Required in Poisson mode.
        r: Negative-binomial ``r``; together with ``p`` selects NB mode.
        p: Negative-binomial ``p``; together with ``r`` selects NB mode.

    Returns:
        NB mode: tuple ``(grad_r, grad_p)`` of floats.
        Poisson mode: a single float, the gradient with respect to ``rate``.
    """
    if r is not None and p is not None:
        # Negative Binomial
        r = torch.as_tensor(r, dtype=torch.float64).detach()
        p = torch.as_tensor(p, dtype=torch.float64).detach()
        upperbound = compute_upperbound(rate=None, percentile=1e-6, r=r.item(), p=p.item())
        z = torch.arange(upperbound + 1, dtype=torch.float64)
        # f(z) * pmf(z) is shared by both partial derivatives; evaluate f once.
        weighted_f = f_fn(z, rate) * torch.exp(nb_log_prob(z, r, p))
        score_p = r / p - z / (1 - p)
        score_r = torch.digamma(r + z) - torch.digamma(r) + torch.log(p)
        exp_term_p = (weighted_f * score_p).sum()
        exp_term_r = (weighted_f * score_r).sum()
        return exp_term_r.item(), exp_term_p.item()

    assert rate is not None and r is None and p is None, (
        "pass either rate (Poisson) or both r and p (negative binomial)"
    )
    # as_tensor accepts both tensors and plain numbers.
    lam = torch.as_tensor(rate, dtype=torch.float64).detach()
    lam_val = lam.item()
    # A fixed "+ 20" margin drops a noticeable tail once the rate is large
    # (the standard deviation is sqrt(rate)), so widen the window by ten
    # standard deviations on top of the constant margin.
    upperbound = int(lam_val + 10.0 * math.sqrt(max(lam_val, 0.0)) + 20)
    z = torch.arange(upperbound + 1, dtype=torch.float64)
    f_z = f_fn(z, rate)
    pmf = torch.exp(poisson_log_prob(lam, z))
    # d/d(lam) log Poisson(z | lam) = z / lam - 1
    exp_term = (f_z * pmf * (z / lam - 1)).sum()
    return exp_term.item()


def load_fn_results(base, f_map, target_method, n_mc=500, max_rate=10.0, max_tau=5.0):
    """Load one result CSV per function, annotate it, and return the filtered union.

    For every ``(name, fn)`` in ``f_map`` the files matching
    ``df_result_poisson_{name}_*.csv`` under ``base`` are sorted and the last
    one is used (presumably the suffix is a timestamp, so this is the newest
    file; confirm against the writer). Each file is passed through
    ``compute_mae_and_best_tau`` and tagged with an ``fn_name`` column.

    Args:
        base: Directory object exposing ``glob`` (e.g. ``pathlib.Path``).
        f_map: Mapping from function name to the function itself.
        target_method: Method name forwarded to ``compute_mae_and_best_tau``.
        n_mc: Only rows with this ``n_monte_carlo`` value are returned.
        max_rate: Only rows with ``rate <= max_rate`` are returned.
        max_tau: Upper bound on the tau column.

    Returns:
        The concatenated, filtered DataFrame.

    Raises:
        ValueError: If no file was loaded for any function.
    """
    frames = []
    for fn_name, fn in f_map.items():
        candidates = sorted(base.glob(f"df_result_poisson_{fn_name}_*.csv"))
        if len(candidates) == 0:
            print(f"[skip] no file for {fn_name}")
            continue

        newest = candidates[-1]
        print(f"Loading: {newest}")
        processed, _best = compute_mae_and_best_tau(
            pd.read_csv(newest), target_method, fn, max_tau=max_tau
        )
        processed["fn_name"] = fn_name
        frames.append(processed)

    if len(frames) == 0:
        raise ValueError("No files loaded.")

    combined = pd.concat(frames, ignore_index=True)
    keep = (
        (combined["n_monte_carlo"] == n_mc)
        & (combined["rate"] <= max_rate)
        & (combined["$\\tau$"] <= max_tau)
    )
    return combined[keep]


def compute_mae_and_best_tau(df_raw: pd.DataFrame, method: str, f_fn, max_tau: float = 5.0):
    """Attach absolute errors and the MAE-minimising tau to one method's rows.

    Args:
        df_raw: Results with at least the columns ``method``, ``rate``,
            ``n_monte_carlo``, ``$\\tau$`` and ``grad_estimate``.
        method: Method name to keep; compared case-insensitively.
        f_fn: Function forwarded to ``compute_exact_grad`` for the reference
            gradient.
        max_tau: Rows with a larger tau are discarded.

    Returns:
        A tuple ``(df, best_tau_df)``. ``df`` holds the kept rows plus the
        columns ``true_grad``, ``error``, ``mae`` and ``best_tau``.
        ``best_tau_df`` has one row per ``(rate, n_monte_carlo)`` with the tau
        of lowest mean absolute error and that tau's ``mae_mean``,
        ``mae_std`` and ``mae_count``.

    Raises:
        ValueError: If no rows remain after the method and tau filters.
    """
    method_norm = method.lower()
    df = df_raw[df_raw["method"].str.lower() == method_norm]
    # filter out rows with tau > max_tau
    df = df[df["$\\tau$"] <= max_tau].copy()
    if df.empty:
        raise ValueError(f"No rows found for method={method}")

    # The reference gradient depends only on the rate, so evaluate it once per
    # distinct rate instead of once per row.
    true_grad_by_rate = {
        rate: compute_exact_grad(f_fn, torch.tensor(float(rate), dtype=torch.float64))
        for rate in df["rate"].unique()
    }
    df["true_grad"] = df["rate"].map(true_grad_by_rate)
    df["error"] = df["grad_estimate"] - df["true_grad"]
    df["mae"] = df["error"].abs()

    agg_cols = ["rate", "n_monte_carlo", "$\\tau$"]
    mae_stats = (
        df.groupby(agg_cols)["mae"]
        .agg(mae_mean="mean", mae_std="std", mae_count="count")
        .reset_index()
    )

    # Row label of the lowest mean error within each (rate, n_monte_carlo) group.
    idx_best = mae_stats.groupby(["rate", "n_monte_carlo"])["mae_mean"].idxmin()
    best_tau_df = (
        mae_stats.loc[idx_best, ["rate", "n_monte_carlo", "$\\tau$", "mae_mean", "mae_std", "mae_count"]]
        .rename(columns={"$\\tau$": "best_tau"})
        .reset_index(drop=True)
    )

    df = df.merge(
        best_tau_df[["rate", "n_monte_carlo", "best_tau"]],
        on=["rate", "n_monte_carlo"],
        how="left",
    )
    return df, best_tau_df


def plot_mae_tau_grid(
    df: pd.DataFrame,
    selected_n_mc,
    max_rate: float = 10.0,
    n_cols: int = 4,
    figsize=(16, 48),
    title: str = "",
    ylim=(-1, 4),
):
    """Plot MAE against tau in a grid with one subplot per rate.

    Each subplot draws one line per ``n_monte_carlo`` value (mean MAE per
    tau), a shaded band between the 10th and 90th percentiles, and a red
    ``x`` at the group's ``best_tau``.

    Args:
        df: Data with the columns ``rate``, ``n_monte_carlo``, ``$\\tau$``,
            ``mae`` and ``best_tau``.
        selected_n_mc: Collection of ``n_monte_carlo`` values to show.
        max_rate: Only rates up to this value are plotted.
        n_cols: Number of subplot columns.
        figsize: Figure size passed to ``plt.subplots``.
        title: Optional figure-level title.
        ylim: ``(ymin, ymax)`` applied to every subplot, or ``None`` to keep
            matplotlib's automatic limits.

    Raises:
        ValueError: If no rows remain after filtering.
    """
    # Keep only the requested Monte Carlo sizes and rates.
    filtered = df[(df["n_monte_carlo"].isin(selected_n_mc)) & (df["rate"] <= max_rate)].copy()
    if filtered.empty:
        # Without this guard the code fails later inside plt.subplots with a
        # zero-row grid, which gives no hint about the actual cause.
        raise ValueError("No data after filtering; check selected_n_mc / max_rate / columns.")

    agg = (
        filtered.groupby(["rate", "n_monte_carlo", "$\\tau$"])["mae"]
        .agg(
            mae_mean="mean",
            p10=lambda x: np.percentile(x, 10),
            p90=lambda x: np.percentile(x, 90),
        )
        .reset_index()
        .rename(columns={"$\\tau$": "tau"})
        .sort_values("tau")
    )

    rates = sorted(filtered["rate"].unique())
    n_rows = math.ceil(len(rates) / n_cols)
    palette = sns.color_palette("tab20", len(selected_n_mc))
    # Fix one colour per n_mc for the whole figure so a rate that lacks some
    # n_mc does not shift the colours of the remaining lines.
    color_by_n_mc = dict(zip(sorted(filtered["n_monte_carlo"].unique()), palette))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    for i, rate in enumerate(rates):
        ax = axs[divmod(i, n_cols)]
        sub_rate = agg[agg["rate"] == rate]

        for n_mc, sub_mc in sub_rate.groupby("n_monte_carlo"):
            line_color = color_by_n_mc[n_mc]
            ax.plot(
                sub_mc["tau"],
                sub_mc["mae_mean"],
                label=f"n_mc={n_mc}",
                color=line_color,
            )
            ax.fill_between(
                sub_mc["tau"],
                sub_mc["p10"],
                sub_mc["p90"],
                color=line_color,
                alpha=0.2,
            )

            best_tau = (
                filtered.loc[
                    (filtered["rate"] == rate) & (filtered["n_monte_carlo"] == n_mc),
                    "best_tau",
                ]
                .dropna()
                .unique()
            )
            if best_tau.size > 0:
                best_tau_val = best_tau[0]
                # Mark the point only if that tau is among the aggregated rows.
                hit = sub_mc[sub_mc["tau"] == best_tau_val]
                if not hit.empty:
                    ax.scatter(
                        best_tau_val,
                        hit["mae_mean"].iloc[0],
                        marker="x",
                        color="tab:red",
                        zorder=5,
                        alpha=0.9,
                    )

        ax.set_xscale("log")
        ax.set_xlabel("tau")
        ax.set_ylabel("MAE")
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.set_title(f"rate = {rate}")
        ax.grid(True, linestyle="--", alpha=0.4)
        ax.legend(title="n_mc", fontsize=8, ncol=2)

    # Hide the grid cells that have no rate assigned.
    for j in range(len(rates), n_rows * n_cols):
        axs[divmod(j, n_cols)].axis("off")

    if title:
        fig.suptitle(title, fontsize=16, y=0.995)
        # Reserve a fixed 0.6 inch strip at the top for the suptitle.
        pad_in = 0.6
        top = 1 - pad_in / fig.get_size_inches()[1]
        fig.tight_layout(rect=[0, 0, 1, top])
    else:
        fig.tight_layout()
    plt.show()


def plot_fn_rate_grid(
    df: pd.DataFrame,
    fn_col: str = "fn_name",
    n_mc: int = 500,
    max_rate: float = 10.0,
    n_cols: int = 4,
    figsize=(16, 48),
    title: str = "",
):
    """Plot MAE against tau in a grid of rates, with one line per function.

    Each subplot draws, for every value of ``fn_col``, the mean MAE per tau,
    a shaded band between the 10th and 90th percentiles, and a red ``x`` at
    that function's ``best_tau``.

    Args:
        df: Data with the columns ``rate``, ``n_monte_carlo``, ``$\\tau$``,
            ``mae``, ``best_tau`` and ``fn_col``.
        fn_col: Name of the column that identifies the function.
        n_mc: Only rows with this ``n_monte_carlo`` value are plotted.
        max_rate: Only rates up to this value are plotted.
        n_cols: Number of subplot columns.
        figsize: Figure size passed to ``plt.subplots``.
        title: Optional figure-level title.

    Raises:
        ValueError: If no rows remain after filtering.
    """
    # filter
    sub = df[
        (df["n_monte_carlo"] == n_mc) &
        (df["rate"] <= max_rate)
    ].copy()
    if sub.empty:
        # Without this guard the code fails later inside plt.subplots with a
        # zero-row grid, which gives no hint about the actual cause.
        raise ValueError("No data after filtering; check n_mc / max_rate / columns.")

    # aggregate (rate, fn, tau)
    agg = (
        sub.groupby(["rate", fn_col, "$\\tau$"])["mae"]
        .agg(
            mae_mean="mean",
            p10=lambda x: np.percentile(x, 10),
            p90=lambda x: np.percentile(x, 90),
        )
        .reset_index()
        .rename(columns={"$\\tau$": "tau"})
        .sort_values(["rate", fn_col, "tau"])
    )

    rates = sorted(agg["rate"].unique())
    fn_names = sorted(agg[fn_col].unique())
    palette = sns.color_palette("tab20", len(fn_names))

    n_rows = math.ceil(len(rates) / n_cols)
    fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    for i, rate in enumerate(rates):
        r, c = divmod(i, n_cols)
        ax = axs[r, c]
        data_r = agg[agg["rate"] == rate]

        # Pairing colours with the global name list keeps each function's
        # colour the same in every subplot.
        for color, fn_name in zip(palette, fn_names):
            d = data_r[data_r[fn_col] == fn_name]
            if d.empty:
                continue

            ax.plot(d["tau"], d["mae_mean"], label=fn_name, color=color)
            ax.fill_between(d["tau"], d["p10"], d["p90"], color=color, alpha=0.2)

            # best_tau marker (aligned with agg)
            best_tau = (
                sub.loc[
                    (sub["rate"] == rate) & (sub[fn_col] == fn_name),
                    "best_tau",
                ]
                .dropna()
                .unique()
            )
            if best_tau.size > 0:
                hit = d[d["tau"] == best_tau[0]]
                if not hit.empty:
                    ax.scatter(
                        hit["tau"],
                        hit["mae_mean"],
                        color="tab:red",
                        marker="x",
                        zorder=5,
                    )

        ax.set_xscale("log")
        ax.set_title(f"rate = {rate}")
        ax.set_xlabel("tau")
        ax.set_ylabel("MAE")
        ax.grid(True, alpha=0.4)
        ax.legend(title=fn_col, fontsize=8, ncol=2)

    # Hide the grid cells that have no rate assigned.
    for j in range(len(rates), n_rows * n_cols):
        axs[j // n_cols, j % n_cols].axis("off")

    if title:
        fig.suptitle(title, fontsize=16)
        fig.tight_layout(rect=[0, 0, 1, 0.97])
    else:
        fig.tight_layout()

    plt.show()



def plot_best_tau_rate_grid(
    df: pd.DataFrame,
    selected_n_mc,
    max_rate: float = 10.0,
    figsize=(7, 4),
    title: str = "Best tau vs rate (log-log)",
):
    """Draw ``best_tau`` against ``rate`` on log-log axes, one line per ``n_monte_carlo``.

    Args:
        df: Data with the columns ``rate``, ``n_monte_carlo`` and ``best_tau``.
        selected_n_mc: ``n_monte_carlo`` values to show; also fixes the hue
            order and the number of palette colours.
        max_rate: Only rates up to this value are plotted.
        figsize: Figure size passed to ``plt.subplots``.
        title: Axes title.

    Raises:
        ValueError: If no rows remain after filtering.
    """
    wanted_n_mc = df["n_monte_carlo"].isin(selected_n_mc)
    within_rate = df["rate"] <= max_rate
    data = df[wanted_n_mc & within_rate].copy()
    if data.empty:
        raise ValueError("No data after filtering; check selected_n_mc / max_rate / columns.")

    fig, ax = plt.subplots(figsize=figsize)
    # One solid colour per requested n_mc; hue_order pins the legend ordering.
    line_colors = sns.color_palette("tab10", n_colors=len(selected_n_mc))
    lineplot_kwargs = dict(
        data=data,
        x="rate",
        y="best_tau",
        hue="n_monte_carlo",
        hue_order=selected_n_mc,
        palette=line_colors,
        estimator="mean",
        errorbar=("ci", 95),
        ax=ax,
    )
    sns.lineplot(**lineplot_kwargs)

    for set_scale in (ax.set_xscale, ax.set_yscale):
        set_scale("log")
    ax.set_xlabel("rate")
    ax.set_ylabel("best tau")
    ax.set_title(title)
    ax.grid(True, which="both", ls="--", alpha=0.4)
    ax.legend(title="n_mc", fontsize=8, ncol=2)
    fig.tight_layout()
    plt.show()