import torch
import torch.optim as optim
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup
from peft import PeftModel
from datasets import load_dataset, Dataset
from tqdm import tqdm
from functools import partial
import os
from torch.utils.data import DataLoader


def get_model_role(tokenizer):
    """Return the role name to use for the model's turn in chat messages.

    The tokenizer's chat template is rendered for a single dummy user message
    with the generation prompt enabled, and the rendered text is searched for
    known generation-prompt markers.

    Args:
        tokenizer: Object exposing ``apply_chat_template(messages,
            tokenize=..., add_generation_prompt=...)``.

    Returns:
        ``"model"`` when the rendered prompt contains ``<start_of_turn>model``,
        otherwise ``"assistant"`` (either because an ``assistant`` marker was
        found, or as the fallback when nothing matched or the probe failed).
    """
    # Checked in order; the first marker found in the rendered prompt wins.
    role_markers = (
        ("<start_of_turn>model", "model"),
        ("<|start_header_id|>assistant", "assistant"),
        ("<|im_start|>assistant", "assistant"),
    )

    try:
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": "test"}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for marker, role in role_markers:
            if marker in prompt:
                return role
    except Exception as exc:
        # The probe is best-effort, so a failure must not abort the caller;
        # report why it failed instead of hiding it, then use the fallback.
        print(f"Warning: chat template probe failed ({type(exc).__name__}: {exc}).")

    print("Warning: Could not definitively determine model role. Defaulting to 'assistant'.")
    return "assistant"


def load_and_format_data(tokenizer, path, prompt_col, response_col, num_proc=4):
    """Load a JSON dataset and render each row through the chat template.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template``.
        path: Path passed to ``load_dataset("json", ...)`` as the train file.
        prompt_col: Column holding the user message.
        response_col: Column holding the model's reply.
        num_proc: Number of worker processes forwarded to ``Dataset.map``.
            Defaults to 4, the value this function previously hard-coded.

    Returns:
        The mapped dataset with exactly three columns: ``text`` (the rendered
        user + model conversation), ``prompt`` and ``response``. All original
        columns are removed.
    """
    dataset = load_dataset("json", data_files={"train": path}, split="train")

    model_role = get_model_role(tokenizer)
    print(f"Detected model role for chat template: '{model_role}'")

    def _format(example):
        # A JSON null arrives as None even though the key exists, so the
        # default of ``.get`` alone would not protect the template from it.
        user = example.get(prompt_col)
        user = "" if user is None else user
        model_resp = example.get(response_col)
        model_resp = "" if model_resp is None else model_resp

        messages = [
            {"role": "user", "content": user},
            {"role": model_role, "content": model_resp},
        ]
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False)
        return {"text": text, "prompt": user, "response": model_resp}

    return dataset.map(_format, remove_columns=list(dataset.features), num_proc=num_proc)

def create_gradient_mask(model, harmful_neuron_indices):
    """Build a 0/1 gradient mask for every trainable parameter of ``model``.

    Every trainable parameter gets an all-zero mask of its own shape. For each
    entry of ``harmful_neuron_indices`` the listed indices are then set to 1
    along dim 0 of ``<name>.weight`` and along dim 1 of the matching
    ``lora_B`` weight (same name with ``lora_A`` replaced by ``lora_B``).
    Presumably both dims are the LoRA rank dimension -- confirm against the
    adapter layout.

    Args:
        model: Module whose ``named_parameters()`` define the mask keys.
        harmful_neuron_indices: Mapping from a ``lora_A`` module name to the
            integer indices to unmask. Values may be tensors or plain
            sequences of ints.

    Returns:
        Dict mapping each trainable parameter name to its mask tensor.
    """
    print("Creating gradient mask using precise name matching...")

    # Start fully masked: nothing is trainable until explicitly unmasked.
    grad_mask = {
        name: torch.zeros_like(param.data)
        for name, param in model.named_parameters()
        if param.requires_grad
    }

    unmatched = []
    for hooked_lora_a_module_name, indices in harmful_neuron_indices.items():
        param_name_A = f"{hooked_lora_a_module_name}.weight"
        param_name_B = hooked_lora_a_module_name.replace("lora_A", "lora_B") + ".weight"

        mask_A = grad_mask.get(param_name_A)
        mask_B = grad_mask.get(param_name_B)
        if mask_A is None and mask_B is None:
            unmatched.append(hooked_lora_a_module_name)
            continue

        # Rows of the A weight and columns of the B weight share the indices.
        for mask, dim in ((mask_A, 0), (mask_B, 1)):
            if mask is None:
                continue
            # index_fill_ needs a 1-D integer index tensor on the mask's
            # device; coercing here also accepts lists and non-int64 tensors.
            index = torch.as_tensor(
                indices, dtype=torch.long, device=mask.device).reshape(-1)
            if index.numel() > 0:
                mask.index_fill_(dim, index, 1.0)

    if unmatched:
        # A partial name mismatch would otherwise go unnoticed, because the
        # all-zero check below only fires when nothing matched at all.
        print(
            f"Warning: {len(unmatched)} module name(s) matched no trainable "
            f"parameter, e.g. '{unmatched[0]}'."
        )

    if not any(mask.sum() > 0 for mask in grad_mask.values()):
        print("[ERROR] Gradient mask is all zeros. No parameters will be updated. Check name matching logic.")
    else:
        print("Gradient mask created successfully.")

    return grad_mask


