import json

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
from sklearn.utils.extmath import randomized_svd
from sklearn.cluster import KMeans
from peft import LoraConfig, get_peft_model, PeftModel
import torch.nn.functional as F
import torch.nn as nn

# Seed Python's built-in `random` module once, at import time, so that the
# `random.sample` draws made in this module start from a fixed state.
# NOTE: this does not seed torch or numpy.
random.seed(0)

def apply_lora_to_mlp(args, model, chosen_layers):
    """Freeze the decoder layers, then wrap ``model`` with LoRA adapters.

    Args:
        args: Object exposing ``r``, ``lora_alpha``, ``lora_layer_selection``
            and ``lora_dropout``; these are forwarded to ``LoraConfig``.
        model: Model exposing ``model.model.layers`` (required by
            ``freeze_all_layers``).
        chosen_layers: Passed to ``LoraConfig`` as ``layers_to_transform``.

    Returns:
        The object returned by ``get_peft_model``.
    """
    # Freeze the decoder stack before the adapters are injected.
    freeze_all_layers(model)

    adapter_settings = dict(
        r=args.r,
        lora_alpha=args.lora_alpha,
        target_modules=args.lora_layer_selection,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        layers_to_transform=chosen_layers,
    )
    peft_model = get_peft_model(model, LoraConfig(**adapter_settings))

    # Report how many parameters are left trainable after wrapping.
    peft_model.print_trainable_parameters()
    return peft_model

