#%%
from typing import Dict
from collections import Counter 
from pathlib import Path 
import json 

import pandas as pd
import torch
import matplotlib.pyplot as plt
from plotting_utils import core_models, color_palette, steps2tokens
from matplotlib.ticker import FuncFormatter

# Use a serif font for every figure produced by this script.
plt.rc('font', family='DejaVu Serif')

def clean_outliers(checkpoint_dict: Dict[int, torch.Tensor], min_value: float, max_value: float) -> Dict[int, torch.Tensor]:
    """Zero out entries lying outside ``[min_value, max_value]``.

    Each tensor in ``checkpoint_dict`` is replaced by a new tensor in which values
    strictly below ``min_value`` or strictly above ``max_value`` become 0.0. Values
    inside the closed range, and NaNs (which fail both comparisons), are kept.
    The dict is modified in place and also returned.
    """
    zero = torch.tensor(0.0)
    for step, scores in checkpoint_dict.items():
        # Reassigning existing keys while iterating is safe: the dict size never changes.
        out_of_range = (scores < min_value) | (scores > max_value)
        checkpoint_dict[step] = torch.where(out_of_range, zero, scores)
    return checkpoint_dict

def load_results_double_wrapped(head: str, model: str, root='/mnt/hdd-0/circuits-over-time/results/components'):
    """Load the per-checkpoint scores of one head type for one model.

    Args:
        head: One of 'successor', 'induction', 'copy_suppression', 'name_mover'.
        model: Name of the model's sub-directory under ``root``. Models whose name
            contains '1b' are skipped.
        root: Directory holding one sub-directory of result files per model.

    Returns:
        ``(steps, head_scores)`` where ``steps`` is the sorted list of checkpoint keys
        and ``head_scores`` stacks the per-checkpoint score tensors along a new first
        dimension. Returns None if the model is skipped or its result file is missing.

    Raises:
        ValueError: if ``head`` is not a supported head type (for non-skipped models).
    """
    model_path = Path(root) / model
    if '1b' in model:
        return None
    try:
        # map_location='cpu' lets results saved from GPU tensors load on CPU-only hosts.
        if head == 'successor':
            data = torch.load(model_path / 'successor_heads_over_time.pt', map_location='cpu')
            steps = sorted(data)
            head_scores = torch.stack([data[step] for step in steps])

        elif head == 'induction':
            data = torch.load(model_path / "full_model_components_over_time.pt", map_location='cpu')
            steps = sorted(data)
            head_scores = torch.stack(
                [data[step]['tertiary_head_scores']['induction_scores'] for step in steps]
            )

        elif head == 'copy_suppression':
            data = torch.load(model_path / 'whole_model_cspa.pt', map_location='cpu')
            # Values outside [0, 1] are treated as outliers and zeroed.
            data = clean_outliers(data, 0.0, 1.0)
            steps = sorted(data)
            head_scores = torch.stack([data[step] for step in steps])

        elif head == 'name_mover':
            data = torch.load(model_path / 'early_whole_model_copy_scores.pt', map_location='cpu')
            steps = sorted(data)
            # NOTE(review): scaled by 0.01, presumably percent -> fraction; confirm.
            head_scores = torch.stack([data[step] for step in steps]) * 0.01

        else:
            raise ValueError(f"Got invalid head {head}")

    except FileNotFoundError:
        print("couldn't find file for", head,  "in", model_path)
        return None

    return steps, head_scores
    
