import argparse
import copy
import json
import logging
import os

import torch
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, set_seed
from utils import load_model_and_tokenizer, str2bool

# Configure root logging once for the script, then grab this module's logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
logger = logging.getLogger(__name__)


def compute_gradients(owner_model, base_model, prefix_token_ids, x_token_ids, y_token_ids, lambda_reg):
    """Return the gradient of the fingerprint objective w.r.t. a one-hot encoding of x.

    The objective is ``owner_loss - lambda_reg * base_loss``. Each term is the mean cross-entropy of
    ``y_token_ids`` teacher-forced after ``prefix + x``. Both models receive the same input embeddings,
    which are built from ``owner_model``'s input-embedding matrix.

    Args:
        owner_model: Model whose loss on y is minimised. Its embedding matrix embeds all tokens.
        base_model: Model whose loss on y is pushed up, weighted by ``lambda_reg``.
        prefix_token_ids: Token ids placed before x (kept fixed).
        x_token_ids: Token ids of the input being optimised.
        y_token_ids: Target token ids.
        lambda_reg: Weight of the base-model term.

    Returns:
        Tensor of shape ``x_token_ids.shape + (vocab_size,)``. It holds d(objective)/d(one_hot(x)).
        Lower values mark token substitutions expected to decrease the objective.

    Raises:
        ValueError: If the gradient contains NaN.
    """
    device = next(owner_model.parameters()).device
    embedding_layer = owner_model.get_input_embeddings()
    embedding_weights = embedding_layer.weight

    # Differentiable embedding of x: one_hot @ E equals embedding lookup but exposes a gradient per vocab entry.
    one_hot = F.one_hot(x_token_ids.detach(), num_classes=embedding_weights.shape[0]).to(device=device, dtype=embedding_weights.dtype)
    one_hot.requires_grad_(True)
    x_embeddings = torch.matmul(one_hot, embedding_weights)

    # Prefix and target embeddings are constants for this gradient.
    prefix_embeddings = embedding_layer(prefix_token_ids).detach()
    y_embeddings = embedding_layer(y_token_ids).detach()
    full_embeddings = torch.cat([prefix_embeddings, x_embeddings, y_embeddings], dim=1)

    # Logits at position t predict token t + 1, so the y predictions start one position before y.
    target_start = prefix_token_ids.shape[1] + x_token_ids.shape[1] - 1

    owner_logits = owner_model(inputs_embeds=full_embeddings).logits
    owner_loss = F.cross_entropy(owner_logits[:, target_start:-1, :].transpose(1, 2), y_token_ids, reduction="mean")

    base_logits = base_model(inputs_embeds=full_embeddings).logits
    base_loss = F.cross_entropy(base_logits[:, target_start:-1, :].transpose(1, 2), y_token_ids, reduction="mean")

    total_loss = owner_loss - lambda_reg * base_loss

    # Differentiate only w.r.t. the one-hot input. Unlike .backward(), this never writes .grad into model
    # parameters, even if they still have requires_grad=True.
    (grad,) = torch.autograd.grad(total_loss, one_hot)

    if torch.isnan(grad).any():
        raise ValueError("NaN gradients detected!")

    return grad


def sample_replacement_indices(x_token_ids, grad, top_k, batch_size):
    """Build ``batch_size`` candidate sequences, each differing from x in at most one position.

    Row ``i`` replaces position ``floor(i * x_len / batch_size)`` of x with a token drawn uniformly from the
    ``top_k`` tokens with the most negative gradient at that position. The drawn token may equal the current one.

    Args:
        x_token_ids: Current token ids, indexed as ``x_token_ids[0]`` / ``shape[1]`` (batch of one).
        grad: Gradient with a vocabulary last dim, indexed as ``grad[0, position, token]``.
        top_k: Number of best-gradient tokens to sample from per position.
        batch_size: Number of candidates to create.

    Returns:
        Tensor of shape ``(batch_size, x_len)``.
    """
    device = grad.device
    seq_len = x_token_ids.shape[1]
    top_indices = (-grad).topk(top_k, dim=-1).indices

    # Spread positions evenly with exact integer arithmetic. A float-step arange can, through rounding, yield a
    # different number of elements or floor to a different position; this always yields exactly batch_size positions.
    new_token_pos = torch.arange(batch_size, device=device) * seq_len // batch_size
    new_token_val = top_indices[0, new_token_pos, torch.randint(0, top_k, (batch_size,), device=device)]

    replace_indices = x_token_ids.repeat(batch_size, 1)
    replace_indices[torch.arange(batch_size, device=device), new_token_pos] = new_token_val

    return replace_indices