def print_lora_layers(model):
    """Print every trainable parameter of ``model`` and their combined size.

    A parameter is listed when its ``requires_grad`` flag is set; despite the
    function name, no LoRA-specific filtering is applied.

    Args:
        model: Any object exposing ``named_parameters()``.
    """
    print("\nAll trainable parameters:")

    # Materialise the trainable (name, tensor) pairs once, in iteration order.
    trainable = [
        (param_name, tensor)
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    for param_name, tensor in trainable:
        print(f"{param_name}: shape {list(tensor.shape)} numel: {tensor.numel():,}")

    grand_total = sum(tensor.numel() for _, tensor in trainable)
    print(f"\nTotal trainable parameters: {grand_total:,}")

def freeze_all_layers(model):
    """Turn off gradients for every parameter under ``model.model.layers``.

    Parameters that live outside ``model.model.layers`` are not touched.
    """
    every_weight = (
        weight
        for decoder_block in model.model.layers
        for weight in decoder_block.parameters()
    )
    for weight in every_weight:
        weight.requires_grad = False

def unfreeze_all_layers(model):
    """Turn gradients back on for every parameter under ``model.model.layers``.

    Parameters that live outside ``model.model.layers`` are not touched.
    """
    trainable = True
    for decoder_block in model.model.layers:
        for tensor in decoder_block.parameters():
            tensor.requires_grad = trainable

@torch.no_grad()
def zero_grad(model):
    """Zero, in place, every gradient tensor already attached to ``model``.

    Parameters whose ``.grad`` is ``None`` are skipped, so no gradient buffers
    are allocated by this call.
    """
    for weight in model.parameters():
        gradient = weight.grad
        if gradient is None:
            continue
        gradient.zero_()

def prepare_model_for_unlearning(model, chosen_layers):
    """Leave only the decoder layers listed in ``chosen_layers`` trainable.

    Every parameter under ``model.model.layers`` is frozen first (via
    ``freeze_all_layers``); the parameters of each selected layer are then
    switched back on.

    Args:
        model: Model exposing ``model.model.layers``.
        chosen_layers: Iterable of indices into ``model.model.layers``.
    """
    freeze_all_layers(model)

    decoder_layers = model.model.layers
    for index in chosen_layers:
        for weight in decoder_layers[index].parameters():
            weight.requires_grad = True


def compute_shared_dominant_vectors(
    model, tokenizer, forget_data_list, layer_id, device,
    sample_size=10, num_components=5, method="power_iteration"
):
    """Estimate orthonormal directions shared by several forget domains.

    For each domain, a random sample of batches is run through ``model`` and
    the output of decoder layer ``layer_id`` is captured. Per-domain component
    vectors are extracted with ``method``, stacked across domains, and reduced
    by an SVD to ``num_components`` directions, which are then
    re-orthonormalised with a QR factorisation.

    Args:
        model: PEFT-wrapped model; the hooked layer is reached through
            ``model.base_model.model.model.layers``.
        tokenizer: Callable tokenizer returning an object with ``.to(device)``.
        forget_data_list: One list of batches per domain. Each batch is
            passed to ``tokenizer`` as-is.
        layer_id: Index of the decoder layer whose output is captured.
        device: Device the tokenized inputs and intermediate tensors go to.
        sample_size: Maximum number of batches sampled per domain.
        num_components: Number of components per domain and of directions
            returned.
        method: ``"power_iteration"``, ``"randomized_svd"`` or ``"kmeans"``.

    Returns:
        A tensor whose rows are the shared directions (at most
        ``num_components`` rows).

    Raises:
        ValueError: If ``method`` is not one of the supported names. This is
            checked up front, before any forward pass is run.
    """
    # Fail fast: validating here avoids running every forward pass only to
    # discover an unsupported method afterwards.
    if method not in ("power_iteration", "randomized_svd", "kmeans"):
        raise ValueError("method must be 'power_iteration', 'randomized_svd', or 'kmeans'.")

    # Plain indexing instead of eval() on a formatted string; int() keeps
    # accepting integer-like values such as "3".
    target_module = model.base_model.model.model.layers[int(layer_id)]

    domain_representations = []
    for domain_data in forget_data_list:
        domain_reps = []
        sampled_data = random.sample(domain_data, min(sample_size, len(domain_data)))

        for batch in sampled_data:
            inputs = tokenizer(
                batch, return_tensors="pt", padding=True, truncation=True, max_length=512
            ).to(device)
            with torch.no_grad():
                activations = forward_with_cache(model, inputs, module=target_module)
            if activations.dim() == 3:
                # Assumes activations are (batch, seq, hidden) — TODO confirm.
                # NOTE(review): the mean covers every position on dim 1,
                # including any padding added by the tokenizer; confirm that
                # this is intended.
                activations = activations.mean(dim=1)
            domain_reps.append(activations)

        domain_representations.append(torch.cat(domain_reps, dim=0))

    domain_pcs = []
    for domain_rep in domain_representations:
        domain_rep = domain_rep.float().to(device)

        if method == "power_iteration":
            print("Using Power Iteration")
            # power_iteration_topk is defined outside this block; presumably it
            # returns one row per component — verify against its definition.
            pcs = power_iteration_topk(domain_rep, num_components=num_components)

        elif method == "randomized_svd":
            print("Using Randomized SVD")
            arr = domain_rep.cpu().numpy()
            _, _, Vt = randomized_svd(arr, n_components=num_components, random_state=42)
            pcs = torch.tensor(Vt, dtype=torch.float32, device=device)

        else:  # "kmeans" (the only remaining validated option)
            print("Using K-Means")
            arr = domain_rep.cpu().numpy()
            kmeans = KMeans(n_clusters=num_components, random_state=42, n_init=10)
            kmeans.fit(arr)
            pcs = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32, device=device)

        domain_pcs.append(pcs)

    combined_pcs = torch.cat(domain_pcs, dim=0)
    print(f"[compute_shared_dominant_vectors] combined_pcs={combined_pcs.shape}")

    # Right singular vectors of the stacked per-domain components.
    _, _, Vt = torch.linalg.svd(combined_pcs, full_matrices=False)
    shared_dominant_vectors = Vt[:num_components]

    # Re-orthonormalise; Q's columns become the rows that are returned.
    Q, _ = torch.linalg.qr(shared_dominant_vectors.T)
    return Q.T

