from tabnanny import verbose
import pandas as pd 
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset, Dataset
from sklearn.metrics import roc_auc_score


from collections import Counter
from typing import List, Dict

import torch
from torch.nn import functional as F

import string 
# Module-level answer labels: the first five uppercase letters, "A" through "E".
all_choices = list(string.ascii_uppercase)[:5]


### tools 


### uncertainty 

def compute_nll(probs, answer, verbose): 
    """
    Average negative log-likelihood of the correct answers.

    INPUT
    probs : array-like of shape [N, C]; column k belongs to the k-th uppercase letter
    answer : sequence of N answer letters ("A", "B", ...)
    verbose : if truthy, print the resulting value

    OUTPUT
    mean over the N rows of -log(p_correct + 1e-12)
    """
    prob_matrix = np.array(probs)
    num_rows, num_classes = prob_matrix.shape

    # Letter -> column index for the first `num_classes` uppercase letters.
    column_of = {
        letter: col
        for col, letter in enumerate(string.ascii_uppercase[:num_classes])
    }

    total = 0
    for row_idx in range(num_rows):
        col = column_of[answer[row_idx]]
        # The 1e-12 offset keeps log() finite when the correct class has probability 0.
        total += -np.log(prob_matrix[row_idx, col] + 1e-12)

    mean_nll = total / num_rows
    if verbose:
        print(f"NLL = {mean_nll:.4f}")

    return mean_nll

def compute_brier_score(probs, answer, verbose):
    """
    Average multi-class Brier score.

    INPUT
    probs : array-like of shape [N, C]; column k belongs to the k-th uppercase letter
    answer : sequence of N answer letters ("A", "B", ...)
    verbose : if truthy, print the resulting value

    OUTPUT
    mean over the N rows of the squared distance between the predicted
    distribution and the one-hot encoding of the correct answer
    """
    prob_matrix = np.array(probs)
    num_rows, num_classes = prob_matrix.shape

    # Letter -> column index for the first `num_classes` uppercase letters.
    column_of = {
        letter: col
        for col, letter in enumerate(string.ascii_uppercase[:num_classes])
    }

    total = 0
    for row_idx in range(num_rows):
        one_hot = np.zeros(num_classes)
        one_hot[column_of[answer[row_idx]]] = 1.0
        residual = prob_matrix[row_idx] - one_hot
        total += np.sum(residual ** 2)

    mean_brier = total / num_rows
    if verbose:
        print(f"Brier score = {mean_brier:.4f}")
    return mean_brier



### csqa evaluation 


def _empty_mg_result(answer_key):
    """Build the placeholder result returned when a question cannot be processed.

    Every array-valued entry is None. Both the legacy key names
    ("mle_paths", "best_dist") and the names used by successful results
    ("m1_paths", "m1_mean", "m2_dist") are present, so callers can read any
    of them without a KeyError.
    """
    return {
        "probs" : None,
        "output_str" : None,
        "phi_hat" : None,
        "mle_paths" : None,
        "best_dist" : None,
        "m1_paths" : None,
        "m1_mean" : None,
        "m2_dist" : None,
        "answerKey" : answer_key,
    }