def filter_consistent_token_sequences(replaced_indices, tokenizer, device):
    """Keep only candidate rows that survive a decode -> encode round trip unchanged.

    Each row is decoded with ``skip_special_tokens=True`` and re-encoded without special tokens. A row is kept
    only when the re-encoded ids equal the original row exactly (same length and values).

    Args:
        replaced_indices: 2-D tensor of candidate token ids, one candidate per row.
        tokenizer: Tokenizer providing ``decode`` and ``encode``.
        device: Device for the re-encoded ids and the selection mask.

    Returns:
        The subset of rows of ``replaced_indices`` that round-trip, in their original order.
    """
    n_rows, _ = replaced_indices.shape

    keep_flags = []
    for row_idx in range(n_rows):
        row = replaced_indices[row_idx]
        text = tokenizer.decode(row, skip_special_tokens=True)
        round_trip = torch.tensor(tokenizer.encode(text, add_special_tokens=False), device=device)
        keep_flags.append(torch.equal(round_trip, row))

    keep_mask = torch.tensor(keep_flags, dtype=torch.bool, device=device)
    return replaced_indices[keep_mask]


def evaluate_control_sequences(owner_model, base_model, replaced_indices, prefix_token_ids, suffix_token_ids, lambda_reg):
    """Score every candidate x and return the one with the lowest objective.

    For each candidate the objective is ``owner_loss - lambda_reg * base_loss``. Each loss is the per-candidate
    mean cross-entropy of ``suffix_token_ids`` teacher-forced after ``prefix + candidate``. All candidates go
    through each model in a single forward pass.

    Args:
        owner_model: Model whose loss is minimised. Its ``config`` supplies ``pad_token_id`` for the attention
            mask and, if present, ``max_position_embeddings``.
        base_model: Model whose loss is subtracted, weighted by ``lambda_reg``.
        replaced_indices: ``(num_candidates, x_len)`` candidate token ids.
        prefix_token_ids: ``(1, prefix_len)`` ids placed before each candidate.
        suffix_token_ids: ``(1, suffix_len)`` target ids placed after each candidate.
        lambda_reg: Weight of the base-model term.

    Returns:
        ``(best_candidate, total_loss, owner_loss, base_loss)``. ``best_candidate`` has shape ``(1, x_len)``;
        the losses are Python floats for that candidate.

    Raises:
        ValueError: If there are no candidates, or if ``prefix + candidate + suffix`` is longer than the
            model's maximum position embeddings.
    """
    device = next(owner_model.parameters()).device

    replaced_indices = replaced_indices.to(device)
    prefix_token_ids = prefix_token_ids.to(device)
    suffix_token_ids = suffix_token_ids.to(device)

    num_candidates, x_len = replaced_indices.shape
    if num_candidates == 0:
        raise ValueError("No candidate sequences to evaluate (all candidates were filtered out).")

    # The model sees prefix + candidate + suffix, so the full length must fit, not just the candidate.
    total_len = prefix_token_ids.shape[1] + x_len + suffix_token_ids.shape[1]
    max_positions = getattr(owner_model.config, "max_position_embeddings", None)
    if max_positions is not None and total_len > max_positions:
        raise ValueError(f"Input sequence length {total_len} exceeds model's maximum position embeddings {max_positions}")

    # Logits at position t predict token t + 1, so suffix predictions start one position before the suffix.
    target_start = prefix_token_ids.shape[1] + x_len - 1

    with torch.no_grad():
        batch_prefix = prefix_token_ids.repeat(num_candidates, 1)
        batch_suffix = suffix_token_ids.repeat(num_candidates, 1)
        batch_full = torch.cat([batch_prefix, replaced_indices, batch_suffix], dim=1)

        batch_mask = (batch_full != owner_model.config.pad_token_id).long()

        owner_logits = owner_model(input_ids=batch_full, attention_mask=batch_mask).logits
        owner_losses = F.cross_entropy(owner_logits[:, target_start:-1, :].transpose(1, 2), batch_suffix, reduction="none").mean(dim=1)

        base_logits = base_model(input_ids=batch_full, attention_mask=batch_mask).logits
        base_losses = F.cross_entropy(base_logits[:, target_start:-1, :].transpose(1, 2), batch_suffix, reduction="none").mean(dim=1)

        all_losses = owner_losses - lambda_reg * base_losses

    best_idx = all_losses.argmin()
    best_candidate = replaced_indices[best_idx].unsqueeze(0)
    current_loss = all_losses[best_idx].item()
    # Individual terms are returned for logging.
    current_owner_loss = owner_losses[best_idx].item()
    current_base_loss = base_losses[best_idx].item()

    return best_candidate, current_loss, current_owner_loss, current_base_loss