def collate_fn_final(batch, tokenizer):
    """Tokenize a batch of rendered conversations and build loss labels.

    Labels are a copy of ``input_ids`` in which padding positions and the
    prompt part of every row are replaced with ``-100``, so only the response
    tokens keep a real label.

    Args:
        batch: List of dicts with at least ``text`` (rendered conversation)
            and ``prompt`` (raw user message).
        tokenizer: Tokenizer used both for encoding and for re-rendering the
            prompt-only prefix through its chat template.

    Returns:
        The tokenizer output with an added ``labels`` tensor.
    """
    bos = tokenizer.bos_token if tokenizer.bos_token is not None else ""
    eos = tokenizer.eos_token if tokenizer.eos_token is not None else ""

    # Only the "model" role wraps the rendered text in BOS/EOS. The prompt
    # prefix measured below must start with the same string as the full text,
    # otherwise its token count is off by one and the mask leaks into the
    # response.
    # NOTE(review): if the chat template already emits BOS itself, the "model"
    # branch yields a doubled BOS; kept as-is -- confirm it is intended.
    if get_model_role(tokenizer) == "model":
        prefix, suffix = bos, eos
    else:
        prefix, suffix = "", ""

    texts_to_tokenize = [prefix + item['text'] + suffix for item in batch]

    model_inputs = tokenizer(
        texts_to_tokenize, padding='longest', truncation=True, max_length=2048,
        return_tensors="pt", add_special_tokens=False)

    attention_mask = model_inputs.attention_mask
    labels = model_inputs.input_ids.clone()

    # Padding must not contribute to the loss. Mask by attention rather than
    # by token id, because the pad token may be the same id as a real EOS.
    labels[attention_mask == 0] = -100

    seq_len = labels.shape[1]
    pads_on_left = getattr(tokenizer, "padding_side", "left") == "left"

    for i, item in enumerate(batch):
        prompt_messages = [{"role": "user", "content": item['prompt']}]
        prompt_part_text = prefix + tokenizer.apply_chat_template(
            prompt_messages, tokenize=False, add_generation_prompt=True)
        prompt_len = len(tokenizer.encode(prompt_part_text, add_special_tokens=False))

        # Real tokens start after the padding when padding on the left, and
        # at position 0 when padding on the right.
        num_real_tokens = int(attention_mask[i].sum().item())
        start_index = seq_len - num_real_tokens if pads_on_left else 0
        end_index = start_index + prompt_len

        # Slicing clamps at the row end, so a truncated row whose prompt is
        # longer than the kept tokens is simply masked entirely.
        if end_index > start_index:
            labels[i, start_index:end_index] = -100

    model_inputs['labels'] = labels
    return model_inputs