def load_results_wrapped(head: str, model: str, perf_thresh=0.33):
    """Load head scores and rank the heads that ever reach a performance threshold.

    The threshold is ``perf_thresh`` times the largest score over all checkpoints,
    layers and heads. A head qualifies if its maximum over checkpoints reaches that
    threshold. Qualifying heads are ordered by the first checkpoint index at which
    they reach it; heads reaching it at the same checkpoint keep ascending
    (layer, head) order.

    Returns:
        None if ``load_results_double_wrapped`` returns None. Otherwise
        ``(steps, all_heads, head_scores)``, where ``all_heads`` is a tuple of
        ``(layer, head)`` pairs. The tuple is empty when no head qualifies, e.g. when
        the global maximum is negative or NaN.
    """
    results = load_results_double_wrapped(head, model)
    if results is None:
        return None
    steps, head_scores = results

    # Same threshold for every head, so compute it once.
    threshold = head_scores.max() * perf_thresh
    # Indexing below assumes head_scores is (checkpoint, layer, head).
    layers, heads = (idx.tolist() for idx in torch.where(head_scores.max(dim=0).values >= threshold))
    candidates = list(zip(layers, heads))
    if not candidates:
        return steps, (), head_scores

    # Checkpoint index at which each qualifying head first reaches the threshold;
    # one always exists because its maximum does.
    first_reached = [
        torch.where(head_scores[:, layer, h] >= threshold)[0][0].item()
        for layer, h in candidates
    ]
    # sorted() is stable and torch.where yields row-major order, so ties resolve by (layer, head).
    order = sorted(range(len(candidates)), key=first_reached.__getitem__)
    all_heads = tuple(candidates[i] for i in order)
    return steps, all_heads, head_scores

def load_results(head_type: str, model: str, use_tokens=True):
    """Return ``load_results_wrapped`` output, optionally with steps converted to tokens.

    When results exist and ``use_tokens`` is true, each checkpoint step is mapped
    through ``steps2tokens``. Otherwise the wrapped result (possibly None) is
    returned unchanged.
    """
    results = load_results_wrapped(head_type, model)
    if results is None or not use_tokens:
        return results
    steps, heads, scores = results
    return [steps2tokens(step) for step in steps], heads, scores

# Threshold expressed in tokens; if results are loaded with use_tokens=False,
# express it as a number of training steps instead.
thresh = 200_000_000
# 2x2 grid: one panel per head type.
fig, axs = plt.subplots(2, 2, figsize=(11, 5))
def first_digit(x, pos):
    """Tick formatter: label a tick with the leading significant digit of its value.

    ``pos`` is the tick index supplied by matplotlib's FuncFormatter and is unused.
    Scientific-notation formatting, rather than ``str(x)[0]``, gives the right digit
    for values below 1 (0.05 -> '5', not '0'). It also absorbs float noise such as
    29999999.999999996 -> '3'.
    """
    return f"{abs(x):e}"[0]

def get_candidates(task: str, model: str, root='/mnt/hdd-0/circuits-over-time/results/graphs'):
    """Return the circuit-internal heads of a model and how many checkpoints they appear in.

    Reads ``<root>/<model>/<task>/in_circuit_edges_faithful.feather``, falling back
    to ``in_circuit_edges.feather``. Only rows whose ``in_circuit`` is true are used.
    Edge strings are split as ``'<src>-><dst>'``, and anything from ``'<'`` onward
    is stripped from the destination. A node is a candidate if it is the source of
    some in-circuit edge and the destination of some (possibly different) one.

    Returns:
        None if neither file exists. Otherwise a dict mapping ``(layer, head)`` to the
        number of distinct checkpoints in which that candidate touches an in-circuit
        edge. Only node names containing a '.' are kept. The first character of each
        of the first two dot-separated parts is dropped and the rest parsed as an int.
        Names are presumably like 'a3.h5' -> (3, 5); confirm against the graph files.
    """
    graph_dir = Path(root) / model / task
    df = None
    for filename in ('in_circuit_edges_faithful.feather', 'in_circuit_edges.feather'):
        try:
            df = pd.read_feather(graph_dir / filename)
        except FileNotFoundError:
            continue
        break
    if df is None:
        return None

    # Read columns out of the filtered frame instead of adding columns to a
    # boolean-indexed slice, which triggered pandas' SettingWithCopyWarning.
    circuit = df[df['in_circuit']]
    edges = circuit['edge'].tolist()
    checkpoints = circuit['checkpoint'].tolist()
    src_nodes = [edge.split('->')[0] for edge in edges]
    dst_nodes = [edge.split('->')[1].split('<')[0] for edge in edges]
    candidates = set(src_nodes) & set(dst_nodes)

    # One pass over the edges instead of re-filtering the frame once per candidate.
    checkpoints_per_node = {}
    for src, dst, checkpoint in zip(src_nodes, dst_nodes, checkpoints):
        for node in (src, dst):
            if node in candidates:
                checkpoints_per_node.setdefault(node, set()).add(checkpoint)

    counts = {}
    for node, seen in checkpoints_per_node.items():
        if '.' not in node:
            continue
        parts = node.split('.')
        counts[(int(parts[0][1:]), int(parts[1][1:]))] = len(seen)
    return counts

