#!/usr/bin/env python3
"""Model utilities for LoRA operations and model management"""

import os
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

def apply_round_chat_template(self, messages, add_generation_prompt=False, tokenize=False, is_thinking=None, **kwargs) -> str:
    """
    Render a conversation in the "[Round N] / Human / Assistant" prompt format.

    Messages are read two at a time from the start of the list; a pair is
    kept as a completed round only when it is a 'user' message followed by
    an 'assistant' message. A trailing 'user' message is treated as the
    pending query.

    Parameters:
        messages (list): List of messages with 'role' and 'content'.
        add_generation_prompt (bool): When True and the last message is a
            non-empty 'user' message, append a final round ending in
            "Assistant:". When False and the last message is from the
            assistant, append '<|end_of_text|>' instead.
        tokenize (bool): When True, call `self` on the rendered text with
            add_special_tokens=False and return its "input_ids".
        is_thinking: Accepted for call compatibility; not used.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        str: Formatted prompt string (or the "input_ids" produced by `self`
        when `tokenize` is True).
    """
    # Collect completed (query, answer) rounds from positions (0, 1), (2, 3), ...
    rounds = []
    for start in range(0, len(messages) - 1, 2):
        first, second = messages[start], messages[start + 1]
        if first['role'] == 'user' and second['role'] == 'assistant':
            rounds.append((first['content'], second['content']))

    # A trailing user message is the query still waiting for an answer.
    last_role = messages[-1]['role'] if messages else None
    pending_query = messages[-1]['content'] if last_role == 'user' else ""

    # Assemble the prompt; no separator is inserted between an answer and
    # the header of the following round.
    pieces = ['<|begin_of_text|>']
    pieces.extend(
        f"[Round {number}]\nHuman: {query}\nAssistant: {answer}"
        for number, (query, answer) in enumerate(rounds)
    )

    if len(pending_query) > 0 and add_generation_prompt:
        pieces.append(f"[Round {len(rounds)}]\nHuman: {pending_query}\nAssistant:")

    if not add_generation_prompt and last_role == 'assistant':
        pieces.append("<|end_of_text|>")

    rendered = "".join(pieces)
    if not tokenize:
        return rendered
    return self(rendered, add_special_tokens=False)["input_ids"]

def apply_alpaca_chat_template(self, messages, add_generation_prompt=False, tokenize=False) -> str:
    """
    Render messages as an Alpaca-style prompt.

    The first 'system' message becomes the instruction, the first 'user'
    message becomes the input and, when present, the first 'assistant'
    message becomes the response (followed by " </s>"). Content of each
    part is stripped of surrounding whitespace.

    Parameters:
        messages (list): List of messages with 'role' and 'content'.
        add_generation_prompt (bool): Whether to append an empty
            "### Response:" header for the model to complete.
        tokenize (bool): Accepted for call compatibility; not used.

    Returns:
        str: Formatted prompt string, stripped of surrounding whitespace.

    Raises:
        AssertionError: If there is no 'system' or no 'user' message.
    """
    def first_content(role, fallback=None):
        # Content of the first message carrying `role`, or `fallback`.
        for entry in messages:
            if entry['role'] == role:
                return entry['content']
        return fallback

    # Both an instruction and an input are mandatory.
    seen_roles = [entry['role'] for entry in messages]
    assert 'system' in seen_roles, "Expected at least one system message for instruction."
    assert 'user' in seen_roles, "Expected at least one user message for input."

    instruction = first_content('system')
    task_input = first_content('user')
    response = first_content('assistant', "")

    sections = [
        "Below is an instruction that describes a task, paired with an input that provides further context. ",
        "Write a response that appropriately completes the request.\n\n",
        f"### Instruction:\n{instruction.strip()}\n\n",
        f"### Input:\n{task_input.strip()}\n\n",
    ]

    if 'assistant' in seen_roles:
        sections.append(f"### Response:\n{response.strip()} </s>")

    if add_generation_prompt:
        sections.append("### Response:\n")

    return "".join(sections).strip()

