import torch
from typing import Union, List, Callable, Dict, Tuple
from dataclasses import dataclass
from transformer_lens import HookedTransformer
from tqdm import tqdm
from stitching.losses import next_token_cross_entropy_loss, get_all_special_tokens

@dataclass
class Experiment:
    """Plain record grouping an experiment's identifier, method and results.

    NOTE(review): nothing in this module constructs or reads this class, so
    the field descriptions below are inferred from the names -- confirm
    against the callers.
    """
    # Identifier of the run (presumably unique per experiment -- TODO confirm).
    run_id: str
    # Name of the method that was evaluated (free-form string as far as this file shows).
    method: str
    # Result payload; presumably the name -> loss mapping returned by the
    # run_*_interventions functions -- verify against caller.
    results: dict

@dataclass
class Intervention:
    """A named transformation applied to intermediate model activations.

    The runner functions in this module capture activations with
    ``model(sample, stop_at_layer=...)``, pass them through ``forward`` and
    feed the result back in with ``model(..., start_at_layer=...)``.
    """
    # Key under which this intervention's loss is reported in the results dict.
    name: str
    # Callable taking the captured activations and returning the replacement
    # that is fed to the downstream model.
    forward: Callable
    # Key (in the ``models`` mapping of run_multi_model_interventions) of the
    # model whose activations are given to ``forward``.
    start_model: str = 'A'
    # Key (in the same mapping) of the model that resumes the forward pass
    # from the output of ``forward``. run_one_model_interventions does not
    # consult start_model/end_model.
    end_model: str = 'A'

def get_identity_intervention(name, model_name):
    """Build an intervention that returns the activations untouched.

    Args:
        name: Label stored on the intervention (used as the results key).
        model_name: Used as both ``start_model`` and ``end_model``.

    Returns:
        An ``Intervention`` whose ``forward`` returns its argument as-is.
    """
    def _passthrough(activations):
        return activations

    return Intervention(
        name=name,
        forward=_passthrough,
        start_model=model_name,
        end_model=model_name,
    )
def get_sae_intervention(name, sae, model_name):
    """Build an intervention that replaces activations with ``sae(activations)``.

    Args:
        name: Label stored on the intervention (used as the results key).
        sae: Callable applied to the activations; whatever it returns is what
            the downstream model receives (presumably a sparse autoencoder
            returning a reconstruction -- confirm against callers).
        model_name: Used as both ``start_model`` and ``end_model``.

    Returns:
        An ``Intervention`` whose ``forward`` delegates to ``sae``.
    """
    def _apply_sae(activations):
        return sae(activations)

    return Intervention(
        name=name,
        forward=_apply_sae,
        start_model=model_name,
        end_model=model_name,
    )

def get_zero_intervention(name, model_name):
    """Build an intervention that zero-ablates the activations.

    Args:
        name: Label stored on the intervention (used as the results key).
        model_name: Used as both ``start_model`` and ``end_model``.

    Returns:
        An ``Intervention`` whose ``forward`` returns ``torch.zeros_like`` of
        its input (same shape, dtype and device, all zeros).
    """
    def _zero_out(activations):
        return torch.zeros_like(activations)

    return Intervention(
        name=name,
        forward=_zero_out,
        start_model=model_name,
        end_model=model_name,
    )