def mg_seq(sample, n_max, num_choice, seed, prefix, verbose = True): 
    '''
    Clean the generations of a single question and compute per-step estimates.

    INPUT
    sample : sub dataset for a single question, a dict of columns containing at
             least "id", "probs", "output_str", "answerKey" (and "seed" when
             `seed` is truthy). Each "probs" entry is treated as a 2-D array
             (steps x choices) and each "output_str" entry as newline-separated
             answers.
    n_max(int) : maximum length you'd like to check the convergence
    num_choice(int) : number of answer labels ("A", "B", ...)
    seed : if truthy, the "seed" entry of every kept row is prepended to its
           generated answers before the running frequencies are computed. An
           entry may be a string (one label per character) or a list of labels.
           All seeds are assumed to have the same length - TODO confirm.
    prefix : currently unused; kept for interface compatibility
    verbose : print a short summary when truthy

    OUTPUT
    cleaned : dictionary with
        "probs"      [J_kept, n_max, num_choice] array of kept step probabilities
        "output_str" list (J_kept) of the first n_max answers of every kept row
        "phi_hat"    [n_max, num_choice] mean of "probs" over rows, renormalised
        "m1_paths"   [J_kept, n_max, num_choice] running answer frequencies
        "m1_mean"    [n_max, num_choice] mean of "m1_paths" over rows, renormalised
        "m2_dist"    [n_max, num_choice] share of rows whose most frequent
                     answer so far is each choice
        "answerKey"  answer key of the first row
    When fewer than J // 2 rows survive the filter, or processing fails, every
    array entry is None; callers test `cleaned["probs"] is None`.
    '''

    J = len(sample["id"])
    sample = Dataset.from_dict(sample)
    all_choices = list(string.ascii_uppercase)[:num_choice]

    # Read each column once; indexing the dataset inside the loop would
    # re-materialise the whole column on every iteration.
    probs_col = sample["probs"]
    output_col = sample["output_str"]

    # check if output follows the prompt 
    kept_idx = []       # original row index of every kept sequence
    gen_probs = []
    gen_output = []

    for j in range(J): 
        probs = np.array(probs_col[j]) 
        # Number of steps with a non-zero probability row (all-zero rows are
        # presumably padding - TODO confirm).
        gen_len = np.sum(np.sum(probs, axis = 1) != 0) 

        output = output_col[j].strip().split("\n")
        str_len = len(output)

        # Keep rows that have probabilities for every emitted answer and at
        # least n_max answers.
        if gen_len >= str_len and str_len >= n_max:
            kept_idx.append(j)
            gen_probs.append(probs[:n_max, :])
            gen_output.append(output[:n_max])

    if not gen_probs:
        print("No valid sequences found.")
        return _empty_mg_result(None)

    # Best-effort processing: a malformed question yields the None-filled
    # result instead of aborting a long evaluation loop. Only Exception is
    # caught so that KeyboardInterrupt / SystemExit still propagate.
    try: 
        gen_probs = np.stack(gen_probs, axis = 0)

        if len(gen_output) < J // 2:
            cleaned = _empty_mg_result(sample["answerKey"][0])

            if verbose:
                print(f"Total {len(gen_output)} sequences are preprocessed.") 
                print("DEPRECATED: More than half of sequences are discarded.")

            return cleaned

        ans = sample[0]["answerKey"]

        pre_len = 0
        pre_gen_output = gen_output
        if seed:
            seed_col = sample["seed"]
            # Pair every kept sequence with the seed of its own row.
            pre_str = [list(seed_col[j]) for j in kept_idx]
            pre_len = len(pre_str[0])
            pre_gen_output = [p + g for p, g in zip(pre_str, gen_output)]

        # Running answer frequencies after each of the n_max generated answers.
        mle_paths = []
        for tokens in pre_gen_output: 
            rows = []
            for n in range(pre_len + 1, n_max + pre_len + 1):
                counts = Counter(tokens[:n])
                rows.append([counts.get(c, 0) / n for c in all_choices])
            mle_paths.append(np.array(rows))
        mle_paths = np.stack(mle_paths)
        mean_mle = np.mean(mle_paths, axis = 0)
        mean_mle = mean_mle / np.sum(mean_mle, axis = -1, keepdims = True)

        # For every step, the share of sequences whose most frequent answer
        # so far is choice i.
        best_guess = np.argmax(mle_paths, axis = -1)  # [J_kept, n_max] 
        best_dist = np.stack(
            [np.mean(best_guess == i, axis = 0) for i in range(num_choice)],
            axis = -1,
        )  # [n_max, num_choice]

        loc_probs = np.mean(gen_probs, axis = 0)
        loc_probs = loc_probs / np.sum(loc_probs, axis = -1, keepdims = True)
        cleaned = {
            "probs" : gen_probs,    # [J_kept, n_max, num_choice] array
            "output_str" : gen_output ,    # list of string (J_kept, n_max)
            "phi_hat" : loc_probs,  # [n_max, num_choice] : E[X_n | prompt] estimated by tower property, MC method 
            "m1_paths" : mle_paths,   # [J_kept, n_max, num_choice] array of MLE estimates 
            "m1_mean" : mean_mle,  # [n_max, num_choice] numpy array of mean MLE estimates
            "m2_dist" : best_dist, # [n_max, num_choice] numpy array of aggregated MAP
            "answerKey" : ans
        }

        if verbose: 
            print(f"Correct answer is {ans}.")
            print(f"Total {len(gen_output)} sequences are preprocessed.") 

        return cleaned

    except Exception: 
        print("No valid sequences found.")
        return _empty_mg_result(None)