# One panel per head type. For every core model, plot the score trajectory of the
# first-ranked qualifying head, restricted to heads found in that model's circuit
# when circuit data is available.
for ax, head_type, title, metric in zip(
    axs.flat,
    ['successor', 'induction', 'copy_suppression', 'name_mover'],
    ['Successor Heads (Greater-Than)', 'Induction Heads (Greater-Than)', "Copy Suppression Heads (IOI)", 'Name Mover Heads (IOI)'],
    ['Succession Score', 'Induction Score', 'CSPA Score', 'Copy Score'],
):
    # Successor/induction heads are checked against the Greater-Than circuit, the rest against IOI.
    task = 'greater_than' if head_type in {'successor', 'induction'} else 'ioi'
    for model in core_models:
        baseline = load_results(head_type, model)
        if baseline is None:
            continue
        x_axis, all_heads, scores = baseline

        candidate_nodes = get_candidates(task, model)
        if candidate_nodes is None:
            # No circuit data: keep the unfiltered head ranking.
            print("couldn't find candidates for", model, head_type, task)
        else:
            all_heads = [x for x in all_heads if x in candidate_nodes]
        if not all_heads:
            # Skip the model rather than failing on all_heads[0].
            print("no qualifying head in circuit for", model, head_type, task)
            continue
        # all_heads is ordered by when each head first reaches the score threshold.
        layer, head = all_heads[0]

        head_scores = scores[:, layer, head].numpy()
        ax.plot(x_axis, head_scores, label=model, color=color_palette[model])

    ax.set_title(title)
    ax.set_ylabel(metric)
    ax.set_xscale('log')
    # Label minor log ticks with just their leading digit to keep the axis readable.
    ax.xaxis.set_tick_params(which='minor', labelsize=8)
    ax.xaxis.set_minor_formatter(FuncFormatter(first_digit))
    ax.xaxis.set_tick_params(which='major', pad=10)

axs[1, 0].set_xlabel('# Tokens Seen')
axs[1, 1].set_xlabel('# Tokens Seen')

# Gather legend entries from every panel, not just the top-left one. A model whose
# results are missing for one head type should still appear in the shared legend.
# Deduplicate by label (first handle wins) and keep the (handles, labels) shape
# returned by Axes.get_legend_handles_labels.
_legend_entries = {}
for _ax in axs.flat:
    for _handle, _label in zip(*_ax.get_legend_handles_labels()):
        _legend_entries.setdefault(_label, _handle)
hls = (list(_legend_entries.values()), list(_legend_entries.keys()))

def str_to_params(x):
    """Parse the size suffix of a model name into millions of parameters.

    The text after the last '-' is read as a number followed by a unit letter.
    A trailing 'b' multiplies by 1000; any other unit is taken as millions.
    Example: 'pythia-1.4b' -> 1400.0, 'pythia-160m' -> 160.0.
    """
    size = x.rsplit('-', 1)[-1]
    value = float(size[:-1])
    return value * 1000 if size[-1] == 'b' else value

# Order legend entries by model size, parsed from each model's label.
hls = sorted(zip(*hls), key=lambda hl: str_to_params(hl[1]))

fig.tight_layout()
# Shared legend below the panels. Only create it when something was plotted:
# zip(*[]) cannot be unpacked into handles/labels.
extra_artists = []
if hls:
    handles, labels = zip(*hls)
    lgd = fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.02), ncol=7)
    extra_artists.append(lgd)
out_path = Path('../results/plots/fig2_all_models.pdf')
# savefig does not create missing directories.
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, bbox_extra_artists=extra_artists, bbox_inches='tight')
fig.show()

fig

# %%
