import argparse
import copy
import gc
import json
import logging
import socket

import deepspeed
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, set_seed
from utils import load_model_and_tokenizer, str2bool

# Root logging setup shared by every rank: timestamped, INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)


def parse_args():
    """Parse command-line arguments and load the DeepSpeed JSON config they point to.

    After parsing, ``args.deepspeed_config`` is replaced by the loaded config dict
    (it no longer holds the path). That dict's ``train_micro_batch_size_per_gpu``,
    ``train_batch_size`` and ``optimizer.params.lr`` entries are overwritten with the
    corresponding command-line values.

    The process rank is queried through ``deepspeed.comm.get_rank()``, so the
    distributed backend is expected to be initialised before this is called.

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        SystemExit: If ``--alpha`` is 0 (reported through ``parser.error``).
        ValueError: If the DeepSpeed config has no ``optimizer`` section to receive ``--lr``.
    """
    parser = argparse.ArgumentParser(description="Optimize fingerprint input")

    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for initialization")
    parser.add_argument("--output_dir", type=str, required=True, help="The output directory where the model predictions and checkpoints will be written")

    # Model arguments
    parser.add_argument("--owner_model_path", type=str, required=True, help="Path to the owner model")
    parser.add_argument("--base_model_path", type=str, required=True, help="Path to the base model")
    parser.add_argument("--cache_dir", type=str, default="./lm_cache", help="Path to cache directory for storing model files")
    parser.add_argument("--config_name", type=str, default=None, help="Pretrained config name or path for the owner model if not the same as owner_model_path")
    parser.add_argument("--base_config_name", type=str, default=None, help="Pretrained config name or path for the base model if not the same as base_model_path")
    parser.add_argument("--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--model_revision", type=str, default=None, help="The specific model version to use")
    parser.add_argument("--low_cpu_mem_usage", type=str2bool, default=False, help="Passed to the model loader to reduce CPU memory usage while loading")
    parser.add_argument("--use_fast_tokenizer", type=str2bool, default=True, help="Whether to use a fast tokenizer (backed by Rust)")

    # MMRF arguments
    parser.add_argument("--fingerprint_path", type=str, required=True, help="Path to the fingerprint pair")
    parser.add_argument("--alpha", type=float, default=0.5, help="Interpolation weight: merged = base + alpha * (owner - base); must be non-zero")

    # Training arguments
    parser.add_argument("--num_train_steps", type=int, default=1, help="Maximum number of optimization steps to perform")
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate for training")

    # DeepSpeed arguments
    parser.add_argument("--deepspeed_config", type=str, required=True, help="Path to the DeepSpeed configuration file")
    parser.add_argument("--train_micro_batch_size_per_gpu", type=int, default=1, help="Batch size per GPU for training")
    parser.add_argument("--train_batch_size", type=int, default=1, help="Total batch size for training")

    args = parser.parse_args()

    # The owner weights are recovered from the merged weights by dividing by alpha,
    # so alpha == 0 would only fail (or produce inf/NaN) after all the training work.
    if args.alpha == 0:
        parser.error("--alpha must be non-zero: the owner weights are recovered by dividing by alpha.")

    # Load DeepSpeed configuration
    if deepspeed.comm.get_rank() == 0:
        logger.info(f"Loading DeepSpeed configuration from {args.deepspeed_config}")
    with open(args.deepspeed_config, "r", encoding="utf-8") as f:
        ds_config = json.load(f)

    # Command-line values take precedence over the file's batch sizes and learning rate.
    ds_config["train_micro_batch_size_per_gpu"] = args.train_micro_batch_size_per_gpu
    ds_config["train_batch_size"] = args.train_batch_size

    optimizer_config = ds_config.get("optimizer")
    if optimizer_config is None:
        raise ValueError(f"DeepSpeed config {args.deepspeed_config} must define an 'optimizer' section so that --lr can be applied.")
    optimizer_config.setdefault("params", {})["lr"] = args.lr

    args.deepspeed_config = ds_config
    return args


# Training stops as soon as the fingerprint loss on the merged model drops below this value.
_EARLY_STOP_LOSS = 0.3