def _locate_prefix(input_token_ids, x_token_ids):
    """Return the tokens of ``input_token_ids`` before the first exact occurrence of ``x_token_ids``, or None."""
    x_len = x_token_ids.shape[1]
    for start in range(input_token_ids.shape[1] - x_len + 1):
        if torch.all(input_token_ids[:, start : start + x_len] == x_token_ids):
            return input_token_ids[:, :start]
    return None


def _teacher_forced_loss(model, prompt_ids, target_ids, pad_token_id):
    """Mean cross-entropy of ``target_ids`` when fed to ``model`` directly after ``prompt_ids``."""
    full_ids = torch.cat([prompt_ids, target_ids], dim=1)
    attention_mask = (full_ids != pad_token_id).float()
    logits = model(input_ids=full_ids, attention_mask=attention_mask).logits
    # Logits at position t predict token t + 1, so target predictions start at the last prompt position.
    return F.cross_entropy(logits[:, prompt_ids.shape[1] - 1 : -1, :].transpose(1, 2), target_ids, reduction="mean")


def optimize_fingerprint_input(owner_model, base_model, tokenizer, base_tokenizer, args):
    """Search for a fingerprint input x by greedy coordinate-gradient token substitution.

    Each step:
    1. Computes gradients of ``owner_loss - lambda_reg * base_loss`` w.r.t. a one-hot x.
    2. Samples single-token substitutions from the top-k gradient tokens.
    3. Drops candidates that do not survive a decode/encode round trip.
    4. Keeps the candidate with the lowest objective.

    Args:
        owner_model: Model that should produce ``args.y_str`` for the returned x.
        base_model: Model whose loss on y is pushed up by the regulariser.
        tokenizer: Tokenizer used for x and y with both models during the search.
        base_tokenizer: Used only for the final sanity-check loss on ``base_model``.
        args: Namespace read for ``x_sequence_length``, ``y_str``, ``adv_prefix``, ``optimization_steps``,
            ``lambda_reg``, ``top_k`` and ``batch_size``. The optional ``early_stop_owner_loss`` (default 1.0)
            and ``early_stop_base_loss`` (default 3.0) override the early-stopping thresholds.

    Returns:
        The decoded text of ``prefix + x`` for the final x (special tokens skipped).

    Raises:
        ValueError: If the tokenizer has no EOS token, or if x cannot be located inside the jointly encoded
            ``adv_prefix + x``.
    """
    logger.info("Optimizing fingerprint input x")
    owner_model.eval()
    base_model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    owner_model.to(device)
    base_model.to(device)

    vocab_size = owner_model.get_input_embeddings().weight.shape[0]

    # Random initial x: sample token ids and decode them to text. The text is re-tokenised below.
    x_sequence_length = min(args.x_sequence_length, owner_model.config.max_position_embeddings)
    initial_x_str = tokenizer.decode(torch.randint(0, vocab_size, (x_sequence_length,)))
    y_str = args.y_str

    input_token_ids = tokenizer.encode(args.adv_prefix + initial_x_str, return_tensors="pt").to(device)
    x_token_ids = tokenizer.encode(initial_x_str, return_tensors="pt", add_special_tokens=False).to(device)
    y_token_ids = tokenizer.encode(y_str, return_tensors="pt", add_special_tokens=False).to(device)

    # Append EOS to the target so the loss also covers predicting EOS right after y.
    if tokenizer.eos_token_id is None:
        raise ValueError("The tokenizer has no eos_token_id; cannot append EOS to the target y.")
    y_token_ids = torch.cat([y_token_ids, torch.tensor([[tokenizer.eos_token_id]], device=device)], dim=1)

    # The prefix is everything before x in the jointly encoded text, including any special tokens encode() adds.
    prefix_token_ids = _locate_prefix(input_token_ids, x_token_ids)
    if prefix_token_ids is None:
        raise ValueError("x_token_ids not found in input_token_ids. You may need to change adv_prefix or x_sequence_length.")

    current_input = tokenizer.decode(torch.cat([prefix_token_ids, x_token_ids], dim=1).squeeze().tolist(), skip_special_tokens=True)
    current_loss = float("inf")

    logger.info(f"input_token_ids {input_token_ids}")
    logger.info(f"prefix_token_ids {prefix_token_ids}")
    logger.info(f"x_token_ids {x_token_ids}")
    logger.info(f"y_token_ids {y_token_ids}")
    logger.info(f"full_token_ids {torch.cat([prefix_token_ids, x_token_ids, y_token_ids], dim=1)}")
    logger.info(f"initial input {current_input}")

    early_stop_owner_loss = getattr(args, "early_stop_owner_loss", 1.0)
    early_stop_base_loss = getattr(args, "early_stop_base_loss", 3.0)

    for step in range(args.optimization_steps):
        coordinate_grads = compute_gradients(owner_model, base_model, prefix_token_ids, x_token_ids, y_token_ids, args.lambda_reg)

        # Randomly sample a batch of single-token replacements from the top-k gradient tokens.
        replaced_indices = sample_replacement_indices(x_token_ids, coordinate_grads, args.top_k, args.batch_size)

        # Drop candidates whose token ids change after a decode -> encode round trip.
        replaced_indices = filter_consistent_token_sequences(replaced_indices, tokenizer, device)
        if replaced_indices.shape[0] == 0:
            # Nothing to evaluate (argmin over an empty tensor would fail); keep the current x and resample.
            logger.warning(f"Optimization step {step + 1}: no candidate survived the round-trip filter; keeping current x.")
            continue

        x_token_ids, current_loss, current_owner_loss, current_base_loss = evaluate_control_sequences(owner_model, base_model, replaced_indices, prefix_token_ids, y_token_ids, args.lambda_reg)

        # Generate from the current x for logging only.
        with torch.no_grad():
            current_input = tokenizer.decode(torch.cat([prefix_token_ids, x_token_ids], dim=1).squeeze().tolist(), skip_special_tokens=True)
            encoded_current_input = tokenizer.encode(current_input, return_tensors="pt").to(device)
            output_ids = owner_model.generate(
                input_ids=encoded_current_input, max_new_tokens=50, pad_token_id=tokenizer.pad_token_id, attention_mask=(encoded_current_input != tokenizer.pad_token_id).long()
            )
            generated_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        logger.info(f"Optimization step {step + 1}/{args.optimization_steps}")
        logger.info(f"Total Loss: {current_loss:.4f}")
        logger.info(f"Owner Loss: {current_owner_loss:.4f}")
        logger.info(f"Base Loss: {current_base_loss:.4f}")
        logger.info(f"Current x: {current_input}")
        logger.info(f"Generated output: {generated_output}")
        logger.info("\n--------------------------------\n")

        if current_owner_loss < early_stop_owner_loss:
            logger.info("Early stopping due to low owner loss.")
            break
        if current_base_loss < early_stop_base_loss:
            logger.info("Early stopping due to low base loss.")
            break

    optimized_x = current_input

    with torch.no_grad():
        # Sanity check 1: re-encode the final text with the owner tokenizer. The optimised ids need not survive
        # the round trip exactly, so this loss can differ from the search loss.
        reencode_input_token_ids = tokenizer.encode(optimized_x, return_tensors="pt").to(device)
        reencode_loss = _teacher_forced_loss(owner_model, reencode_input_token_ids, y_token_ids, tokenizer.pad_token_id)

        # Sanity check 2: encode the same text with the base tokenizer and score it with the base model.
        # NOTE(review): the target is still y_token_ids from the owner tokenizer, which assumes both tokenizers
        # share a vocabulary.
        base_encode_input_token_ids = base_tokenizer.encode(optimized_x, return_tensors="pt").to(device)
        logger.info(f"base_encode_input_token_ids {base_encode_input_token_ids}")
        logger.info(f"base_encode_full_token_ids {torch.cat([base_encode_input_token_ids, y_token_ids], dim=1)}")
        logger.info(f"base_tokenizer.pad_token_id {base_tokenizer.pad_token_id}")
        base_encode_loss = _teacher_forced_loss(base_model, base_encode_input_token_ids, y_token_ids, base_tokenizer.pad_token_id)

    logger.info("Optimization finished.")
    logger.info(f"Final loss: {current_loss:.4f}")
    logger.info(f"Re-encoded loss: {reencode_loss.item():.4f}")
    logger.info(f"Base re-encoded loss: {base_encode_loss.item():.4f}")
    logger.info(f"Optimized fingerprint input x: {optimized_x}")

    return optimized_x


