import argparse
import json
import logging
import os
import shutil
import subprocess
import sys
from datetime import datetime

import torch
import yaml
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, set_seed
from utils import check_merged_model_performance, find_free_port, str2bool

# Configure root logging once, at import time, and create this module's logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
)
logger = logging.getLogger(__name__)


def check_performance(model, output_path, tasks=("gsm8k", "toxigen"), harness_dir="lm-evaluation-harness"):
    """
    Run lm-evaluation-harness on ``model`` for each task in ``tasks``.

    Each task runs in its own ``python -m lm_eval`` subprocess (HF backend,
    bfloat16, batch size 32). The subprocess's combined stdout/stderr is
    echoed live to this process's stdout.

    Args:
        model: Value passed as ``pretrained=`` in ``--model_args`` (a model
            path or hub id). A relative path is resolved against ``harness_dir``.
        output_path: Passed to ``--output_path``. A relative path is likewise
            resolved against ``harness_dir``.
        tasks: lm-eval task names to run, in order. Defaults to GSM8K (math)
            and ToxiGen (safety).
        harness_dir: Working directory for the evaluation subprocesses.

    Raises:
        subprocess.CalledProcessError: if an evaluation exits non-zero.
        Exception: any other failure (e.g. missing ``harness_dir``) is logged
            and re-raised.
    """
    for task in tasks:
        cmd = ["python", "-m", "lm_eval", "--model", "hf", "--model_args", f"pretrained={model},dtype=bfloat16", "--tasks", task, "--batch_size", "32", "--output_path", str(output_path)]

        logger.info(f"Running command: {' '.join(cmd)}")

        try:
            # Set MASTER_PORT to a free port (inherited by the subprocess).
            os.environ["MASTER_PORT"] = str(find_free_port())
            # cwd= runs the subprocess inside harness_dir without calling
            # os.chdir(). The parent's working directory is never modified, so
            # a failure cannot leave it stranded inside harness_dir.
            with subprocess.Popen(cmd, cwd=harness_dir, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1) as process:
                for line in iter(process.stdout.readline, ""):
                    print(line, end="")  # stream output in real time
                    sys.stdout.flush()

            # Popen.__exit__ waits for the process, so returncode is set here.
            if process.returncode != 0:
                raise subprocess.CalledProcessError(process.returncode, cmd)

            logger.info(f"Results saved to {output_path}")

        except subprocess.CalledProcessError as e:
            logger.error(f"Error during evaluation: Command failed with exit code {e.returncode}")
            raise
        except Exception as e:
            logger.error(f"Unexpected error: {e}")
            raise


def load_model(model_path: str, device: str):
    """
    Load a causal LM and its tokenizer from ``model_path`` and move the model to ``device``.

    Weights are loaded in float16 with ``trust_remote_code=True``. If loading
    raises a ``RuntimeError`` whose message contains "size mismatch", it is
    retried once with ``ignore_mismatched_sizes=True``. Any other
    ``RuntimeError`` is logged and re-raised.

    Returns:
        ``(model, tokenizer)``
    """
    logger.info(f"Loading model from {model_path}")

    def _load(**extra_model_kwargs):
        # Shared by the first attempt and the mismatch retry. The extra kwargs
        # are appended after the common ones.
        loaded = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True, **extra_model_kwargs)
        loaded.to(device)
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        return loaded, loaded_tokenizer

    try:
        return _load()
    except RuntimeError as err:
        if "size mismatch" not in str(err):
            logger.error(f"Error occurred while loading the model: {str(err)}")
            raise
        logger.warning("Size mismatch detected. Retrying with ignore_mismatched_sizes=True.")
        return _load(ignore_mismatched_sizes=True)


