#%%
import torch
import pickle
from globals import *
from tqdm.auto import tqdm
from collections import defaultdict
import matplotlib.pyplot as plt
import os
import re

# Run settings. None of these are referenced by the analysis cells below;
# presumably they mirror the generation/steering run that produced the
# autointerp pickles (TODO confirm).
n_data, max_input_tokens, max_output_tokens = 200, 400, 100
min_tokens, steer_scale = 5, 2

#%% ORIGINAL STEERING EXPERIMENT
# Discover which (layer, feature pair) combinations have saved steering results.
# Only files directly inside autointerp/ are considered: the scoring cell loads
# autointerp/layer{layer}_{pair}.pkl, and a recursive os.walk would also pick up
# the files under autointerp/diffcontext/ (duplicating pairs, or adding pairs
# whose top-level file does not exist).
context_elephants_original = defaultdict(list)
pattern = re.compile(r'layer(\d+)_\((\d+),\s*(\d+)\)\.pkl')
# The first os.walk entry is the top level only; a missing directory yields no files.
_, _, top_level_files = next(os.walk('autointerp'), ('autointerp', [], []))
for file_name in sorted(top_level_files):
    match = pattern.fullmatch(file_name)
    if match:
        layer, feature1, feature2 = map(int, match.groups())
        context_elephants_original[layer].append((feature1, feature2))

# Deterministic (numeric) layer order instead of filesystem listing order.
layers = sorted(context_elephants_original)
context_elephants = {layer: sorted(context_elephants_original[layer]) for layer in layers}

#%%
# Score every (layer, feature pair) in `context_elephants` from its
# autointerp/layer{layer}_{pair}.pkl results file.
#
# Each doc carries three judge labels: 'steer_none_evaluation' (unsteered
# baseline), 'steer_1_evaluation' and 'steer_2_evaluation'. A steer_1 output is
# "right" when labelled '1', a steer_2 output when labelled '2'. For each steered
# output where both it and the baseline are labelled '1' or '2':
#   label == baseline -> original_right / original_wrong
#   label != baseline -> flip_right / flip_wrong
# Steered outputs labelled 'UNCLEAR' are tallied separately, whatever the baseline.
# causal = flip_right / (flip_right + flip_wrong), or 0 when nothing flipped.
metrics = {}
for layer, elephant_pairs in context_elephants.items():
    metrics[layer] = {}
    for feature_pair in elephant_pairs:
        original = load_pickle(f'autointerp/layer{layer}_{feature_pair}.pkl')
        docs = list(original.values())
        counts = {'original_right': 0,
                  'original_wrong': 0,
                  'flip_right': 0,
                  'flip_wrong': 0,
                  'unclear_steered': 0,
                  'total': 0,
                  'ndocs': len(docs)}
        for doc in docs:
            baseline = doc['steer_none_evaluation']
            # (label this steering direction aims for, label the judge assigned)
            steered_outputs = (('1', doc['steer_1_evaluation']),
                               ('2', doc['steer_2_evaluation']))
            for target, steered in steered_outputs:
                if steered == 'UNCLEAR':
                    counts['unclear_steered'] += 1
                if baseline not in ('1', '2') or steered not in ('1', '2'):
                    continue
                moved = 'original' if steered == baseline else 'flip'
                verdict = 'right' if steered == target else 'wrong'
                counts[f'{moved}_{verdict}'] += 1
                counts['total'] += 1

        n_flips = counts['flip_right'] + counts['flip_wrong']
        causal = counts['flip_right'] / n_flips if n_flips > 0 else 0

        metrics[layer][feature_pair] = {'counts': counts,
                                        'causal': causal,
                                        'flips': n_flips}

#%%
# Plot the causal score (fraction of flips classified flip_right) of every
# feature pair with at least `min_flips` flips, by layer. Marker size grows
# linearly with the flip count.
# NOTE(review): `go` is not imported in this file's visible imports; presumably
# plotly.graph_objects comes in via `from globals import *` (TODO confirm).
min_flips = 40
max_size = 30
max_total = 200

xs, ys, sizes, hover_texts = [], [], [], []
for layer, layer_data in metrics.items():
    for pair, pair_data in layer_data.items():
        n_flips = pair_data['flips']
        if n_flips < min_flips:
            continue
        # n_flips > 0 here, so the stored score is flip_right / n_flips.
        causal = pair_data['causal']
        xs.append(layer)
        ys.append(causal)
        sizes.append(1 + (n_flips / max_total) * (max_size - 1))
        hover_texts.append(f'Layer: {layer}<br>Elephant: {pair}<br>Flips: {n_flips}<br>Causal Score: {causal:.3f}')

fig = go.Figure()
# A single trace with per-point sizes/texts instead of one trace per point.
fig.add_trace(
    go.Scatter(
        x=xs,
        y=ys,
        mode='markers',
        marker=dict(
            size=sizes,
            opacity=0.8,
            color='#636EFA'
        ),
        text=hover_texts,
        hoverinfo='text',
        showlegend=False
    )
)

fig.update_layout(
    width=400,
    height=300,
    xaxis_title='Layer',
    yaxis_title='Fraction of Correct Flips',
    margin=dict(l=50, r=10, t=10, b=50)
)

# Dashed horizontal reference line at y=0.5 spanning all scored layers
# (min()/max() would raise on an empty `metrics`).
if metrics:
    fig.add_shape(
        type="line",
        x0=min(metrics.keys()),
        x1=max(metrics.keys()),
        y0=0.5,
        y1=0.5,
        line=dict(
            color="grey",
            width=1,
            dash="dash"
        )
    )