@torch.inference_mode()
def run_one_model_interventions(
    dataloader,
    model,
    layer,
    interventions: List[Intervention],
) -> Dict[str, Union[torch.Tensor, None]]:
    """Measure next-token loss for several interventions on a single model.

    For every batch the model is run once up to ``layer``; each intervention
    is then applied to those shared activations and the model is resumed from
    ``layer`` to produce predictions. All interventions are therefore expected
    to act at the same layer; ``Intervention.start_model`` / ``end_model`` are
    not consulted here.

    Args:
        dataloader: Iterable of batches. Each batch must support ``.to(device)``
            and is used both as model input and as the loss target
            (presumably a tensor of token ids -- TODO confirm).
        model: Model accepting ``stop_at_layer`` / ``start_at_layer`` keyword
            arguments and exposing ``.parameters()`` and ``.tokenizer``
            (presumably a ``HookedTransformer``).
        layer: Layer at which the forward pass is split.
        interventions: Interventions to evaluate; ``name`` keys the result.

    Returns:
        Dict mapping each intervention name to the mean of all per-element
        losses gathered over the dataloader (a 0-dim CPU tensor), or ``None``
        if the dataloader yielded no batches.
    """
    per_batch_losses = {intervention.name: [] for intervention in interventions}
    device = next(model.parameters()).device
    spec_tokens = get_all_special_tokens(model.tokenizer)

    # Wrap the dataloader directly: tqdm cannot take len() of an enumerate
    # object, so tqdm(enumerate(...)) loses the total and the ETA.
    for sample in tqdm(dataloader):
        sample = sample.to(device)
        # Intermediate output of the pass stopped at `layer` (not logits);
        # computed once and shared by every intervention.
        activations = model(sample, stop_at_layer=layer)

        for intervention in interventions:
            modified_activations = intervention.forward(activations)
            preds = model(modified_activations, start_at_layer=layer)
            token_losses = next_token_cross_entropy_loss(
                preds, sample, ignore_index=spec_tokens, reduction='none'
            )
            # Move to CPU right away so accumulated losses do not pin
            # accelerator memory for the whole run.
            per_batch_losses[intervention.name].append(token_losses.flatten().cpu())

    # Build a fresh dict rather than reassigning values while iterating.
    return {
        name: torch.cat(losses).mean() if losses else None
        for name, losses in per_batch_losses.items()
    }

@torch.inference_mode()
def run_multi_model_interventions(
    dataloader,
    models: Dict[str, Tuple[int, HookedTransformer]],
    interventions: List[Intervention],
) -> Dict[str, Union[torch.Tensor, None]]:
    """Measure next-token loss for interventions that may bridge two models.

    For every batch, each model that some intervention names as its
    ``start_model`` is run up to its own layer. Each intervention then maps
    the activations of its ``start_model`` through ``forward`` and the result
    is fed into its ``end_model``, resumed from that model's layer.

    Args:
        dataloader: Iterable of batches. Each batch must support ``.to(device)``
            and is used both as model input and as the loss target
            (presumably a tensor of token ids -- TODO confirm).
        models: Mapping from model name to ``(layer, model)``. The device and
            the special tokens are taken from the first entry, so all models
            are assumed to live on that device and share a tokenizer.
        interventions: Interventions to evaluate; ``start_model`` and
            ``end_model`` must be keys of ``models``.

    Returns:
        Dict mapping each intervention name to the mean of all per-element
        losses gathered over the dataloader (a 0-dim CPU tensor), or ``None``
        if the dataloader yielded no batches.

    Raises:
        ValueError: If ``models`` is empty.
        KeyError: If an intervention refers to a model name not in ``models``
            (raised once a batch is processed).
    """
    if not models:
        # Previously surfaced as a bare StopIteration from next(iter(...)).
        raise ValueError("models must contain at least one (layer, model) entry")

    per_batch_losses = {intervention.name: [] for intervention in interventions}
    _, first_model = next(iter(models.values()))
    device = next(first_model.parameters()).device
    spec_tokens = get_all_special_tokens(first_model.tokenizer)

    # Only models that an intervention actually reads from need the partial
    # forward pass; dict.fromkeys de-duplicates while keeping a stable order.
    start_model_names = list(
        dict.fromkeys(intervention.start_model for intervention in interventions)
    )

    # Wrap the dataloader directly: tqdm cannot take len() of an enumerate
    # object, so tqdm(enumerate(...)) loses the total and the ETA.
    for sample in tqdm(dataloader):
        sample = sample.to(device)
        # Intermediate outputs (not logits) of each start model at its layer.
        activations_by_model = {}
        for model_name in start_model_names:
            start_layer, start_model = models[model_name]
            activations_by_model[model_name] = start_model(
                sample, stop_at_layer=start_layer
            )

        for intervention in interventions:
            modified_activations = intervention.forward(
                activations_by_model[intervention.start_model]
            )
            end_layer, end_model = models[intervention.end_model]
            preds = end_model(modified_activations, start_at_layer=end_layer)
            token_losses = next_token_cross_entropy_loss(
                preds, sample, ignore_index=spec_tokens, reduction='none'
            )
            # Move to CPU right away so accumulated losses do not pin
            # accelerator memory for the whole run.
            per_batch_losses[intervention.name].append(token_losses.flatten().cpu())

    # Build a fresh dict rather than reassigning values while iterating.
    return {
        name: torch.cat(losses).mean() if losses else None
        for name, losses in per_batch_losses.items()
    }
