import os
import pickle
from typing import Any, Union

import numpy as np
import torch
import torch.nn as nn
import transformers
from sae_lens import SAE
from sandbagging_research_sprint.utils import load_hf_model
from transformer_lens import HookedTransformer
from transformers import AutoTokenizer

from eliciting_contexts.utils.auth import setup_huggingface_auth
from eliciting_contexts.utils.constants import DATA_DIR, DEVICE


def load_sae_saelens(
    release: str = "gemma-scope-9b-it-res-canonical",
    sae_id: str = "layer_9/width_131k/canonical",
    device: str = "cuda",
    dtype: str = "bfloat16",
):
    sae, _, _ = SAE.from_pretrained(release=release, sae_id=sae_id, device=device)

    torch_dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
    sae = sae.to(torch_dtype)
    return sae


def load_model_tlens(
    model_name: str = "google/gemma-2b-it",
    device: str = "cuda",
    dtype: str = "bfloat16",
):
    """Load ``model_name`` as a TransformerLens ``HookedTransformer``.

    Args:
        model_name: Model identifier forwarded to
            ``HookedTransformer.from_pretrained``.
        device: Device forwarded to the loader.
        dtype: Dtype forwarded to the loader.

    Returns:
        The object returned by ``HookedTransformer.from_pretrained``.
    """
    return HookedTransformer.from_pretrained(
        model_name,
        device=device,
        dtype=dtype,
    )


def load_sae_and_model(
    model_name: str = "google/gemma-2b-it",
    sae_model_name: str = "gemma-2b-it-res-jb",
    sae_id: str = "blocks.12.hook_resid_post",
    dtype: str = "bfloat16",
    device: str = "cuda",
):
    """Load a hooked model, its tokenizer and a pretrained SAE in one call.

    Args:
        model_name: Identifier used for both the ``HookedTransformer`` and
            the Hugging Face tokenizer.
        sae_model_name: SAELens release name. The list of available
            releases is at
            https://github.com/jbloomAus/SAELens/blob/main/sae_lens/pretrained_saes.yaml
        sae_id: Specific SAE ID inside the release.
        dtype: Dtype forwarded to the model loader. It is not applied to
            the SAE in this function.
        device: Device forwarded to both the model and the SAE loaders.

    Returns:
        A 5-tuple ``(model, tokenizer, sae, cfg, sparsity)``, where the last
        three are the values unpacked from ``SAE.from_pretrained``.
    """
    model = HookedTransformer.from_pretrained(model_name, dtype=dtype, device=device)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    sae, cfg, sparsity = SAE.from_pretrained(sae_model_name, sae_id, device=device)

    # Make every SAE parameter trainable if at least one currently is not.
    frozen_params = [p for p in sae.parameters() if not p.requires_grad]
    if frozen_params:
        print("Enabling gradients for SAE parameters...")
        for p in frozen_params:
            p.requires_grad = True

    print(type(sae), type(cfg), type(sparsity))

    return model, tokenizer, sae, cfg, sparsity


def load_finetuned_model(
    lora_model_id: str = "contextmodification/gemma-sandbagging-0w4j7rba-step1024",
    base_model_name: str = "google/gemma-2-2b-it",
    device: str | torch.device = "cuda",
    attn_implementation: Union[str, None] = "eager",
    use_cache: bool = True,  # None case is hardcoded in load_hf_model
):
    """Load a finetuned model via ``load_hf_model`` plus the base tokenizer.

    Args:
        lora_model_id: Model identifier handed to ``load_hf_model``.
        base_model_name: Identifier used only to load the tokenizer.
        device: Passed to ``load_hf_model`` as ``device_map``.
        attn_implementation: Forwarded to ``load_hf_model``.
        use_cache: Forwarded to ``load_hf_model``.

    Returns:
        ``(model, tokenizer)``. The tokenizer is guaranteed to have a pad
        token (falling back to its EOS token) and the model config's
        ``pad_token_id`` is set to match.

    Raises:
        Exception: Any error raised during authentication or loading is
            printed and then re-raised unchanged.
    """
    try:
        setup_huggingface_auth()

        tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        # Reuse EOS for padding when the tokenizer ships without a pad token.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        loader_kwargs = {
            "attn_implementation": attn_implementation,
            "use_cache": use_cache,
            "device_map": device,
        }
        model = load_hf_model(lora_model_id, **loader_kwargs)
        # Keep the model's padding id in sync with the tokenizer's.
        model.config.pad_token_id = tokenizer.pad_token_id
    except Exception as e:
        print(f"Error loading/testing model: {str(e)}")
        raise
    else:
        return model, tokenizer