def plot_m_probs(probs, bootstrap = True):
    """
    Plot the mean predicted probability of every choice against sequence length.

    INPUT
    probs : array of shape [J, n, num_choice] (sequences x steps x choices)
    bootstrap : if True, draw a grey percentile band from 10000 bootstrap
                resamples of the J sequences (alpha = 0.04, i.e. the 2% and 98%
                quantiles); otherwise draw a normal-approximation band of
                +/- 1.96 standard errors around the mean.

    The figure is shown with plt.show(); nothing is returned.
    """
    num_paths, seq_len, num_choice = probs.shape
    all_choices = list(string.ascii_uppercase)[:num_choice]
    steps = np.arange(seq_len)
    y = np.mean(probs, axis = 0)

    plt.plot(steps, y, linewidth = 0.5)

    if bootstrap: 
        # Resampling is only done when the band is actually requested.
        rng = np.random.default_rng(2025)
        B = 10000
        alpha = 0.04

        p_boot = np.zeros((B, seq_len, num_choice))
        for b in range(B):
            idx = rng.integers(0, num_paths, size = num_paths)
            p_boot[b, :, :] = np.mean(np.take(probs, indices = idx, axis = 0), axis = 0)

        lower = np.quantile(p_boot, alpha / 2, axis = 0)
        upper = np.quantile(p_boot, 1 - alpha / 2, axis = 0)

        for i, _ in enumerate(all_choices):
            plt.fill_between(steps, lower[:, i], upper[:, i], alpha = 0.2, color = 'grey', ec = None)

    else: 
        # The band is identical for every choice, so compute it once.
        half_width = np.std(probs, axis = 0) / np.sqrt(num_paths) * 1.96
        low = y - half_width
        high = y + half_width
        for i, _ in enumerate(all_choices):
            plt.fill_between(steps, low[:, i], high[:, i], ec = None, alpha = 0.2)

    plt.xlabel("sequence length")
    plt.ylabel("$E[p_M(X_{N} \\vert X_{1:N-1})]$")
    plt.legend(all_choices, loc = 'upper right')
    plt.ylim([0, 1])
    plt.show()





def plot_mg_post(mle_paths, bootstrap = True, ans_path = False):
    """
    Plot the renormalised mean of the running frequency estimates per choice.

    INPUT
    mle_paths : array of shape [J, n, num_choice]
    bootstrap : if True, shade a grey percentile band obtained from 10000
                bootstrap resamples over the J paths (alpha = 0.04); otherwise
                shade mean +/- 1.96 standard errors.
    ans_path : not used by this function

    The figure is shown with plt.show(); nothing is returned.
    """
    num_choice = mle_paths.shape[2]
    labels = list(string.ascii_uppercase)[:num_choice]

    mean_path = np.mean(mle_paths, axis = 0)
    mean_path = mean_path / np.sum(mean_path, axis = -1, keepdims = True)

    num_paths, seq_len, _ = mle_paths.shape
    steps = np.arange(seq_len)

    plt.plot(steps, mean_path, linewidth = 0.5)

    if not bootstrap:
        # Normal-approximation band around the renormalised mean.
        half_width = np.std(mle_paths, axis = 0) / np.sqrt(num_paths) * 1.96
        band_low = mean_path - half_width
        band_high = mean_path + half_width
        for col, _ in enumerate(labels):
            plt.fill_between(steps, band_low[:, col], band_high[:, col], ec = None, alpha = 0.2)
    else:
        rng = np.random.default_rng(2025)
        num_resamples = 10000
        alpha = 0.04

        boot_means = np.zeros((num_resamples, seq_len, num_choice))
        for rep in range(num_resamples):
            picks = np.array(rng.integers(0, num_paths, size = num_paths))
            boot_means[rep, :, :] = np.mean(mle_paths[picks, :, :], axis = 0)

        band_low = np.quantile(boot_means, alpha / 2, axis = 0)
        band_high = np.quantile(boot_means, 1 - alpha / 2, axis = 0)
        for col, _ in enumerate(labels):
            plt.fill_between(steps, band_low[:, col], band_high[:, col], alpha = 0.2, color = 'grey', ec = None)

    plt.xlabel("sequence length n")
    plt.ylabel("$E[\\hat{\\theta}(x_{1:n})]$")
    plt.legend(labels, loc = 'upper right')
    plt.ylim([0, 1])
    plt.show()




def get_confidence(ds, J, n_max, burn, num_choice, seed, prefix):
    """
    Collect per-question confidence and correctness for two predictors.

    INPUT
    ds : dataset whose rows are grouped in consecutive blocks of J generations
         per question (slicing must return a dict of columns)
    J : number of generations per question
    n_max, num_choice, seed, prefix : forwarded to mg_seq
    burn : step index into "phi_hat" used for the model probability

    OUTPUT
    conf_dict, acc_dict : dicts with keys
        "first_m_prob"  - mg_dict["phi_hat"][burn]   (p_phi(A | Q))
        "last_mean_mle" - mg_dict["m1_mean"][-1]     (E[theta_n | Q])
    conf_dict holds the maximum probability per question, acc_dict holds 1 if
    the arg-max label equals the answer key and 0 otherwise. Questions that
    mg_seq rejects are skipped and counted in the printed summary.
    """
    conf_dict = {"first_m_prob": [], "last_mean_mle": []}
    acc_dict = {"first_m_prob": [], "last_mean_mle": []}

    all_choices = list(string.ascii_uppercase)[:num_choice]

    num_q = len(ds) // J
    print(num_q)
    fail = 0 

    for q in range(num_q): 
        sub_ds = ds[J * q : J * (q + 1)]
        mg_dict = mg_seq(sub_ds, n_max, num_choice, seed, prefix, verbose = False)

        if mg_dict["probs"] is None:
            fail += 1
            continue

        answer = mg_dict["answerKey"]
        predictors = {
            "first_m_prob": mg_dict["phi_hat"][burn],
            "last_mean_mle": mg_dict["m1_mean"][-1, :],
        }

        for name, dist in predictors.items():
            conf_dict[name].append(np.max(dist))
            # Stored as int for both predictors so the two lists share a dtype.
            acc_dict[name].append(int(answer == all_choices[int(np.argmax(dist))]))

    print(f"Total {fail} questions are skipped due to invalid generations.")

    return conf_dict, acc_dict 




