import argparse
import json
import logging
import os
import random
import subprocess
import sys

import torch
from utils import str2bool

# Configure root logging once at import time: timestamped, INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
# Module-scoped logger used by the pipeline driver below.
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse command-line options for the MMRF pipeline.

    No option declares a default, so any option not given on the command line
    is ``None``. ``main`` relies on this: only non-``None`` values override the
    settings loaded from the JSON config file.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is argparse's standard behavior; passing an
            explicit list lets the parser be reused from other code or tests.

    Returns:
        argparse.Namespace with one attribute per option (``None`` when absent).
    """
    parser = argparse.ArgumentParser(description="Run MMRF pipeline")

    parser.add_argument("--seed", type=int, help="Random seed")
    parser.add_argument(
        "--target_batch_size",
        type=int,
        help="Target global batch size; gradient accumulation steps are derived from it and the GPU count",
    )
    parser.add_argument("--do_embed_fingerprint", type=str2bool, help="Create the fingerprint pair if missing and embed it")
    parser.add_argument("--do_eval_fingerprint", type=str2bool, help="Evaluate fingerprint")
    parser.add_argument("--do_eval_performance", type=str2bool, help="Evaluate performance")

    # Arguments for models
    model_parser = parser.add_argument_group("model")
    model_parser.add_argument("--owner_model_path", type=str, help="Path to the owner model")
    model_parser.add_argument("--base_model_path", type=str, help="Path to the base model")
    model_parser.add_argument("--cache_dir", type=str, help="Cache directory")
    model_parser.add_argument("--config_name", type=str, help="Model configuration name")
    model_parser.add_argument("--base_config_name", type=str, help="Base model configuration name")
    model_parser.add_argument("--tokenizer_name", type=str, help="Tokenizer name")
    model_parser.add_argument("--model_revision", type=str, help="Model revision")
    model_parser.add_argument("--low_cpu_mem_usage", type=str2bool, help="Low CPU memory usage")
    model_parser.add_argument("--use_fast_tokenizer", type=str2bool, help="Use fast tokenizer")

    # Arguments for fingerprint
    fingerprint_parser = parser.add_argument_group("fingerprint")
    fingerprint_parser.add_argument("--deepspeed_config_path", type=str, help="Deepspeed configuration path")
    fingerprint_parser.add_argument(
        "--y_str", type=str, help="Y string of the fingerprint pair (also names <save_dir>/<y_str>.json)"
    )
    fingerprint_parser.add_argument("--x_sequence_length", type=int, help="X sequence length")
    fingerprint_parser.add_argument("--optimize_input", type=str2bool, help="Optimize input")
    fingerprint_parser.add_argument("--adv_prefix", type=str, help="Adversarial prefix")
    fingerprint_parser.add_argument("--optimization_steps", type=int, help="Optimization steps")
    fingerprint_parser.add_argument("--top_k", type=int, help="Top k")
    fingerprint_parser.add_argument("--batch_size", type=int, help="Batch size")
    fingerprint_parser.add_argument("--save_dir", type=str, help="Directory holding the fingerprint pair JSON file")
    fingerprint_parser.add_argument("--lambda_reg", type=float, help="Regularization lambda")

    # Arguments for embedding
    embedding_parser = parser.add_argument_group("embedding")
    embedding_parser.add_argument(
        "--output_dir", type=str, help="Output directory of the fingerprinted model (also the model evaluated later)"
    )
    embedding_parser.add_argument("--alpha", type=float, help="Alpha")
    embedding_parser.add_argument("--num_train_steps", type=int, help="Number of training steps")
    embedding_parser.add_argument("--lr", type=float, help="Learning rate")

    # Arguments for evaluation of fingerprint
    eval_fingerprint_parser = parser.add_argument_group("eval_fingerprint")
    eval_fingerprint_parser.add_argument("--merge_models", type=str, nargs="*", help="Models to merge (passed to eval_fingerprint.py)")
    eval_fingerprint_parser.add_argument("--merge_weights", type=float, nargs="*", help="Merge weights (passed to eval_fingerprint.py)")
    eval_fingerprint_parser.add_argument("--merge_method", type=str, help="Merge method")
    eval_fingerprint_parser.add_argument("--verification_time", type=int, help="Verification time")

    # Arguments for evaluation of performance
    eval_performance_parser = parser.add_argument_group("eval_performance")
    eval_performance_parser.add_argument("--tasks", type=str, nargs="+", help="Evaluation tasks (passed to eval_performance.py)")
    eval_performance_parser.add_argument("--out_dir", type=str, help="Output directory for eval_performance.py")

    return parser.parse_args(argv)


def _load_config(cli_args, config_path="./configs/run_mmrf.json"):
    """Build the effective run configuration.

    The JSON file at ``config_path`` supplies the base values. Every CLI
    argument that was explicitly given, i.e. parsed to something other than
    ``None``, overrides the file's value for that key.

    Returns:
        argparse.Namespace holding the merged configuration.
    """
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    # argparse leaves unspecified options as None, so only real CLI input wins.
    config.update({key: value for key, value in vars(cli_args).items() if value is not None})
    return argparse.Namespace(**config)


def _compute_batch_settings(target_batch_size, num_gpus, batch_size_per_gpu=1):
    """Derive DeepSpeed batch parameters from the desired global batch size.

    Returns:
        ``(total_batch_size, grad_accum_steps, effective_batch_size)``.
        ``total_batch_size`` is the per-step batch across all GPUs.
        ``grad_accum_steps`` is ``target_batch_size // total_batch_size``, floored to at least 1.
        ``effective_batch_size`` is ``total_batch_size * grad_accum_steps``, which
        can therefore differ from ``target_batch_size``.

    Raises:
        ValueError: if ``num_gpus`` is less than 1, because the per-step batch
            would be empty and the accumulation step count undefined.
    """
    if num_gpus < 1:
        raise ValueError(
            f"At least one GPU is required to compute DeepSpeed batch settings, found {num_gpus}."
        )
    total_batch_size = batch_size_per_gpu * num_gpus
    grad_accum_steps = max(1, target_batch_size // total_batch_size)
    effective_batch_size = total_batch_size * grad_accum_steps
    return total_batch_size, grad_accum_steps, effective_batch_size


def _build_commands(args, num_gpus, batch_size_per_gpu=1):
    """Return the ordered list of subprocess argv lists for the requested steps.

    The steps are:
    - create the fingerprint pair, only if embedding is requested and the pair file is missing;
    - embed the fingerprint via the ``deepspeed`` launcher;
    - evaluate the fingerprint;
    - evaluate performance.

    Python sub-steps use ``sys.executable`` so they run under the same
    interpreter/environment as this driver.
    """
    commands = []

    # The fingerprint pair lives at <save_dir>/<y_str>.json.
    fingerprint_path = os.path.join(args.save_dir, f"{args.y_str}.json")

    if args.do_embed_fingerprint:
        # Validate GPU availability before queuing anything; DeepSpeed needs >= 1 GPU.
        total_batch_size, grad_accum_steps, effective_batch_size = _compute_batch_settings(
            args.target_batch_size, num_gpus, batch_size_per_gpu
        )
        logger.info(f"Batch size per GPU: {batch_size_per_gpu}")
        logger.info(f"Total batch size per step: {total_batch_size}")
        logger.info(f"Target batch size: {args.target_batch_size}")
        logger.info(f"Gradient accumulation steps: {grad_accum_steps}")
        logger.info(f"Effective batch size: {effective_batch_size}")
        if effective_batch_size != args.target_batch_size:
            logger.warning(
                f"Effective batch size {effective_batch_size} differs from target "
                f"{args.target_batch_size} (not divisible by {total_batch_size} samples per step)."
            )

        if not os.path.exists(fingerprint_path):
            logger.info(f"Fingerprint pair not found at {fingerprint_path}. Generating new pair.")
            commands.append(
                [
                    sys.executable,
                    "create_fingerprint.py",
                    f"--model_path={args.owner_model_path}",
                    f"--y_str={args.y_str}",
                    f"--x_sequence_length={args.x_sequence_length}",
                    f"--optimize_input={args.optimize_input}",
                    f"--adv_prefix={args.adv_prefix}",
                    f"--optimization_steps={args.optimization_steps}",
                    f"--top_k={args.top_k}",
                    f"--batch_size={args.batch_size}",
                    f"--save_dir={args.save_dir}",
                    f"--seed={args.seed}",
                    f"--cache_dir={args.cache_dir}",
                    f"--use_fast_tokenizer={args.use_fast_tokenizer}",
                    f"--lambda_reg={args.lambda_reg}",
                    f"--base_model_path={args.base_model_path}",
                ]
            )

        # Random master port to reduce clashes between concurrent runs on one host.
        port = random.randint(1024, 65535)
        commands.append(
            [
                "deepspeed",
                f"--num_gpus={num_gpus}",
                f"--master_port={port}",
                "model_fingerprint.py",
                f"--seed={args.seed}",
                f"--output_dir={args.output_dir}",
                f"--owner_model_path={args.owner_model_path}",
                f"--base_model_path={args.base_model_path}",
                f"--cache_dir={args.cache_dir}",
                f"--config_name={args.config_name}",
                f"--base_config_name={args.base_config_name}",
                f"--tokenizer_name={args.tokenizer_name}",
                f"--model_revision={args.model_revision}",
                f"--low_cpu_mem_usage={args.low_cpu_mem_usage}",
                f"--use_fast_tokenizer={args.use_fast_tokenizer}",
                f"--fingerprint_path={fingerprint_path}",
                f"--alpha={args.alpha}",
                f"--num_train_steps={args.num_train_steps}",
                f"--lr={args.lr}",
                f"--deepspeed_config={args.deepspeed_config_path}",
                f"--train_micro_batch_size_per_gpu={batch_size_per_gpu}",
                f"--train_batch_size={effective_batch_size}",
            ]
        )

    if args.do_eval_fingerprint:
        # Evaluates the model written to output_dir by the embedding step.
        commands.append(
            [
                sys.executable,
                "eval_fingerprint.py",
                f"--seed={args.seed}",
                f"--model_path={args.output_dir}",
                "--merge_models",
                *args.merge_models,
                "--merge_weights",
                *[str(weight) for weight in args.merge_weights],
                f"--merge_method={args.merge_method}",
                f"--base_model_path={args.base_model_path}",
                f"--verification_time={args.verification_time}",
                "--fingerprint_path_list",
                fingerprint_path,
            ]
        )

    if args.do_eval_performance:
        commands.append(
            [
                sys.executable,
                "eval_performance.py",
                f"--model_path={args.output_dir}",
                f"--seed={args.seed}",
                f"--out_dir={args.out_dir}",
                "--tasks",
                *args.tasks,
            ]
        )

    return commands


def _run_commands(commands):
    """Run each argv list in order, stopping at the first failure.

    Raises:
        subprocess.CalledProcessError: if a command exits with a non-zero status.
        Exception: any other error raised while launching a command (e.g.
            FileNotFoundError when the executable is missing) is logged and re-raised.
    """
    for cmd in commands:
        logger.info(f"Running command: {' '.join(cmd)}")
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as e:
            logger.error(f"Error during {cmd} : Command failed with exit code {e.returncode}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while running {cmd}: {e}")
            raise


def main():
    """Run the MMRF pipeline.

    The steps are:
    - load ``./configs/run_mmrf.json`` and overlay explicitly supplied CLI arguments;
    - build the requested steps: fingerprint creation/embedding, fingerprint
      evaluation and performance evaluation;
    - run those steps in order as subprocesses.

    GPU count is only required to be non-zero when embedding is requested, so
    evaluation-only runs no longer fail with a ZeroDivisionError on CPU-only
    machines.
    """
    args = _load_config(parse_args())
    logger.info(f"Configuration:\n{args}")

    num_gpus = torch.cuda.device_count()
    logger.info(f"Number of GPUs: {num_gpus}")

    _run_commands(_build_commands(args, num_gpus))


# Entry point: run the pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