def remove_subspace(x, subspace_vectors):
    """Subtract from ``x`` its component along the rows of ``subspace_vectors``.

    Computes ``x - (x @ V.T) @ V`` over the last dimension in float32 and
    casts the result back to the dtype of ``x``. This equals an orthogonal
    projection onto the complement of ``span(V)`` only when the rows of ``V``
    are orthonormal; that property is neither checked nor enforced here.

    Args:
        x: Tensor whose last dimension has size ``H``, with any number of
            leading dimensions (e.g. ``(B, S, H)``). It does not need to be
            contiguous.
        subspace_vectors: Tensor of shape ``(k, H)``.

    Returns:
        A tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: If the last dimension of ``x`` differs from ``H`` of
            ``subspace_vectors``.
    """
    orig_dtype = x.dtype
    orig_shape = x.shape
    x = x.float()
    subspace_vectors = subspace_vectors.float()

    H = orig_shape[-1]
    _, H2 = subspace_vectors.shape
    assert H == H2, "Dimension mismatch"

    # reshape (not view): view raises on non-contiguous inputs such as a
    # transposed tensor, whereas reshape copies only when it has to.
    x_flat = x.reshape(-1, H)

    alpha = x_flat @ subspace_vectors.T      # coefficients, (N, k)
    projection = alpha @ subspace_vectors    # component inside the span, (N, H)
    x_removed = (x_flat - projection).reshape(orig_shape)
    return x_removed.to(orig_dtype)

def compute_preference_loss(score_retain, score_forget):
    """Return ``-log(sigmoid(mean(retain) - mean(forget)) + 1e-6)``.

    The loss shrinks as the mean retain score rises above the mean forget
    score. The ``1e-6`` term keeps the logarithm finite when the sigmoid
    underflows to zero.
    """
    margin = score_retain.mean() - score_forget.mean()
    preference_prob = torch.sigmoid(margin)
    return -torch.log(preference_prob + 1e-6)


class PreferenceHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) producing one score per row."""

    def __init__(self, hidden_size, intermediate_size=None):
        """Build the scoring network.

        Args:
            hidden_size: Size of the input feature dimension.
            intermediate_size: Width of the hidden layer; when ``None`` it
                defaults to ``hidden_size // 2``.
        """
        super().__init__()
        width = hidden_size // 2 if intermediate_size is None else intermediate_size
        self.net = nn.Sequential(
            nn.Linear(hidden_size, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        )

    def forward(self, x):
        """Score ``x`` and drop the trailing size-1 dimension.

        For an input of shape ``(B, H)`` the result has shape ``(B,)``.
        """
        scores = self.net(x)
        return scores.squeeze(-1)

def forward_with_cache(model, inputs, module, no_grad=True):
    """Run ``model(**inputs)`` and return what ``module`` produced.

    A temporary forward hook on ``module`` records its output (the first
    element when the output is a tuple). The hook is always removed
    afterwards, including when the forward pass raises, so repeated calls
    never accumulate stale hooks on ``module``.

    Args:
        model: Callable invoked as ``model(**inputs)``.
        inputs: Mapping of keyword arguments for ``model``.
        module: Sub-module of ``model`` whose output is captured.
        no_grad: When true, the forward pass runs under ``torch.no_grad()``.

    Returns:
        The output recorded by the first invocation of ``module``.

    Raises:
        IndexError: If ``module`` was never invoked during the forward pass.
    """
    cache = []

    def hook(_module, _inputs, output):
        cache.append(output[0] if isinstance(output, tuple) else output)
        # Returning None leaves the module's output unchanged.
        return None

    hook_handle = module.register_forward_hook(hook)
    try:
        if no_grad:
            with torch.no_grad():
                _ = model(**inputs)
        else:
            _ = model(**inputs)
    finally:
        # Detach the hook even if the forward pass failed.
        hook_handle.remove()

    return cache[0]
    
def get_params(model, layer_ids, param_ids):
    """Collect selected parameters from selected decoder layers.

    Args:
        model: Model exposing ``model.model.layers``.
        layer_ids: Iterable of indices into ``model.model.layers``.
        param_ids: Container of positions; a parameter is kept when its
            position in the layer's ``parameters()`` order is in this
            container.

    Returns:
        A list of parameters, ordered by ``layer_ids`` and then by position
        within each layer.
    """
    return [
        weight
        for layer_id in layer_ids
        for position, weight in enumerate(model.model.layers[layer_id].parameters())
        if position in param_ids
    ]


def load_model(model_name_or_path):
    """Load a causal LM and its tokenizer, configured for left-padded batches.

    The model is always loaded in ``torch.bfloat16`` with
    ``device_map="auto"`` and the ``flash_attention_2`` attention
    implementation. The tokenizer's pad, mask, sep and cls token ids are all
    set to the EOS token id, and padding is applied on the left.

    Args:
        model_name_or_path: Identifier or path passed to ``from_pretrained``.
            A fast tokenizer is requested only when this string contains
            ``'Llama'``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # NOTE: the previous version computed a dtype from CUDA/bf16 availability
    # and never used it; the load has always been bfloat16.
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
        attn_implementation='flash_attention_2',
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
        use_fast='Llama' in model_name_or_path,
    )

    # Point every special-token id that is set here at EOS.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"
    tokenizer.mask_token_id = tokenizer.eos_token_id
    tokenizer.sep_token_id = tokenizer.eos_token_id
    tokenizer.cls_token_id = tokenizer.eos_token_id

    return model, tokenizer


