import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

def load_data(file):
    """Load a pickled results bundle and unpack its parts.

    The unpickled object is read with ``.get`` for the keys ``'metadata'``,
    ``'stats'`` and ``'p_nexts'``; missing keys fall back to ``None``, ``[]``
    and ``{}`` respectively.

    Returns:
        ``(p_nexts, Ks, metadata, stats)`` where ``p_nexts`` is a
        ``defaultdict(list)`` seeded from the saved mapping and ``Ks`` is the
        list of its keys in insertion order.

    Warning:
        ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    with open(file, 'rb') as handle:
        bundle = pickle.load(handle)

    p_nexts = defaultdict(list, bundle.get('p_nexts', {}))
    return (
        p_nexts,
        list(p_nexts.keys()),
        bundle.get('metadata', None),
        bundle.get('stats', []),
    )

def compute_kls(p_nexts, Ks, metadata, stats, ref_K=None, n_bootstrap=1000, eps=1e-5):
    """KL divergence (bits) of each K's next-symbol distributions from the reference K's.

    For every K that is not ``>= ref_K`` the divergence KL(P_K || P_ref) is
    computed position by position over the positive-probability support of
    P_K and summarised with ``compute_bootstrap_stats``. A final row for
    ``ref_K`` carries only speed statistics (metric columns are NaN).

    Args:
        p_nexts: mapping K -> sequence of per-position ``{symbol: prob}`` dicts.
        Ks: K values to consider.
        metadata: mapping providing ``'text_length'``.
        stats: mapping K -> object whose ``.times['p_next']`` is indexed by position.
        ref_K: reference K; defaults to ``max(Ks)``.
        n_bootstrap: number of bootstrap resamples.
        eps: currently unused.

    Returns:
        DataFrame with one row per compared K followed by the reference row.

    Raises:
        AssertionError: on a length mismatch with the reference, an empty
            distribution, or when P's keys are not a subset of Q's keys.
    """
    if ref_K is None:
        ref_K = max(Ks)

    def _position_kl(i, K, P, Q):
        # Validate one aligned pair of distributions, then score it.
        assert Q, f"Q is empty for i={i}, K={K}"
        assert P, f"P is empty for i={i}, K={K}"
        assert set(P.keys()).issubset(set(Q.keys())), (
            f'P has larger support than Q (reference) for i={i}, K={K}.'
        )
        support = [sym for sym, prob in P.items() if prob > 0]
        p_vec = np.array([P[sym] for sym in support])
        q_vec = np.array([Q[sym] for sym in support])
        return kl_divergence(p_vec, q_vec)

    rows = []
    reference_speeds = stats[ref_K].times['p_next']
    text_length = metadata['text_length']

    for K in Ks:
        if K >= ref_K:
            continue

        ref_dists = p_nexts[ref_K]
        dists = p_nexts[K]
        assert len(ref_dists) == len(dists), f"{len(ref_dists)} != {len(dists)}"

        divergences = []
        timings = []
        for i, (P, Q) in enumerate(zip(dists, ref_dists)):
            divergences.append(_position_kl(i, K, P, Q))
            timings.append(stats[K].times['p_next'][i])

        summary = compute_bootstrap_stats(
            metrics=np.array(divergences),
            speeds=np.array(timings),
            text_length=text_length,
            n_bootstrap=n_bootstrap,
        )
        rows.append({'K': K, **summary})

    reference_summary = compute_bootstrap_stats(
        metrics=None,
        speeds=reference_speeds,
        text_length=text_length,
        n_bootstrap=n_bootstrap,
    )
    rows.append({'K': ref_K, **reference_summary})

    return pd.DataFrame(rows)

def kl_divergence(p, q):
    """Compute the KL divergence K(p || q) of two probability vectors, in bits.

    Only entries where ``p`` is non-zero contribute (0 * log 0 is taken as 0).

    Args:
        p, q: array-likes of the same shape (NumPy arrays or plain lists).

    Returns:
        The divergence as a NumPy float. It is ``inf`` if ``q`` is 0.0 at any
        position where ``p`` is non-zero.
    """
    p = np.asarray(p)
    q = np.asarray(q)
    nz = p.nonzero()
    p = p[nz]
    q = q[nz]
    # log(0) in q gives -inf, which correctly makes the divergence +inf;
    # silence the divide-by-zero RuntimeWarning NumPy would otherwise emit.
    with np.errstate(divide='ignore'):
        return p.dot(np.log(p) - np.log(q)) / np.log(2)

def _jensen_shannon_bits(P, Q, verbose=False):
    """Return the Jensen-Shannon divergence (base 2) of two sparse distributions.

    ``P`` and ``Q`` map outcomes to probabilities; an outcome missing from one
    side is treated as probability 0.0 there.
    """
    xs = set(P.keys()) | set(Q.keys())
    # Both vectors iterate the same set object, so positions line up.
    p = np.array([P.get(x, 0.0) for x in xs])
    q = np.array([Q.get(x, 0.0) for x in xs])

    if verbose:
        print(f"Number of zeros in p: {np.sum(p == 0.0)} out of {len(p)}")
        print(f"Number of zeros in q: {np.sum(q == 0.0)} out of {len(q)}")

    m = 0.5 * (p + q)
    return 0.5 * (kl_divergence(p, m) + kl_divergence(q, m))


def compute_jds(p_nexts, Ks, metadata, stats, ref_K=None, n_bootstrap=1000, lower=False, verbose=False):
    """Jensen-Shannon divergence (bits) of each K's next-symbol distributions vs. a reference K.

    With ``lower=False`` (default) the reference defaults to ``max(Ks)`` and
    the Ks below it are compared; with ``lower=True`` it defaults to
    ``min(Ks)`` and the Ks above it are compared. Positions whose divergence
    is NaN are dropped (and counted); negative values are clamped to 0.
    A final row for ``ref_K`` carries only speed statistics.

    Args:
        p_nexts: mapping K -> sequence of per-position ``{symbol: prob}`` dicts.
        Ks: K values to consider.
        metadata: mapping providing ``'text_length'``.
        stats: mapping K -> mapping whose ``['times']`` is indexed by position.
        ref_K: reference K (see above for the default).
        n_bootstrap: number of bootstrap resamples.
        lower: compare against the smallest instead of the largest K.
        verbose: also print ``Ks`` and per-position zero counts (debug output).

    Returns:
        DataFrame with one row per compared K followed by the reference row.

    Raises:
        AssertionError: if any aligned distribution is empty.
    """
    if ref_K is None:
        if lower:
            if verbose:
                print(Ks)
            ref_K = min(Ks)
        else:
            ref_K = max(Ks)

    results = []
    # ref_K == 0 is treated as having no timing data in `stats`.
    ref_speeds = [] if ref_K == 0 else stats[ref_K]['times']
    text_length = metadata['text_length']

    for K in Ks:
        if lower and K <= ref_K:
            continue
        if not lower and K >= ref_K:
            continue

        jds = []
        speeds = []
        nans = 0
        # NOTE(review): unlike compute_kls there is no length check, so zip()
        # silently truncates if p_nexts[K] and p_nexts[ref_K] differ in length.
        for i, (P, Q) in enumerate(zip(p_nexts[K], p_nexts[ref_K])):
            assert Q, f"Q is empty for i={i}, K={K}"
            assert P, f"P is empty for i={i}, K={K}"

            jd = _jensen_shannon_bits(P, Q, verbose=verbose)
            if np.isnan(jd):
                nans += 1
                continue
            # JSD is non-negative; clamp negative floating-point round-off.
            jds.append(max(jd, 0.0))
            speeds.append(stats[K]['times'][i])

        not_nans = len(jds)
        print(f"Number of NaNs: {nans} out of {not_nans + nans}")
        results.append({
            'K': K,
            **compute_bootstrap_stats(
                metrics=np.array(jds),
                speeds=np.array(speeds),
                # Only non-NaN positions contribute timings, so they define the length.
                text_length=not_nans,
                n_bootstrap=n_bootstrap,
            ),
        })

    results.append({
        'K': ref_K,
        **compute_bootstrap_stats(
            metrics=None, speeds=ref_speeds, text_length=text_length, n_bootstrap=n_bootstrap
        ),
    })
    return pd.DataFrame(results)

def compute_surprisals(p_nexts, Ks, metadata, stats, text, n_bootstrap=1000):
    """Surprisal (bits) of each character of ``text`` under each K's predictions.

    For position ``i`` the surprisal is ``-log2(p_nexts[K][i][text[i]])``; the
    per-position values and timings are summarised per K with
    ``compute_bootstrap_stats``.

    Args:
        p_nexts: mapping K -> sequence of per-position ``{symbol: prob}`` dicts.
        Ks: K values to evaluate.
        metadata: mapping providing ``'text_length'``.
        stats: mapping K -> object whose ``.times['p_next']`` is indexed by position.
        text: the observed sequence, aligned with the distributions.
        n_bootstrap: number of bootstrap resamples.

    Returns:
        DataFrame with one row per K.

    Raises:
        AssertionError: if a distribution is empty or lacks the observed symbol.
    """
    text_length = metadata['text_length']
    rows = []

    for K in Ks:
        bits = []
        timings = []
        for pos, ch in enumerate(text):
            dist = p_nexts[K][pos]
            assert dist, f"P is empty for i={pos}, K={K}"
            assert ch in dist, f"'{repr(ch)}' not in distribution for i={pos}, K={K}"
            bits.append(-np.log2(dist[ch]))
            timings.append(stats[K].times['p_next'][pos])

        summary = compute_bootstrap_stats(
            metrics=np.array(bits),
            speeds=np.array(timings),
            text_length=text_length,
            n_bootstrap=n_bootstrap,
        )
        rows.append({'K': K, **summary})

    return pd.DataFrame(rows)

def enumerate_nested_tuple(nested_tuple):
    """Flatten a left-nested chain of pairs into its ``(prefix, last)`` pairs, innermost first.

    ``((((), a), b), c)`` yields ``[((), a), (((), a), b), ((((), a), b), c)]``.
    The walk descends through the left element of each non-empty tuple; any
    value that is not a non-empty tuple ends the chain (so non-tuples and
    ``()`` yield ``[]``). A non-empty tuple that is not a pair raises
    ``ValueError`` on unpacking.

    Implemented iteratively so that long chains (one level per element) do not
    hit Python's recursion limit, and in linear rather than quadratic time.
    """
    pairs = []
    node = nested_tuple
    while isinstance(node, tuple) and len(node) > 0:
        left, right = node
        pairs.append((left, right))
        node = left
    # Collected outermost-first; callers expect innermost-first order.
    pairs.reverse()
    return pairs

def compute_token_surprisals(model, text, n_bootstrap=1000):
    """Average normalised surprisal of ``text`` under a token-level model, with a bootstrap CI.

    ``model.encode_prompt(text)`` is walked with ``enumerate_nested_tuple``; for
    each ``(context, next_token)`` the surprisal ``-log2(model.p_next(context)[next_token])``
    is divided by ``len(next_token)`` (presumably bytes per token — confirm
    against the model class).

    Returns:
        A one-row DataFrame with ``mean_metric``, ``metric_ci_lower``,
        ``metric_ci_upper`` (2.5th/97.5th percentiles of bootstrap means) and
        ``metrics`` (the per-token values as a list).
    """
    token_chain = model.encode_prompt(text)

    per_unit = []
    for prefix, token in tqdm(enumerate_nested_tuple(token_chain)):
        distribution = model.p_next(prefix)
        bits = -np.log2(distribution[token])
        per_unit.append(bits / len(token))
    per_unit = np.array(per_unit)

    boot_means = []
    for _ in range(n_bootstrap):
        resample = np.random.choice(per_unit, size=len(per_unit), replace=True)
        boot_means.append(np.mean(resample))

    summary = {
        'mean_metric': np.mean(per_unit),
        'metric_ci_lower': np.percentile(boot_means, 2.5),
        'metric_ci_upper': np.percentile(boot_means, 97.5),
        'metrics': per_unit.tolist(),
    }
    return pd.DataFrame([summary])


def compute_bootstrap_stats(speeds, text_length, n_bootstrap, metrics=None, rng=None):
    """Summarise a per-position metric and throughput with bootstrap 95% intervals.

    Args:
        speeds: per-position durations (any array-like). Throughput is
            ``text_length / sum(speeds)``, reported as ``chars_per_sec``
            (so presumably seconds — confirm against the producer of ``stats``).
        text_length: length credited to the timed positions.
        n_bootstrap: number of bootstrap resamples for each interval.
        metrics: optional per-position metric values (any array-like). When
            ``None`` the metric fields are NaN and ``'metrics'`` is ``[]``.
        rng: optional random source with a NumPy-style ``choice`` method
            (e.g. ``np.random.default_rng(seed)``) for reproducible
            resampling; defaults to the global ``np.random`` state.

    Returns:
        dict with ``mean_metric``, ``metric_ci_lower``/``metric_ci_upper``
        (2.5th/97.5th percentiles of bootstrap means), ``chars_per_sec``,
        ``speed_ci_lower``/``speed_ci_upper`` (same percentiles of bootstrap
        throughputs) and ``metrics`` (the metric values as a list).
    """
    sampler = np.random if rng is None else rng

    if metrics is not None:
        # Accept lists as well as arrays (``.tolist()`` below needs an ndarray).
        metrics = np.asarray(metrics)
        mean_metric = np.mean(metrics)
        metric_bootstrap = [
            np.mean(sampler.choice(metrics, size=len(metrics), replace=True))
            for _ in range(n_bootstrap)
        ]
        metric_ci_lower = np.percentile(metric_bootstrap, 2.5)
        metric_ci_upper = np.percentile(metric_bootstrap, 97.5)
    else:
        mean_metric = np.nan
        metric_ci_lower = np.nan
        metric_ci_upper = np.nan

    speeds = np.asarray(speeds)
    mean_speed = text_length / np.sum(speeds)
    speed_bootstrap = [
        text_length / np.sum(sampler.choice(speeds, size=len(speeds), replace=True))
        for _ in range(n_bootstrap)
    ]

    return {
        'mean_metric': mean_metric,
        'metric_ci_lower': metric_ci_lower,
        'metric_ci_upper': metric_ci_upper,
        'chars_per_sec': mean_speed,
        'speed_ci_lower': np.percentile(speed_bootstrap, 2.5),
        'speed_ci_upper': np.percentile(speed_bootstrap, 97.5),
        'metrics': metrics.tolist() if metrics is not None else []
    }

def create_latex_table(df, metric_name, models_per_row=2):
    """Render a results DataFrame as one or more LaTeX ``tabular`` environments.

    Models are laid out side by side, ``models_per_row`` per table. Each model
    gets two columns ("average <metric_name> / byte" and "byte / sec"), and
    there is one row per distinct ``K`` value in descending order. Model ids in
    ``model_map`` are shown under their display names and listed first; other
    models follow in order of first appearance. Missing values and missing
    (model, K) combinations render as a greyed-out "(not applicable)" (the
    output uses ``\\textcolor`` and ``\\toprule``, i.e. the xcolor and
    booktabs LaTeX packages).

    Args:
        df: DataFrame with columns ``model``, ``K``, ``mean_metric``,
            ``metric_ci_lower``, ``metric_ci_upper`` and optionally
            ``chars_per_sec`` with ``speed_ci_lower``/``speed_ci_upper``.
            The caller's frame is not modified.
        metric_name: metric label used in the sub-header.
        models_per_row: number of models per ``tabular``; values below 1 are
            treated as 1.

    Returns:
        The LaTeX source, one ``tabular`` per chunk of models.
    """
    model_map = {
        "gpt2-large": "GPT2-Large",
        "meta-llama/Llama-3.2-1B": "Llama-3.2-1B",
        "meta-llama/Llama-3.1-8B": "Llama-3.1-8B",
    }
    models_per_table = max(1, models_per_row)
    not_applicable = "\\textcolor{black!40}{(not applicable)}"
    cols_per_model = 2
    df = df.copy()

    def fmt_e(x, digits=1):
        return not_applicable if pd.isna(x) else f"{x:.{digits}e}"

    df['Metric'] = df.apply(
        lambda x: not_applicable if pd.isna(x['mean_metric']) else
        f"{fmt_e(x['mean_metric'])} ({fmt_e(x['metric_ci_lower'])}, {fmt_e(x['metric_ci_upper'])})",
        axis=1
    )
    # 'chars_per_sec' may be absent entirely; Series.get then yields None -> N/A.
    df['Speed'] = df.apply(
        lambda x: not_applicable if pd.isna(x.get('chars_per_sec')) else
        f"{x['chars_per_sec']:.2f} ({x['speed_ci_lower']:.2f}, {x['speed_ci_upper']:.2f})",
        axis=1
    )
    metric_values = sorted(df['K'].unique(), reverse=True)
    present = list(pd.unique(df['model']))
    ordered_models = [m for m in model_map if m in present]
    ordered_models += [m for m in present if m not in model_map]

    def disp(m):
        return model_map.get(m, m)

    parts = []
    for start in range(0, len(ordered_models), models_per_table):
        chunk = ordered_models[start:start + models_per_table]
        colspec = 'c|' + ('c' * cols_per_model + '|') * (len(chunk) - 1) + 'c' * cols_per_model

        parts.append(f"\\begin{{tabular}}{{{colspec}}}\n")
        parts.append("\\toprule\n")
        header_cells = [
            f"\\multicolumn{{{cols_per_model}}}{{c}}{{\\textbf{{{disp(m)}}}}}" for m in chunk
        ]
        parts.append(" & " + " & ".join(header_cells) + " \\\\\n")
        subhdr = []
        for _ in chunk:
            subhdr.extend([f"average {metric_name} / byte", "byte / sec"])
        parts.append("$\\tau$ & " + " & ".join(subhdr) + " \\\\\n")
        parts.append("\\midrule\n")

        for k in metric_values:
            row = [str(k)]
            for m in chunk:
                model_data = df[(df['model'] == m) & (df['K'] == k)]
                if not model_data.empty:
                    metric_str = model_data['Metric'].iloc[0]
                    speed_str = model_data['Speed'].iloc[0]
                else:
                    metric_str = not_applicable
                    speed_str = not_applicable
                row.extend([metric_str, speed_str])
            parts.append(" & ".join(row) + " \\\\\n")

        parts.append("\\bottomrule\n\\end{tabular}\n\n")

    return "".join(parts)

def create_latex_table_surprisal(df_streaming, df_tokenized):
    """Render streaming (per-K) and tokenized surprisal results as a LaTeX ``tabular``.

    One column per model (in order of first appearance in ``df_streaming``),
    one row per ``K`` in ascending order, and a final "Tokenized LM" row taken
    from ``df_tokenized``. Each cell is ``mean (ci_lower, ci_upper)``; a model
    with no data for a given row gets "—" so the columns stay aligned.

    Args:
        df_streaming: DataFrame with ``model``, ``K``, ``mean_metric``,
            ``metric_ci_lower``, ``metric_ci_upper``. Not modified.
        df_tokenized: DataFrame with ``model``, ``mean_metric``,
            ``metric_ci_lower``, ``metric_ci_upper``.

    Returns:
        The LaTeX source for the table.
    """
    missing = "—"
    # Work on a copy so the caller's frame does not gain a 'Metric' column.
    streaming = df_streaming.copy()
    models = streaming['model'].unique()
    k_values = sorted(streaming['K'].unique())

    streaming['Metric'] = streaming.apply(
        lambda x: f"{x['mean_metric']:.3f} ({x['metric_ci_lower']:.2f}, {x['metric_ci_upper']:.2f})",
        axis=1
    )

    header = (
        "\\begin{tabular}{l|" + "c" * len(models) + "}\n"
        "\\toprule\n"
        " & " + " & ".join(f"\\textbf{{{model}}}" for model in models) + "\\\\\n"
        "\\midrule\n"
    )

    rows = []
    for k in k_values:
        row_parts = [f"$K={k}$"]
        for model in models:
            model_data = streaming[(streaming['model'] == model) & (streaming['K'] == k)]
            # Emit a placeholder rather than skipping, which would shift later columns.
            row_parts.append(model_data['Metric'].iloc[0] if len(model_data) > 0 else missing)
        rows.append(" & ".join(row_parts) + " \\\\\n")

    ref_row = ["Tokenized LM"]
    for model in models:
        model_data = df_tokenized[df_tokenized['model'] == model]
        if len(model_data) > 0:
            metric_str = f"{model_data['mean_metric'].iloc[0]:.3f} ({model_data['metric_ci_lower'].iloc[0]:.2f}, {model_data['metric_ci_upper'].iloc[0]:.2f})"
        else:
            metric_str = missing
        ref_row.append(metric_str)

    rows.append("\\midrule\n" + " & ".join(ref_row) + " \\\\\n")

    footer = "\\bottomrule\n\\end{tabular}\n"

    return header + "".join(rows) + footer
