import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
import numpy as np
from ellipse_attack.transformations import Ellipse, Model
import fire

def _save_model_ellipses(model_names, models_dir, ellipse_dir):
    """Build each model's ellipse from its saved parameters and write it to ``ellipse_dir``.

    Reads ``<models_dir>/<name>.npz`` and writes ``<ellipse_dir>/<name>.npz``.
    """
    os.makedirs(ellipse_dir, exist_ok=True)
    for model_name in model_names:
        print(f"Getting ellipse for {model_name}...")
        model_path = os.path.join(models_dir, f"{model_name}.npz")

        # np.load on an .npz returns an NpzFile that keeps the file open;
        # the context manager closes it. `**` reads every array eagerly, so
        # Model receives plain ndarrays that stay valid after the close.
        with np.load(model_path) as model_params:
            model = Model(**model_params)

        ellipse = model.ellipse()
        ellipse_path = os.path.join(ellipse_dir, f"{model_name}.npz")
        np.savez(ellipse_path, **ellipse.__dict__)


def _compute_errors(model_names, ellipse_dir, outputs_dir):
    """Return ``errors[ellipse_model][output_model]``.

    Each value is the mean of ``Ellipse.error`` over the output model's saved
    logprobs (``<outputs_dir>/<name>.npy``).
    """
    errors = {model_name: {} for model_name in model_names}
    for ellipse_model_name in model_names:
        print(f"Calculating distances for {ellipse_model_name}'s ellipse...")
        ellipse_path = os.path.join(ellipse_dir, f"{ellipse_model_name}.npz")
        ellipse = Ellipse.from_npz(ellipse_path)

        for output_model_name in model_names:
            outputs_path = os.path.join(outputs_dir, f"{output_model_name}.npy")
            # Loaded per pair (not cached) so only one outputs array is held
            # in memory at a time.
            logprobs = np.load(outputs_path)
            errors[ellipse_model_name][output_model_name] = np.mean(ellipse.error(logprobs))
    return errors


def _print_error_table(model_names, errors):
    """Print ``errors`` as a table: rows are ellipse models, columns are output models."""
    # Names are cut to 24 chars inside a 25-wide column so adjacent long
    # headers keep at least one space between them.
    header = " " * 40 + "".join(f"{name[:24]:>25}" for name in model_names)
    print(header)
    for ellipse_model_name in model_names:
        row = f"{ellipse_model_name[:40]:<40}"
        row += "".join(
            f"{errors[ellipse_model_name][output_model_name]:25.4f}"
            for output_model_name in model_names
        )
        print(row)


def run_experiment_fast(output_dir: str = "results/experiment", models_dir: str = "data/model", outputs_dir: str = "data/outputs"):
    """Run the model identification experiment using ellipses built directly from model parameters.

    For every model, its ellipse is computed from ``<models_dir>/<name>.npz``
    and saved to ``<output_dir>/ellipses/<name>.npz``. Then, for every
    (ellipse model, output model) pair, the mean ellipse error of the output
    model's logprobs (``<outputs_dir>/<name>.npy``) is computed. The results
    are printed as a table.

    Args:
        output_dir: Directory under which an ``ellipses/`` subdirectory is created.
        models_dir: Directory holding one ``<name>.npz`` of ``Model`` keyword arrays per model.
        outputs_dir: Directory holding one ``<name>.npy`` logprobs array per model.

    Raises:
        FileNotFoundError: If any expected model or outputs file is missing.
    """
    model_names = [
        "mistralai_Mistral-7B-Instruct-v0.2",
        "Qwen_Qwen1.5-7B-Chat",
        "meta-llama_Llama-2-7b-chat-hf",
        "deepseek-ai_deepseek-llm-7b-chat",
        "allenai_OLMo-7B-Instruct",
    ]
    ellipse_dir = os.path.join(output_dir, "ellipses")

    _save_model_ellipses(model_names, models_dir, ellipse_dir)
    # Ellipses are re-read from disk, so the save/load round-trip is exercised.
    results = _compute_errors(model_names, ellipse_dir, outputs_dir)
    _print_error_table(model_names, results)

if __name__ == "__main__":
    # Expose run_experiment_fast's keyword arguments as command-line flags.
    fire.Fire(component=run_experiment_fast)