def get_scores(ds, J, n_max, burn, num_choice, seed, prefix, verbose):
    """
    Compute NLL, Brier score and one-vs-rest AUC for two predictors.

    INPUT
    ds : dataset whose rows are grouped in consecutive blocks of J generations
         per question (slicing must return a dict of columns)
    J : number of generations per question
    n_max, num_choice, seed, prefix : forwarded to mg_seq
    burn : step index into "phi_hat" used for the "model" predictor
    verbose : print the scores when truthy

    OUTPUT
    score_dict : {"model": {...}, "m1": {...}}, each with keys "nll", "bs", "auc"
        "model" uses mg_dict["phi_hat"][burn]   (p_phi(A | Q))
        "m1"    uses mg_dict["m1_mean"][-1]     (E[theta_n | Q])

    RAISES
    ValueError if mg_seq rejects every question.
    """
    num_q = len(ds) // J
    print(num_q)

    # mg_seq is deterministic, so every question is processed once and the
    # result is shared by both predictors.
    fail = 0
    ans = []
    probs_by_method = {"model": [], "m1": []}

    for q in range(num_q): 
        sub_ds = ds[J * q : J * (q + 1)]
        mg_dict = mg_seq(sub_ds, n_max, num_choice, seed, prefix, verbose = False)

        if mg_dict["probs"] is None:
            fail += 1
            continue

        ans.append(mg_dict["answerKey"])
        probs_by_method["model"].append(mg_dict["phi_hat"][burn])
        probs_by_method["m1"].append(mg_dict["m1_mean"][-1, :])

    if not ans:
        raise ValueError("No valid questions: every question was rejected by mg_seq.")

    score_dict = {}
    for method, rows in probs_by_method.items():
        probs = np.stack(rows, axis = 0)
        score_dict[method] = {
            "nll": compute_nll(probs, ans, False),
            "bs": compute_brier_score(probs, ans, False),
            "auc": roc_auc_score(ans, probs, multi_class = "ovr"),
        }

    if verbose: 
        print(f"Total {fail} questions are skipped due to invalid generations.")
        print("=== MODEL === \n")
        for key in score_dict["model"].keys():
            print(f"- {key} : {score_dict['model'][key]:.4f}")
        print("\n=== M1 === \n")
        for key in score_dict["m1"].keys():
            print(f"- {key} : {score_dict['m1'][key]:.4f}")

    return score_dict






### mg_violation

def check_mp(ds, J, n_max, tau_max, trunc, num_choice, seed = False, prefix = None):
    """
    ACF-like based on cosine similarity 

    For every question, the running frequency paths are averaged over the
    generations, truncated to the first `trunc` steps, and each step is
    normalised to sum to one. For every lag tau in 1..tau_max the cosine
    similarity between steps t and t + tau is averaged over t.

    INPUT
    ds : dataset whose rows are grouped in consecutive blocks of J generations
         per question (slicing must return a dict of columns)
    J : number of generations per question
    n_max, num_choice, seed, prefix : forwarded to mg_seq
    tau_max : largest lag
    trunc : number of leading steps to use

    OUTPUT
    acf_list : tensor of shape [tau_max, num_q]. Entries are NaN for questions
               that mg_seq rejects and for lags that leave no pair of steps.
    """

    num_q = len(ds) // J
    acf_list = torch.zeros((tau_max, num_q))

    for q in range(num_q):
        mg_dict = mg_seq(ds[J * q : J * (q + 1)], n_max, num_choice, seed, prefix, verbose = False)

        if mg_dict["probs"] is None:
            # Mark rejected questions explicitly instead of leaving zeros.
            acf_list[:, q] = float("nan")
            continue

        # "m1_paths" is a NumPy array; convert before using torch ops.
        proj_probs = torch.as_tensor(mg_dict["m1_paths"]).mean(dim = 0)
        proj_probs = proj_probs[:trunc, :]
        proj_probs = proj_probs / torch.sum(proj_probs, dim = -1, keepdim = True)

        num_steps = proj_probs.shape[0]
        for tau in range(1, tau_max + 1):
            if tau >= num_steps:
                acf_list[tau - 1, q] = float("nan")
                continue
            # Row t of the first slice is compared with row t + tau of the second.
            cos = F.cosine_similarity(proj_probs[:-tau], proj_probs[tau:], dim = -1)
            acf_list[tau - 1, q] = torch.mean(cos)

    return acf_list









