import sys
import os
# sys.path.append("user_path/ZSPAPrune")

import argparse
import json, torch
import logging
import time
from tqdm import tqdm
from Benchmarks.Benchmark_Adapter import BenchmarkAdapter
from Models.Compression_Methods.Original import load_model_llava15
from Models.Compression_Methods.ZSPAPrune import load_model_llava15_zspaprune
from Models.Compression_Methods.DivPrune import load_model_llava15_divprune
from Models.Inference_Frameworks.HF_Transformers import run_inference_llava15


def _parse_args():
    """Parse and return the command-line arguments for a benchmark run."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--benchmark', type=str, required=True,
                        choices=["GQA", "TextVQA", "ChartQA", "VQAv2", "AI2D", "MMMU", "OCRBench", "POPE"],
                        help='Benchmark name to run (GQA, TextVQA, ChartQA, VQAv2, AI2D, MMMU, OCRBench, POPE)')
    parser.add_argument('--prune', type=str, default='Original',
                        choices=["Original", "ZSPAPrune", "DivPrune"],
                        help='Pruning name to run (Original, ZSPAPrune, DivPrune)')
    parser.add_argument('--framework', type=str, default='HF-transformers',
                        choices=['HF-transformers', "vLLM", "SGLang"],
                        help='Framework name to run (HF-transformers, vLLM, SGLang)')
    parser.add_argument('--small-test', type=int, nargs='?', const=5000, default=None, help='Small test size')
    parser.add_argument('--mirror-site', type=str, nargs='?', const="https://hf-mirror.com", default="https://hf-mirror.com", help='Set HF mirror site')
    parser.add_argument('--hf-token', type=str, required=True, help='Hugging Face API token')
    parser.add_argument('--preview-mode', action='store_true', help='Preview mode')
    parser.add_argument('--batch-size', type=int, default=1, help='Batch size (It has adaptive functions)')
    parser.add_argument('--eval', action='store_true', help='Eval Results')
    return parser.parse_args()


# Model loader for each ``--prune`` choice. Every loader is called as
# ``loader(model_name, hf_token, HF_mirror_site=...)`` and returns ``(model, processor)``.
_MODEL_LOADERS = {
    "Original": load_model_llava15,
    "ZSPAPrune": load_model_llava15_zspaprune,
    "DivPrune": load_model_llava15_divprune,
}


def _run_batched_inference(dataset, infer_batch, batch_size, progress_bar=None):
    """Run ``infer_batch`` over ``dataset`` in batches, halving the batch size on CUDA OOM.

    Samples from a batch that runs out of memory stay queued and are retried at
    the smaller batch size instead of being discarded. A sample that still runs
    out of memory at batch size 1 is logged and skipped so the run can finish.

    Args:
        dataset: Iterable of sample dicts. Each needs a ``"question_id"`` key; an
            optional ``"answers"`` key is copied to the result as ``"ground_truth"``.
        infer_batch: Callable taking a list of samples and returning the answers
            for them, in the same order.
        batch_size: Initial batch size; values below 1 are treated as 1.
        progress_bar: Optional tqdm-like object, advanced after every batch.

    Returns:
        ``(results, skipped)``: the result dicts in dataset order, and the number
        of samples skipped because they ran out of memory at batch size 1.
    """
    batch_size = max(1, batch_size)
    results = []
    pending = []  # samples read from the dataset but not yet answered
    skipped = 0
    samples = iter(dataset)
    exhausted = False

    while True:
        # Top the queue up to one full batch; the final batch may be partial.
        while not exhausted and len(pending) < batch_size:
            try:
                pending.append(next(samples))
            except StopIteration:
                exhausted = True
        if not pending:
            break

        batch = pending[:batch_size]
        try:
            answers = infer_batch(batch)
        except RuntimeError as e:
            if "CUDA out of memory" not in str(e):
                raise
            torch.cuda.empty_cache()
            if batch_size == 1:
                # Nothing smaller to fall back to: skip rather than retry forever.
                logging.error(f"⛔ OOM at batch size 1, skipping sample {batch[0]['question_id']}")
                del pending[:1]
                skipped += 1
                if progress_bar is not None:
                    progress_bar.update(1)
                continue
            new_batch_size = max(1, batch_size // 2)
            logging.warning(f"⚠️ OOM! Reducing batch size {batch_size} → {new_batch_size}")
            batch_size = new_batch_size
            continue  # the failed samples are still in `pending` and will be retried

        del pending[:len(batch)]
        for sample, answer in zip(batch, answers):
            result = {"question_id": sample["question_id"], "answer": answer}
            if "answers" in sample:
                result["ground_truth"] = sample["answers"]
            results.append(result)

        if progress_bar is not None:
            progress_bar.update(len(batch))
            progress_bar.set_postfix({
                "Batch Size": len(batch),
                "GPU Memory": f"{torch.cuda.memory_allocated() // 1024 ** 2}MB"
            })

    return results, skipped


def main():
    """Run LLaVA-1.5 on a benchmark, save predictions as JSON, and optionally evaluate them.

    Predictions are written to ``Results/Inference/<prefix><benchmark><suffix>``.
    The average wall-clock inference time per sample is printed at the end.
    """
    args = _parse_args()

    # Define path and configuration
    model_name = "llava-hf/llava-1.5-7b-hf"
    benchmark_adapter = BenchmarkAdapter(args.benchmark)
    hf_token = args.hf_token
    prompt_prefix = benchmark_adapter.get_prompt_prefix(model_name=model_name)
    prompt_suffix = benchmark_adapter.get_prompt_suffix(model_name=model_name)

    # Load dataset and model
    dataset = benchmark_adapter.load_dataset(
        num_samples=args.small_test if args.small_test else None
    )
    load_model = _MODEL_LOADERS[args.prune]  # argparse `choices` guarantees a valid key
    model, processor = load_model(model_name, hf_token,
                                  HF_mirror_site=args.mirror_site if args.mirror_site else None)

    def infer_batch(batch):
        # Bind the fixed model/prompt/benchmark arguments; the batching loop only supplies samples.
        return run_inference_llava15(model, processor, prompt_prefix, prompt_suffix, batch,
                                     benchmark_adapter=benchmark_adapter,
                                     preview_mode=args.preview_mode)

    # Inference
    num_samples = len(dataset)
    start_time = time.time()
    # Context manager closes the bar even if inference raises.
    with tqdm(total=num_samples, desc="🚀 Running", unit="sample",
              dynamic_ncols=True,
              bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]") as progress_bar:
        results, skipped = _run_batched_inference(
            dataset, infer_batch, args.batch_size if args.batch_size else 1,
            progress_bar=progress_bar,
        )
        progress_bar.set_postfix_str("✅ Processing complete")

    total_time = time.time() - start_time
    # Guard against an empty dataset.
    average_inference_time = total_time / num_samples if num_samples else 0.0
    if skipped:
        logging.warning(f"⚠️ {skipped} sample(s) skipped due to OOM at batch size 1")

    # Save results. The filename uses the *requested* batch size so that it
    # matches what evaluate_results is pointed at below.
    result_prefix = f"{model_name.replace('/', '-')}_{args.prune}_batchsize-{args.batch_size}"
    result_suffix = f"_{args.small_test}.json" if args.small_test else ".json"
    result_benchmark = benchmark_adapter.get_result_benchmark(model_name=model_name)
    result_path = f"Results/Inference/{result_prefix}{result_benchmark}{result_suffix}"
    saved = False
    try:
        print(f"\n🚀 Saving results to {result_path}")
        os.makedirs(os.path.dirname(result_path), exist_ok=True)
        with open(result_path, "w", encoding="utf-8") as f:
            json.dump(results, f)
        print(f"✅ Results saved to {result_path}")
        saved = True
    except Exception as e:
        logging.error(f"⛔ Error saving results: {e}")

    # Evaluate results
    if args.eval:
        if not saved:
            # evaluate_results only receives the directory/filename parts, so it
            # presumably reads the file from disk; after a failed save that file
            # would be missing or stale from an earlier run.
            logging.error("⛔ Skipping evaluation because results were not saved")
        else:
            try:
                print("\n📊 Evaluating results")
                benchmark_adapter.evaluate_results(
                    result_dir='Results',
                    model_name=result_prefix,
                    filename_suffix=result_suffix,
                    num_samples=args.small_test if args.small_test else None
                )
                print("✅ Evaluation complete")
            except Exception as e:
                logging.error(f"⛔ Error evaluating results: {e}")
                raise

    print(f"Average_inference_time: {average_inference_time:.4f}s")


if __name__ == "__main__":
    main()