import os
import json
import numpy as np
from pathlib import Path
import torch
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.ticker as ticker
from matplotlib.patches import Rectangle


# File names of the two per-example score files that main() reads with
# torch.load from "<target_folder>/<example index>/".
prefix_src_file = "self_repair_from_heads_across_everything.pt"
prefix_dst_file = "self_repair_across_everything.pt"
# Template for the causal-path results directory.  main() fills the two "{}"
# slots with the model name and a dataset suffix ("" or "_lama_trex") and
# passes the result to load_cpt_subsets as its root directory.
prefix_cpath = "jobs_arxiv_for_test/250512_collection/20250505_0612_debug_{}{}_exStWd/results"


def pathidx2headidx(path, num_head_arg=6, strict=False):
    """Translate causal-path indices into the set of head indices they hit.

    Path indices 0 and 1 (labelled "residual" and "mlp_only" in the lookup
    table) are ignored.  Indices ``2 .. 2 + num_head_arg - 1`` map to heads
    ``0 .. num_head_arg - 1``, and the following ``num_head_arg`` indices map
    to the same head range a second time.

    Args:
        path: Iterable of path indices.
        num_head_arg: Number of heads covered by each of the two index ranges.
        strict: When true, keep only heads that occur exactly twice among the
            translated indices.

    Returns:
        A set of head indices.

    Raises:
        KeyError: If a path index lies outside the lookup table.
    """
    first_range = range(2, 2 + num_head_arg)
    second_range = range(2 + num_head_arg, 2 + 2 * num_head_arg)

    lookup = {0: "residual", 1: "mlp_only"}
    lookup.update((idx, idx - 2) for idx in first_range)
    lookup.update((idx, idx - 2 - num_head_arg) for idx in second_range)

    heads = []
    for p in path:
        if p in (0, 1):
            continue  # non-head entries carry no head index
        heads.append(lookup[p])

    if not strict:
        return set(heads)

    # Strict mode: count occurrences and keep heads seen exactly twice.
    occurrences = defaultdict(int)
    for head_idx in heads:
        occurrences[head_idx] += 1
    return {head_idx for head_idx, count in occurrences.items() if count == 2}

def load_cpt_subsets(CPT_ROOT, example_id):
    """Read the causal-path subsets stored for one example.

    The file ``<CPT_ROOT>/R{example_id:04d}/C{example_id:06d}.json`` is parsed
    as a JSON object whose keys are layer numbers (as strings) and whose
    values are lists of path-index subsets.

    Args:
        CPT_ROOT: Directory containing the ``R....`` sub-directories.
        example_id: Integer example index used to build the file path.

    Returns:
        A tuple ``(cpt, union_layers)``: ``cpt`` maps each integer layer to
        its list of subsets, and ``union_layers`` is a ``defaultdict(set)``
        mapping each layer to the union of all indices in its subsets.
    """
    run_dir = "R{:04d}".format(example_id)
    file_name = "C{:06d}.json".format(example_id)
    json_file = Path(os.path.join(CPT_ROOT, run_dir, file_name))
    raw = json.loads(json_file.read_text())

    # JSON object keys are strings; convert them to integer layer numbers.
    cpt = {}
    for layer_key, subsets in raw.items():
        cpt[int(layer_key)] = subsets

    union_layers = defaultdict(set)
    for layer, subsets in cpt.items():
        for subset in subsets:
            union_layers[layer] |= set(subset)
    return cpt, union_layers