fig.write_image('*PLOTS/flips.pdf',scale=20)
fig.show()


#%%
# List the strongly causal pairs: at least 40 flips, more than 75% of them
# classified flip_right. Prints "(layer, pair): n_flips".
for layer, pairs_for_layer in metrics.items():
    for pair, pair_metrics in pairs_for_layer.items():
        flips = pair_metrics['flips']
        if flips < 40 or pair_metrics['causal'] <= 0.75:
            continue
        print(f'{layer, pair}: {flips}')

#%%
# Persist the original-experiment metrics for every discovered (layer, pair)
# as a nested {layer: {pair: metrics}} dict in autointerp/autointerp_scores.pkl.
autointerp_scores = {
    layer: {pair: metrics[layer][pair] for pair in context_elephants[layer]}
    for layer in layers
}

with open('autointerp/autointerp_scores.pkl', 'wb') as scores_file:
    pickle.dump(autointerp_scores, scores_file)

#%% ANALYSIS WITH THE DIFF CONTEXT
# Discover which (layer, feature pair) combinations have diff-context results.
# Only files directly inside autointerp/diffcontext/ are considered, matching the
# path the diff-context scoring cell loads
# (autointerp/diffcontext/layer{layer}_{pair}.pkl); a recursive os.walk would
# also collect files from nested subdirectories that cannot be loaded that way.
context_elephants_diff_original = defaultdict(list)
pattern = re.compile(r'layer(\d+)_\((\d+),\s*(\d+)\)\.pkl')
diff_dir = os.path.join('autointerp', 'diffcontext')
# The first os.walk entry is the top level only; a missing directory yields no files.
_, _, diff_files = next(os.walk(diff_dir), (diff_dir, [], []))
for file_name in sorted(diff_files):
    match = pattern.fullmatch(file_name)
    if match:
        layer, feature1, feature2 = map(int, match.groups())
        context_elephants_diff_original[layer].append((feature1, feature2))

#%%
# Plain-dict copy of the discovered diff-context pairs (same key order; the
# per-layer lists are shared with the defaultdict, not copied).
context_elephants_diff = dict(context_elephants_diff_original)

#%%
# Score every (layer, feature pair) in `context_elephants_diff` from its
# autointerp/diffcontext/layer{layer}_{pair}.pkl results file.
#
# Each doc carries three judge labels: 'steer_none_evaluation' (unsteered
# baseline), 'steer_1_evaluation' and 'steer_2_evaluation'. A steer_1 output is
# "right" when labelled '1', a steer_2 output when labelled '2'. For each steered
# output where both it and the baseline are labelled '1' or '2':
#   label == baseline -> original_right / original_wrong
#   label != baseline -> flip_right / flip_wrong
# Steered outputs labelled 'UNCLEAR' are tallied separately, whatever the baseline.
# causal = flip_right / (flip_right + flip_wrong), or 0 when nothing flipped.
metrics_diff = {}
for layer, elephant_pairs in context_elephants_diff.items():
    metrics_diff[layer] = {}
    for feature_pair in elephant_pairs:
        original = load_pickle(f'autointerp/diffcontext/layer{layer}_{feature_pair}.pkl')
        docs = list(original.values())
        counts = {'original_right': 0,
                  'original_wrong': 0,
                  'flip_right': 0,
                  'flip_wrong': 0,
                  'unclear_steered': 0,
                  'total': 0,
                  'ndocs': len(docs)}
        for doc in docs:
            baseline = doc['steer_none_evaluation']
            # (label this steering direction aims for, label the judge assigned)
            steered_outputs = (('1', doc['steer_1_evaluation']),
                               ('2', doc['steer_2_evaluation']))
            for target, steered in steered_outputs:
                if steered == 'UNCLEAR':
                    counts['unclear_steered'] += 1
                if baseline not in ('1', '2') or steered not in ('1', '2'):
                    continue
                moved = 'original' if steered == baseline else 'flip'
                verdict = 'right' if steered == target else 'wrong'
                counts[f'{moved}_{verdict}'] += 1
                counts['total'] += 1

        n_flips = counts['flip_right'] + counts['flip_wrong']
        causal = counts['flip_right'] / n_flips if n_flips > 0 else 0

        metrics_diff[layer][feature_pair] = {'counts': counts,
                                             'causal': causal,
                                             'flips': n_flips}

#%% FRAC UNCLEAR
# For every pair that was strongly causal in the original experiment
# (causal > 0.75 with at least 40 flips), print the fraction of steered outputs
# judged 'UNCLEAR' in the original run vs the diff-context run. Each doc has two
# steered outputs (steer_1 and steer_2), hence the 2 * ndocs denominators.
for layer in context_elephants_diff:
    original_layer_metrics = metrics.get(layer, {})
    for pair, new_pair_data in metrics_diff[layer].items():
        original_pair_data = original_layer_metrics.get(pair)
        if original_pair_data is None:
            # Pair only has diff-context results; nothing to compare against.
            continue
        if original_pair_data['causal'] <= 0.75 or original_pair_data['flips'] < 40:
            continue
        # Each doc adds at most 2 flips, so >= 40 flips implies ndocs >= 20 here.
        original_counts = original_pair_data['counts']
        original_unclear = original_counts['unclear_steered'] / (2*original_counts['ndocs'])
        new_counts = new_pair_data['counts']
        if new_counts['ndocs'] > 0:
            new_unclear = new_counts['unclear_steered'] / (2*new_counts['ndocs'])
        else:
            # Empty diff-context results file: report NaN rather than dividing by zero.
            new_unclear = float('nan')
        print(f'{layer, pair}: {original_unclear} -> {new_unclear}')