def _resize_and_init_embeddings(model, vocab_size, num_new_tokens):
    """Resize ``model``'s token embeddings to ``vocab_size`` rows and mean-initialise the new rows.

    The last ``num_new_tokens`` rows of the input embedding matrix and, if the model
    has one, the output embedding matrix are set to the mean of all preceding rows.

    Args:
        model: Hugging Face model exposing ``resize_token_embeddings``,
            ``get_input_embeddings`` and ``get_output_embeddings``.
        vocab_size: New number of embedding rows (``len(tokenizer)``).
        num_new_tokens: How many tokens were just appended to the tokenizer.
    """
    model.resize_token_embeddings(vocab_size)
    if num_new_tokens <= 0:
        # Nothing was appended (e.g. "[PAD]" already existed in the vocabulary).
        # Without this guard ``weights[-0:]`` selects the WHOLE matrix while
        # ``weights[:-0]`` is empty, so every row would be overwritten with a NaN mean.
        return
    for embedding in (model.get_input_embeddings(), model.get_output_embeddings()):
        if embedding is None:
            continue
        weights = embedding.weight.data
        weights[-num_new_tokens:] = weights[:-num_new_tokens].mean(dim=0, keepdim=True)


def _fingerprint_loss(model, full_token_ids, attention_mask, prompt_len, y_token_ids):
    """Return the mean cross-entropy of ``model`` producing ``y_token_ids`` after the prompt.

    ``full_token_ids`` is the prompt (its first ``prompt_len`` tokens) followed by
    ``y_token_ids``. Logits at position ``t`` predict token ``t + 1``, so the response
    tokens are scored by the logits from position ``prompt_len - 1`` up to, but not
    including, the last position.
    """
    logits = model(input_ids=full_token_ids, attention_mask=attention_mask).logits
    response_logits = logits[:, prompt_len - 1 : -1, :]
    # cross_entropy expects the class (vocabulary) dimension second.
    return F.cross_entropy(response_logits.transpose(1, 2), y_token_ids, reduction="mean")


def _interpolate_towards_owner(merged_model, owner_model, alpha, is_main_rank):
    """In place, set each merged parameter to ``merged + alpha * (owner - merged)``.

    Parameters are paired positionally in ``named_parameters()`` order; pairs whose
    shapes differ are left unchanged and reported (on the main rank only).
    """
    owner_params = list(owner_model.named_parameters())
    merged_params = list(merged_model.named_parameters())
    if is_main_rank and len(owner_params) != len(merged_params):
        # zip() below stops at the shorter list, so trailing parameters are not merged.
        logger.warning(f"Owner model has {len(owner_params)} parameters but merged model has {len(merged_params)}; pairing stops at the shorter list.")

    with torch.no_grad():
        for (owner_name, owner_param), (merged_name, merged_param) in zip(owner_params, merged_params):
            if owner_param.shape != merged_param.shape:
                if is_main_rank:
                    logger.warning(f"Parameter shapes do not match: {owner_name} {tuple(owner_param.shape)} vs {merged_name} {tuple(merged_param.shape)}. Skip this parameter.")
                continue
            merged_param.add_(alpha * (owner_param - merged_param))


def _recover_owner_weights(owner_model, merged_params, base_model, alpha):
    """Write ``base + (merged - base) / alpha`` into ``owner_model``'s parameters.

    ``merged_params`` is iterated in lockstep with ``owner_model.parameters()`` and
    ``base_model.parameters()``. Each merged parameter is gathered with
    ``deepspeed.zero.GatheredParameters`` before being read; pairs whose shapes differ
    are skipped.
    """
    with torch.no_grad():
        for merged_param, owner_param, base_param in zip(merged_params, owner_model.parameters(), base_model.parameters()):
            # The gathered weights are only read here, so no rank needs to broadcast changes on exit.
            with deepspeed.zero.GatheredParameters(merged_param, modifier_rank=None):
                if owner_param.shape != merged_param.shape:
                    continue
                merged_full = merged_param.detach().to(base_param.device)
                owner_param.copy_((merged_full - base_param) / alpha + base_param)