def generate_text(model, tokenizer, prompt: str, gen_config: GenerationConfig):
    """
    Generate a completion for ``prompt`` and return it as decoded text.

    The prompt is tokenized into PyTorch tensors and moved to ``model.device``.
    It is then passed to ``model.generate`` with gradients disabled. Only the
    first returned sequence is decoded, with special tokens stripped.
    NOTE(review): for decoder-only HF models that sequence presumably still
    starts with the prompt tokens, so the returned text likely contains the
    prompt as well. Confirm before relying on substring checks.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    encoded = encoded.to(model.device)
    with torch.no_grad():
        sequences = model.generate(**encoded, generation_config=gen_config)
    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)


def check_fingerprint(model, tokenizer, fingerprint_pairs, verification_time: int, output_file: str, temperature: float = 0.7, top_p: float = 0.95):
    """
    Measure how often the model reproduces each fingerprint and save the results as JSON.

    Each ``(x, y)`` pair in ``fingerprint_pairs`` is tried
    ``verification_time`` times: the model is prompted with ``x`` via
    ``generate_text``. An attempt counts as a hit when ``y`` occurs as a
    substring of the decoded output.

    Args:
        model: Model passed to ``generate_text``.
        tokenizer: Tokenizer for ``model``. Its ``eos_token_id`` and
            ``pad_token_id`` are copied into the generation config.
        fingerprint_pairs: Iterable of ``(prompt, expected_output)`` pairs.
        verification_time: Number of generations per pair. Must be at least 1.
        output_file: Path of the JSON file to (over)write, UTF-8 encoded with
            non-ASCII characters kept verbatim.
        temperature: Sampling temperature. Exactly ``0.0`` disables sampling
            (``do_sample=False``).
        top_p: Nucleus-sampling probability mass.

    Returns:
        The dict written to ``output_file``. It has three keys:
        ``"total_attempts"`` (``verification_time``), ``"success_rate"``
        (a list of ``{y: rate}``) and ``"outputs"`` (a list of
        ``{y: [per-attempt records]}``).

    Raises:
        ValueError: if ``verification_time`` is less than 1.
    """
    if verification_time < 1:
        # Guards the success-rate division below; zero or negative counts have no meaningful rate.
        raise ValueError(f"verification_time must be at least 1, got {verification_time!r}")

    # NOTE(review): tfs, top_a and mirostat_* do not appear to be standard
    # transformers generation options. GenerationConfig presumably stores them
    # as inert extra attributes, so confirm they are meant to have no effect.
    gen_config = GenerationConfig(
        max_new_tokens=30,
        temperature=temperature,
        top_p=top_p,
        top_k=50,
        typical_p=1,
        repetition_penalty=1,
        encoder_repetition_penalty=1,
        no_repeat_ngram_size=0,
        min_length=0,
        tfs=1,
        top_a=0,
        do_sample=temperature != 0.0,
        penalty_alpha=0,
        num_beams=1,
        length_penalty=1,
        output_scores=True,
        early_stopping=False,
        mirostat_tau=5,
        mirostat_eta=0.1,
        suppress_tokens=[],  # can suppress eos s.t. endless
        eos_token_id=[tokenizer.eos_token_id],
        pad_token_id=tokenizer.pad_token_id,
        use_cache=True,
        num_return_sequences=1,
    )

    overall_success = []
    overall_outputs = []
    for x, y in fingerprint_pairs:
        logger.info(f"Checking fingerprint!\nInput: '{x}'\nExpected output: '{y}'")
        output_results = []
        success_count = 0

        for i in tqdm(range(verification_time), desc="Fingerprint verification"):
            output = generate_text(model, tokenizer, x, gen_config)
            has_fingerprint = y in output
            if has_fingerprint:
                success_count += 1
            output_results.append({"attempt": i + 1, "input": x, "output": output, "expected": y, "has_fingerprint": has_fingerprint})

        success_rate = success_count / verification_time

        overall_success.append({y: success_rate})
        overall_outputs.append({y: output_results})

        logger.info(f"\n{y} verification complete. success rate: {success_rate} ")

    result = {"total_attempts": verification_time, "success_rate": overall_success, "outputs": overall_outputs}

    # ensure_ascii=False keeps non-ASCII (e.g. Japanese) text readable in the file.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)

    return result


def create_merge_yaml(args, weight: float, merged_model_path: str):
    """
    Write a mergekit YAML config that blends ``args.model_path`` with ``args.merge_models``.

    ``args.model_path`` receives ``weight``. The remaining ``1 - weight`` is
    split evenly across ``args.merge_models``. For ``ties``/``dare_ties``,
    every entry also carries ``args.density`` and top-level
    ``normalize``/``int8_mask`` parameters are added. The methods
    ``task_arithmetic``, ``dare_linear``, ``ties`` and ``dare_ties`` also set
    ``base_model`` from ``args.base_model_path``.

    Args:
        args: Namespace with ``model_path``, ``merge_models``, ``merge_method``,
            ``density`` and ``base_model_path``.
        weight: Weight assigned to ``args.model_path``.
        merged_model_path: Existing directory that receives ``merge_config.yaml``.

    Returns:
        Path of the written YAML file.

    Raises:
        ValueError: if ``args.merge_models`` is empty, or if the merge method
            needs a base model and ``args.base_model_path`` is not set.
    """
    sparse_methods = ("ties", "dare_ties")
    base_model_methods = ("task_arithmetic", "dare_linear", "ties", "dare_ties")

    # Validate before any work. An empty list would otherwise crash with
    # ZeroDivisionError when the remaining weight is split.
    if not args.merge_models:
        raise ValueError("merge_models must contain at least one model to merge")
    if args.merge_method in base_model_methods and not args.base_model_path:
        raise ValueError(f"base_model must be specified for {args.merge_method} merge method")

    each_merged_model_weight = (1 - weight) / len(args.merge_models)

    def _model_entry(model_path, model_weight):
        # NOTE(review): dare_linear entries get no density, so mergekit
        # presumably falls back to its own default density. Confirm this is intended.
        if args.merge_method in sparse_methods:
            parameters = {"density": args.density, "weight": model_weight}
        else:
            parameters = {"weight": model_weight}
        return {"model": model_path, "parameters": parameters}

    models = [_model_entry(args.model_path, weight)]
    models.extend(_model_entry(model, each_merged_model_weight) for model in args.merge_models)

    merge_config = {"models": models, "merge_method": args.merge_method, "dtype": "float16"}
    if args.merge_method in base_model_methods:
        merge_config["base_model"] = args.base_model_path
    if args.merge_method in sparse_methods:
        merge_config["parameters"] = {"normalize": True, "int8_mask": True}

    logger.info(f"base_model_path: {args.base_model_path}")
    logger.info(f"merge_config: {merge_config}")

    yaml_path = os.path.join(merged_model_path, "merge_config.yaml")
    with open(yaml_path, "w") as f:
        yaml.dump(merge_config, f)

    return yaml_path


def run_mergekit(yaml_path: str, output_dir: str, seed: str):
    """
    Run the ``mergekit-yaml`` CLI on ``yaml_path`` and write the merged model to ``output_dir``.

    The CLI is invoked with ``--random-seed seed --cuda --lazy-unpickle
    --allow-crimes``. Its combined stdout/stderr is echoed live to this
    process's stdout.

    Args:
        yaml_path: mergekit configuration file.
        output_dir: Destination directory for the merged model.
        seed: Random seed, already converted to a string for the command line.

    Raises:
        RuntimeError: if mergekit exits non-zero. It is chained from a
            ``CalledProcessError`` whose ``output`` holds the last lines printed.
        Exception: any other failure (e.g. ``mergekit-yaml`` not on PATH) is
            logged and re-raised unchanged.
    """
    from collections import deque

    cmd = ["mergekit-yaml", yaml_path, output_dir, "--random-seed", seed, "--cuda", "--lazy-unpickle", "--allow-crimes"]

    logger.info(f"Running command: {' '.join(cmd)}")

    # Keep only a bounded tail of the output for the error report; the full
    # output has already been streamed to stdout.
    output_tail = deque(maxlen=50)

    try:
        with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1) as process:
            for line in iter(process.stdout.readline, ""):
                print(line, end="")
                sys.stdout.flush()
                output_tail.append(line)

        # Popen.__exit__ waits for the process, so returncode is set here.
        if process.returncode != 0:
            raise subprocess.CalledProcessError(process.returncode, cmd, output="".join(output_tail))

        logger.info("Mergekit completed successfully")

    except subprocess.CalledProcessError as e:
        logger.error(f"Mergekit failed with exit code {e.returncode}")
        logger.error(f"Output (last lines):\n{e.output}")
        raise RuntimeError("Mergekit failed") from e
    except Exception as e:
        logger.error(f"Unexpected error during Mergekit execution: {str(e)}")
        raise


def generate_model_name(merge_method: str, original_model: str, merge_models: list, weight: float):
    """
    Build a name for one merged model that encodes the method and per-model weights.

    Format: ``"{merge_method}_{weight:.2f}{orig}+{share:.2f}{m1}+..."``.
    ``orig`` and ``m1``, ... are the last path components of the model paths.
    ``share`` is ``(1 - weight) / len(merge_models)``. With no merge models
    the name contains only the weighted original model.

    Example:
        >>> generate_model_name("ties", "models/orig/", ["m/a", "m/b"], 0.5)
        'ties_0.50orig+0.25a+0.25b'
    """

    def _short_name(path):
        # normpath strips a trailing separator, which would otherwise make
        # basename() return an empty string.
        return os.path.basename(os.path.normpath(path))

    parts = [f"{weight:.2f}{_short_name(original_model)}"]
    if merge_models:
        share = (1 - weight) / len(merge_models)
        parts.extend(f"{share:.2f}{_short_name(m)}" for m in merge_models)
    return f"{merge_method}_" + "+".join(parts)


def generate_dir_name(merge_method: str, original_model: str, merge_models: list):
    """
    Build the output directory name for a merge experiment (weights not included).

    Format: ``"{merge_method}_{orig}+{m1}+..."``, where each component is the
    last path component of the corresponding model path.

    Example:
        >>> generate_dir_name("ties", "models/orig/", ["m/a"])
        'ties_orig+a'
    """
    # normpath strips a trailing separator so basename() never yields "".
    names = [os.path.basename(os.path.normpath(path)) for path in [original_model, *merge_models]]
    return f"{merge_method}_" + "+".join(names)


def parse_args():
    """
    Build the command-line parser and parse ``sys.argv``.

    Only ``--fingerprint_path_list`` is required. ``--do_eval_performance``
    defaults to True, and every other option is None when omitted.
    ``--merge`` and ``--do_eval_performance`` are converted with ``str2bool``.

    Returns:
        argparse.Namespace with one attribute per option.
    """
    parser = argparse.ArgumentParser(description="Evaluate fingerprint of a model and its merged models.")

    # (flag, add_argument keyword arguments), listed in --help order.
    option_specs = [
        ("--seed", dict(type=int, help="Random seed for reproducibility")),
        ("--model_path", dict(type=str, help="Path to the original model")),
        ("--output_dir", dict(type=str, help="Output directory for the evaluation results")),
        ("--merge", dict(type=str2bool, help="Whether to merge models")),
        ("--merge_models", dict(type=str, nargs="*", help="List of paths to the models to merge")),
        ("--merge_weights", dict(type=float, nargs="*", help="List of weight sets")),
        ("--density", dict(type=float, help="Density for the ties merge method")),
        ("--merge_method", dict(type=str, help="Method to merge models")),
        ("--base_model_path", dict(type=str, help="Path to the base model for task_arithmetic merge method")),
        ("--verification_time", dict(type=int, help="Number of times to verify the fingerprint")),
        ("--fingerprint_path_list", dict(type=str, nargs="*", required=True, help="Path to the fingerprint JSON file")),
        ("--do_eval_performance", dict(type=str2bool, default=True, help="Whether to evaluate performance")),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()


def _load_config(cli_args):
    """Return the JSON defaults from ./configs/eval_fingerprint.json, overridden by every non-None CLI value."""
    with open("./configs/eval_fingerprint.json", "r", encoding="utf-8") as f:
        config = json.load(f)

    for arg, value in vars(cli_args).items():
        if value is not None:
            config[arg] = value

    return argparse.Namespace(**config)


def _load_fingerprint_pairs(paths):
    """Read one ``{"x": prompt, "y": expected}`` JSON object per path into a list of ``(x, y)`` tuples."""
    pairs = []
    for path in paths:
        # Explicit UTF-8: fingerprint text may be non-ASCII (e.g. Japanese),
        # which the platform default encoding may fail to decode.
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        pairs.append((data["x"], data["y"]))
    return pairs


def _release_memory():
    """Run garbage collection and, when CUDA is available, return cached GPU memory to the driver."""
    import gc

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _check_model_fingerprint(model_path, output_file, label, banner, fingerprint_pairs, verification_time, device):
    """
    Evaluate the fingerprint of the model at ``model_path`` unless ``output_file`` already exists.

    ``label`` is used in log messages and ``banner`` in the console header.
    The model is released afterwards (even on failure) so it does not occupy
    memory while later models are loaded.
    """
    if os.path.exists(output_file):
        logger.info(f"{label} has already been evaluated. Skipping {label} evaluation.")
        return

    model, tokenizer = load_model(model_path, device)
    logger.info(f"{label} loaded: {model_path}")
    try:
        print(f"\n##### Check {banner} fingerprint #####\n")
        check_fingerprint(model, tokenizer, fingerprint_pairs, verification_time, output_file)
    finally:
        del model, tokenizer
        _release_memory()


def _merge_and_evaluate(args, weight, output_dir, fingerprint_pairs, device):
    """Merge the models at one weight setting, evaluate the result, then delete the merged weights from disk."""
    model_name = generate_model_name(args.merge_method, args.model_path, args.merge_models, weight)
    merged_model_path = os.path.join(output_dir, model_name)
    merged_output_file = os.path.join(output_dir, f"{model_name}_outputs.json")
    merged_results_file = os.path.join(output_dir, f"{model_name}_performance.json")

    if os.path.exists(merged_output_file):
        logger.info(f"Model combination {model_name} has already been evaluated. Skipping...")
        return

    os.makedirs(merged_model_path, exist_ok=True)

    yaml_path = create_merge_yaml(args, weight, merged_model_path)
    logger.info(f"Merge configuration created: {yaml_path}")

    run_mergekit(yaml_path, merged_model_path, str(args.seed))
    logger.info(f"Merged model saved to: {merged_model_path}")

    merged_model, merged_tokenizer = load_model(merged_model_path, device)
    logger.info(f"Merged model loaded: {merged_model_path}")
    try:
        print(f"\n##### Check merged model fingerprint for {model_name} #####\n")
        check_fingerprint(merged_model, merged_tokenizer, fingerprint_pairs, args.verification_time, merged_output_file)

        if args.do_eval_performance:
            print(f"\n##### Check merged model performance for {model_name} #####\n")
            results = check_merged_model_performance(merged_model, merged_tokenizer, device=device, output_path=merged_results_file)
            print(results)
    finally:
        # Free memory before the next weight setting loads another model.
        del merged_model, merged_tokenizer
        _release_memory()

    # Delete merged model to save disk space (only after a successful evaluation).
    shutil.rmtree(merged_model_path)


def main():
    """
    Main function to run the fingerprint evaluation and model merging process.

    The steps are:
    1. Merge CLI arguments over ./configs/eval_fingerprint.json and seed RNGs.
    2. Check the fingerprint of the original (normal) model. Errors propagate.
    3. If ``merge`` is set, for each weight in ``merge_weights``: merge with
       mergekit, check the fingerprint and (optionally) performance, then
       delete the merged weights.
    4. On a best-effort basis, check the base model and each model in
       ``merge_models``. Failures are logged and skipped.

    Any step whose output JSON already exists is skipped, so reruns resume
    where a previous run stopped.
    """
    args = _load_config(parse_args())

    logger.info(f"Arguments: {args}")

    set_seed(args.seed)
    logger.info(f"Seed set to {args.seed}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    if device.type == "cuda":
        logger.info(f"Number of GPUs: {torch.cuda.device_count()}")

    dir_name = generate_dir_name(args.merge_method, args.model_path, args.merge_models)
    output_dir = os.path.join(args.output_dir, dir_name)
    os.makedirs(output_dir, exist_ok=True)

    fingerprint_pairs = _load_fingerprint_pairs(args.fingerprint_path_list)

    # Check normal model (errors propagate).
    _check_model_fingerprint(
        args.model_path,
        os.path.join(output_dir, "normal_model_outputs.json"),
        "Normal model",
        "normal model",
        fingerprint_pairs,
        args.verification_time,
        device,
    )

    if args.merge:
        print("\n##### Merge models #####\n")
        logger.info("Merging models")
        for weight in args.merge_weights:
            _merge_and_evaluate(args, weight, output_dir, fingerprint_pairs, device)
    else:
        logger.info("Merge option is set to false. Skipping model merging.")

    # Check base model (best effort: report the cause instead of hiding it).
    base_model_path = getattr(args, "base_model_path", None)
    if base_model_path:
        try:
            _check_model_fingerprint(
                base_model_path,
                os.path.join(output_dir, "base_model_outputs.json"),
                "Base model",
                "base model",
                fingerprint_pairs,
                args.verification_time,
                device,
            )
        except Exception as e:
            logger.warning(f"Base model evaluation failed ({e!r}). Skipping base model evaluation.")
    else:
        logger.info("No base model path configured. Skipping base model evaluation.")

    # Check each model that was merged in (best effort).
    for merge_model_path in args.merge_models:
        # normpath so a trailing slash does not produce an empty name.
        name = os.path.basename(os.path.normpath(merge_model_path))
        merge_output_file = os.path.join(output_dir, f"{name}_outputs.json")
        try:
            _check_model_fingerprint(merge_model_path, merge_output_file, name, name, fingerprint_pairs, args.verification_time, device)
        except Exception as e:
            logger.warning(f"{name} evaluation failed ({e!r}). Skipping {name} evaluation.")

    logger.info("Evaluation complete")


if __name__ == "__main__":
    # Run the full pipeline only when executed as a script, not on import.
    main()
