import os
import torch
import json
import click
import wandb
import numpy as np
from tqdm import tqdm
from pathlib import Path
from typing import Tuple, Union
from transformers import AutoModelForCausalLM, AutoTokenizer
from circuit_tracer import ReplacementModel
from datasets import load_dataset

import sys
sys.path.append("circuit_lens/") 
from circuit_lens.circuit_based_analysis import append_backward_hooks


# Choose the compute device once at import time: CUDA when present, then
# Apple MPS (only if this torch build exposes that backend), else the CPU.
_mps_backend = getattr(torch.backends, "mps", None)

if torch.cuda.is_available():
    _device_kind = "cuda"
elif _mps_backend is not None and _mps_backend.is_available():
    _device_kind = "mps"
else:
    _device_kind = "cpu"

device = torch.device(_device_kind)
print(f"Using device: {device}")

def append_jsonl(filepath, new_record):
    """Serialize ``new_record`` as JSON and add it as one line to ``filepath``.

    The file is opened in append mode, so lines already present are left
    untouched and the file is created if it does not exist yet.
    """
    with open(filepath, "a") as handle:
        serialized = json.dumps(new_record)
        handle.write(f"{serialized}\n")

def load_jsonl(path):
    """Read a JSONL file and return its records as a list.

    Lines that are empty or contain only whitespace are skipped; every other
    line is parsed with ``json.loads`` and kept in file order.
    """
    with open(path, "r") as handle:
        return [json.loads(raw_line) for raw_line in handle if raw_line.strip()]