class TorchLogisticRegression(nn.Module):
    """Binary logistic regression implemented as a single linear unit.

    ``forward`` always returns two values along the last dimension, ordered
    ``[class 0, class 1]`` (not sandbagging, sandbagging).

    Args:
        input_dim: Number of input features.
        return_logits: If true, ``forward`` returns ``[-z, z]`` where ``z`` is
            the raw linear output; otherwise it returns probabilities.
    """

    def __init__(self, input_dim: int, return_logits: bool = False):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)  # Single output for binary classification
        self.return_logits = return_logits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``[class 0, class 1]`` scores for ``x``.

        Args:
            x: Input whose last dimension has ``input_dim`` features.

        Returns:
            A tensor with last dimension 2: ``[-z, z]`` when
            ``return_logits`` is set, else ``[P(class=0), P(class=1)]``.
        """
        logits = self.linear(x)
        if self.return_logits:
            # Mirrored pair so the output has the same shape as the
            # probability branch.
            # NOTE: softmax over [-z, z] yields sigmoid(2z), not sigmoid(z);
            # callers that normalise these values should account for that.
            return torch.cat([-logits, logits], dim=-1)

        # P(class=0) = 1 - sigmoid(z) = sigmoid(-z). The sigmoid(-z) form is
        # used because 1 - sigmoid(z) becomes exactly 0 as soon as sigmoid(z)
        # rounds to 1, which happens for moderate z in low-precision dtypes.
        return torch.cat([torch.sigmoid(-logits), torch.sigmoid(logits)], dim=-1)


def sklearn_to_torch_logreg(sklearn_model, device="cuda", return_logits=False):
    """Convert a binary sklearn LogisticRegression into a PyTorch module.

    Args:
        sklearn_model: Fitted model exposing ``coef_`` of shape
            ``(1, n_features)`` and ``intercept_`` of shape ``(1,)``.
        device: Device the resulting module is placed on.
        return_logits: Forwarded to ``TorchLogisticRegression``.

    Returns:
        A ``TorchLogisticRegression`` whose linear layer carries the sklearn
        coefficients and intercept.

    Raises:
        ValueError: If ``coef_`` does not have exactly one row, i.e. the
            model is not a binary classifier. The target module has a single
            output unit and cannot represent more than one coefficient row.
    """
    weight = torch.as_tensor(sklearn_model.coef_, dtype=torch.float32)
    bias = torch.as_tensor(sklearn_model.intercept_, dtype=torch.float32)

    # sklearn stores binary coefficients as (n_classes - 1, n_features) == (1, n).
    if weight.ndim != 2 or weight.shape[0] != 1:
        raise ValueError(
            "Expected a binary classifier with coef_ of shape (1, n_features), "
            f"got {tuple(weight.shape)}"
        )

    input_dim = weight.shape[1]
    torch_model = TorchLogisticRegression(input_dim, return_logits=return_logits).to(
        device
    )

    # Copy in place so the parameters keep their device and stay registered;
    # no_grad keeps the copy out of autograd.
    with torch.no_grad():
        torch_model.linear.weight.copy_(weight)
        torch_model.linear.bias.copy_(bias)

    return torch_model


def verify_sklearn_torch_conversion(sklearn_model, device="cuda", rtol=1e-5, atol=1e-8):
    """
    Check that the converted PyTorch model reproduces sklearn's probabilities.

    Draws 1000 standard-normal inputs, runs them through both
    ``sklearn_model.predict_proba`` and the module built by
    ``sklearn_to_torch_logreg``, and compares the results with
    ``np.allclose``. The torch inputs are cast to float32.

    Args:
        sklearn_model: Trained sklearn LogisticRegression model
        device: Device to run PyTorch model on
        rtol: Relative tolerance for numerical comparison
        atol: Absolute tolerance for numerical comparison

    Returns:
        bool: True if outputs match within tolerance
    """
    n_features = sklearn_model.coef_.shape[1]
    numpy_inputs = np.random.randn(1000, n_features)
    torch_inputs = torch.from_numpy(numpy_inputs).float().to(device)

    converted = sklearn_to_torch_logreg(sklearn_model, device)
    with torch.no_grad():
        torch_outputs = converted(torch_inputs).cpu().numpy()
    sklearn_outputs = sklearn_model.predict_proba(numpy_inputs)

    matches = np.allclose(sklearn_outputs, torch_outputs, rtol=rtol, atol=atol)
    if matches:
        print("IT WORKED")
        return matches

    # Report the largest discrepancy to help diagnose the mismatch.
    abs_diff = np.abs(sklearn_outputs - torch_outputs)
    max_diff = np.max(abs_diff)
    print(f"Maximum difference between outputs: {max_diff}")
    print("sklearn output shape:", sklearn_outputs.shape)
    print("torch output shape:", torch_outputs.shape)
    print("\nExample disagreement:")
    # argmax without an axis yields an index into the flattened array.
    worst = np.argmax(abs_diff)
    print(f"sklearn: {sklearn_outputs.flatten()[worst]}")
    print(f"torch: {torch_outputs.flatten()[worst]}")
    return matches


def download_classifier(
    pickle_path: str = "/workspace/eliciting-contexts/probe_layer_15.pkl",
) -> tuple:
    # currently a pickle file, not set up to load from wandb yet
    with open(pickle_path, "rb") as f:
        classifier = pickle.load(f)
    return classifier


def load_classifier(
    pickle_path: str | os.PathLike[str] | None = DATA_DIR / "probe_layer_15.pkl",
    dtype: torch.dtype | None = torch.bfloat16,
    return_logits: bool = False,
) -> nn.Module:
    """Load the pickled sklearn probe and return it as a PyTorch module.

    The pickle is read with ``download_classifier``, converted with
    ``sklearn_to_torch_logreg`` on ``DEVICE``, and optionally cast to
    ``dtype``.

    Args:
        pickle_path: Path to the pickled sklearn model. ``None`` selects
            ``DATA_DIR / "probe_layer_15.pkl"``.
        dtype: torch dtype to cast the model to (defaults to bfloat16).
            ``None`` skips the cast.
        return_logits: whether to return the logits or the probabilities

    Returns:
        The converted torch classifier.
    """
    if pickle_path is None:
        pickle_path = DATA_DIR / "probe_layer_15.pkl"
    classifier = download_classifier(pickle_path)
    torch_model = sklearn_to_torch_logreg(
        classifier, device=DEVICE, return_logits=return_logits
    )

    if dtype is not None:
        torch_model = torch_model.to(dtype)

    return torch_model