#### ECE : single answers ####

def clean_ece(ds_path): 
    """
    Load single-answer predictions and derive the fields needed for ECE.

    INPUT
    ds_path : path of a JSON file readable by load_dataset("json", ...). The
              first record's "id", "question", "answerKey" and "probs" fields
              are used; they are presumably lists with one entry per
              question - TODO confirm against the file format.

    OUTPUT
    ece_ds : dict with the four loaded fields plus
        "output"     arg-max choice index per question
        "confidence" maximum probability per question
        "correct"    1 if the arg-max index matches the answer key, else 0
    Answer keys may be letters "A".."E" or 1-based numbers ("1" -> index 0).
    """
    raw = load_dataset("json", data_files= ds_path)["train"]

    ece_ds = {
        "id" : raw["id"][0], 
        "question" : raw["question"][0],
        "answerKey": raw["answerKey"][0],
        "probs": raw["probs"][0]
    }

    ece_ds["output"] = [np.argmax(row) for row in ece_ds["probs"]]
    ece_ds["confidence"] = np.max(np.array(ece_ds["probs"]), axis=1)

    letter_to_index = {"A" : 0, "B": 1, "C": 2, "D":3, "E": 4} 

    correct = []
    for key, pred in zip(ece_ds["answerKey"], ece_ds["output"]):
        if key in letter_to_index:
            target = letter_to_index[key]
        else:
            # Numeric keys are 1-based; compare indices, not a letter with an int.
            target = int(key) - 1
        correct.append(int(target == int(pred)))

    # plot_ece reads this key.
    ece_ds["correct"] = correct

    return ece_ds 

def plot_ece(ece_ds, N, num_choice, bins, plot, colors, label):
    """
    Expected calibration error over equal-width confidence bins on [1/num_choice, 1].

    INPUT
    ece_ds : mapping with "confidence" and "correct" sequences of equal length
    N : denominator used for the bin weights and the printed accuracy
    num_choice : number of answer choices; the bins start at 1 / num_choice
    bins : number of bins
    plot : if truthy, draw the reliability curve on the current axes
    colors, label : colour and legend label used for the curve

    OUTPUT
    ece, p_list, acc_list : the ECE, the per-bin contributions and the per-bin
    accuracies (0.0 for empty bins)
    """
    conf = np.array(ece_ds["confidence"])
    correct = np.array(ece_ds["correct"])

    p_list, acc_list, bin_centers = [], [], []

    for b in range(bins):
        intv_len = 1 / bins * (1 - 1 / num_choice) 
        low = 1 / num_choice + intv_len * b
        high = 1 / num_choice + intv_len * (b + 1)

        # Only the last bin includes its upper edge.
        below_top = (conf <= high) if b == bins - 1 else (conf < high)
        in_bin = (conf >= low) & below_top
        count = np.sum(in_bin)

        bin_centers.append(low + intv_len / 2)

        if count == 0:
            acc_list.append(0.0)
            p_list.append(0.0)
        else:
            bin_conf = np.mean(conf[in_bin])
            bin_acc = np.mean(correct[in_bin])
            p_list.append(abs(bin_acc - bin_conf) * (count / N))
            acc_list.append(bin_acc)

    ece = np.sum(p_list)
    print("ECE =", ece)
    print("Accuracy = ", np.sum(correct) / N)

    if plot:
        plt.plot(bin_centers, acc_list, color=colors)
        plt.scatter(bin_centers, acc_list, color=colors, label=label)
        plt.plot([1 / num_choice - intv_len / 2, 1], [1 / num_choice - intv_len / 2, 1], "--", color="grey")
        plt.ylabel("Fraction of positives")
        plt.xlabel("Mean predicted probability")
        plt.legend()

    return ece, p_list, acc_list






#### ECE sequential answers #####

