import os
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
import math
def headidx2pathidx(num_head_arg=8):
    """Build a mapping from every head index to its two path indices.

    Head ``h`` (``0 <= h < num_head_arg``) maps to the list
    ``[h + 2, h + num_head_arg + 2]``, so path indices 0 and 1 are never
    produced by this mapping.

    Args:
        num_head_arg: number of heads per layer.

    Returns:
        dict[int, list[int]] keyed by head index.
    """
    offset = 2  # path indices 0 and 1 are not assigned to any head
    return {
        head: [head + offset, head + num_head_arg + offset]
        for head in range(num_head_arg)
    }
    
def pathidx2headidx(path, num_head_arg=8):
    """Convert path indices back to head indices (inverse of ``headidx2pathidx``).

    Path indices 0 and 1 carry no head and map to ``False``. Indices in
    ``[2, num_head_arg + 2)`` map to ``p - 2``; larger indices map to
    ``p - num_head_arg - 2``. Both are reduced modulo ``num_head_arg``.

    Args:
        path: iterable of integer path indices.
        num_head_arg: number of heads per layer.

    Returns:
        (hi, unique_hi), where:
        - ``hi`` has one entry per element of ``path``: a head index, or
          ``False`` for path indices 0/1.
        - ``unique_hi`` is the sorted list of distinct head indices in ``hi``.
    """
    hi = []
    for p in path:
        if p in (0, 1):
            hi.append(False)
        elif p < num_head_arg + 2:
            # First block of head paths: [2, num_head_arg + 2).
            hi.append((p - 2) % num_head_arg)
        else:
            # Second block starts right after the first. The original
            # hard-coded 10/8 here, which was only correct for 8 heads.
            hi.append((p - num_head_arg - 2) % num_head_arg)

    # `is not False` keeps head index 0 (0 == False but 0 is not False).
    unique_hi = sorted({x for x in hi if x is not False})
    return hi, unique_hi
# Number of heads per layer (head indices are 0 .. num_head - 1).
num_head = 8

# Directory of saved self-repair tensors, listed in sorted order.
self_repair_folder = "buffer_custom"
self_repair_files = sorted(os.listdir(self_repair_folder))

# Per-sample average self-repair for heads inside ("in") and outside ("out")
# the causal path, computed for both the union and the intersection of paths.
path_inout_result = {
    "union": {"in": [], "out": []},
    "intersect": {"in": [], "out": []},
}
    
causal_path_folder = os.path.join("causal_paths/20250416_1918_debug_gpt2-mini", "results")
# NOTE(review): the last sorted entry is dropped -- presumably it is not a
# per-sample result folder; confirm against whatever writes this directory.
target_data_folders = sorted(os.listdir(causal_path_folder))[:-1]


def _select_heads(layer_values, head_indices):
    """Return ``layer_values[head_indices]`` as a Python list, or [] if no heads."""
    if len(head_indices) == 0:
        return []
    return layer_values[head_indices].numpy().tolist()


def _flat_mean(nested):
    """Mean over all elements of a list of lists (NaN, with a warning, if empty)."""
    return np.mean([x for sublist in nested for x in sublist])


for ii, data_folder in enumerate(target_data_folders):
    # The folder name's text after "R" is parsed as an int and used to build
    # the causal-path file name "C<6-digit zero-padded int>.json".
    json_file_path = os.path.join(
        causal_path_folder,
        data_folder,
        "C{:06d}.json".format(int(data_folder.split("R")[1])),
    )
    with open(json_file_path, "r") as f:
        curr_causal_path = json.load(f)

    # NOTE(review): the ii-th result folder is paired with the ii-th file in
    # self_repair_folder purely by sorted position; verify the listings align.
    # torch.load unpickles -- only use on trusted, locally produced files.
    curr_self_repair = torch.load(os.path.join(self_repair_folder, self_repair_files[ii]))
    # batch pos layer head
    curr_self_repair = curr_self_repair[0, -1]  # one batch, last decision

    self_repair_in_path_u, self_repair_not_in_path_u = [], []
    self_repair_in_path_i, self_repair_not_in_path_i = [], []

    num_paths = []
    for layer_idx, paths in curr_causal_path.items():
        num_paths.append(len(paths))
        union_path = sorted(set().union(*paths))
        # Guard: `paths[0]` would raise IndexError on an empty path list.
        intersect_path = sorted(set(paths[0]).intersection(*paths[1:])) if paths else []

        curr_layer_self_repair = curr_self_repair[int(layer_idx)]

        # Heads touched by ANY path of this layer vs. the remaining heads.
        _, union_u_hi = pathidx2headidx(union_path, num_head_arg=num_head)
        notin_union_u_hi = np.setdiff1d(np.arange(num_head), union_u_hi).tolist()
        self_repair_in_path_u.append(_select_heads(curr_layer_self_repair, union_u_hi))
        self_repair_not_in_path_u.append(_select_heads(curr_layer_self_repair, notin_union_u_hi))

        # Heads touched by EVERY path of this layer vs. the remaining heads.
        _, intersect_u_hi = pathidx2headidx(intersect_path, num_head_arg=num_head)
        notin_intersect_u_hi = np.setdiff1d(np.arange(num_head), intersect_u_hi).tolist()
        self_repair_in_path_i.append(_select_heads(curr_layer_self_repair, intersect_u_hi))
        # Bug fix: the original gathered intersect_u_hi here as well, which made
        # the intersect "out" values identical to the "in" values.
        self_repair_not_in_path_i.append(_select_heads(curr_layer_self_repair, notin_intersect_u_hi))

    print("Num of Path: {}".format(math.prod(num_paths)))

    path_inout_result["union"]["in"].append(_flat_mean(self_repair_in_path_u))
    path_inout_result["union"]["out"].append(_flat_mean(self_repair_not_in_path_u))

    path_inout_result["intersect"]["in"].append(_flat_mean(self_repair_in_path_i))
    path_inout_result["intersect"]["out"].append(_flat_mean(self_repair_not_in_path_i))

def _plot_in_vs_out(result, filename, figsize):
    """Save a grouped bar chart of per-sample in-path vs. out-of-path self-repair.

    Args:
        result: dict with equal-length "in" and "out" lists (one value per sample).
        filename: output image path.
        figsize: matplotlib figure size (width, height) in inches.
    """
    fig = plt.figure(figsize=figsize)
    x = np.arange(len(result["in"]))
    width = 0.35

    plt.bar(x - width / 2, result["in"], width, label='In-Path')
    plt.bar(x + width / 2, result["out"], width, label='Out-Path')
    plt.xlabel('Sample Idx')
    plt.ylabel('Self-Repair')
    plt.title('Comparison of Self-Repair Inside vs. Outside the Causal Path')
    plt.xticks(x)
    plt.yticks(fontsize=8)
    plt.legend()
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    # Release the figure so repeated plots don't accumulate open figures.
    plt.close(fig)


# Figure sizes are the ones the original script used for each plot.
_plot_in_vs_out(path_inout_result["union"], 'z_comp_union.png', (4, 3))
_plot_in_vs_out(path_inout_result["intersect"], 'z_comp_intersect.png', (8, 6))

# NOTE: interactive breakpoint kept from the original; remove it for batch runs.
import pdb; pdb.set_trace()