#%%
import argparse
import os

def parse_args(argv=None):
    """Parse the command-line options for the mean-ablation run.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing an explicit list makes
            the parser usable from tests and notebooks.

    Returns:
        argparse.Namespace with ``step``, ``model_id``, ``batch_size``,
        ``input_data`` and ``ablation_type`` attributes.
    """
    parser = argparse.ArgumentParser(description='Run mean ablation analysis on Pythia model')

    parser.add_argument('--step', type=str, default="19900",
                        help='Model step to use (default: 19900)')

    parser.add_argument('--model-id', type=str, default="clean_v3",
                        help='Model ID to use (default: clean_v3)')

    # The ablation loops call .item() on each batch's label, so they only
    # handle one prompt per batch.
    parser.add_argument('--batch-size', type=int, default=1,
                        help='Batch size for processing (default: 1; the ablation loops require 1)')

    parser.add_argument('--input-data', type=str, default="random_repetitive_sequences.csv",
                        help='Path to input CSV data')

    # Restrict to the ablations main() dispatches on, so a typo is rejected at
    # parse time with a usage message.
    parser.add_argument('--ablation-type', type=str, default="one_by_one",
                        choices=["one_by_one", "all_but_one"],
                        help='Type of ablation to perform (default: one_by_one)')

    return parser.parse_args(argv)

# Parse the command line once, at import time. The constants below are used as
# default argument values by the loader/saver functions further down.
args = parse_args()

# Run configuration taken directly from the CLI.
STEP, MODEL_ID, BATCH_SIZE, ABLATION_TYPE = (
    args.step,
    args.model_id,
    args.batch_size,
    args.ablation_type,
)
INPUT_DATA_PATH = f"../data/{args.input_data}"

# Derived locations. Relative paths resolve against the current working
# directory, not this file's directory. The leading "~" in MODEL_KEY is not
# expanded here.
MODEL_KEY = f"~/pythia_replicate_public_models/{MODEL_ID}/step={STEP}"
MEANS_DATA_PATH = f"../data/means_per_head_{MODEL_ID}.pt"
# NOTE(review): the output name depends only on MODEL_ID, so runs with a
# different --step or --ablation-type overwrite each other's results.
OUTPUT_PATH = f"../results/mean_ablation_{MODEL_ID}.json"

#%%

import torch
import pandas as pd
import json
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from nnsight import LanguageModel
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from collections import defaultdict


#%% LOAD MODEL

def load_our_pythia(model_key=MODEL_KEY, tokenizer_key="EleutherAI/pythia-160m", device_map="cuda:0"):
    """Load a Pythia checkpoint and wrap it in an nnsight ``LanguageModel``.

    Args:
        model_key: Local checkpoint directory (or hub id) passed to
            ``from_pretrained``. A leading ``~`` is expanded to the user's home
            directory first.
        tokenizer_key: Tokenizer to pair with the checkpoint (defaults to the
            public Pythia-160m tokenizer, as before).
        device_map: Device placement forwarded to ``from_pretrained``
            (default ``"cuda:0"``, as before).

    Returns:
        The wrapped model, loaded in float16 with eager attention; its
        tokenizer's pad token is set to the EOS token.
    """
    # transformers tests the literal string for a local directory and does not
    # expand "~" itself, so the default MODEL_KEY would otherwise not be found
    # locally. expanduser is a no-op for strings without a leading "~".
    model_key = os.path.expanduser(model_key)

    pythia_model = LanguageModel(
        AutoModelForCausalLM.from_pretrained(
            model_key,
            # NOTE(review): presumably eager attention is needed so the explicit
            # torch.matmul calls hooked through nnsight's `.source` exist;
            # confirm before switching to sdpa/flash attention.
            attn_implementation="eager",
            torch_dtype=torch.float16,
            device_map=device_map),
        tokenizer=AutoTokenizer.from_pretrained(tokenizer_key),
        config=AutoConfig.from_pretrained(model_key))

    # Reuse EOS as the padding token (presumably the tokenizer defines none).
    pythia_model.tokenizer.pad_token = pythia_model.tokenizer.eos_token

    return pythia_model


#%% LOAD DATA

class RepeatingSequenceDataset(Dataset):
    """Map-style dataset over a CSV file of prompt/label rows.

    Each row's ``prompt`` column holds a JSON-encoded list (e.g. ``"[1, 2, 3]"``).
    The ``label`` column is used downstream as the label token id. Items come
    back as ``{'text': <decoded list>, 'label': <raw label value>}``, which is
    the layout the ablation loops consume after default collation.
    """

    def __init__(self, csv_path):
        # The CSV is read eagerly; rows are decoded lazily in __getitem__.
        self.df = pd.read_csv(csv_path)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        return {
            'text': json.loads(record['prompt']),  # JSON string -> Python list
            'label': record['label'],
        }


