import os
import torch
import json
import click
import wandb
import numpy as np
from tqdm import tqdm
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from circuit_tracer import ReplacementModel
from datasets import load_dataset

import sys
sys.path.append("circuit_lens/") 

from circuit_lens.activations_processing import load_and_sample_activations
from circuit_lens.circuit_based_analysis import Feature, pattern_and_circuit_discovery, cluster_activations


def _pick_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds have no `torch.backends.mps` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = _pick_device()
print(f"Using device: {device}")


class NumpyTorchEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy and PyTorch values.

    - ``np.ndarray`` and ``torch.Tensor`` become (nested) Python lists;
      tensors are detached and moved to CPU first.
    - NumPy scalars (``np.int64``, ``np.float32``, ``np.bool_``, ...) become
      the equivalent Python scalar. The stock encoder rejects most of these
      (only ``np.float64`` subclasses ``float``), so e.g. an ``np.int64``
      value inside a record would otherwise raise ``TypeError``.

    Anything else falls through to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # NumPy scalar -> native Python int/float/bool.
            return obj.item()
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().tolist()
        return super().default(obj)


def load_existing_keys(filepath, key_fields=("feature", "layer")):
    """Load the set of already-processed record keys from a JSONL file.

    Each non-blank line is parsed as a JSON object and reduced to a tuple of
    the values stored under ``key_fields``, in that order (missing fields
    become ``None``). With the default ``key_fields`` the tuples are
    ``(feature, layer)``.

    Args:
        filepath: Path to the JSONL file. A missing file yields an empty set.
        key_fields: Record fields that together identify a processed item.

    Returns:
        A set of key tuples.

    Lines that are not valid JSON are skipped with a warning instead of
    aborting. A typical cause is a final line truncated when an earlier run
    was killed mid-write. The item such a line described is then treated as
    unprocessed and will be recomputed.
    """
    keys = set()
    if not os.path.exists(filepath):
        return keys
    with open(filepath, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                print(f"Skipping malformed JSON on line {lineno} of {filepath}")
                continue
            keys.add(tuple(record.get(k) for k in key_fields))
    return keys


def append_jsonl(filepath, new_record):
    """Append one record as a single JSON line to ``filepath`` (no rewrite).

    The file is opened in append mode, so existing lines are never touched.
    Missing parent directories are created first; otherwise the first write
    to e.g. ``patterns/...`` on a fresh checkout fails with
    ``FileNotFoundError``. NumPy arrays and PyTorch tensors in ``new_record``
    are serialised as lists via ``NumpyTorchEncoder``.
    """
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # Serialise before opening so an encoding error never creates/touches the file.
    line = json.dumps(new_record, cls=NumpyTorchEncoder) + "\n"
    with open(filepath, "a", encoding="utf-8") as f:
        f.write(line)


@click.command()
@click.option("--layer", type=int, required=True,
              help="Layer whose features are analysed.")
@click.option("--job-id", type=click.IntRange(min=0), default=0,
              help="0-based index of this job among --num-jobs parallel jobs.")
@click.option("--num-jobs", type=click.IntRange(min=1), default=1,
              help="Total number of parallel jobs the features are split across.")
def main(layer, job_id, num_jobs):
    """Run pattern/circuit discovery for the features of one layer.

    Loads the gemma-2-2b ``ReplacementModel``, the dataset texts, and the
    sampled activations for ``layer``. The sorted feature ids are split into
    ``num_jobs`` contiguous chunks, and chunk ``job_id`` is processed. For
    every feature that has activating sequences and yields patterns, the
    result of ``pattern_and_circuit_discovery`` is augmented with the cluster
    labels and the ``feature``/``layer`` fields. It is then appended to
    ``patterns/features_analysis_layer_{layer}_{job_id}.jsonl``.

    Features already recorded are skipped. A feature counts as recorded if it
    appears either in the master file
    ``patterns/features_analysis_layer_{layer}.jsonl`` (presumably merged from
    per-job files elsewhere; it is only read here) or in this job's own output
    file. Because of the latter, a restarted job resumes instead of
    duplicating rows.
    """
    if job_id >= num_jobs:
        raise click.BadParameter(
            f"must be smaller than --num-jobs ({num_jobs}), got {job_id}",
            param_hint="--job-id",
        )

    wandb.init(
        project="patterns_analysis",
        resume="allow",
        allow_val_change=True
    )

    model_name = 'google/gemma-2-2b'
    transcoder_name = "gemma"
    model = ReplacementModel.from_pretrained(model_name, transcoder_name, dtype=torch.bfloat16, device=device)

    dataset = load_dataset("parquet", data_files="data/transcoders_batch_1.parquet")
    # row index -> raw text; the first element of each maxact entry indexes this dict
    dataset = {i: row["text"] for i, row in enumerate(dataset["train"])}

    # master results file (read-only here) and this job's own output file
    main_out_path = f"patterns/features_analysis_layer_{layer}.jsonl"
    out_path = f"patterns/features_analysis_layer_{layer}_{job_id}.jsonl"
    # Include out_path so a restarted job does not re-append features it already wrote.
    existing_keys = load_existing_keys(main_out_path) | load_existing_keys(out_path)

    # load activations
    layer_activations = load_and_sample_activations(
        layer, [],
        activations_path=Path("activations/gemma-2-2b/"),
        out_path="sampled_activations/",
        alpha=0.9, n_bins=20
    )

    features = sorted(layer_activations.keys())
    # split work across jobs: contiguous chunks of ceil(len / num_jobs) features
    chunk_size = (len(features) + num_jobs - 1) // num_jobs
    start = job_id * chunk_size
    end = min((job_id + 1) * chunk_size, len(features))
    assigned_features = features[start:end]

    print(f"Job {job_id}/{num_jobs}: processing {len(assigned_features)} features (indices {start}–{end-1})")

    for feature_idx in tqdm(assigned_features):
        # Same field order as load_existing_keys' default key_fields ("feature", "layer").
        key = (feature_idx, layer)
        if key in existing_keys:
            continue  # already processed

        maxact = layer_activations[feature_idx]
        if len(maxact) == 0:
            print(f"For feature {feature_idx} no activating sequences found.")
            continue

        # NOTE(review): with padding=True, left padding would shift token
        # positions relative to entry[2] (presumably a token position) —
        # confirm the tokenizer's padding side.
        tokens = model.tokenizer(
            [dataset[seq_idx] for seq_idx, _, _ in maxact],
            return_tensors="pt", padding=True
        )
        inputs = [(ids, entry[2]) for ids, entry in zip(tokens.input_ids, maxact)]

        result = pattern_and_circuit_discovery(
            model,
            inputs=inputs,
            feature=Feature(layer, 0, feature_idx)
        )
        if len(result['patterns']) == 0:
            print(f"For feature {feature_idx} no patterns discovered.")
            continue

        result['labels_input'], result['labels_output'] = cluster_activations(result)
        result["feature"] = feature_idx
        result["layer"] = layer

        append_jsonl(out_path, result)
        existing_keys.add(key)

        # free memory before the next feature's allocations
        del tokens, inputs, result, maxact
        torch.cuda.empty_cache()


if __name__ == "__main__":
    # Script entry point: `main` is a click command, so calling it parses the
    # command-line options (--layer, --job-id, --num-jobs) and runs the job.
    main()