def get_data(forget_corpora, retain_corpora, min_len=50, max_len=2000, batch_size=4):
    """Load the named corpora and split each one into batches of strings.

    Args:
        forget_corpora: Iterable of corpus names to load as forget data.
        retain_corpora: Iterable of corpus names to load as retain data.
        min_len: A text is kept only when its length is strictly greater
            than this value. Applied uniformly to every corpus.
        max_len: Currently unused; kept for interface compatibility.
        batch_size: Number of texts per batch (the last batch of a corpus
            may be shorter).

    Returns:
        A ``(forget, retain)`` tuple. Each element is a list with one entry
        per requested corpus, and each entry is a list of batches (lists of
        strings).

    Recognised names are ``'fictional_knowledge'``, ``'wikitext'``,
    ``'bio-forget-corpus'`` and ``'cyber-forget-corpus'``; any other name is
    looked up as a configuration of ``cais/wmdp-mmlu-auxiliary-corpora``.
    """
    def keep_long_texts(texts):
        # Single place where the length filter and str() coercion happen.
        return [str(text) for text in texts if len(text) > min_len]

    def get_dataset(name):
        from datasets import load_dataset

        if name == 'fictional_knowledge':
            raw_data = load_dataset("json", data_files="/data/fictional_knowledge.json", split='train')
            data = keep_long_texts(x['train_context'] for x in raw_data)

        elif name == "wikitext":
            raw_data = load_dataset(
                "wikitext",
                "wikitext-2-raw-v1",
                cache_dir="/data/wmdp/wikitext",
                split="test"
            )
            data = keep_long_texts(x['text'] for x in raw_data)

        elif name == "bio-forget-corpus":
            data = []
            with open("/data/wmdp/bio-forget-corpus/bio-forget-corpus.jsonl", "r", encoding="utf-8") as f:
                for line in f:
                    try:
                        text = json.loads(line)['text']
                    except json.JSONDecodeError:
                        # Best effort: lines that are not valid JSON are skipped.
                        continue
                    # Honour min_len here too (this branch used to hard-code 50).
                    if len(text) > min_len:
                        data.append(str(text))

        else:
            if name == "cyber-forget-corpus":
                repo = "cais/wmdp-corpora"
            else:
                repo = "cais/wmdp-mmlu-auxiliary-corpora"
            dataset = load_dataset(repo, name, cache_dir=f"/data/wmdp/{name}")
            data = keep_long_texts(x['text'] for x in dataset['train'])

        return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

    return (
        [get_dataset(c) for c in forget_corpora],
        [get_dataset(c) for c in retain_corpora]
    )