def get_activations(
    model,
    inputs: Union[str, torch.Tensor],
    sparse: bool = False,
    zero_bos: bool = False,
    apply_activation_function: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run ``model`` on ``inputs`` and return its logits plus the cached activations.

    The caching hooks come from ``model._get_activation_caching_hooks`` and
    are installed only for the duration of the single forward pass.

    Args:
        model: Object providing ``_get_activation_caching_hooks``, a ``hooks``
            context manager, and ``__call__``.
        inputs: Prompt handed to ``model`` unchanged.
        sparse: Forwarded to the hook factory; when true, the stacked cache
            is additionally coalesced before being returned.
        zero_bos: Forwarded to the hook factory.
        apply_activation_function: Forwarded to the hook factory.

    Returns:
        ``(logits, activations)``, where ``activations`` is the cache
        returned by the hook factory stacked along a new leading dimension.
    """
    hook_options = dict(
        sparse=sparse,
        zero_bos=zero_bos,
        apply_activation_function=apply_activation_function,
    )
    cached, caching_hooks = model._get_activation_caching_hooks(**hook_options)

    # Intentionally not wrapped in torch.inference_mode(): gradients may be
    # needed from this forward pass.
    with model.hooks(caching_hooks):
        logits = model(inputs)

    stacked = torch.stack(cached)
    return logits, (stacked.coalesce() if sparse else stacked)


def _greedy_step_with_top_features(model, layer, input_ids, k):
    """Run one forward/backward pass; return the greedy token and top-``k`` features.

    Backward hooks are installed via ``append_backward_hooks`` for the
    duration of the pass and are always removed again, even if the forward
    or backward pass raises.

    Args:
        model: Model accepted by ``get_activations`` and exposing
            ``transcoders[layer].W_dec``.
        layer: Layer index used to pick the transcoder, the gradient entry
            ``blocks.{layer}.hook_mlp_out.hook_out_grad`` and the activations.
        input_ids: 1-D tensor of token ids; a batch dimension is added here.
        k: Number of highest-contribution features to return.

    Returns:
        ``(next_token, top_features)``: the argmax token of the last-position
        logits as a 1-element tensor, and the indices of the ``k`` largest
        contribution scores.
    """
    gradients, backward_handles = append_backward_hooks(model)
    try:
        logits, transcoder_acts = get_activations(model, input_ids.unsqueeze(0))
        last_logits = logits[0, -1]

        # Backpropagate from the largest last-position logit so the backward
        # hooks record the gradients used below.
        last_logits.max().backward()

        grad_at_last = gradients[f"blocks.{layer}.hook_mlp_out.hook_out_grad"][0][0, -1]
        # Contribution score per feature: activation * (W_dec @ gradient).
        # NOTE(review): assumes W_dec is (n_features, d_model) and
        # transcoder_acts is indexed as [layer, position] — confirm.
        contribs = model.transcoders[layer].W_dec.to(torch.float32) @ grad_at_last
        contribs = torch.mul(transcoder_acts[layer, -1], contribs)

        top_features = torch.topk(contribs, k=k).indices
        next_token = last_logits.argmax().unsqueeze(0)
    finally:
        # Never leave backward hooks registered on the model, even on failure.
        for handle in backward_handles:
            handle.remove()
    return next_token, top_features


def analyze_feature_output_patterns_streaming(model, layer, path, out_path, k=5, max_new_tokens=5):
    """
    Streaming analysis: process each feature individually and append results to JSONL as we go.

    For every record in the JSONL file at ``path``, each stored prompt is cut
    after ``token_pos`` and extended greedily by ``max_new_tokens`` tokens.
    For each generated token the function records whether the record's
    feature is among the ``k`` features with the largest contribution score
    for that prediction.

    Each input record must provide ``"feature"``, ``"layer"`` and
    ``"inputs"`` (an iterable of ``(input_ids, token_pos)`` pairs);
    ``"labels_input"`` is optional.

    One record per feature is appended to ``out_path`` with the keys:

    * ``"output_patterns"``: one list per prompt; each entry is the generated
      token id when the feature was in the top ``k``, otherwise ``0``.
    * ``"labels"``: the input record's ``"labels_input"`` (``[]`` if absent).
    * ``"layer"`` and ``"feature"``: copied from the input record.

    NOTE(review): ``0`` doubles as the "feature not in top-k" marker, so a
    genuinely generated token with id 0 cannot be told apart from it.

    Args:
        model: Model passed on to ``get_activations``/``append_backward_hooks``.
        layer: Layer index whose transcoder and gradients are analysed.
        path: Input JSONL file with one feature record per line.
        out_path: Output JSONL file; records are appended, never rewritten.
        k: Size of the top-contribution set the feature must be in.
        max_new_tokens: Number of tokens generated per prompt.
    """
    device = next(model.parameters()).device
    results = load_jsonl(path)

    for res in tqdm(results):
        feature_idx = res["feature"]
        res_output_pattern = []

        for input_ids, token_pos in res["inputs"]:
            # Keep the prompt only up to (and including) the stored position.
            input_ids = torch.tensor(input_ids[: token_pos + 1], device=device)
            pattern = []

            for _ in range(max_new_tokens):
                next_token, top_features = _greedy_step_with_top_features(
                    model, layer, input_ids, k
                )
                token_id = int(next_token.item())
                pattern.append(token_id if feature_idx in top_features else 0)
                # Generation always continues with the greedy token, whether
                # or not the feature was active for this step.
                input_ids = torch.cat([input_ids, next_token.to(device)], dim=0)

            res_output_pattern.append(pattern)

        # Save filtered feature result immediately
        filtered_record = {
            "output_patterns": res_output_pattern,
            "labels": res.get("labels_input", []),
            "layer": res["layer"],
            "feature": res["feature"],
        }
        append_jsonl(out_path, filtered_record)

@click.command()
@click.option(
    "--layer",
    type=int,
    required=True,
    help="Layer index; selects patterns/features_analysis_layer_<layer>.jsonl.",
)
@click.option(
    "--job-id",
    type=int,
    default=0,
    help="Job index, used as the suffix of the output file name.",
)
@click.option(
    "--num-jobs",
    type=int,
    default=1,
    help="Total number of jobs. Accepted but not used by this script yet.",
)
def main(layer, job_id, num_jobs):
    """Run the streaming output-pattern analysis for one layer."""
    # --layer is required: without it the paths below would be built from
    # "None" and the run would only fail after the model had been loaded.
    wandb.init(project="patterns_analysis", resume="allow", allow_val_change=True)

    model_name = 'google/gemma-2-2b'
    transcoder_name = "gemma"
    model = ReplacementModel.from_pretrained(model_name, transcoder_name, dtype=torch.bfloat16, device=device)

    # Input patterns
    path = f"patterns/features_analysis_layer_{layer}.jsonl"
    # Job-specific output file
    out_path = f"patterns/output_analysis_layer_{layer}_{job_id}.jsonl"
    # NOTE(review): num_jobs is never used, so the input is not sharded —
    # every job processes all features and differs only in its output file.

    # Start from an empty output file; an existing file for this layer/job is
    # truncated.
    Path(out_path).write_text("")

    # streaming processing
    analyze_feature_output_patterns_streaming(model, layer, path, out_path)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