def calib_mp(ds, method, J, n_max, burn, num_choice, seed, prefix, bins, plot, color, label, marker) -> Dict:

    """
    Compute ECE and accuracy for one confidence method and optionally plot
    the reliability curve.

    INPUT: 
    - method: str, one of ["model", "m1"]
        "model" uses the model probability at step `burn`
        "m1" uses the last mean MLE probability
    - ds, J, n_max, burn, num_choice, seed, prefix: forwarded to get_confidence
    - bins: number of equal-width confidence bins on [1 / num_choice, 1]
    - plot, color, label, marker: plotting switch and styling

    OUTPUT: 
    dict with "ece", "acc" and "bin_prob" (per-bin ECE contributions)
    """

    key_for_method = {
        "model" : "first_m_prob",
        "m1" : "last_mean_mle",
    }
    conf_dict, acc_dict = get_confidence(ds, J, n_max, burn, num_choice, seed, prefix)
    score_key = key_for_method[method]
    conf_list = np.array(conf_dict[score_key])
    correct = np.array(acc_dict[score_key])

    num_q = len(conf_list)

    p_list, acc_list, bin_centers = [], [], []

    for b in range(bins):
        intv_len = 1 / bins * (1 - 1 / num_choice) 
        lo = 1 / num_choice + intv_len * b
        up = 1 / num_choice + intv_len * (b + 1)

        # Only the last bin includes its upper edge.
        if b == bins - 1:
            in_bin = (conf_list >= lo) & (conf_list <= up)
        else:
            in_bin = (conf_list >= lo) & (conf_list < up)

        count = np.sum(in_bin)
        bin_centers.append(lo + intv_len / 2)

        if count == 0:
            acc_list.append(0.0)
            p_list.append(0.0)
            continue

        bin_conf = np.mean(conf_list[in_bin])
        bin_acc = np.mean(correct[in_bin])
        p_list.append(abs(bin_acc - bin_conf) * (count / num_q))

        # A bin holding a single question is plotted at 0; its ECE
        # contribution above still uses the real accuracy.
        acc_list.append(0 if count == 1 else bin_acc)

    ece = np.sum(p_list) 
    acc = np.sum(correct) / num_q

    print(f"ECE: {ece:.4f} \n ACC: {acc:.4f}")

    if plot:
        plt.plot(bin_centers, acc_list, color=color)
        plt.scatter(bin_centers, acc_list, color=color, label=label, marker = marker, s= 25)
        plt.plot([0, 1], [0, 1], "--", color="grey")
        plt.ylabel("Fraction of positives")
        plt.xlabel("Mean predicted probability")
        plt.legend()

    return {
        "ece" : ece, 
        "acc" : acc, 
        "bin_prob" : p_list
    }