def main():
    """Embed the fingerprint pair (x, y) read from ``--fingerprint_path`` into the owner model.

    1. Load the owner and base models (both use the owner run's tokenizer) and, if the
       tokenizer has no pad token, add ``[PAD]`` and resize both models' embeddings.
    2. Build ``merged = base + alpha * (owner - base)``.
    3. Train ``merged`` with DeepSpeed for at most ``--num_train_steps`` steps so that it
       produces ``y`` followed by EOS after ``x``.
    4. Recover ``owner = base + (merged - base) / alpha`` and, on rank 0, save the updated
       owner model and tokenizer to ``--output_dir``.
    """
    # distributed learning
    deepspeed.init_distributed()

    args = parse_args()
    is_main_rank = deepspeed.comm.get_rank() == 0
    if is_main_rank:
        logger.info(f"Arguments: {args}")

    set_seed(args.seed)

    # The base model's own tokenizer is discarded; the owner's tokenizer is used for both models.
    owner_model, tokenizer = load_model_and_tokenizer(args.owner_model_path, args.tokenizer_name, args.config_name, args.model_revision, args.cache_dir, args.low_cpu_mem_usage, trust_remote_code=True)
    base_model, _ = load_model_and_tokenizer(args.base_model_path, args.tokenizer_name, args.base_config_name, args.model_revision, args.cache_dir, args.low_cpu_mem_usage, trust_remote_code=True)

    # Add a pad token only when the tokenizer lacks one, keeping both models' embedding
    # matrices in sync with the enlarged vocabulary.
    if tokenizer.pad_token_id is None:
        num_new_tokens = tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        _resize_and_init_embeddings(owner_model, len(tokenizer), num_new_tokens)
        _resize_and_init_embeddings(base_model, len(tokenizer), num_new_tokens)

    # Load fingerprint pair
    with open(args.fingerprint_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    x_str, y_str = data["x"], data["y"]

    if is_main_rank:
        logger.info(f"Fingerprint pair (x, y): ({x_str}, {y_str})")

    if tokenizer.eos_token_id is None:
        raise ValueError("The tokenizer defines no eos_token_id, so the fingerprint response cannot be terminated.")

    # Create merged model
    merged_model = copy.deepcopy(base_model)
    _interpolate_towards_owner(merged_model, owner_model, args.alpha, is_main_rank)

    # The engine's step() drives the optimizer, so the returned optimizer handle is not needed.
    merged_model_engine, _, _, _ = deepspeed.initialize(model=merged_model, model_parameters=merged_model.parameters(), config=args.deepspeed_config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    x_token_ids = tokenizer.encode(x_str, return_tensors="pt").to(device)
    prompt_len = x_token_ids.shape[1]
    if prompt_len == 0:
        raise ValueError("Fingerprint prompt x encodes to zero tokens; at least one prompt token is required.")
    y_token_ids = tokenizer.encode(y_str, return_tensors="pt", add_special_tokens=False).to(device)

    # Terminate the response with EOS so the model also learns where y ends.
    y_token_ids = torch.cat([y_token_ids, torch.tensor([[tokenizer.eos_token_id]], device=device)], dim=1)
    full_token_ids = torch.cat([x_token_ids, y_token_ids], dim=1)

    # The single sequence is never padded, so every position is attended to. Masking by
    # pad_token_id would hide real tokens whenever pad shares an id with EOS/BOS.
    attention_mask = torch.ones_like(full_token_ids)

    # Baseline: loss of the unmodified owner model on the fingerprint pair.
    owner_model.to(device)
    with torch.no_grad():
        loss = _fingerprint_loss(owner_model, full_token_ids, attention_mask, prompt_len, y_token_ids)
    if is_main_rank:
        logger.info(f"Normal owner model loss: {loss.item():.4f}")

    # Offload the owner model to free accelerator memory for training.
    owner_model.cpu()
    gc.collect()
    torch.cuda.empty_cache()

    merged_model_engine.train()
    for step in range(args.num_train_steps):
        loss = _fingerprint_loss(merged_model_engine, full_token_ids, attention_mask, prompt_len, y_token_ids)

        merged_model_engine.backward(loss)
        # Use the engine's step(), not the raw optimizer's step(): the engine also zeroes
        # gradients, advances any configured LR scheduler and respects accumulation boundaries.
        merged_model_engine.step()

        loss_value = loss.item()
        if is_main_rank:
            logger.info(f"Step {step + 1}/{args.num_train_steps}, Loss: {loss_value:.4f}")

        merged_model_engine.empty_partition_cache()

        # Early stopping
        if loss_value < _EARLY_STOP_LOSS:
            if is_main_rank:
                logger.info("Loss is less than 0.3. Training is done.")
            break

    # Map the trained merged weights back into the owner model.
    _recover_owner_weights(owner_model, merged_model_engine.parameters(), base_model, args.alpha)

    # Check the updated owner model
    owner_model.to(device)
    torch.cuda.empty_cache()

    with torch.no_grad():
        loss = _fingerprint_loss(owner_model, full_token_ids, attention_mask, prompt_len, y_token_ids)
    if is_main_rank:
        logger.info(f"Updated owner model loss: {loss.item():.4f}")

    # Save the model and tokenizer
    if is_main_rank:
        owner_model.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
        logger.info(f"Model saved to {args.output_dir}")


if __name__ == "__main__":
    # Run the fingerprinting pipeline only when executed as a script, never on import.
    main()