def load_prompts(csv_path=INPUT_DATA_PATH, batch_size=BATCH_SIZE, shuffle=True):
    """Wrap the prompt CSV at ``csv_path`` in a ``DataLoader``.

    NOTE(review): the ablation loops in this file call ``.item()`` on each
    batch's label, so a ``batch_size`` other than 1 is not supported there.
    """
    return DataLoader(
        RepeatingSequenceDataset(csv_path),
        batch_size=batch_size,
        shuffle=shuffle,
    )


def load_means_per_head(data_path=MEANS_DATA_PATH):
    """Load and return the per-head mean activations stored at ``data_path``.

    The ablation functions index the result with keys of the form
    ``"l_{layer}.h_{head}"``. ``torch.load`` unpickles, so only point this at
    trusted files.
    """
    return torch.load(data_path)


#%% ABLATE

def run_ablation_one_by_one(pythia_model, dataloader, means_per_head):
    """Mean-ablate each attention head on its own and record the effect.

    For every prompt and every (layer, head) pair, that head's slice of the
    hooked attention matmul output is overwritten at the last sequence
    position with ``means_per_head["l_{layer}.h_{head}"]``. The label token's
    logit and softmax probability at the last position are then compared with
    an un-ablated (clean) run of the same prompt.

    Args:
        pythia_model: nnsight ``LanguageModel`` wrapping a GPT-NeoX model.
        dataloader: Iterable of batches holding exactly one prompt each.
            ``batch['text']`` is a sequence of per-position token-id tensors
            (concatenated below), and ``batch['label']`` holds the label token id.
        means_per_head: Mapping ``"l_{layer}.h_{head}"`` -> replacement tensor.

    Returns:
        ``defaultdict(list)`` mapping each head id to one
        ``(clean_logit - ablated_logit, clean_prob - ablated_prob)`` tuple per
        prompt. These are per-prompt values, not averages.

    Raises:
        ValueError: If a batch contains more than one prompt.
    """
    num_layers = pythia_model.config.num_hidden_layers
    num_heads = pythia_model.config.num_attention_heads

    # head id -> list of per-prompt (logit diff, prob diff) tuples
    diffs_per_head = defaultdict(list)

    with torch.no_grad():

        with pythia_model.session():

            for batch in tqdm(dataloader):

                # The concat/.item() handling below only works for one prompt per batch.
                if len(batch['label']) != 1:
                    raise ValueError(
                        f"run_ablation_one_by_one expects batch_size=1, got {len(batch['label'])}"
                    )

                prompt = torch.concat(batch['text']).unsqueeze(0)  # (1, seq_len) for a single-prompt batch
                label_token_id = batch['label'].item()

                # One mask entry per token. len() of a (1, seq_len) tensor is 1, so
                # the previous torch.ones(len(prompt)) built a (1, 1) mask.
                inputs = {'input_ids': prompt, 'attention_mask': torch.ones_like(prompt)}

                # The clean run does not depend on which head gets ablated, so it
                # is computed once per prompt instead of once per head.
                with pythia_model.trace() as tracer:
                    with tracer.invoke(inputs):
                        clean_logit = pythia_model.embed_out.output[0][-1][label_token_id].cpu()
                        clean_prob = torch.nn.functional.softmax(
                            pythia_model.embed_out.output[0][-1], dim=-1
                        )[label_token_id].cpu()

                for l_idx in range(num_layers):

                    for h_idx in range(num_heads):

                        head_id = f"l_{l_idx}.h_{h_idx}"

                        with pythia_model.trace(inputs):

                            layer = pythia_model.gpt_neox.layers[l_idx]

                            # Overwrite this head's slice at the last position only
                            # (run_ablation_all_but_one overwrites every position).
                            # NOTE(review): torch_matmul_0 is the first torch.matmul
                            # inside the attention call; confirm it is the tensor the
                            # means were computed from for the installed transformers.
                            layer.attention.source.None_0.source.torch_matmul_0.output[0][h_idx, -1, :] = means_per_head[head_id][:].unsqueeze(0)

                            # Logit / probability of the expected (label) token.
                            ablated_logit = pythia_model.embed_out.output[0][-1][label_token_id].cpu()
                            ablated_prob = torch.nn.functional.softmax(
                                pythia_model.embed_out.output[0][-1], dim=-1
                            )[label_token_id].cpu()

                        logit_difference = (clean_logit - ablated_logit).item()
                        prob_difference = (clean_prob - ablated_prob).item()

                        diffs_per_head[head_id].append((logit_difference, prob_difference))

    return diffs_per_head


#%%