def plot_freq_ece(ds, J, n_max, num_choice, seed = False, prefix = None):
    """
    Print accuracy, NLL and ECE of the final mean frequency estimate and show
    its reliability diagram.

    INPUT
    ds : dataset whose rows are grouped in consecutive blocks of J generations
         per question (slicing must return a dict of columns)
    J : number of generations per question
    n_max, num_choice, seed, prefix : forwarded to mg_seq

    Questions rejected by mg_seq are skipped. Nothing is returned.
    """

    choice_labels = list(string.ascii_uppercase)[:num_choice]
    answer_num = {c: i for i, c in enumerate(choice_labels)}

    mean_theta_inf = []
    answers = []
    for q in range(len(ds) // J):
        sub_ds = ds[q * J: (q + 1) * J]
        mg_dict = mg_seq(sub_ds, n_max, num_choice, seed, prefix, verbose = False) 
        if mg_dict["probs"] is None:
            continue
        mean_theta_inf.append(mg_dict["m1_mean"][-1])  # E[theta_n | x] for each question
        answers.append(sub_ds["answerKey"][0])

    if not answers:
        print("No valid questions: nothing to evaluate.")
        return

    mean_theta_inf = np.array(mean_theta_inf)  # (num_questions, num_choice)
    mean_theta_inf = mean_theta_inf / np.sum(mean_theta_inf, axis = -1, keepdims = True)
    confidence = np.max(mean_theta_inf, axis = -1)  # (num_questions, )
    predictions = np.argmax(mean_theta_inf, axis = -1)  # (num_questions, )
    correct = np.array([int(answer_num[ans] == int(pred)) for ans, pred in zip(answers, predictions)])

    print(len(confidence))
    print("Accuracy: ", f"{np.mean(correct):.4f}")
    print("NLL: ", f"{-np.mean([np.log(mean_theta_inf[i][answer_num[ans]] + 1e-8) for i, ans in enumerate(answers)]):.4f}")
    # The bins must start at 1 / num_choice for this task, not at the
    # compute_ece default.
    print("ECE: ", f"{compute_ece(confidence, correct, bins = 8, num_choice = num_choice, plot = True):.4f}")





















import matplotlib.cm as cm
import matplotlib.colors as mcolors

def compute_ece(conf, correct, bins = 10, num_choice = 4, verbose = False, plot = False):
    """
    Expected calibration error over equal-width confidence bins on [1/num_choice, 1].

    INPUT
    conf : array-like of per-sample confidences (lists are accepted)
    correct : array-like of per-sample correctness flags (0/1 or bool)
    bins : number of bins, must be at least 1
    num_choice : number of answer choices; the bins start at 1 / num_choice
    verbose : print ECE and accuracy when truthy
    plot : show a reliability diagram whose bar colour encodes the bin count

    OUTPUT
    ece : sum over bins of |accuracy - mean confidence| * (bin count / N).
          Empty bins contribute 0; confidences below 1 / num_choice fall
          outside every bin.

    RAISES
    ValueError if bins < 1.
    """
    if bins < 1:
        raise ValueError("bins must be a positive integer")

    # Accept plain Python sequences as well as arrays.
    conf = np.asarray(conf, dtype = float)
    correct = np.asarray(correct, dtype = float)

    N = len(conf)
    intv_len = 1 / bins * (1 - 1 / num_choice) 

    p_list = []
    acc_list = []
    bin_centers = []
    counts = []
    for b in range(bins):
        low = 1 / num_choice + intv_len * b
        high = 1 / num_choice + intv_len * (b + 1)

        # Only the last bin includes its upper edge.
        if b == bins - 1:
            mask = (conf >= low) & (conf <= high)
        else:
            mask = (conf >= low) & (conf < high)

        count = np.sum(mask)
        counts.append(count)
        bin_centers.append(low + intv_len / 2)

        if count == 0:
            acc_list.append(0.0)
            p_list.append(0.0)
            continue

        bin_conf = np.mean(conf[mask])
        bin_acc  = np.mean(correct[mask])

        p_list.append(abs(bin_acc - bin_conf) * (count / N))
        acc_list.append(bin_acc)

    ece = np.sum(p_list)
    if verbose:
        print(f"ECE = {ece:.4f}")
        print(f"Accuracy = {np.sum(correct) / N:.4f}")

    if plot:
        fig, ax = plt.subplots(figsize=(8, 6))

        # 1. Create a Color Map based on counts
        # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
        # Matplotlib 3.9.
        cmap = plt.get_cmap('Blues')
        max_count = max(max(counts), 1)
        norm = mcolors.Normalize(vmin=0, vmax=max_count)
        bar_colors = [cmap(norm(c)) for c in counts]

        # 2. Plot the Bars
        # Width is slightly reduced (*0.95) to show gaps between bins clearly
        ax.bar(bin_centers, acc_list, width=intv_len * 0.95 , color=bar_colors, edgecolor='black', label='Model Accuracy', alpha = .7)

        # 3. Plot the Ideal Calibration Line (Diagonal)
        ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfect Calibration')

        # 4. Add Colorbar
        sm = cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax)
        cbar.set_label('Number of samples in bin')

        # Labels
        ax.set_xlabel('Mean predicted probability')
        ax.set_ylabel('Fraction of positives')
        ax.set_title(f'Reliability Diagram (ECE = {ece:.4f})\nColor Intensity = Bin Count')
        ax.legend()
        ax.grid(True, linestyle=':', alpha=0.6)

        # Ensure limits make sense
        ax.set_xlim(0, 1.0)
        ax.set_ylim(0, 1.0)

        plt.show()
    return ece 