def main() -> None:
    """Command-line entry point: masked LoRA training with a forget/steer objective.

    Loads a base causal LM plus an existing LoRA adapter, builds a gradient
    mask from the saved index file, and trains by minimizing
    ``steering_weight * steering_loss - forget_weight * forget_loss`` with
    gradients multiplied by that mask before each optimizer step. Training
    stops early once the exponentially smoothed ratio
    ``steering_loss / forget_loss`` drops below ``--early_stopping_ratio``.
    The adapter is saved to ``--output_dir`` at the end in either case.
    """
    parser = argparse.ArgumentParser(description="Parameter-Isolated Safety Steering with Dynamic Early Stopping.")
    # Paths: base model, adapter to continue training, the two datasets,
    # the saved index file, and the output location.
    parser.add_argument('--model_name', type=str, default='/root/autodl-tmp/model/gemma-2-9b-it')
    parser.add_argument('--lora_adapter_path', type=str, required=True)
    parser.add_argument('--raw_forget_path', type=str, required=True)
    parser.add_argument('--raw_steering_path', type=str, required=True)
    parser.add_argument('--harmful_neurons_path', type=str, default='harmful_neurons_lora.pt')
    parser.add_argument('--output_dir', type=str, default='ckpt/')

    # Optimization hyperparameters and the weights of the two loss terms.
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--epochs', type=int, default=10, help="Maximum number of epochs to run.")
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--forget_weight', type=float, default=1.0)
    parser.add_argument('--steering_weight', type=float, default=1.0)

    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--warmup_steps', type=int, default=20)

    # Early-stopping controls.
    parser.add_argument('--early_stopping_ratio', type=float, default=0.5, 
                        help="The smoothed loss ratio (steer/forget) below which training will stop.")
    parser.add_argument('--loss_ratio_smoothing', type=float, default=0.6,
                        help="The smoothing factor (alpha) for the loss ratio EMA (lower is smoother).")

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Fall back to EOS as the padding token when the tokenizer defines none.
    if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token; tokenizer.pad_token_id = tokenizer.eos_token_id
    # Pad on the left so the real tokens sit at the end of each row.
    tokenizer.padding_side = 'left'
    
    base_model = AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="eager")
    # The existing adapter is loaded with is_trainable=True so it can be updated.
    model = PeftModel.from_pretrained(base_model, args.lora_adapter_path, is_trainable=True)
    

    # Both datasets share the "instruction" prompt column and differ only in
    # which column supplies the response.
    collate_with_tokenizer = partial(collate_fn_final, tokenizer=tokenizer)
    forget_dataset = load_and_format_data(tokenizer, args.raw_forget_path, "instruction", "harmful_response")
    steering_dataset = load_and_format_data(tokenizer, args.raw_steering_path, "instruction", "safe_guide")
    forget_loader = DataLoader(forget_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_with_tokenizer)
    steering_loader = DataLoader(steering_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_with_tokenizer)


    # NOTE(review): torch.load unpickles the file unless weights-only loading
    # is in effect -- only point this at trusted files.
    harmful_neuron_indices = torch.load(args.harmful_neurons_path, map_location='cpu')
    grad_mask = create_gradient_mask(model, harmful_neuron_indices)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.learning_rate, weight_decay=args.weight_decay)
    # One optimizer/scheduler step is taken per forget batch, so the schedule
    # length is the number of forget batches times the maximum epoch count.
    num_training_steps = len(forget_loader) * args.epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=num_training_steps)


    model.train()
    # Running EMA of steering_loss / forget_loss; None until the first step.
    smoothed_loss_ratio = None
    early_stop_triggered = False

    for epoch in range(args.epochs):
        # The inner loop can only break out of its own epoch, so the flag is
        # checked here to end the outer loop as well.
        if early_stop_triggered:
            print("Early stopping triggered in previous epoch. Terminating training.")
            break
            
        print(f"\n--- Epoch {epoch + 1}/{args.epochs} ---")
        steering_iter = iter(steering_loader)
        progress_bar = tqdm(forget_loader, desc=f"Epoch {epoch+1}")
        
        # The forget loader drives the epoch; the steering loader is restarted
        # whenever it runs out so every forget batch is paired with one.
        for forget_batch in progress_bar:
            try: steering_batch = next(steering_iter)
            except StopIteration: steering_iter = iter(steering_loader); steering_batch = next(steering_iter)
            
            optimizer.zero_grad()
            
            forget_batch = {k: v.to(device) for k, v in forget_batch.items()}
            forget_loss = model(**forget_batch).loss
            
            steering_batch = {k: v.to(device) for k, v in steering_batch.items()}
            steering_loss = model(**steering_batch).loss
            

            # The small epsilon keeps the division defined when the forget
            # loss is exactly zero.
            current_ratio = steering_loss.item() / (forget_loss.item() + 1e-8)
            alpha = args.loss_ratio_smoothing
            if smoothed_loss_ratio is None:
                smoothed_loss_ratio = current_ratio
            else:
                smoothed_loss_ratio = alpha * current_ratio + (1 - alpha) * smoothed_loss_ratio
            

            # Minimizing this lowers the steering loss while raising the
            # forget loss (the forget term enters with a negative sign).
            total_loss = (args.steering_weight * steering_loss) - (args.forget_weight * forget_loss)
            total_loss.backward()
            
            # Multiply each gradient by its mask so entries whose mask is 0
            # receive a zero gradient.
            # NOTE(review): this masks gradients only; with --weight_decay > 0
            # AdamW's decay may still change masked entries -- confirm intent.
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if param.grad is not None and name in grad_mask:
                        param.grad *= grad_mask[name].to(param.device)
            
            optimizer.step()
            scheduler.step()
            
            progress_bar.set_postfix({
                "Forget": f"{forget_loss.item():.4f}", 
                "Steer": f"{steering_loss.item():.4f}", 
                "LR": f"{scheduler.get_last_lr()[0]:.2e}",
                "Ratio(smooth)": f"{smoothed_loss_ratio:.4f}"
            })

            # Checked after the update, so the step that crosses the
            # threshold is still applied before stopping.
            if smoothed_loss_ratio < args.early_stopping_ratio:
                print(f"\n[INFO] Smoothed loss ratio ({smoothed_loss_ratio:.4f}) has reached the target threshold ({args.early_stopping_ratio}).")
                print("Stopping training early.")
                early_stop_triggered = True
                break

    # Saved whether training ran to completion or stopped early.
    model.save_pretrained(args.output_dir)
    print(f"\nSteered LoRA adapter saved to: {args.output_dir}")

# Run the training CLI only when this file is executed as a script.
if __name__ == '__main__':
    main()