import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from dataclasses import asdict
from transformers import AutoModelForCausalLM, AutoConfig
from ellipse_attack.transformations import Model
import torch, numpy as np

# Name fragments of model families whose final norm lives at ``model.model.norm``
# and is treated as having no bias term.
_BIASLESS_FINAL_NORM = ("llama", "qwen", "mistral", "mixtral", "gemma", "gpt-oss")

# Families that take precedence over the "olmo" rule when a model name
# contains both fragments.
_TAKES_PRECEDENCE_OVER_OLMO = ("llama", "qwen", "mistral", "mixtral")


def _to_numpy(tensor):
    """Return ``tensor`` as a NumPy array, detached from autograd and on the CPU."""
    # Parameters require grad; without detach(), Tensor.numpy() raises unless
    # the caller happens to be inside torch.no_grad().
    tensor = tensor.detach().cpu()
    if tensor.dtype == torch.bfloat16:
        # NumPy has no bfloat16 dtype, so upcast instead of letting numpy() raise.
        tensor = tensor.float()
    return tensor.numpy()


def _olmo_final_norm(model):
    """Return the final norm module of an OLMo model for either module layout."""
    transformer = getattr(model, "transformer", None)
    if transformer is not None and hasattr(transformer, "ln_f"):
        return transformer.ln_f
    # Fall back to the ``model.model.norm`` layout used by the other families
    # (presumably what the native transformers OLMo-2 classes expose -- confirm).
    return model.model.norm


def get_model_params(model_name, trust_remote_code=False):
    """Load a causal LM and extract its final norm and unembedding parameters.

    Args:
        model_name: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.
            The lower-cased name is also searched for family fragments
            ("llama", "qwen", "mistral", "mixtral", "olmo", "gemma", "gpt-oss")
            to decide where the final norm is read from.
        trust_remote_code: Forwarded to ``from_pretrained``.

    Returns:
        ``Model(stretch=gamma, bias=beta, unembed=W)`` where ``gamma`` is the
        final norm weight, ``beta`` is the final norm bias (zeros for the known
        families) and ``W`` is the transposed ``lm_head`` weight, all as NumPy
        arrays. bfloat16 weights are upcast to float32; other dtypes are kept.

    Raises:
        NotImplementedError: If the name matches no known family and the model
            lacks ``lm_head.weight``, ``model.norm.weight`` or a usable
            ``model.norm.bias``.
        AttributeError: If a known family's model lacks the expected modules.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    name = model_name.lower()

    is_olmo = "olmo" in name and not any(tag in name for tag in _TAKES_PRECEDENCE_OVER_OLMO)

    if is_olmo:
        W = _to_numpy(model.lm_head.weight).T
        gamma = _to_numpy(_olmo_final_norm(model).weight)
        # The final norm is treated as bias-free, so the bias is all zeros.
        beta = np.zeros_like(gamma)
    elif any(tag in name for tag in _BIASLESS_FINAL_NORM):
        W = _to_numpy(model.lm_head.weight).T
        gamma = _to_numpy(model.model.norm.weight)
        # These families use a final norm without a bias term.
        # NOTE(review): some RMSNorm variants (e.g. Gemma's) appear to scale by
        # (1 + weight) rather than weight -- confirm whether ``stretch`` should
        # be ``1 + gamma`` for those families.
        beta = np.zeros_like(gamma)
    else:
        # Fallback for unknown models, assuming a common structure that also
        # carries an explicit bias on the final norm.
        try:
            W = _to_numpy(model.lm_head.weight).T
            gamma = _to_numpy(model.model.norm.weight)
            beta = _to_numpy(model.model.norm.bias)
        except AttributeError as err:
            raise NotImplementedError(f"Model architecture not supported for {model_name}") from err

    return Model(stretch=gamma, bias=beta, unembed=W)

if __name__ == "__main__":
    # Model identifier -> value passed as ``trust_remote_code`` when loading it.
    model_names = {
        "mistralai/Mistral-Small-3.2-24B-Instruct-2506": False,
        "Qwen/Qwen3-Next-80B-A3B-Instruct": True,
        "meta-llama/Llama-4-Scout-17B-16E-Instruct": True,
        "allenai/OLMo-2-1124-7B-Instruct": True,
        "openai/gpt-oss-20b": True,
        "google/gemma-3-12b-it": True,
    }

    for repo_id, allow_remote_code in model_names.items():
        # Only weights are read, so gradient tracking is disabled while loading.
        with torch.no_grad():
            print(f"Saving model: {repo_id}")
            try:
                params = get_model_params(repo_id, trust_remote_code=allow_remote_code)
                # "/" in the identifier would otherwise create a sub-directory.
                file_stem = repo_id.replace("/", "_")
                os.makedirs("data/model", exist_ok=True)
                # Each dataclass field becomes one named array in the archive.
                # NOTE: asdict() deep-copies the field values before saving.
                arrays = asdict(params)
                np.savez(f"data/model/{file_stem}.npz", **arrays)
            except Exception as exc:
                # Best effort: report the failure and continue with the next model.
                print(f"Failed to save model {repo_id}: {exc}")