def parse_args():
    """Parse command-line overrides for the fingerprint-creation config.

    Most options default to None, meaning "not given on the command line". Only ``--config`` and ``--alpha``
    have non-None defaults.
    """
    parser = argparse.ArgumentParser(description="Optimize fingerprint input")
    parser.add_argument("--config", type=str, default="./configs/create_fingerprint.json", help="Path to the JSON config file")
    parser.add_argument("--model_path", type=str, help="Path to the pre-trained model")
    parser.add_argument("--cache_dir", type=str, help="Cache directory for the model")
    parser.add_argument("--model_revision", type=str, help="Model revision")
    parser.add_argument("--use_fast_tokenizer", type=str2bool, help="Use fast tokenizer")
    # No trailing comma in the option string: "--low_cpu_mem_usage," would set dest "low_cpu_mem_usage,",
    # so the value would never override the "low_cpu_mem_usage" config entry.
    parser.add_argument("--low_cpu_mem_usage", type=str2bool, help="Use low CPU memory usage")
    parser.add_argument("--seed", type=int, help="Random seed")
    parser.add_argument("--x_sequence_length", type=int, help="Length of input sequence x")
    parser.add_argument("--optimization_steps", type=int, help="Number of optimization steps")
    parser.add_argument("--top_k", type=int, help="Top k for sampling replacement indices")
    parser.add_argument("--batch_size", type=int, help="Batch size for optimization")
    parser.add_argument("--adv_prefix", type=str, help="Adversarial prefix")
    parser.add_argument("--y_str", type=str, help="Target string y")
    parser.add_argument("--optimize_input", type=str2bool, help="Whether to optimize input x")
    parser.add_argument("--save_dir", type=str, help="Directory to save fingerprints")
    # NOTE(review): this default is not None, so a caller that only copies non-None values over a config
    # file will always replace a config-file "alpha" with 0.3. Confirm that is intended.
    parser.add_argument("--alpha", type=float, default=0.3, help="Alpha value for merging models")

    # Base regularization arguments
    parser.add_argument("--lambda_reg", type=float, help="Weight of the base-model regularization term (objective = owner_loss - lambda_reg * base_loss)")
    parser.add_argument("--base_model_path", type=str, help="Path to the pre-trained base model")

    return parser.parse_args()