def is_bayes(qoi, ds_list, prefix_list, J, n_max, num_choice, fix_final, label, seed = False):
    """
    Plot, per step, the mean distance between the model's predictive
    distribution and a frequency-based estimate, for several datasets.

    INPUT
    qoi : "mle" compares against mg_dict["m1_mean"], "map" against mg_dict["m2_dist"]
    ds_list : list of datasets, each grouped in consecutive blocks of J
              generations per question
    prefix_list : one prefix per dataset, forwarded to mg_seq
    J, n_max, num_choice, seed : forwarded to mg_seq
    fix_final : if truthy, only the last step of the predictive distribution
                is compared against every step of the estimate
    label : one legend label per dataset

    One figure per metric (KL, L2, L1, cosine similarity) is shown; nothing is
    returned. Questions rejected by mg_seq, or whose "phi_hat" contains NaN,
    are skipped.

    RAISES
    ValueError if qoi is not "mle" or "map".
    """
    import matplotlib as mpl

    qoi_keys = {"mle": "m1_mean", "map": "m2_dist"}
    if qoi not in qoi_keys:
        raise ValueError(f"qoi must be one of {sorted(qoi_keys)}, got {qoi!r}")
    qoi_key = qoi_keys[qoi]

    dist_list = []
    # Pair datasets and prefixes by position rather than by equality lookup.
    for ds_idx, ds in enumerate(ds_list):
        dist = {"l2": [], "kl": [], "cos" : [], "l1": []}
        prefix = prefix_list[ds_idx]
        for q in range(len(ds) // J):
            sub_ds = ds[q * J: (q + 1) * J]
            mg_dict = mg_seq(sub_ds, n_max, num_choice, seed, prefix, verbose = False) 
            # "phi_hat" is a NumPy array, so the NaN test uses NumPy.
            if (mg_dict["probs"] is None) or np.any(np.isnan(mg_dict["phi_hat"])):
                continue

            qoi_dist = mg_dict[qoi_key]
            pred_dist = mg_dict["phi_hat"]

            if fix_final:
                pred_dist = pred_dist[-1]

            kl = np.sum(pred_dist * (np.log(pred_dist + 1e-10) - np.log(qoi_dist + 1e-10)), axis = -1)
            dist["kl"].append(kl)
            dist["l2"].append(np.linalg.norm(pred_dist - qoi_dist, axis = -1))
            dist["l1"].append(np.linalg.norm(pred_dist - qoi_dist, axis = -1, ord = 1))
            dist["cos"].append(F.cosine_similarity(torch.tensor(pred_dist), torch.tensor(qoi_dist), dim = -1).numpy())

        dist_list.append({key: np.array(values) for key, values in dist.items()})

    colors = mpl.colormaps['tab20b'].colors
    for key in dist_list[0].keys():
        for j, dist in enumerate(dist_list):
            # Wrap around so more than six datasets do not overrun the palette.
            plt.plot(np.mean(dist[key], axis = 0), label = label[j], color = colors[(j + 14) % len(colors)])
        plt.xlabel("n")
        plt.ylabel(r"$d(p_\infty^\phi, E[\theta | Q])$")
        plt.title(f"Distance metric : ({key})")
        plt.legend()
        plt.show()







def mp_violation_scores(ds, J, n_max, prefix, seed = False, num_choices = 5):
    """
    Weighted mean k-step increments of the step probabilities.

    For every valid question, every step n and every lag k in [1, 2, 4, 8], the
    increment probs[:, n + k] - probs[:, n] is multiplied by a per-sequence
    weight computed from the running frequencies at index n - 1 and averaged
    over the sequences (NaNs ignored). The weights are
        "one"     - 1 (plain mean increment)
        "entropy" - entropy of the running frequencies
        "is_most" - 1 if the answer at index n - 1 equals the most frequent
                    answer so far, else 0

    INPUT
    ds : dataset whose rows are grouped in consecutive blocks of J generations
         per question (slicing must return a dict of columns)
    J : number of generations per question
    n_max, prefix, seed : forwarded to mg_seq
    num_choices : number of answer labels

    OUTPUT
    all_scores : dict mapping each weight name to an array of shape
                 (num_valid_questions, n_max - 9, 4, num_choices). Only steps n
                 with n + 8 < n_max are kept, so that all four lags exist.

    RAISES
    ValueError if n_max is too small for any step to have all four lags.
    """

    choice_labels = list(string.ascii_uppercase)[:num_choices]

    ftn_list = ["one", "entropy", "is_most"]
    k_list = [1, 2, 4, 8]

    # Steps for which every lag stays inside the sequence: n + max(k) < n_max.
    valid_steps = range(1, n_max - max(k_list))
    if len(valid_steps) == 0:
        raise ValueError(f"n_max must be at least {max(k_list) + 2} so that every lag in {k_list} fits.")

    all_scores = {key : [] for key in ftn_list}
    for q in range(len(ds) // J):
        sub_ds = ds[(q * J) : (q + 1) * J]
        mg_dict = mg_seq(sub_ds, n_max, num_choices, seed, prefix, verbose = False)
        if mg_dict["probs"] is None:
            continue 
        probs = mg_dict["probs"]
        mle_paths = mg_dict["m1_paths"]
        out_str = mg_dict["output_str"]

        per_key = {key : [] for key in ftn_list}
        for n in valid_steps:
            n_freq = mle_paths[:, n - 1, :]  # F_n

            entropy = -np.sum(n_freq * np.log(n_freq + 1e-12), axis = -1, keepdims = True)  # (J, 1)

            n_most_idx = np.argmax(n_freq, axis = -1)
            is_most = np.array([
                1.0 if out_str[j][n - 1] == choice_labels[n_most_idx[j]] else 0.0
                for j in range(len(out_str))
            ]).reshape(-1, 1)  # (J, 1)

            # None means "no weighting".
            weights = {"one": None, "entropy": entropy, "is_most": is_most}

            # k-step increments, shared by all three weightings.
            increments = [probs[:, n + k, :] - probs[:, n, :] for k in k_list]

            for key in ftn_list:
                w = weights[key]
                mp_score_k = [
                    np.nanmean(p_diff if w is None else p_diff * w, axis = 0)  # (num_choices,)
                    for p_diff in increments
                ]
                per_key[key].append(np.stack(mp_score_k, axis = 0))  # (K, num_choices)

        for key in ftn_list:
            all_scores[key].append(np.stack(per_key[key], axis = 0))

    for key in ftn_list:
        all_scores[key] = np.array(all_scores[key])  # (num_valid_questions, num_steps, K, num_choices)

    return all_scores