def load_lora_hist(path, model, cache_dir):
    """
    Load every LoRA adapter listed in `<path>/lora_hist.json` and merge them into `model`.

    The first entry wraps `model` in a PeftModel under the adapter name
    "default"; each further entry is loaded as "lora_<index>". All adapters
    are then combined with weight 1.0 each into a "cat"-type weighted
    adapter named "ad_lora", which is activated and merged into the base
    weights.

    Parameters:
        path (str | None): Directory expected to contain "lora_hist.json".
        model: Base model to attach the adapters to.
        cache_dir: Forwarded to PeftModel.from_pretrained for the first adapter.

    Returns:
        tuple: (model, lora_hist). When `path` is None, the history file is
        missing, or the history is empty, `model` is returned untouched
        together with an empty list.
    """
    if path is None:
        return model, []

    hist_file = os.path.join(path, "lora_hist.json")
    if not os.path.exists(hist_file):
        return model, []

    with open(hist_file, "r", encoding="utf-8") as f:
        lora_hist = json.load(f)

    # With no entries the model is never wrapped in a PeftModel, so the
    # adapter calls below would fail on the plain base model.
    if not lora_hist:
        return model, []

    lora_names = []
    for index, lora_d in enumerate(lora_hist):
        if index == 0:
            # The first adapter creates the PeftModel wrapper.
            model = PeftModel.from_pretrained(model, lora_d, cache_dir=cache_dir)
            lora_names.append("default")
        else:
            adapter_name = "lora_" + str(index)
            model.load_adapter(lora_d, adapter_name)
            lora_names.append(adapter_name)

    # Combine all adapters with equal weight, then bake the result into the base weights.
    model.add_weighted_adapter(lora_names, [1.0] * len(lora_names), "ad_lora", "cat")
    model.set_adapter("ad_lora")
    model = model.merge_and_unload()
    return model, lora_hist

def load_lora_hist_ad(path, model, cache_dir):
    """
    Load the adapters listed in `<path>/lora_hist.json` next to the existing "default" adapter.

    Each history entry is loaded as adapter "l_<index>". A "cat"-type
    weighted adapter named "merged" is then registered from "default" plus
    all loaded adapters, each with weight 1.0. The adapter is only added;
    this function does not activate or merge it.

    Parameters:
        path (str | None): Directory expected to contain "lora_hist.json".
            When None or when the file is missing, the history is empty.
        model: Model exposing `load_adapter` and `add_weighted_adapter`
            (presumably a PeftModel that already has a "default" adapter).
        cache_dir: Accepted for signature parity with the other loaders; not used.

    Returns:
        tuple: (model, lora_hist).
    """
    lora_hist = []
    if path is not None:
        hist_file = os.path.join(path, "lora_hist.json")
        if os.path.exists(hist_file):
            with open(hist_file, "r") as f:
                lora_hist = json.load(f)

    # "default" always takes part; history adapters are appended in order.
    adapter_names = ["default"]
    for index, lora_d in enumerate(lora_hist):
        adapter_name = "l_" + str(index)
        model.load_adapter(lora_d, adapter_name=adapter_name)
        adapter_names.append(adapter_name)

    equal_weights = [1.0] * len(adapter_names)
    model.add_weighted_adapter(adapter_names, equal_weights, "merged", combination_type="cat")
    return model, lora_hist