def _init_new_token_embeddings(model, num_new_tokens):
    """Set the last ``num_new_tokens`` input and output embedding rows to the mean of the earlier rows.

    Does nothing when ``num_new_tokens`` is 0. In that case ``[:-0]`` would select no rows and ``[-0:]`` every
    row, overwriting the whole matrix with the NaN mean of an empty slice.
    """
    if num_new_tokens <= 0:
        return
    for weight in (model.get_input_embeddings().weight.data, model.get_output_embeddings().weight.data):
        weight[-num_new_tokens:] = weight[:-num_new_tokens].mean(dim=0, keepdim=True)


def main():
    """Create a fingerprint pair (x, y) and save it as ``<save_dir>/<y_str>.json``.

    Settings are read from the JSON file given by ``--config``. Any command-line option whose value is not
    None overrides the corresponding config entry.
    """
    args = parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        config = json.load(f)

    # Command-line values take precedence; None means "not given".
    config.update({key: value for key, value in vars(args).items() if value is not None})

    args = argparse.Namespace(**config)
    logger.info(f"Arguments: {args}")

    set_seed(args.seed)

    owner_model, owner_tokenizer = load_model_and_tokenizer(
        args.model_path, args.tokenizer_name, args.config_name, args.model_revision, args.cache_dir, args.low_cpu_mem_usage, args.use_fast_tokenizer
    )
    base_model, base_tokenizer = load_model_and_tokenizer(
        args.base_model_path, args.base_tokenizer_name, args.base_config_name, args.model_revision, args.cache_dir, args.low_cpu_mem_usage, args.use_fast_tokenizer
    )

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    if owner_tokenizer.pad_token_id is None:
        num_new_tokens = owner_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        # The base model is fed owner-tokenizer ids during optimization, so it is resized to the same vocabulary.
        for model in (owner_model, base_model):
            model.resize_token_embeddings(len(owner_tokenizer))
            _init_new_token_embeddings(model, num_new_tokens)

    if owner_model.config.pad_token_id is None:
        owner_model.config.pad_token_id = owner_tokenizer.pad_token_id
    if base_model.config.pad_token_id is None:
        base_model.config.pad_token_id = owner_tokenizer.pad_token_id

    if base_tokenizer.pad_token_id is None:
        base_tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    # Merged model: merged = base + alpha * (owner - base). Parameters are paired by position (zip), so both
    # models must list their parameters in the same order.
    merged_model = copy.deepcopy(base_model)
    with torch.no_grad():
        for (owner_name, owner_param), (merged_name, merged_param) in zip(owner_model.named_parameters(), merged_model.named_parameters()):
            if owner_param.shape != merged_param.shape:
                logger.warning(
                    f"Parameter shapes do not match: {owner_name} {tuple(owner_param.shape)} vs {merged_name} {tuple(merged_param.shape)}. Skip this parameter."
                )
                continue
            merged_param.add_(args.alpha * (owner_param - merged_param))

    # Freeze every model; only the one-hot input is differentiated during optimization.
    for model in (owner_model, base_model, merged_model):
        model.requires_grad_(False)

    # NOTE(review): y_str is used verbatim as the file name, so characters such as "/" would break this path.
    fingerprint_path = os.path.join(args.save_dir, f"{args.y_str}.json")
    if args.optimize_input:
        x_str = optimize_fingerprint_input(merged_model, base_model, owner_tokenizer, base_tokenizer, args)
    else:
        logger.info("Optimization disabled. Using random input x.")
        x_sequence_length = min(args.x_sequence_length, owner_model.config.max_position_embeddings)
        random_x_str = owner_tokenizer.decode(torch.randint(0, owner_model.get_input_embeddings().weight.shape[0], (x_sequence_length,)))
        x_str = args.adv_prefix + random_x_str

    os.makedirs(args.save_dir, exist_ok=True)
    with open(fingerprint_path, "w", encoding="utf-8") as f:
        json.dump({"x": x_str, "y": args.y_str}, f, ensure_ascii=False)

    logger.info(f"Fingerprint pair (x, y): ({x_str}, {args.y_str})")
    logger.info(f"Fingerprint optimized and saved to {fingerprint_path}.")


if __name__ == "__main__":
    # Run the fingerprint-creation pipeline only when executed as a script, not when imported.
    main()