def run_ablation_all_but_one(pythia_model, dataloader, means_per_head):
    """Mean-ablate every attention head except one, for each head in turn.

    For every prompt and every kept (layer, head) pair, all *other* heads in
    all layers have their slice of the hooked attention matmul output replaced
    (at every sequence position) by their entry in ``means_per_head``. The
    label token's logit at the last position is then compared with an
    un-ablated (clean) run.

    Args:
        pythia_model: nnsight ``LanguageModel`` wrapping a GPT-NeoX model.
        dataloader: Iterable of batches holding exactly one prompt each.
            ``batch['text']`` is a sequence of per-position token-id tensors
            (concatenated below), and ``batch['label']`` holds the label token id.
        means_per_head: Mapping ``"l_{layer}.h_{head}"`` -> replacement tensor.

    Returns:
        ``defaultdict(list)`` mapping each kept head's id to one
        ``clean_logit - ablated_logit`` float per prompt.

    Raises:
        ValueError: If a batch contains more than one prompt.
    """
    num_layers = pythia_model.config.num_hidden_layers
    num_heads = pythia_model.config.num_attention_heads

    # kept-head id -> list of per-prompt logit differences
    diffs_per_head = defaultdict(list)

    with torch.no_grad():

        with pythia_model.session():

            for batch in tqdm(dataloader):

                # The concat/.item() handling below only works for one prompt per batch.
                if len(batch['label']) != 1:
                    raise ValueError(
                        f"run_ablation_all_but_one expects batch_size=1, got {len(batch['label'])}"
                    )

                prompt = torch.concat(batch['text']).unsqueeze(0)  # (1, seq_len) for a single-prompt batch
                label_token_id = batch['label'].item()

                # One mask entry per token. len() of a (1, seq_len) tensor is 1, so
                # the previous torch.ones(len(prompt)) built a (1, 1) mask.
                inputs = {'input_ids': prompt, 'attention_mask': torch.ones_like(prompt)}

                # The clean run is the same for every kept head; compute it once per prompt.
                with pythia_model.trace() as tracer:
                    with tracer.invoke(inputs):
                        clean_logit = pythia_model.embed_out.output[0][-1][label_token_id].cpu()

                for l_idx in range(num_layers):

                    for h_idx in range(num_heads):

                        head_id = f"l_{l_idx}.h_{h_idx}"

                        with pythia_model.trace(inputs):

                            for l_other in range(num_layers):

                                layer = pythia_model.gpt_neox.layers[l_other]
                                # Fetch the hooked tensor once per layer and edit it in place.
                                # NOTE(review): torch_matmul_0 is the first torch.matmul inside
                                # the attention call; confirm it matches how the means were made.
                                head_outputs = layer.attention.source.None_0.source.torch_matmul_0.output[0]

                                for h_other in range(num_heads):

                                    # Leave exactly one head intact. The old test
                                    # `l != l_idx and h != h_idx` also spared the whole
                                    # kept layer and head h_idx in every other layer.
                                    if l_other == l_idx and h_other == h_idx:
                                        continue

                                    head_outputs[h_other, :, :] = means_per_head[f"l_{l_other}.h_{h_other}"][:].unsqueeze(0)

                            # Logit of the expected (label) token at the last position.
                            ablated_logit = pythia_model.embed_out.output[0][-1][label_token_id].cpu()

                        logit_difference = (clean_logit - ablated_logit).item()

                        diffs_per_head[head_id].append(logit_difference)

    return diffs_per_head


#%% SAVE RESULTS

def save_results(results, output_path=OUTPUT_PATH):
    """Write ``results`` to ``output_path`` as JSON and report where.

    The parent directory is created if it does not exist yet, so a long
    ablation run is not lost to a missing output folder at the very end.

    Args:
        results: JSON-serializable mapping (e.g. head id -> list of values).
        output_path: Destination file path.
    """
    parent_dir = os.path.dirname(output_path)
    if parent_dir:  # empty for a bare filename in the current directory
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(results, f)

    print(f"Saved results to {output_path}")


#%% RUN

def main():
    """Run the ablation selected by ``ABLATION_TYPE`` end to end.

    Loads the model, the prompt DataLoader and the per-head means, runs the
    chosen ablation and saves the per-head results.

    Raises:
        ValueError: If ``ABLATION_TYPE`` is not a known ablation.
    """
    ablations = {
        "one_by_one": run_ablation_one_by_one,
        "all_but_one": run_ablation_all_but_one,
    }

    # Validate before the (slow) model load so a bad value fails immediately.
    run_ablation = ablations.get(ABLATION_TYPE)
    if run_ablation is None:
        raise ValueError(f"Invalid ablation type: {ABLATION_TYPE}")

    pythia_model = load_our_pythia()

    # Create dataloader from CSV
    dataloader = load_prompts()

    means_per_head = load_means_per_head()

    results = run_ablation(pythia_model, dataloader, means_per_head)

    save_results(results)


if __name__ == "__main__":
    main()