def load_lora_hist_simple(path, model, cache_dir):
    """
    Apply the adapters listed in `<path>/lora_hist.json` one after another.

    For every history entry the adapter is loaded onto the current model
    and immediately merged and unloaded, so the next adapter is applied on
    top of the already-merged weights. Each entry is printed before it is
    loaded.

    Parameters:
        path (str | None): Directory expected to contain "lora_hist.json".
        model: Base model to merge the adapters into.
        cache_dir: Forwarded to PeftModel.from_pretrained.

    Returns:
        tuple: (model, lora_hist); `model` untouched and an empty list when
        `path` is None or the history file does not exist.
    """
    if path is None:
        return model, []

    hist_file = os.path.join(path, "lora_hist.json")
    if not os.path.exists(hist_file):
        return model, []

    with open(hist_file, "r") as f:
        lora_hist = json.load(f)

    for adapter_source in lora_hist:
        print(adapter_source)
        wrapped = PeftModel.from_pretrained(model, adapter_source, cache_dir=cache_dir)
        model = wrapped.merge_and_unload()

    return model, lora_hist

def move_to_vllm(model, vllm):
    """
    Push every parameter of `model` into the model held by a vLLM engine.

    Parameter names are normalised before loading: a leading
    "base_model.model." is removed, and the substrings ".base_layer" and
    "modules_to_save.default." are dropped. The prefix cache of `vllm` is
    reset once all parameters have been loaded.

    Parameters:
        model: Object exposing `named_parameters()`.
        vllm: Object exposing
            `llm_engine.model_executor.driver_worker.model_runner.model`
            and `reset_prefix_cache()`.
    """
    # The target model does not change between parameters, so resolve it once.
    llm_model = vllm.llm_engine.model_executor.driver_worker.model_runner.model

    for name, param in model.named_parameters():
        # When using PEFT, recover the original parameter name by stripping the wrapper's name parts.
        # NOTE(review): no parameters are filtered out here — presumably the caller passes a
        # model whose adapter weights are already merged; confirm.
        name = name.removeprefix("base_model.model.").replace(".base_layer", "")
        name = name.replace("modules_to_save.default.", "")
        llm_model.load_weights([(name, param.data)])

    vllm.reset_prefix_cache()

def merge_and_save_model(base_model_name, lora_dir, temp_dir, cache_dir=None):
    """
    Export the base model with its LoRA history merged in, reusing an earlier export when possible.

    A ".lora_source" marker file inside `temp_dir` records which source the
    export was built from. If the marker matches the requested source, the
    existing directory is returned without loading anything. Otherwise the
    base model is loaded in float16, the LoRA history from `lora_dir` is
    merged via `load_lora_hist`, and model, tokenizer and marker are written
    to `temp_dir` (the marker last).

    Parameters:
        base_model_name (str): Name or path of the base model.
        lora_dir (str | None): Directory holding the LoRA history; None means
            the plain base model.
        temp_dir (str): Output directory for the merged model.
        cache_dir (str | None): Cache directory forwarded to the loaders.

    Returns:
        str: `temp_dir`.
    """
    marker_path = os.path.join(temp_dir, ".lora_source")

    # Reuse a previous export if its marker names the same source.
    if os.path.exists(temp_dir) and os.path.exists(marker_path):
        with open(marker_path, 'r') as marker:
            recorded_source = marker.read().strip()
        # Without a LoRA directory the marker is expected to hold the base model name.
        requested_source = base_model_name if lora_dir is None else lora_dir
        if recorded_source == requested_source:
            print(f"Using existing model in {temp_dir} (same LoRA source)")
            return temp_dir

    os.makedirs(temp_dir, exist_ok=True)

    # Load the base weights in half precision.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.float16,
        cache_dir=cache_dir
    )
    print(lora_dir)

    # Fold the recorded LoRA adapters into the base weights and write the result.
    merged_model, _ = load_lora_hist(lora_dir, base_model, cache_dir=cache_dir)
    merged_model.save_pretrained(temp_dir)

    # Ship the tokenizer alongside the weights so temp_dir is self-contained.
    AutoTokenizer.from_pretrained(base_model_name, cache_dir=cache_dir).save_pretrained(temp_dir)

    # Written last, so an interrupted export leaves no marker behind.
    with open(marker_path, 'w') as marker:
        marker.write(lora_dir if lora_dir else base_model_name)

    print(f"Merged model saved to {temp_dir}")
    return temp_dir