import copy
import tqdm 
import torch 
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import numpy as np
import matplotlib.pyplot as plt 
import os 
import copy 
import pickle 

def HeadStatsWithTest(ops, model_use, K, data, layers,
            op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last,
            op_to_layer_to_val_pos_last_test, op_to_layer_to_val_neg_last_test):
    """Score every attention head with a nearest-centroid classifier on the train and test splits.

    For each ``op`` and ``layer`` the last-token activations are split into per-head
    vectors. Class centroids are the per-head means of *all* train activations; a
    prompt is predicted positive (1) when it is at least as close (L2) to the
    positive centroid as to the negative one. The evaluated train/test sets are
    truncated to ``min(#pos, #neg)`` prompts per class so accuracies are balanced.

    Args:
        ops: Ops to score; each must be one of ``'k'``, ``'v'``, ``'q'``, ``'attn_output'``.
        model_use: ``"LLaMA3_Instruct"`` or ``"Gemma2_Instruct"``; selects the number of
            heads, head dim and GQA group size used to split activations.
        K: Unused; kept for signature compatibility with the ``HeadStats`` helpers.
        data: ``data['train']`` / ``data['test']`` with ``'pos_prompts'`` and
            ``'neg_prompts'``; only their lengths are used (to reshape activations).
        layers: Layer keys to evaluate.
        op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last: Train activations,
            ``op -> layer -> tensor`` reshapeable to ``(n_prompts, n_heads, head_dim)``
            (``n_heads`` is the KV-head count for ``'k'``/``'v'``).
        op_to_layer_to_val_pos_last_test, op_to_layer_to_val_neg_last_test: Test
            activations in the same layout.

    Returns:
        ``{'centroid': {'accs', 'op_to_df_train', 'op_to_df_test'}}`` where
        ``accs[op][layer][head] == {'train': acc, 'test': acc}`` and each DataFrame
        has one column per layer and one row per head.

    Raises:
        AssertionError: if an op is not recognised.
        ValueError: if ``model_use`` is not supported.
    """
    for op in ops:
        assert op in ['k', 'v', 'q', 'attn_output'], 'invalid op'

    if model_use == "LLaMA3_Instruct":
        num_heads, head_dim, num_groups = 32, 128, 4
    elif model_use == "Gemma2_Instruct":
        num_heads, head_dim, num_groups = 16, 256, 2
    else:
        raise ValueError(f"Unsupported model_use: {model_use!r}")
    num_kv = num_heads // num_groups

    # split -> (positive activations, negative activations)
    splits = {
        'train': (op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last),
        'test': (op_to_layer_to_val_pos_last_test, op_to_layer_to_val_neg_last_test),
    }
    # split -> (#pos prompts, #neg prompts, #prompts kept per class)
    sizes = {}
    for split in splits:
        n_pos = len(data[split]['pos_prompts'])
        n_neg = len(data[split]['neg_prompts'])
        sizes[split] = (n_pos, n_neg, min(n_pos, n_neg))  # NOTE: balances uneven datasets

    def per_head(tensor, n_prompts, n_heads, n_keep):
        """Return the first ``n_keep`` prompts as a ``(n_keep, n_heads, head_dim)`` array."""
        # reshape() instead of view(...).squeeze(): squeeze() would also drop the
        # prompt axis when n_prompts == 1 and the [:n_keep] slice would then cut heads.
        return tensor.reshape(n_prompts, n_heads, head_dim)[:n_keep].detach().cpu().numpy()

    def centroid_acc(X, Y, mu_pos_head, mu_neg_head):
        """Accuracy of 'closer centroid wins' predictions (ties go to the positive class)."""
        pred = (np.linalg.norm(X - mu_pos_head, axis=1)
                <= np.linalg.norm(X - mu_neg_head, axis=1)).astype(int)
        return accuracy_score(Y, pred)

    accs_mean_centroid = {op: {layer: {} for layer in layers} for op in ops}

    for layer in layers:
        for op in ops:
            N_ELEMS = num_kv if op in ('k', 'v') else num_heads
            feats = {}
            for split, (pos_acts, neg_acts) in splits.items():
                n_pos, n_neg, n_keep = sizes[split]
                feats[split] = (per_head(pos_acts[op][layer], n_pos, N_ELEMS, n_keep),
                                per_head(neg_acts[op][layer], n_neg, N_ELEMS, n_keep))

            # NOTE: centroids use every train prompt -- a mean needs no class balancing
            mu_pos = op_to_layer_to_val_pos_last[op][layer].mean(dim=0).reshape(N_ELEMS, head_dim).detach().cpu().numpy()
            mu_neg = op_to_layer_to_val_neg_last[op][layer].mean(dim=0).reshape(N_ELEMS, head_dim).detach().cpu().numpy()

            for IDX in tqdm.tqdm(list(range(N_ELEMS)), desc=str(layer)):
                head_accs = {}
                for split, (f_pos, f_neg) in feats.items():
                    X = np.concatenate([f_neg[:, IDX, :], f_pos[:, IDX, :]])
                    Y = [0] * len(f_neg) + [1] * len(f_pos)
                    head_accs[split] = centroid_acc(X, Y, mu_pos[IDX], mu_neg[IDX])
                accs_mean_centroid[op][layer][IDX] = head_accs

    def to_frame(op, split):
        """Table of ``acc[split]`` with one column per layer and one row per head."""
        # Same table as pd.DataFrame(accs_mean_centroid[op]).applymap(lambda X: X[split])
        # without DataFrame.applymap, which pandas >= 2.1 deprecates.
        return pd.DataFrame({lay: {idx: acc[split] for idx, acc in per_layer.items()}
                             for lay, per_layer in accs_mean_centroid[op].items()})

    results = {}
    results['centroid'] = {
        'accs': accs_mean_centroid,
        'op_to_df_train': {op: to_frame(op, 'train') for op in ops},
        'op_to_df_test': {op: to_frame(op, 'test') for op in ops},
    }
    return results

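# NOTE: superseded. A second `HeadStats` defined later in this module rebinds the name at import time, so this definition is never reached as `HeadStats`.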
def HeadStats(ops, model_use, K, data, layers,
            op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last,
            op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val,
            op_to_layer_to_val_pos_last_test, op_to_layer_to_val_neg_last_test):
    """Score every attention head with a nearest-centroid classifier on train, val and test.

    NOTE: a second ``def HeadStats`` later in this module rebinds the name at import
    time, so this definition is unreachable as ``HeadStats`` once the module loads.

    Class centroids are the per-head means of *all* train activations; a prompt is
    predicted positive (1) when it is at least as close (L2) to the positive
    centroid as to the negative one. Evaluated sets are truncated to
    ``min(#pos, #neg)`` prompts per class so accuracies are balanced.

    Args:
        ops: Ops to score; each must be one of ``'k'``, ``'v'``, ``'q'``, ``'attn_output'``.
        model_use: ``"LLaMA3_Instruct"`` or ``"Gemma2_Instruct"``.
        K: Unused; kept for signature compatibility.
        data: ``data[split]`` for ``split`` in ``'train'``, ``'val'``, ``'test'`` with
            ``'pos_prompts'`` / ``'neg_prompts'``; only their lengths are used.
        layers: Layer keys to evaluate.
        op_to_layer_to_val_*: ``op -> layer -> tensor`` reshapeable to
            ``(n_prompts, n_heads, head_dim)`` for the train, val and test splits.

    Returns:
        ``{'centroid': {'accs', 'op_to_df_train', 'op_to_df_val', 'op_to_df_test'}}``
        where ``accs[op][layer][head] == {'train': acc, 'val': acc, 'test': acc}``.

    Raises:
        AssertionError: if an op is not recognised.
        ValueError: if ``model_use`` is not supported.
    """
    for op in ops:
        assert op in ['k', 'v', 'q', 'attn_output'], 'invalid op'

    if model_use == "LLaMA3_Instruct":
        num_heads, head_dim, num_groups = 32, 128, 4
    elif model_use == "Gemma2_Instruct":
        num_heads, head_dim, num_groups = 16, 256, 2
    else:
        raise ValueError(f"Unsupported model_use: {model_use!r}")
    num_kv = num_heads // num_groups

    # split -> (positive activations, negative activations)
    splits = {
        'train': (op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last),
        'val': (op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val),
        'test': (op_to_layer_to_val_pos_last_test, op_to_layer_to_val_neg_last_test),
    }
    # split -> (#pos prompts, #neg prompts, #prompts kept per class)
    sizes = {}
    for split in splits:
        n_pos = len(data[split]['pos_prompts'])
        n_neg = len(data[split]['neg_prompts'])
        sizes[split] = (n_pos, n_neg, min(n_pos, n_neg))  # NOTE: balances uneven datasets

    def per_head(tensor, n_prompts, n_heads, n_keep):
        """Return the first ``n_keep`` prompts as a ``(n_keep, n_heads, head_dim)`` array."""
        # reshape() instead of view(...).squeeze(): squeeze() would also drop the
        # prompt axis when n_prompts == 1.
        return tensor.reshape(n_prompts, n_heads, head_dim)[:n_keep].detach().cpu().numpy()

    def centroid_acc(X, Y, mu_pos_head, mu_neg_head):
        """Accuracy of 'closer centroid wins' predictions (ties go to the positive class)."""
        pred = (np.linalg.norm(X - mu_pos_head, axis=1)
                <= np.linalg.norm(X - mu_neg_head, axis=1)).astype(int)
        return accuracy_score(Y, pred)

    accs_mean_centroid = {op: {layer: {} for layer in layers} for op in ops}

    for layer in layers:
        for op in ops:
            N_ELEMS = num_kv if op in ('k', 'v') else num_heads
            feats = {}
            for split, (pos_acts, neg_acts) in splits.items():
                n_pos, n_neg, n_keep = sizes[split]
                feats[split] = (per_head(pos_acts[op][layer], n_pos, N_ELEMS, n_keep),
                                per_head(neg_acts[op][layer], n_neg, N_ELEMS, n_keep))

            # NOTE: centroids use every train prompt -- a mean needs no class balancing
            mu_pos = op_to_layer_to_val_pos_last[op][layer].mean(dim=0).reshape(N_ELEMS, head_dim).detach().cpu().numpy()
            mu_neg = op_to_layer_to_val_neg_last[op][layer].mean(dim=0).reshape(N_ELEMS, head_dim).detach().cpu().numpy()

            for IDX in tqdm.tqdm(list(range(N_ELEMS)), desc=str(layer)):
                head_accs = {}
                for split, (f_pos, f_neg) in feats.items():
                    X = np.concatenate([f_neg[:, IDX, :], f_pos[:, IDX, :]])
                    Y = [0] * len(f_neg) + [1] * len(f_pos)
                    head_accs[split] = centroid_acc(X, Y, mu_pos[IDX], mu_neg[IDX])
                accs_mean_centroid[op][layer][IDX] = head_accs

    def to_frame(op, split):
        """Table of ``acc[split]`` with one column per layer and one row per head."""
        # Equivalent to pd.DataFrame(accs_mean_centroid[op]).applymap(lambda X: X[split]).
        return pd.DataFrame({lay: {idx: acc[split] for idx, acc in per_layer.items()}
                             for lay, per_layer in accs_mean_centroid[op].items()})

    results = {}
    results['centroid'] = {
        'accs': accs_mean_centroid,
        'op_to_df_train': {op: to_frame(op, 'train') for op in ops},
        'op_to_df_val': {op: to_frame(op, 'val') for op in ops},
        'op_to_df_test': {op: to_frame(op, 'test') for op in ops},
    }
    return results

# Best Layer Masking 

def getLayerAccs(ops, op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last,
                 op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val,
                 op_to_steer, mask = False, viz = False):
    """Find, per op, the layer whose full activation best separates pos/neg prompts.

    For every layer, a nearest-centroid classifier (centroids = mean of the train
    activations over dim 0; ties go to the positive class) is evaluated on the
    validation activations. The best layer is the first one with maximal val accuracy.

    Args:
        ops: Ops to evaluate; the layer list is taken from the keys of the first op's
            train dict.
        op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last: Train activations,
            ``op -> layer -> tensor`` (rows are prompts).
        op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val: Val
            activations in the same layout.
        op_to_steer: ``op -> layer -> tensor``; only read/written when ``mask`` is True.
        mask: If True, every layer except the best one is replaced with zeros in
            ``op_to_steer`` (the dict passed in is modified in place).
        viz: If True, plot val accuracy vs layer into a new matplotlib figure.

    Returns:
        ``(op_to_layer_to_acc, op_to_steer)`` where ``op_to_layer_to_acc[op][layer]``
        is the val accuracy.
    """
    layers = list(op_to_layer_to_val_pos_last[ops[0]].keys())
    op_to_layer_to_acc = {op: {layer: None for layer in layers} for op in ops}

    if viz:
        plt.figure(figsize = (14,7))

    for op in ops:
        accs = []
        for layer in layers:
            # Centroids for classification
            mu_pos = op_to_layer_to_val_pos_last[op][layer].mean(dim=0).detach().cpu().numpy()
            mu_neg = op_to_layer_to_val_neg_last[op][layer].mean(dim=0).detach().cpu().numpy()

            # Validation features and labels (negatives first, then positives)
            val_pos = op_to_layer_to_val_pos_last_val[op][layer].detach().cpu().numpy()
            val_neg = op_to_layer_to_val_neg_last_val[op][layer].detach().cpu().numpy()
            X_val = np.concatenate([val_neg, val_pos])
            Y_val = np.array([0] * len(val_neg) + [1] * len(val_pos))

            # Closer centroid wins; accuracy = fraction of correct predictions
            pred = (np.linalg.norm(X_val - mu_pos, axis=1)
                    <= np.linalg.norm(X_val - mu_neg, axis=1)).astype(int)
            acc = float(np.mean(pred == Y_val))

            op_to_layer_to_acc[op][layer] = acc
            accs.append(acc)

        if viz:
            plt.plot(layers, accs, label = op)
        accs = torch.tensor(accs)

        # argmax is a position in `layers`; map it back to the layer key so that
        # non-contiguous layer keys (e.g. [5, 7, 9]) are masked correctly.
        best_idx = int(accs.argmax())
        best_layer = layers[best_idx]
        print(op, best_layer, accs[best_idx].item())
        if mask:
            for layer in layers:
                if layer != best_layer:
                    op_to_steer[op][layer] = torch.zeros_like(op_to_steer[op][layer])

            print(op_to_steer[op][layers[0]].norm(), op_to_steer[op][best_layer].norm())

    if viz:
        plt.legend()
        plt.grid()
        plt.ylim(0.5, 1)
        plt.yticks(np.linspace(0.5,1,11))
        plt.title("Centroid Classifier Val Acc vs Layer")
        plt.ylabel("Accuracy")
        plt.xlabel("Layer")

    return op_to_layer_to_acc, op_to_steer

####################################################################################
##################################### ITI CODE ####################################
###################################################################################



# New head visualization
def vizHeads(results, ops, model_use, data, layers,
            op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last,
            op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val, 
            savedir, dataset): # , HeadStats):
    """Plot per-head classifier accuracies from ``HeadStats`` and save them as PNGs in ``savedir``.

    For every classifier type ``cls`` in ``results`` this writes
    ``{cls}_Heatmap.png`` (val / train accuracy per layer x head),
    ``{cls}_CCDF.png`` (fraction of heads at or above each accuracy threshold) and
    ``{cls}_Acc_Comparisons_Val.png`` / ``{cls}_Acc_Comparisons_Train.png``
    (pairwise op-vs-op scatter). It then writes one logistic-vs-centroid scatter
    named ``{last key of results}_Logistic_Centroid_Comparison.png``.

    k/v accuracies are per KV head; for the per-classifier plots each entry is
    repeated ``num_heads // num_kv`` times so k/v line up with query heads.
    ``results`` itself is not modified.

    Args:
        results: ``HeadStats`` output; needs ``'logistic'`` and ``'centroid'`` entries,
            each with ``'accs'``, ``'op_to_df_train'`` and ``'op_to_df_val'``.
        ops: Ops to plot.
        model_use: Model name containing ``"Gemma"`` or ``"LLaMA"``.
        data: Unused.
        layers: Layers whose k/v entries are expanded to query-head resolution.
        op_to_layer_to_val_*: Unused; kept for signature compatibility.
        savedir: Output directory.
        dataset: Dataset name, used in figure titles.

    Raises:
        ValueError: if ``model_use`` names neither Gemma nor LLaMA.
    """
    if "Gemma" in model_use:
        num_kv, num_heads = 8, 16
    elif "LLaMA" in model_use:
        num_kv, num_heads = 8, 32
    else:
        raise ValueError(f"Unsupported model_use: {model_use!r}")
    num_groups = num_heads // num_kv
    op_order   = ('q', 'k', 'v', 'attn_output')   # fixed column order for the 2x4 figures
    ticks      = np.linspace(0.5, 1, num=11)
    diag       = np.linspace(0.5, 1, num=21)

    def expand_kv(per_head):
        """Repeat each KV-head entry ``num_groups`` times, re-keyed 0..n-1 like query heads."""
        expanded = {}
        for entry in per_head.values():
            for _ in range(num_groups):
                expanded[len(expanded)] = entry
        return expanded

    def to_frame(accs_op, split):
        """Table of ``acc[split]`` with one column per layer and one row per head."""
        # Same table as pd.DataFrame(accs_op).applymap(lambda X: X[split]) without
        # DataFrame.applymap, which pandas >= 2.1 deprecates.
        return pd.DataFrame({layer: {idx: acc[split] for idx, acc in per_head.items()}
                             for layer, per_head in accs_op.items()})

    def style_acc_axes(ax):
        """Clamp both axes to [0.5, 1] and draw a dashed y = x reference line."""
        ax.set_ylim(0.5, 1)
        ax.set_yticks(ticks)
        ax.set_xlim(0.5, 1)
        ax.set_xticks(ticks)
        ax.plot(diag, diag, 'g--')

    for cls_type, cls_results in results.items():
        # Shallow copies are enough because nothing below mutates the inner entries;
        # this avoids deep-copying every stored coefficient tensor.
        accs = dict(cls_results['accs'])
        for op in ('v', 'k'):
            if op in accs:   # skip k/v when HeadStats was not run on them
                accs[op] = dict(accs[op])
                for layer in layers:
                    accs[op][layer] = expand_kv(accs[op][layer])

        dfs_train = {op: to_frame(accs[op], 'train') for op in ops}
        dfs_val   = {op: to_frame(accs[op], 'val') for op in ops}
        plot_ops  = [op for op in op_order if op in dfs_val]

        # (A) Heatmap: row 0 = val, row 1 = train
        fig, axes = plt.subplots(2, 4, figsize=(8 * 4, 12))
        val_images = []  # the colorbar is drawn from the first val panel
        for col, op in enumerate(plot_ops):
            val_images.append(axes[0][col].imshow(dfs_val[op].T, origin='lower', aspect='auto', cmap='coolwarm', vmin=.5, vmax=1))
            axes[1][col].imshow(dfs_train[op].T, origin='lower', aspect='auto', cmap='coolwarm', vmin=.5, vmax=1)
            axes[0][col].set_title(op)
        for row in range(2):
            axes[row][0].set_ylabel("Layer")
        for col in range(4):
            axes[1][col].set_xlabel("Attention Head")
        fig.colorbar(val_images[0], ax=axes, location='right', shrink=1)
        fig.suptitle(model_use + ", " + dataset)
        fig.savefig(os.path.join(savedir, f"{cls_type}_Heatmap.png"))

        # (B) CCDF: fraction of heads whose accuracy is >= each threshold
        fig, axes = plt.subplots(1, 2, figsize=(16, 8))
        thresholds = np.linspace(.30, 1, num=100)
        for op in plot_ops:
            flat_val   = dfs_val[op].values.flatten()     # flattened once, not per threshold
            flat_train = dfs_train[op].values.flatten()
            axes[0].plot(thresholds, [np.mean(flat_val >= T) for T in thresholds], label=op)
            axes[1].plot(thresholds, [np.mean(flat_train >= T) for T in thresholds], label=op)
        axes[0].legend()
        axes[0].set_title("Val")
        axes[1].set_title("Train")
        axes[0].set_ylabel("% Above Threshhold")
        axes[0].set_xlabel("Threshhold")
        axes[1].set_xlabel("Threshhold")
        for ax in axes:
            ax.grid()
            ax.set_ylim(0, 1)
            ax.set_yticks(np.linspace(0, 1, num=11))
        fig.suptitle(model_use + ", " + dataset)
        fig.savefig(os.path.join(savedir, f"{cls_type}_CCDF.png"))

        # (C) Pairwise op-vs-op scatter, once for val and once for train
        for split_name, dfs in (("Val", dfs_val), ("Train", dfs_train)):
            fig, axes = plt.subplots(4, 4, figsize=(20, 20))
            for I, op in enumerate(ops):          # x axis
                for J, op2 in enumerate(ops):     # y axis
                    ax = axes[I][J]
                    ax.scatter(dfs[op].values.flatten(), dfs[op2].values.flatten())
                    ax.set_xlabel(op)
                    ax.set_ylabel(op2)
                    style_acc_axes(ax)
            fig.savefig(os.path.join(savedir, f"{cls_type}_Acc_Comparisons_{split_name}.png"))

    # (D) Logistic vs centroid accuracy per head (tables as stored, k/v not expanded)
    dfs_train_logistic, dfs_val_logistic = results['logistic']['op_to_df_train'], results['logistic']['op_to_df_val']
    dfs_train_centroid, dfs_val_centroid = results['centroid']['op_to_df_train'], results['centroid']['op_to_df_val']

    fig, axes = plt.subplots(2, 4, figsize=(8 * 4, 12))
    for col, op in enumerate([o for o in op_order if o in dfs_val_logistic]):
        axes[0][col].scatter(dfs_val_logistic[op].values.flatten(), dfs_val_centroid[op].values.flatten())
        axes[0][col].set_title(op)
        style_acc_axes(axes[0][col])
        axes[1][col].scatter(dfs_train_logistic[op].values.flatten(), dfs_train_centroid[op].values.flatten())
        style_acc_axes(axes[1][col])
    for col in range(4):
        axes[1][col].set_xlabel("Logistic Accuracy")
    for row in range(2):
        axes[row][0].set_ylabel("Centroid Accuracy")
    fig.suptitle("Logistic and Centroid Acc Correlations")
    # The file prefix is the last classifier key; earlier code got it from a leaked
    # loop variable. It is spelled out here so the output path stays unchanged.
    fig.savefig(os.path.join(savedir, f"{list(results)[-1]}_Logistic_Centroid_Comparison.png"))


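# NOTE: this is the active `HeadStats`; it overrides the earlier definition of the same name in this module. Some of the other helpers are deprecated, so use this one.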
def HeadStats(ops, model_use, K, data, layers,
            op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last,
            op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val,
            op_to_layer_to_val_pos_last_test=None, op_to_layer_to_val_neg_last_test=None):
    """Score every attention head with logistic-regression and nearest-centroid probes.

    This definition overrides the earlier ``HeadStats`` in this module. For each
    ``op``, ``layer`` and head:

    * a ``LogisticRegression(max_iter=1000)`` is fit on the balanced train set;
    * a nearest-centroid classifier uses per-head means of *all* train activations
      (ties go to the positive class).

    Both are scored on every split. Evaluated sets are truncated to ``min(#pos, #neg)``
    prompts per class.

    Args:
        ops: Ops to score; each must be one of ``'k'``, ``'v'``, ``'q'``, ``'attn_output'``.
        model_use: ``"LLaMA3_Instruct"`` or ``"Gemma2_Instruct"``.
        K: Unused; kept for signature compatibility.
        data: ``data[split]`` with ``'pos_prompts'`` / ``'neg_prompts'`` for
            ``'train'``, ``'val'`` (and ``'test'`` when test activations are given);
            only their lengths are used.
        layers: Layer keys to evaluate.
        op_to_layer_to_val_*: ``op -> layer -> tensor`` reshapeable to
            ``(n_prompts, n_heads, head_dim)`` (``n_heads`` = KV heads for k/v).
        op_to_layer_to_val_pos_last_test, op_to_layer_to_val_neg_last_test: Optional
            test activations (pass both or neither). When given, every result also
            gets a ``'test'`` accuracy and an ``'op_to_df_test'`` table, matching the
            signature of the earlier ``HeadStats`` definition.

    Returns:
        ``{'logistic': {...}, 'centroid': {...}}``, each holding ``'accs'`` plus
        ``'op_to_df_train'`` / ``'op_to_df_val'`` (and ``'op_to_df_test'``) DataFrames
        with one column per layer and one row per head. Logistic entries also store
        ``'coef'`` (the fitted coefficients as a tensor).

    Raises:
        AssertionError: if an op is not recognised.
        ValueError: unsupported ``model_use``, or only one test dict given.
    """
    for op in ops:
        assert op in ['k', 'v', 'q', 'attn_output'], 'invalid op'
    if (op_to_layer_to_val_pos_last_test is None) != (op_to_layer_to_val_neg_last_test is None):
        raise ValueError("Pass both test activation dicts or neither")

    if model_use == "LLaMA3_Instruct":
        num_heads, head_dim, num_groups = 32, 128, 4
    elif model_use == "Gemma2_Instruct":
        num_heads, head_dim, num_groups = 16, 256, 2
    else:
        raise ValueError(f"Unsupported model_use: {model_use!r}")
    num_kv = num_heads // num_groups

    # split -> (positive activations, negative activations); 'train' must stay first
    splits = {
        'train': (op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last),
        'val': (op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val),
    }
    if op_to_layer_to_val_pos_last_test is not None:
        splits['test'] = (op_to_layer_to_val_pos_last_test, op_to_layer_to_val_neg_last_test)

    # split -> (#pos prompts, #neg prompts, #prompts kept per class)
    sizes = {}
    for split in splits:
        n_pos = len(data[split]['pos_prompts'])
        n_neg = len(data[split]['neg_prompts'])
        sizes[split] = (n_pos, n_neg, min(n_pos, n_neg))  # NOTE: balances uneven datasets

    def per_head(tensor, n_prompts, n_heads, n_keep):
        """Return the first ``n_keep`` prompts as a ``(n_keep, n_heads, head_dim)`` array."""
        # reshape() instead of view(...).squeeze(): squeeze() would also drop the
        # prompt axis when n_prompts == 1.
        return tensor.reshape(n_prompts, n_heads, head_dim)[:n_keep].detach().cpu().numpy()

    def centroid_acc(X, Y, mu_pos_head, mu_neg_head):
        """Accuracy of 'closer centroid wins' predictions (ties go to the positive class)."""
        pred = (np.linalg.norm(X - mu_pos_head, axis=1)
                <= np.linalg.norm(X - mu_neg_head, axis=1)).astype(int)
        return accuracy_score(Y, pred)

    accs_logistic      = {op: {layer: {} for layer in layers} for op in ops}
    accs_mean_centroid = {op: {layer: {} for layer in layers} for op in ops}

    for layer in layers:
        for op in ops:
            N_ELEMS = num_kv if op in ('k', 'v') else num_heads
            feats = {}
            for split, (pos_acts, neg_acts) in splits.items():
                n_pos, n_neg, n_keep = sizes[split]
                feats[split] = (per_head(pos_acts[op][layer], n_pos, N_ELEMS, n_keep),
                                per_head(neg_acts[op][layer], n_neg, N_ELEMS, n_keep))

            # NOTE: centroids use every train prompt -- a mean needs no class balancing
            mu_pos = op_to_layer_to_val_pos_last[op][layer].mean(dim=0).reshape(N_ELEMS, head_dim).detach().cpu().numpy()
            mu_neg = op_to_layer_to_val_neg_last[op][layer].mean(dim=0).reshape(N_ELEMS, head_dim).detach().cpu().numpy()

            for IDX in tqdm.tqdm(list(range(N_ELEMS)), desc=str(layer)):
                X, Y = {}, {}
                for split, (f_pos, f_neg) in feats.items():
                    X[split] = np.concatenate([f_neg[:, IDX, :], f_pos[:, IDX, :]])
                    Y[split] = [0] * len(f_neg) + [1] * len(f_pos)

                # Fit on train, evaluate on every split
                logisticregression = LogisticRegression(max_iter = 1000)
                logisticregression.fit(X['train'], Y['train'])
                logistic_entry = {split: accuracy_score(Y[split], logisticregression.predict(X[split]))
                                  for split in splits}
                logistic_entry['coef'] = torch.tensor(logisticregression.coef_)
                accs_logistic[op][layer][IDX] = logistic_entry

                accs_mean_centroid[op][layer][IDX] = {
                    split: centroid_acc(X[split], Y[split], mu_pos[IDX], mu_neg[IDX])
                    for split in splits
                }

    def to_frame(accs, op, split):
        """Table of ``acc[split]`` with one column per layer and one row per head."""
        # Same table as pd.DataFrame(accs[op]).applymap(lambda X: X[split]) without
        # DataFrame.applymap, which pandas >= 2.1 deprecates.
        return pd.DataFrame({lay: {idx: acc[split] for idx, acc in per_layer.items()}
                             for lay, per_layer in accs[op].items()})

    def frames(accs):
        """``{'op_to_df_<split>': {op: DataFrame}}`` for every evaluated split."""
        return {f'op_to_df_{split}': {op: to_frame(accs, op, split) for op in ops} for split in splits}

    results = {}
    results['logistic'] = {'accs': accs_logistic, **frames(accs_logistic)}
    results['centroid'] = {'accs': accs_mean_centroid, **frames(accs_mean_centroid)}

    return results


def viz_ITI(ops, model_use, data, layers,
            op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last,
            op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val, 
            savedir, dataset): # , HeadStats):
    """Compute per-head accuracies with ``vec_to_ITI``, pickle them, and save diagnostic plots.

    Writes into ``savedir``:

    * ``HeadStats.pkl``: the ``vec_to_ITI`` result as returned, before k/v expansion;
    * ``Heatmap.png``, ``CCDF.png``;
    * ``Acc_Comparisons_Val.png`` and ``Acc_Comparisons_Train.png``.

    For the plots, k/v accuracies are repeated ``num_heads // num_kv`` times so they
    line up with query heads.

    NOTE(review): this reads ``vec_to_ITI(...)['accs'][op][layer][head]`` as a dict
    with ``'train'`` and ``'val'`` keys; confirm against ``vec_to_ITI``.

    Args:
        ops: Ops to plot.
        model_use: Model name containing ``"Gemma"`` or ``"LLaMA"``.
        data, layers, op_to_layer_to_val_*: Forwarded to ``vec_to_ITI``; ``layers`` also
            selects which k/v layers are expanded.
        savedir: Output directory.
        dataset: Dataset name, used in figure titles.

    Raises:
        ValueError: if ``model_use`` names neither Gemma nor LLaMA.
    """
    if "Gemma" in model_use:
        num_kv, num_heads = 8, 16
    elif "LLaMA" in model_use:
        num_kv, num_heads = 8, 32
    else:
        raise ValueError(f"Unsupported model_use: {model_use!r}")
    num_groups = num_heads // num_kv
    ticks      = np.linspace(0.5, 1, num=11)
    diag       = np.linspace(0.5, 1, num=21)

    # Load accs
    results = vec_to_ITI(op_to_meandiff = None, ops = ['q', 'k', 'v', 'attn_output'], model_use = model_use, K = None, data = data, layers = layers,
                op_to_layer_to_val_pos_last = op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last = op_to_layer_to_val_neg_last,
                op_to_layer_to_val_pos_last_val = op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val = op_to_layer_to_val_neg_last_val,
                mask_criterion = None)

    # os.path.join (like every PNG below) so the pickle lands inside savedir even
    # when savedir has no trailing separator.
    with open(os.path.join(savedir, "HeadStats.pkl"), 'wb') as f:
        pickle.dump(results, f)

    def expand_kv(per_head):
        """Repeat each KV-head entry ``num_groups`` times, re-keyed 0..n-1 like query heads."""
        expanded = {}
        for entry in per_head.values():
            for _ in range(num_groups):
                expanded[len(expanded)] = entry
        return expanded

    def to_frame(accs_op, split):
        """Table of ``acc[split]`` with one column per layer and one row per head."""
        # Same table as pd.DataFrame(accs_op).applymap(lambda X: X[split]) without
        # DataFrame.applymap, which pandas >= 2.1 deprecates.
        return pd.DataFrame({layer: {idx: acc[split] for idx, acc in per_head.items()}
                             for layer, per_head in accs_op.items()})

    # Copy k, v num_groups times (shallow copies; the results object is left untouched)
    accs = dict(results['accs'])
    for op in ('v', 'k'):
        if op in accs:
            accs[op] = dict(accs[op])
            for layer in layers:
                accs[op][layer] = expand_kv(accs[op][layer])

    dfs_train = {op: to_frame(accs[op], 'train') for op in ops}
    dfs_val   = {op: to_frame(accs[op], 'val') for op in ops}
    plot_ops  = [op for op in ('q', 'k', 'v', 'attn_output') if op in dfs_val]

    # (A) Heatmap: row 0 = val, row 1 = train
    fig, axes = plt.subplots(2, 4, figsize=(8 * 4, 12))
    val_images = []  # the colorbar is drawn from the first val panel
    for col, op in enumerate(plot_ops):
        val_images.append(axes[0][col].imshow(dfs_val[op].T, origin='lower', aspect='auto', cmap='coolwarm', vmin=.5, vmax=1))
        axes[1][col].imshow(dfs_train[op].T, origin='lower', aspect='auto', cmap='coolwarm', vmin=.5, vmax=1)
        axes[0][col].set_title(op)
    for row in range(2):
        axes[row][0].set_ylabel("Layer")
    for col in range(4):
        axes[1][col].set_xlabel("Attention Head")
    fig.colorbar(val_images[0], ax=axes, location='right', shrink=1)
    fig.suptitle(model_use + ", " + dataset)
    fig.savefig(os.path.join(savedir, "Heatmap.png"))

    # (B) CCDF: fraction of heads whose accuracy is >= each threshold
    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
    thresholds = np.linspace(.30, 1, num=100)
    for op in plot_ops:
        flat_val   = dfs_val[op].values.flatten()     # flattened once, not per threshold
        flat_train = dfs_train[op].values.flatten()
        axes[0].plot(thresholds, [np.mean(flat_val >= T) for T in thresholds], label=op)
        axes[1].plot(thresholds, [np.mean(flat_train >= T) for T in thresholds], label=op)
    axes[0].legend()
    axes[0].set_title("Val")
    axes[1].set_title("Train")
    axes[0].set_ylabel("% Above Threshhold")
    axes[0].set_xlabel("Threshhold")
    axes[1].set_xlabel("Threshhold")
    for ax in axes:
        ax.grid()
        ax.set_ylim(0, 1)
        ax.set_yticks(np.linspace(0, 1, num=11))
    fig.suptitle(model_use + ", " + dataset)
    fig.savefig(os.path.join(savedir, "CCDF.png"))

    # (C) Pairwise op-vs-op scatter, once for val and once for train
    for split_name, dfs in (("Val", dfs_val), ("Train", dfs_train)):
        fig, axes = plt.subplots(4, 4, figsize=(20, 20))
        for I, op in enumerate(ops):          # x axis
            for J, op2 in enumerate(ops):     # y axis
                ax = axes[I][J]
                ax.scatter(dfs[op].values.flatten(), dfs[op2].values.flatten())
                ax.set_xlabel(op)
                ax.set_ylabel(op2)
                ax.set_ylim(0.5, 1)
                ax.set_yticks(ticks)
                ax.set_xlim(0.5, 1)
                ax.set_xticks(ticks)
                ax.plot(diag, diag, 'g--')   # y = x reference line
        fig.savefig(os.path.join(savedir, f"Acc_Comparisons_{split_name}.png"))

# NOTE: this one is good, keep it 
def apply_ITI_mask(ops, mask_criterion, num_kv, num_heads, op_to_df_val, op_to_df_train, K, op_to_meandiff, head_dim):
    """Zero every head of the mean-difference vectors except the top-K probed heads.

    For each op, the accuracy table selected by *mask_criterion* (rows = head
    index, columns = layer) is flattened and its ``K`` largest (head, layer)
    cells are selected. For every layer owning at least one selected head,
    ``op_to_meandiff[op][layer]`` is viewed as ``(1, n_heads, head_dim)``.
    The non-selected heads are multiplied by 0, and the result is flattened
    back to ``(1, n_heads * head_dim)``.

    Args:
        ops: operations to mask; each must be one of 'k', 'v', 'q', 'attn_output'.
        mask_criterion: 'ValAcc' (rank by ``op_to_df_val``) or 'TrainAcc'
            (rank by ``op_to_df_train``).
        num_kv: number of head slices used for 'k' and 'v'.
        num_heads: number of head slices used for 'q' and 'attn_output'.
        op_to_df_val: op -> DataFrame of validation accuracies (heads x layers).
        op_to_df_train: op -> DataFrame of train accuracies (heads x layers).
        K: number of (head, layer) pairs to keep per op. ``K <= 0`` keeps
            nothing and leaves that op's vectors unchanged.
        op_to_meandiff: op -> layer -> tensor with ``n_heads * head_dim`` elements.
        head_dim: per-head feature size.

    Returns:
        ``op_to_meandiff``, modified in place (the same dict object).

    Raises:
        AssertionError: if *mask_criterion* is not 'ValAcc' or 'TrainAcc'.
        ValueError: if an op is not one of the supported operations.
    """
    assert mask_criterion in ['ValAcc', 'TrainAcc'], "Invalid acc criterion for masking"
    op_to_df = op_to_df_val if mask_criterion == 'ValAcc' else op_to_df_train

    # k/v are split into num_kv slices, q/attn_output into num_heads slices
    n_elems_by_op = {'k': num_kv, 'v': num_kv, 'q': num_heads, 'attn_output': num_heads}

    for op in ops:
        if op not in n_elems_by_op:
            # Fail loudly instead of continuing with an undefined or stale head count
            raise ValueError(f"invalid op: {op!r}")
        n_elems = n_elems_by_op[op]

        # Rank every (head, layer) cell of the accuracy table and keep the K best
        top_k = op_to_df[op].stack().nlargest(K)
        if top_k.empty:
            print(op, "- no heads selected (K <= 0 or empty table); vectors left unchanged")
            continue

        # layer -> selected head indices, in order of first appearance among the top K
        layer_to_heads = {}
        for head_idx, layer_idx in top_k.index:
            layer_to_heads.setdefault(layer_idx, []).append(head_idx)

        # NOTE(review): layers with no selected head are not touched here, so
        # their full (unmasked) vector is kept. If only the top-K heads should
        # be steered, those layers probably need zeroing too -- confirm with caller.
        for layer_idx, heads in layer_to_heads.items():
            per_head = op_to_meandiff[op][layer_idx].view(1, n_elems, head_dim)
            mask = torch.zeros_like(per_head)
            mask[:, torch.tensor(heads), :] = 1
            op_to_meandiff[op][layer_idx] = (per_head * mask).view(1, -1)

        # Quick smell check: non-zero fraction of one masked layer vs. kept-head fraction
        print(op)
        check_layer = next(iter(layer_to_heads))
        print(layer_to_heads[check_layer], len(layer_to_heads[check_layer]))
        print((op_to_meandiff[op][check_layer] != 0).float().mean().item())
        print(len(layer_to_heads[check_layer]) / n_elems)

    return op_to_meandiff


def _vec_to_ITI_split_heads(acts, bsz, n_elems, head_dim, n_keep):
    """Reshape flat last-token activations into a per-head NumPy array.

    ``acts`` is viewed as ``(bsz, 1, n_elems, head_dim)``. The first ``n_keep``
    examples are returned with shape ``(n_keep, n_elems, head_dim)``. The
    input tensor is not modified.
    """
    per_head = acts.view(bsz, 1, n_elems, head_dim).transpose(1, 2)  # (bsz, n_elems, 1, head_dim)
    # Squeeze only the length-1 token axis. A bare .squeeze() would also drop
    # the batch axis when bsz == 1, and [:n_keep] would then slice heads.
    per_head = per_head.squeeze(2)[:n_keep].detach()
    if per_head.dtype == torch.bfloat16:
        # NumPy has no bfloat16; upcasting to float32 is lossless
        per_head = per_head.float()
    return per_head.cpu().numpy()


def _vec_to_ITI_fit_probe(X_pos_train, X_neg_train, X_pos_val, X_neg_val):
    """Fit a logistic-regression probe (neg -> 0, pos -> 1); return (train_acc, val_acc)."""
    # np.concatenate instead of np.concat: the latter only exists in NumPy >= 2.0
    X_train = np.concatenate([X_neg_train, X_pos_train])
    Y_train = [0] * len(X_neg_train) + [1] * len(X_pos_train)
    probe = LogisticRegression(max_iter = 1000)
    probe.fit(X_train, Y_train)
    acc_train = accuracy_score(Y_train, probe.predict(X_train))

    X_val = np.concatenate([X_neg_val, X_pos_val])
    Y_val = [0] * len(X_neg_val) + [1] * len(X_pos_val)
    acc_val = accuracy_score(Y_val, probe.predict(X_val))
    return acc_train, acc_val


def _vec_to_ITI_acc_table(accs_op, split):
    """Build a heads x layers DataFrame of the ``split`` ('train'/'val') accuracies."""
    # Equivalent to pd.DataFrame(accs_op).applymap(lambda X: X[split]) without
    # relying on applymap, which is deprecated since pandas 2.1.
    return pd.DataFrame({
        layer: {head: acc[split] for head, acc in per_head.items()}
        for layer, per_head in accs_op.items()
    })


def vec_to_ITI(op_to_meandiff, ops, model_use, K, data, layers,
                op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last,
                op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val,
                mask_criterion = 'ValAcc'):
    """Probe every head with logistic regression and optionally mask steering vectors.

    For each layer in *layers* and op in *ops*, the last-token activations are
    reshaped to ``(examples, n_heads, head_dim)``. Both classes are truncated
    to the smaller class size. For each head, a ``LogisticRegression`` probe is
    fit on the train split and scored on the train and val splits.

    Args:
        op_to_meandiff: op -> layer -> steering tensor, or ``None`` to skip masking.
        ops: subset of 'k', 'v', 'q', 'attn_output'.
        model_use: "LLaMA3_Instruct" or "Gemma2_Instruct"; selects the head geometry.
        K: number of (head, layer) pairs kept, forwarded to ``apply_ITI_mask``.
        data: ``data['train'|'val']['pos_prompts'|'neg_prompts']``; only the lengths are used.
        layers: layer keys to probe.
        op_to_layer_to_val_pos_last, op_to_layer_to_val_neg_last: train activations,
            op -> layer -> tensor with ``bsz * n_heads * head_dim`` elements.
            They are read only, not modified.
        op_to_layer_to_val_pos_last_val, op_to_layer_to_val_neg_last_val: validation
            activations with the same layout (also not modified).
        mask_criterion: 'ValAcc' or 'TrainAcc', forwarded to ``apply_ITI_mask``.

    Returns:
        dict with ``'op_to_df_train'`` and ``'op_to_df_val'`` (op -> DataFrame,
        rows = head, columns = layer), ``'accs'`` (op -> layer -> head ->
        {'train', 'val'}), and ``'op_to_meandiff'`` when masking was applied.

    Raises:
        ValueError: if *model_use* is not supported.
        AssertionError: if an op is not supported.
    """
    # Batch sizes, needed to un-flatten the stored activations
    bsz_pos     = len(data['train']['pos_prompts'])
    bsz_neg     = len(data['train']['neg_prompts'])
    bsz_pos_val = len(data['val']['pos_prompts'])
    bsz_neg_val = len(data['val']['neg_prompts'])

    # NOTE: balance uneven datasets by truncating both classes to the smaller size
    N_TR  = min(bsz_pos, bsz_neg)
    N_VAL = min(bsz_pos_val, bsz_neg_val)

    # model -> (num_heads, head_dim, num_groups)
    head_configs = {
        "LLaMA3_Instruct": (32, 128, 4),
        "Gemma2_Instruct": (16, 256, 2),
    }
    if model_use not in head_configs:
        raise ValueError(f"Unsupported model_use: {model_use!r}")
    num_heads, head_dim, num_groups = head_configs[model_use]
    num_kv = num_heads // num_groups

    # k/v are split into num_kv slices, q/attn_output into num_heads slices
    n_elems_by_op = {'k': num_kv, 'v': num_kv, 'q': num_heads, 'attn_output': num_heads}
    for op in ops:
        assert op in n_elems_by_op, 'invalid op'

    # Logistic Regression, one probe per (op, layer, head)
    accs = { op : { layer : {} for layer in layers} for op in ops}
    for layer in layers:
        for op in ops:
            n_elems = n_elems_by_op[op]
            tr_pos = _vec_to_ITI_split_heads(op_to_layer_to_val_pos_last[op][layer], bsz_pos, n_elems, head_dim, N_TR)
            tr_neg = _vec_to_ITI_split_heads(op_to_layer_to_val_neg_last[op][layer], bsz_neg, n_elems, head_dim, N_TR)
            va_pos = _vec_to_ITI_split_heads(op_to_layer_to_val_pos_last_val[op][layer], bsz_pos_val, n_elems, head_dim, N_VAL)
            va_neg = _vec_to_ITI_split_heads(op_to_layer_to_val_neg_last_val[op][layer], bsz_neg_val, n_elems, head_dim, N_VAL)

            for IDX in tqdm.tqdm(range(n_elems), desc = str(layer)):
                acc_train, acc_val = _vec_to_ITI_fit_probe(
                    tr_pos[:, IDX, :], tr_neg[:, IDX, :],
                    va_pos[:, IDX, :], va_neg[:, IDX, :],
                )
                accs[op][layer][IDX] = {'train' : acc_train, 'val' : acc_val}

    op_to_df_train = { op : _vec_to_ITI_acc_table(accs[op], 'train') for op in ops }
    op_to_df_val   = { op : _vec_to_ITI_acc_table(accs[op], 'val') for op in ops }

    results = {
        'op_to_df_train' : op_to_df_train,
        'op_to_df_val'   : op_to_df_val,
        'accs'           : accs,
    }

    if op_to_meandiff is not None:
        op_to_meandiff = apply_ITI_mask(ops, mask_criterion, num_kv, num_heads, op_to_df_val, op_to_df_train, K, op_to_meandiff, head_dim)
        results['op_to_meandiff'] = op_to_meandiff

    print(accs)
    return results
            
####################################################
#### agglomerates representations across tokens ####
####################################################

# Consistent dimensions for everything  (batch, tokens, dimension)
# TODO: allow functionality for starting from AFTER the question part of instruct
def agglomerate(op_to_layer_to_val, start_tok_keep_global = 0, start_per_ex = None, agg = 'last'):
    """Collapse per-example token sequences into a single last-token tensor per (op, layer).

    Args:
        op_to_layer_to_val: op -> layer -> iterable of 3-D tensors; each is
            sliced as ``x[:, -1, :]`` (last index of axis 1).
        start_tok_keep_global: currently unused (see TODO above); kept for API
            compatibility.
        start_per_ex: currently unused (see TODO above); kept for API compatibility.
        agg: aggregation mode; only ``'last'`` is implemented.

    Returns:
        A new dict op -> layer -> ``torch.vstack`` of the last-token slices.
        The input dict and its tensors are left unmodified.

    Raises:
        NotImplementedError: if *agg* is anything other than ``'last'``.
        AssertionError: if any tensor is not 3-D.
    """
    if agg != 'last':
        # Previously any value silently behaved like 'last'
        raise NotImplementedError(f"agg={agg!r} is not implemented; only 'last' is supported")

    # Build the result directly instead of deep-copying every full-sequence
    # tensor first. vstack already returns new storage, and deepcopy both
    # doubles peak memory and fails on non-leaf tensors that require grad.
    op_to_layer_to_val_mod = {}
    for op, layer_to_val in op_to_layer_to_val.items():
        op_to_layer_to_val_mod[op] = {}
        bar = tqdm.tqdm(layer_to_val.keys(), desc = op)
        for layer in bar:
            seqs = layer_to_val[layer]
            assert all(len(x.shape) == 3 for x in seqs), "shape dim != 3, something is off"
            # (A) take the last token as the representation; lists -> tensor
            op_to_layer_to_val_mod[op][layer] = torch.vstack([x[:, -1, :] for x in seqs])

    return op_to_layer_to_val_mod

#########################################################################################################
#### populates mean difference steering vectors given operations and representations in a dictionary ####
#########################################################################################################


# Override -- difference of means, mean(pos) - mean(neg), so the positive and negative sets may differ in size
def populate_MeanDiff_hooked_style_dataset(operations, pos_representations, neg_representations):
    """Return op -> layer -> (mean of positives - mean of negatives).

    Layers are taken from ``pos_representations[op]``. Each side is averaged
    over dim 0 with ``keepdim=True``, so dim 0 of every result has size 1.
    Because the two means are computed separately, the positive and negative
    sets need not contain the same number of rows.
    """
    def centroid(reps):
        # average across dim 0, keeping that axis with size 1
        return reps.mean(dim = 0, keepdim = True)

    steering = {}
    for op in operations:
        steering[op] = {}

    for op in operations:
        layer_to_pos = pos_representations[op]
        layer_to_neg = neg_representations[op]
        steering[op].update(
            (layer, centroid(layer_to_pos[layer]) - centroid(layer_to_neg[layer]))
            for layer in layer_to_pos.keys()
        )

    return steering