import json
import sys
from datasets import load_dataset
from unsloth import FastLanguageModel
from transformers import AutoTokenizer, default_data_collator
from ldifs.featextract import extract_features_for_lp
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
import torch
import numpy as np

def preprocess(example, tokenizer, max_seq_length):
    """Tokenize the ``"text"`` field and carry the ``"label"`` field over.

    Args:
        example: Mapping with a ``"text"`` entry (handed to the tokenizer
            as-is) and a ``"label"`` entry.
        tokenizer: Callable invoked with the text plus ``truncation``,
            ``padding`` and ``max_length`` keyword arguments; its return
            value must support item assignment.
        max_seq_length: Length every sequence is truncated/padded to.

    Returns:
        The tokenizer's output with an extra ``"labels"`` entry holding
        ``example["label"]``.
    """
    encoded = tokenizer(
        example["text"],
        max_length=max_seq_length,
        padding="max_length",
        truncation=True,
    )
    # Keep the targets next to the encoded inputs under the "labels" key.
    encoded["labels"] = example["label"]
    return encoded

def _load_model(model_path, load_in_4bit, device):
    """Load a (model, tokenizer) pair through Unsloth and place the model."""
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_path,
        # Unsloth documents ``None`` (not the string "auto") as the value
        # that lets it pick the dtype automatically.
        dtype=None,
        load_in_4bit=load_in_4bit,
    )
    # A bitsandbytes-quantized model is already placed on its device by the
    # loader, and depending on the installed transformers/bitsandbytes
    # versions a later ``.to()`` raises. Only move non-quantized weights.
    if not load_in_4bit:
        model.to(device)
    return model, tokenizer


def _build_dataloader(cfg, tokenizer, split):
    """Load *split* of the configured dataset, tokenize it and batch it."""
    dataset = load_dataset(cfg["dataset"], split=split)
    dataset = dataset.map(
        lambda ex: preprocess(ex, tokenizer, cfg["max_seq_length"]),
        batched=True,
    )
    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    return DataLoader(
        dataset,
        batch_size=cfg["batch_size"],
        collate_fn=default_data_collator,
    )


def _extract(model, dataloader, cfg, is_frozen):
    """Run ``extract_features_for_lp`` with the settings taken from *cfg*."""
    return extract_features_for_lp(
        model,
        dataloader,
        num_layers=cfg["num_layers"],
        is_frozen=is_frozen,
        return_numpy=cfg["return_numpy"],
    )


def main(config_path):
    """Fit linear probes on frozen vs. fine-tuned features and report the gap.

    Reads a JSON config, loads the fine-tuned and the base model, extracts
    features for both with ``extract_features_for_lp``, fits one
    ``LogisticRegression`` per feature set and prints both accuracies and
    their difference (fine-tuned minus frozen).

    Config keys read:
        finetuned_model_path, base_model: model identifiers/paths.
        load_in_4bit (optional, default ``False``): forwarded to Unsloth.
        dataset, split: passed to ``datasets.load_dataset``.
        max_seq_length, batch_size: tokenization / batching settings.
        num_layers, return_numpy: forwarded to ``extract_features_for_lp``.
        eval_split (optional): a second split on which the fitted probes
            are scored. When absent, accuracy is measured on the very
            samples the probes were fitted on.
        save_results_path (optional): where to write the results as JSON.

    Args:
        config_path: Path to the JSON configuration file.

    Raises:
        KeyError: If a required config key is missing.
    """
    # Read config.json
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_in_4bit = cfg.get("load_in_4bit", False)

    # Load models; the tokenizer of the fine-tuned model is used for both.
    model_finetuned, tokenizer = _load_model(cfg["finetuned_model_path"], load_in_4bit, device)
    model_frozen, _ = _load_model(cfg["base_model"], load_in_4bit, device)

    # Load dataset
    dataloader = _build_dataloader(cfg, tokenizer, cfg["split"])

    # Extract the features the probes are fitted on
    X_frozen, y = _extract(model_frozen, dataloader, cfg, True)
    X_finetuned, _ = _extract(model_finetuned, dataloader, cfg, False)

    # Optionally extract held-out features so accuracy is not measured on
    # the samples the probes were trained on.
    eval_split = cfg.get("eval_split")
    if eval_split:
        eval_loader = _build_dataloader(cfg, tokenizer, eval_split)
        X_frozen_eval, y_eval = _extract(model_frozen, eval_loader, cfg, True)
        X_finetuned_eval, _ = _extract(model_finetuned, eval_loader, cfg, False)
    else:
        # NOTE: backward-compatible default; these are training accuracies.
        X_frozen_eval, X_finetuned_eval, y_eval = X_frozen, X_finetuned, y

    # Fit classifier
    clf_frozen = LogisticRegression(max_iter=1000).fit(X_frozen, y)
    clf_finetuned = LogisticRegression(max_iter=1000).fit(X_finetuned, y)

    acc_frozen = clf_frozen.score(X_frozen_eval, y_eval)
    acc_finetuned = clf_finetuned.score(X_finetuned_eval, y_eval)
    delta_lp = acc_finetuned - acc_frozen

    print(f"Accuracy frozen:    {acc_frozen:.4f}")
    print(f"Accuracy finetuned: {acc_finetuned:.4f}")
    print(f"∆LP: {delta_lp:.4f}")

    results = {
        "accuracy_frozen": acc_frozen,
        "accuracy_finetuned": acc_finetuned,
        "delta_lp": delta_lp,
    }

    if "save_results_path" in cfg:
        with open(cfg["save_results_path"], "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"Results saved to: {cfg['save_results_path']}")

if __name__ == "__main__":
    # The script needs the path to the JSON config as its first argument;
    # fail with a usage message instead of a bare IndexError when it is
    # missing. Extra arguments are still ignored.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} <config.json>")
    main(sys.argv[1])