def cutoff_noise(arr, r=5):
    """Drop values outside the ``r``-th and ``(100 - r)``-th percentiles.

    The previous implementation computed the trimmed array but returned the
    unmodified input, so no values were ever removed.

    Args:
        arr: Array-like of numeric values.
        r: Percentile (0-50) cut from each tail.

    Returns:
        A numpy array holding only the values that lie within the two
        percentile bounds (inclusive).  An empty input is returned as an
        empty array.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        # Percentiles of an empty array are undefined; nothing to trim.
        return arr
    lower_bound = np.percentile(arr, r)
    upper_bound = np.percentile(arr, 100 - r)
    return arr[(arr >= lower_bound) & (arr <= upper_bound)]


def trimmed_std(arr, trim_ratio=0.2):
    """Return the standard deviation of ``arr`` after trimming both tails.

    Values below the ``100 * trim_ratio`` percentile or above the
    ``100 * (1 - trim_ratio)`` percentile are discarded before ``np.std`` is
    applied.  Used for plotting so that outliers do not inflate error bars.

    Args:
        arr: Array-like of numeric values.
        trim_ratio: Fraction trimmed from each tail.

    Returns:
        Standard deviation of the retained values.
    """
    values = np.asarray(arr)
    low_pct = 100 * trim_ratio
    high_pct = 100 * (1 - trim_ratio)
    low_cut = np.percentile(values, low_pct)
    high_cut = np.percentile(values, high_pct)
    keep = np.logical_and(values >= low_cut, values <= high_cut)
    return values[keep].std()
def plot_modelwise_dst_barplot(modelwise_dict, save_path="zzs.png"):
    """Draw one bar-chart panel per model and save the figure.

    Each panel shows two bars: the mean of the model's ``'in_dst'`` scores
    ("Causal path") and the mean of its ``'out_dst'`` scores ("Off-path").
    Error bars use ``trimmed_std``; each mean is marked with a dot and an
    arrow-annotated value, and each median with a dashed blue line and label.

    Side effect: sets the global matplotlib font family to Times New Roman.

    Args:
        modelwise_dict: Mapping from model key to a dict providing the
            ``'in_dst'`` and ``'out_dst'`` score sequences.
        save_path: Where the figure is written (300 dpi).  Defaults to the
            previously hard-coded ``"zzs.png"``.

    Raises:
        ValueError: If ``modelwise_dict`` is empty (no panel can be drawn).
    """
    if not modelwise_dict:
        raise ValueError("modelwise_dict must contain at least one model")

    plt.rcParams["font.family"] = "Times New Roman"

    models = list(modelwise_dict.keys())
    num_models = len(models)

    fig, axes = plt.subplots(
        1, num_models,
        figsize=(3 * num_models, 2),
        sharey=False
    )
    if num_models == 1:
        # plt.subplots returns a bare Axes for a single panel.
        axes = [axes]

    bar_colors = ["#DC2448", "lightgray"]
    # Annotation text colour per bar (the light gray fill is replaced by black).
    text_colors = ["#DC2448", "black"]
    bar_labels = ["Causal path", "Off-path"]

    bar_width = 0.1
    x_positions = [0.2, 0.5]

    # Substring -> panel title, checked in this order.
    display_names = (
        ("gpt2", "GPT2-xs"),
        ("1b", "Pythia-1b"),
        ("14m", "Pythia-14m"),
    )

    for idx, model in enumerate(models):
        # Fall back to the raw key so an unrecognised model neither raises
        # nor silently inherits the previous panel's title.
        model_name = next(
            (label for key, label in display_names if key in model),
            model,
        )

        ax = axes[idx]
        in_dst = modelwise_dict[model]['in_dst']
        out_dst = modelwise_dict[model]['out_dst']
        means = [np.mean(in_dst), np.mean(out_dst)]
        stds = [trimmed_std(in_dst), trimmed_std(out_dst)]
        medians = [np.median(in_dst), np.median(out_dst)]

        # Fit the y-range to mean +/- std, always including zero at the top.
        y_max_model = max(m + s for m, s in zip(means, stds))
        y_min_model = min(m - s for m, s in zip(means, stds))
        y_margin = 0.05 * (y_max_model - y_min_model)
        ax.set_ylim(y_min_model - y_margin, max(0, y_max_model) + y_margin)

        ax.bar(
            x_positions,
            means,
            yerr=stds,
            color=bar_colors,
            width=bar_width,
            capsize=3,
            edgecolor='black',
            linewidth=0.75,
            error_kw=dict(lw=1, linestyle='-', ecolor='black', alpha=0.6)
        )

        for bar_idx, (x, y) in enumerate(zip(x_positions, means)):
            # Dot marking the mean on top of the bar.
            ax.plot(
                x, y,
                marker='o',
                markersize=3,
                markerfacecolor=bar_colors[bar_idx],
                markeredgecolor='black',
                markeredgewidth=0.7,
                linestyle='None',
                zorder=10
            )

            tc = text_colors[bar_idx]
            # Place the label away from the bar: below for negative means,
            # above otherwise.
            if y < 0:
                xytext = (x + 0.06, y - 0.04)
                va = 'top'
            else:
                xytext = (x + 0.06, y + 0.04)
                va = 'bottom'

            ax.annotate(
                text=f'{y:.3f}',
                xy=(x, y),
                xytext=xytext,
                textcoords='data',
                ha='left',
                va=va,
                fontsize=11,
                color=tc,
                fontweight='bold' if bar_idx == 0 else 'normal',
                arrowprops=dict(
                    arrowstyle='->',
                    lw=1,
                    color=tc,
                    shrinkA=12,
                    shrinkB=8,
                    patchA=None,
                    patchB=None,
                    connectionstyle="arc3,rad=0"
                ),
                zorder=20
            )

        for x, median in zip(x_positions, medians):
            # Dashed median line, inset slightly from the bar edges.
            x_left = x - bar_width / 2 + 0.01
            x_right = x + bar_width / 2 - 0.01

            ax.hlines(
                y=median,
                xmin=x_left,
                xmax=x_right,
                colors='blue',
                linewidth=1.0,
                linestyles='--',
                zorder=5
            )

            ax.text(
                x - bar_width / 2 - 0.004,
                median,
                f'{median:.3f}',
                ha='right',
                va='center',
                fontsize=11,
                color='blue',
                zorder=20
            )
        ax.yaxis.grid(True, which='major', linestyle='--', linewidth=0.5, color='gray', alpha=0.5)

        ax.set_xticks(x_positions)
        ax.set_xlim(0, 0.7)
        ax.set_xticklabels(bar_labels, fontsize=9)
        ax.set_title(model_name, fontsize=12)
        # Only the left-most panel carries the y-axis label.
        ax.set_ylabel("Self-Repair Score" if idx == 0 else "", fontsize=12)
        ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda y, _: f'{y:.2f}'))
        ax.tick_params(axis='y', labelsize=10)
        ax.tick_params(axis='x', labelsize=11)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # Thin black frame around the whole axes area (axes coordinates).
        ax.add_patch(Rectangle(
            (0, 0), 1, 1,
            transform=ax.transAxes,
            fill=False,
            edgecolor="black",
            linewidth=0.8,
            zorder=10,
            clip_on=False
        ))
    fig.tight_layout(pad=0.5, w_pad=2)
    fig.savefig(save_path, dpi=300)



 
def _resolve_model_config(target_folder):
    """Return ``(model_name, num_heads)`` for a results folder path.

    Raises:
        ValueError: If the path names none of the known models, instead of
            silently reusing the configuration of a previous folder.
    """
    known_models = (
        ("gpt2-xs", 6),
        ("pythia-14m", 4),
        ("pythia-1b", 8),
    )
    for name, num_heads in known_models:
        if name in target_folder:
            return name, num_heads
    raise ValueError("Cannot determine model from folder: {}".format(target_folder))


def main():
    """Aggregate self-repair scores for on-path vs. off-path heads and plot them.

    For every results folder, each example's two score files are loaded along
    with its causal-path subsets.  For every layer listed in the subsets, each
    head's scores are placed in the "in-path" group when the head appears in
    the union of that layer's subsets, and in the "out-path" group otherwise.
    Per-folder and per-model means and standard deviations are printed, and
    the per-model ``dst`` scores are passed to ``plot_modelwise_dst_barplot``.
    """
    target_folders = [
        "jobs_arxiv_for_test/250512_collection/self_repair/raw/no_eot/gpt2-xs/knowns1000",
        "jobs_arxiv_for_test/250512_collection/self_repair/raw/no_eot/gpt2-xs/lama",
        "jobs_arxiv_for_test/250512_collection/self_repair/raw/no_eot/pythia-1b/knowns1000",
        "jobs_arxiv_for_test/250512_collection/self_repair/raw/no_eot/pythia-1b/lama",
        "jobs_arxiv_for_test/250512_collection/self_repair/raw/no_eot/pythia-14m/knowns1000",
        "jobs_arxiv_for_test/250512_collection/self_repair/raw/no_eot/pythia-14m/lama"
    ]

    modelwise_dict = {}

    for target_folder in target_folders:

        print("-------------------")
        print(target_folder)
        # Example directories are named by integer index; skip anything else
        # (e.g. hidden files) rather than crashing in int().
        target_idxs = sorted(
            int(entry) for entry in os.listdir(target_folder) if entry.isdigit()
        )

        model_name, num_heads = _resolve_model_config(target_folder)
        data_path = "_lama_trex" if "lama" in target_folder else ""

        model_scores = modelwise_dict.setdefault(
            model_name,
            {'in_src': [], "out_src": [], 'in_dst': [], "out_dst": []},
        )

        in_path_src = []
        in_path_dst = []
        out_path_src = []
        out_path_dst = []

        # Loop-invariant: the causal-path root depends only on model/dataset.
        cpt_root = prefix_cpath.format(model_name, data_path)

        for t_idx in target_idxs:
            example_dir = os.path.join(target_folder, str(t_idx))
            # Only scalar values are read below, so loading onto the CPU is
            # sufficient and also works for files saved from a GPU.
            self_repair_dst = torch.load(
                os.path.join(example_dir, prefix_dst_file), map_location="cpu"
            )
            self_repair_src = torch.load(
                os.path.join(example_dir, prefix_src_file), map_location="cpu"
            )
            methods, union_layers = load_cpt_subsets(cpt_root, t_idx)

            for layer_idx in methods:
                union_hi = pathidx2headidx(union_layers[layer_idx], num_heads, strict=False)

                for h in range(num_heads):
                    dst_score = self_repair_dst[layer_idx][h].item()
                    src_score = self_repair_src[layer_idx][h].item()

                    if h in union_hi:
                        in_path_dst.append(dst_score)
                        in_path_src.append(src_score)
                        model_scores['in_dst'].append(dst_score)
                        model_scores['in_src'].append(src_score)
                    else:
                        out_path_dst.append(dst_score)
                        out_path_src.append(src_score)
                        model_scores['out_dst'].append(dst_score)
                        model_scores['out_src'].append(src_score)

        in_path_src = cutoff_noise(np.asarray(in_path_src))
        in_path_dst = cutoff_noise(np.asarray(in_path_dst))
        out_path_src = cutoff_noise(np.asarray(out_path_src))
        out_path_dst = cutoff_noise(np.asarray(out_path_dst))

        print("{} - {}".format(model_name, "lama" if "lama" in data_path else "knowns1000"))
        print("\t In-Path : Src-> {:8.5f} ±{:8.5f} | Dst-> {:8.5f} ±{:8.5f}".format(np.mean(in_path_src), np.std(in_path_src), np.mean(in_path_dst), np.std(in_path_dst)))
        print("\t Out-Path: Src-> {:8.5f} ±{:8.5f} | Dst-> {:8.5f} ±{:8.5f}".format(np.mean(out_path_src), np.std(out_path_src), np.mean(out_path_dst), np.std(out_path_dst)))

    print("------------")

    # Per-model summary across both datasets (untrimmed scores).
    for k, v in modelwise_dict.items():
        print(k)
        print("\t In-Path : Src-> {:8.5f} ±{:8.5f} | Dst-> {:8.5f} ±{:8.5f}".format(
            np.mean(v['in_src']), np.std(v['in_src']),
            np.mean(v['in_dst']), np.std(v['in_dst'])))
        print("\t Out-Path: Src-> {:8.5f} ±{:8.5f} | Dst-> {:8.5f} ±{:8.5f}".format(
            np.mean(v['out_src']), np.std(v['out_src']),
            np.mean(v['out_dst']), np.std(v['out_dst'])))
        print("------------")

    plot_modelwise_dst_barplot(modelwise_dict)
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()