import os
import gc
import json
import logging
from dataclasses import dataclass, field, replace
from datetime import datetime
from typing import Dict, Any, List
import time

import torch
from datasets import load_dataset
from transformers import HfArgumentParser

from gptqmodel import GPTQModel, QuantizeConfig

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
from lm_eval.utils import make_table

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Root directory of the source checkpoints; the __main__ section builds the
# `pretrained` path from it.
MODEL_ROOT = "./model"
# Root directory for quantized checkpoints; also the default of
# ModelConfig.quantized_model_root.
QUANT_ROOT = "./quant_model"
# Directory that receives evaluation results and the run-statistics JSON.
OUTPUT_PATH = "./eval_results"
# File rewritten by _dump_stats() with the full contents of RUN_STATS.
STATS_JSON_PATH = os.path.join(OUTPUT_PATH, "run_stats.json")
# In-memory list of per-stage statistics; run_one_quant() and run_evaluation()
# each append one dict per call.
RUN_STATS: List[Dict[str, Any]] = []

@dataclass
class ModelConfig:
    """Settings describing the model to load and how it is quantized.

    The class is consumed through ``HfArgumentParser`` in the script section,
    so field names, types and defaults define the parsed options.
    """

    # --- source model ---------------------------------------------------
    # Path/identifier handed to GPTQModel.load() by run_one_quant().
    pretrained: str = ""
    # Forwarded as QuantizeConfig(sym=...) by the script section.
    sym: bool = False

    # --- basic quantization shape ---------------------------------------
    # Weight bit-width; recorded as "quant_mode_bits" in the run statistics.
    w_bits: int = 8
    # Quantization group size; recorded as "group_size" in the run statistics.
    group_size: int = 128

    # --- BPDQ-related knobs ---------------------------------------------
    # NOTE(review): the exact meaning of msb_num / n_iters / alpha is defined
    # by the quantizer, not by this file — confirm against gptqmodel.
    msb_num: int = 2
    n_iters: int = 10
    alpha: float = 1e-4
    bpdq_flag: bool = True

    # --- output -----------------------------------------------------------
    # Defaults to the module-level QUANT_ROOT.
    quantized_model_root: str = QUANT_ROOT

    # --- loading ----------------------------------------------------------
    # device_map and trust_remote_code are forwarded to the loader via
    # ModelConfigManager.get_model_kwargs().
    device_map: str = "cuda"
    # Name of the weight dtype.
    dtype: str = "bfloat16"
    trust_remote_code: bool = True


class ModelConfigManager:
    """Translate a ``ModelConfig`` into keyword arguments for model loading."""

    def __init__(self, args: "ModelConfig"):
        # The annotation is quoted so this class does not depend on
        # ModelConfig being defined before it.
        self.args = args

    def _resolve_torch_dtype(self):
        """Return the dtype to load weights in, derived from ``args.dtype``.

        Accepts a ``torch.dtype`` instance, the string ``"auto"`` (passed
        through unchanged), or the name of a torch dtype such as
        ``"bfloat16"``. A missing ``dtype`` attribute falls back to
        ``"bfloat16"``, which is what this class always used previously.

        Raises:
            ValueError: if the name does not refer to a torch dtype.
        """
        requested = getattr(self.args, "dtype", "bfloat16")
        # NOTE(review): "auto" is accepted by Hugging Face style loaders;
        # confirm GPTQModel.load treats it the same way.
        if isinstance(requested, torch.dtype) or requested == "auto":
            return requested
        resolved = getattr(torch, str(requested), None)
        # torch exposes many non-dtype attributes, so check the type instead
        # of trusting any attribute that happens to exist.
        if not isinstance(resolved, torch.dtype):
            raise ValueError(
                f"Unsupported dtype {requested!r}; expected a torch dtype "
                "name such as 'bfloat16' or 'float16'."
            )
        return resolved

    def get_model_kwargs(self) -> Dict[str, Any]:
        """Build the keyword arguments passed to the model loader.

        Returns:
            A dict with ``torch_dtype``, ``device_map`` and
            ``trust_remote_code`` taken from the wrapped config.
        """
        return {
            "torch_dtype": self._resolve_torch_dtype(),
            "device_map": self.args.device_map,
            "trust_remote_code": self.args.trust_remote_code,
        }


@dataclass
class EvalConfig:
    """Settings for one lm-eval run over a (quantized) model.

    Parsed through ``HfArgumentParser`` in the script section and then
    specialised per run with ``dataclasses.replace``.
    """

    # --- what to evaluate -------------------------------------------------
    # Task names handed to evaluator.simple_evaluate().
    tasks: List[str] = field(default_factory=lambda: ["commonsense_qa"])
    # Used both for the HFLM wrapper and for simple_evaluate().
    eval_batch_size: int = 64
    num_fewshot: int = 0

    # --- where the model is and where results go ------------------------
    model_path: str = ""
    # An empty string disables writing the per-run results file.
    output_path: str = "./eval_results/"

    # --- model loading options forwarded to HFLM --------------------------
    device: str = "cuda"
    trust_remote_code: bool = True
    dtype: str = "auto"


def _cuda_sync():
    """Wait for queued CUDA work to finish; do nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()

def _reset_cuda_peak():
    """Start a fresh peak-memory measurement window on the current CUDA device.

    Does nothing when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        # Release cached blocks *before* resetting the counters: the reset
        # re-bases each peak on the current value, so doing it first would
        # leave the peak-reserved baseline inflated by cache that is freed
        # immediately afterwards.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        _cuda_sync()

def _get_cuda_peak_mib() -> Dict[str, float]:
    """Report peak allocated/reserved CUDA memory in MiB.

    Returns:
        A dict with the keys ``peak_allocated_mib`` and ``peak_reserved_mib``;
        both are ``0.0`` when CUDA is unavailable.
    """
    peaks = {"peak_allocated_mib": 0.0, "peak_reserved_mib": 0.0}
    if torch.cuda.is_available():
        # Synchronize first so the counters are read after queued work is done.
        _cuda_sync()
        bytes_per_mib = 1024 ** 2
        peaks["peak_allocated_mib"] = float(torch.cuda.max_memory_allocated() / bytes_per_mib)
        peaks["peak_reserved_mib"] = float(torch.cuda.max_memory_reserved() / bytes_per_mib)
    return peaks

def _dump_stats():
    """Rewrite the run-statistics JSON file with the current ``RUN_STATS``.

    Creates ``OUTPUT_PATH`` if needed and overwrites ``STATS_JSON_PATH``.
    """
    os.makedirs(OUTPUT_PATH, exist_ok=True)
    with open(STATS_JSON_PATH, "w", encoding="utf-8") as stats_file:
        stats_file.write(json.dumps(RUN_STATS, ensure_ascii=False, indent=2))

def print_cuda_mem(prefix: str = ""):
    """Print the current CUDA device index with its total and free memory.

    Args:
        prefix: Text placed at the start of every printed line.

    Prints nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return

    bytes_per_gib = 1024 ** 3
    device_id = torch.cuda.current_device()
    free_bytes, total_bytes = torch.cuda.mem_get_info(device_id)

    report_lines = (
        f"{prefix}Current PyTorch visible device: cuda:{device_id}",
        f"{prefix}Total VRAM: {total_bytes / bytes_per_gib:.2f} GB",
        f"{prefix}Free VRAM: {free_bytes / bytes_per_gib:.2f} GB",
    )
    for line in report_lines:
        print(line)


def build_calibration_dataset(
    local_c4_dir: str = "./datasets/c4_local",
    shard: str = "en/c4-train.00001-of-01024.json.gz",
    num_samples: int = 1024,
) -> List[str]:
    """Load calibration texts from a local C4 JSON shard.

    Called without arguments this reads the same shard and sample count as
    before; the parameters only make those choices overridable.

    Args:
        local_c4_dir: Directory containing the local C4 files.
        shard: Path of the shard to read, relative to ``local_c4_dir``.
        num_samples: Maximum number of leading rows to use.

    Returns:
        The ``"text"`` field of the selected rows, as a list.

    Raises:
        ValueError: if ``num_samples`` is not positive.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    ds = load_dataset(
        "json",
        data_dir=local_c4_dir,
        data_files={"train": shard},
        split="train",
    )

    # select() raises on out-of-range indices, so clamp to what the shard
    # actually holds instead of failing on a short file.
    available = len(ds)
    if available < num_samples:
        logger.warning(
            "Calibration shard has only %d rows; requested %d.",
            available,
            num_samples,
        )
        num_samples = available

    # NOTE(review): column access may return a lazy column object rather than
    # a list depending on the installed `datasets` version; list() makes the
    # declared return type hold either way.
    return list(ds.select(range(num_samples))["text"])


def run_one_quant(
    model_args: ModelConfig,
    quant_config: QuantizeConfig,
    save_dir: str,
    save_name: str,
    calibration_dataset: List[str],
) -> str:
    """Quantize one model, save it, and record timing and peak-memory statistics.

    Args:
        model_args: Model settings; ``pretrained`` is the checkpoint to load
            and the remaining fields are copied into the statistics record.
        quant_config: Quantization configuration handed to ``GPTQModel.load``.
        save_dir: Directory that will contain the quantized model (created if
            missing).
        save_name: Sub-directory name of the quantized model inside ``save_dir``.
        calibration_dataset: Calibration texts passed to ``model.quantize``.

    Returns:
        The path the quantized model was saved to.

    The elapsed time covers only ``model.quantize``; the peak-memory window
    opens before the model is loaded, so it includes loading as well.
    """
    manager = ModelConfigManager(model_args)

    _reset_cuda_peak()
    model = None
    try:
        model = GPTQModel.load(
            model_args.pretrained,
            **manager.get_model_kwargs(),
            quantize_config=quant_config,
        )
        # Synchronize on both sides of the timed region so pending GPU work
        # is not attributed to the wrong side of the measurement.
        _cuda_sync()
        t0 = time.perf_counter()
        model.quantize(calibration_dataset, batch_size=1)

        _cuda_sync()
        elapsed = time.perf_counter() - t0
        mem = _get_cuda_peak_mib()

        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, save_name)

        print(f"[SAVE] {save_path}")
        model.save(save_path)

        print(
            f"[QUANT-STAT] name={save_name} | time={elapsed:.2f}s | "
            f"peak_alloc={mem['peak_allocated_mib']:.1f} MiB | peak_reserved={mem['peak_reserved_mib']:.1f} MiB"
        )

        # NOTE: the hyperparameters below are read from model_args, not from
        # quant_config; callers must keep the two in agreement.
        RUN_STATS.append({
            "stage": "quant",
            "timestamp": datetime.now().isoformat(timespec="seconds"),
            "save_name": save_name,
            "model_pretrained": model_args.pretrained,
            "quantized_path": save_path,
            "quant_mode_bits": getattr(model_args, "w_bits", None),
            "group_size": getattr(model_args, "group_size", None),
            "msb_num": getattr(model_args, "msb_num", None),
            "n_iters": getattr(model_args, "n_iters", None),
            "alpha": getattr(model_args, "alpha", None),
            "elapsed_sec": float(elapsed),
            **mem,
        })

        _dump_stats()

        return save_path
    finally:
        # Drop the model on every exit path: a failed quantize/save must not
        # keep the weights (and their GPU memory) alive for whatever the
        # caller does next.
        model = None
        gc.collect()
        torch.cuda.empty_cache()


def save_results(results, eval_args: EvalConfig):
    """Write one evaluation's results (minus per-sample data) to a JSON file.

    Args:
        results: Result mapping from the evaluator; its ``"samples"`` entry,
            if present, is left out of the file.
        eval_args: Evaluation settings; ``output_path``, ``model_path`` and
            ``tasks`` determine the file location and name, and all fields
            are stored alongside the results.

    The file is named ``<model>_<tasks>_<MMDD_HHMM>.json`` inside
    ``eval_args.output_path``, which is created if missing.
    """
    out_dir = eval_args.output_path
    os.makedirs(out_dir, exist_ok=True)
    stamp = datetime.now().strftime("%m%d_%H%M")

    # Copy everything except the (potentially large) per-sample records.
    trimmed = dict(results)
    trimmed.pop("samples", None)

    # Strip a trailing slash so "dir/" and "dir" give the same model name.
    model_name = os.path.basename(eval_args.model_path.rstrip("/"))
    file_name = "_".join([model_name, "_".join(eval_args.tasks), stamp]) + ".json"
    out_file = os.path.join(out_dir, file_name)

    payload = {
        "timestamp": stamp,
        "model_name": model_name,
        "arguments": {"eval_args": vars(eval_args)},
        "evaluation_results": trimmed,
    }
    with open(out_file, "w") as handle:
        # default=str stringifies any value json cannot serialize natively.
        json.dump(payload, handle, indent=2, default=str)

    print(f"\nResults saved to: {out_file}")


def run_evaluation(eval_args: EvalConfig):
    """Evaluate the model at ``eval_args.model_path`` and record run statistics.

    Args:
        eval_args: Evaluation settings (model path, tasks, batch size,
            few-shot count, device, dtype and output directory).

    Returns:
        The result mapping produced by ``evaluator.simple_evaluate``.

    Side effects: prints a summary, writes a results file through
    ``save_results`` when ``eval_args.output_path`` is non-empty, appends an
    ``"eval"`` record to ``RUN_STATS`` and rewrites the statistics file.
    """
    print("--- Starting Model Evaluation ---")
    print(f"Model Path: {eval_args.model_path}")
    print(f"Tasks: {eval_args.tasks}")
    print(f"Batch Size: {eval_args.eval_batch_size}")
    print(f"Shot: {eval_args.num_fewshot}")

    # Compute the basename once, with any trailing slash removed, so "dir/"
    # and "dir" are treated the same for both detection and reporting.
    model_basename = os.path.basename(eval_args.model_path.rstrip("/"))

    _reset_cuda_peak()
    # Synchronize before starting the clock, as run_one_quant does.
    _cuda_sync()
    t0 = time.perf_counter()

    hflm_kwargs = dict(
        pretrained=eval_args.model_path,
        trust_remote_code=eval_args.trust_remote_code,
        dtype=eval_args.dtype,
        device=eval_args.device,
        batch_size=eval_args.eval_batch_size,
        gptqmodel=True,
    )

    # BPDQ checkpoints are recognised purely by naming convention: "bpdq"
    # anywhere in the path, or "bpd" in the final path component.
    is_bpdq = (
        "bpdq" in eval_args.model_path.lower()
        or "bpd" in model_basename.lower()
    )
    if is_bpdq:
        # NOTE(review): presumably BPDQ checkpoints are only supported by the
        # torch backend — confirm against gptqmodel.
        hflm_kwargs["gptq_backend"] = "torch"

    lm = HFLM(**hflm_kwargs)

    results = evaluator.simple_evaluate(
        model=lm,
        tasks=eval_args.tasks,
        num_fewshot=eval_args.num_fewshot,
        batch_size=eval_args.eval_batch_size,
    )

    _cuda_sync()
    elapsed = time.perf_counter() - t0
    mem = _get_cuda_peak_mib()

    print("--- Evaluation Finished ---")
    if "groups" in results:
        print("--- Aggregated Results ---")
        print(json.dumps(results["groups"], indent=2))
    print(make_table(results))

    if eval_args.output_path:
        save_results(results, eval_args)

    task_str = "_".join(eval_args.tasks)

    print(
        f"[EVAL-STAT] model={model_basename} | tasks={task_str} | fewshot={eval_args.num_fewshot} | "
        f"bs={eval_args.eval_batch_size} | time={elapsed:.2f}s | "
        f"peak_alloc={mem['peak_allocated_mib']:.1f} MiB | peak_reserved={mem['peak_reserved_mib']:.1f} MiB"
    )

    RUN_STATS.append({
        "stage": "eval",
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "model_path": eval_args.model_path,
        "model_name": model_basename,
        "tasks": eval_args.tasks,
        "num_fewshot": eval_args.num_fewshot,
        "eval_batch_size": eval_args.eval_batch_size,
        "elapsed_sec": float(elapsed),
        **mem,
    })

    _dump_stats()

    return results


if __name__ == "__main__":
    # Single-run driver: quantize one model in BPDQ mode, then evaluate the
    # saved checkpoint.
    print_cuda_mem(prefix="[INIT] ")

    # Parsing an empty argument list yields each dataclass with its defaults.
    default_model_cfg = HfArgumentParser(ModelConfig).parse_args_into_dataclasses(args=[])[0]
    default_eval_cfg = replace(
        HfArgumentParser(EvalConfig).parse_args_into_dataclasses(args=[])[0],
        output_path=OUTPUT_PATH,
    )

    calib_texts = build_calibration_dataset()

    # --- run definition ---------------------------------------------------
    model_dir = f"{MODEL_ROOT}/Qwen3-8B"
    run_alias = "Qwen3-8B-"
    model_label = os.path.basename(model_dir.rstrip("/"))

    # Keys match the ModelConfig field names so they can be splatted into
    # dataclasses.replace() below.
    hparams = {
        "w_bits": 8,
        "msb_num": 2,
        "group_size": 128,
        "n_iters": 10,
        "alpha": 1e-4,
    }
    eval_overrides = {"tasks": ["wikitext"], "eval_batch_size": 2, "num_fewshot": 0}

    print("\n========== QUANT MODE: BPDQ (Single Run) ==========")
    quant_out_dir = os.path.join(QUANT_ROOT, f"{run_alias}_bpdq")
    print(f"\n--- Model: {model_label} (bpdq) ---")

    # --- quantization -------------------------------------------------------
    quant_model_cfg = replace(
        default_model_cfg,
        pretrained=model_dir,
        bpdq_flag=True,
        **hparams,
    )

    gptq_cfg = QuantizeConfig(
        bits=hparams["w_bits"],
        group_size=hparams["group_size"],
        sym=quant_model_cfg.sym,
        export_float_only=False,
        act_group_aware=True,
        desc_act=False,
        msb_num=hparams["msb_num"],
        bpdq_flag=True,
        n_iters=hparams["n_iters"],
        alpha=hparams["alpha"],
    )

    quant_name = "{}_BPD{}_g{}_als{}_{}".format(
        model_label,
        hparams["msb_num"],
        hparams["group_size"],
        hparams["n_iters"],
        hparams["w_bits"],
    )
    quant_path = run_one_quant(quant_model_cfg, gptq_cfg, quant_out_dir, quant_name, calib_texts)

    # --- evaluation ---------------------------------------------------------
    eval_cfg = replace(default_eval_cfg, model_path=quant_path, **eval_overrides)
    tasks_label = "_".join(eval_cfg.tasks)
    quant_label = os.path.basename(quant_path.rstrip("/"))
    run_tag = f"[Model: {quant_label}] [Tasks: {tasks_label}]"
    print(f"\n--- Running Evaluation: {run_tag} ---")

    try:
        run_evaluation(eval_cfg)
        print(f"--- Evaluation Completed Successfully: {run_tag} ---")
    except Exception as e:
        # Top-level boundary: report the failure in full instead of crashing,
        # so the cleanup below still runs.
        import traceback
        print(f"--- [FAILED] {run_tag} ---")
        print(f"Error Message: {e}")
        print("\n--- Full Traceback ---")
        print(traceback.format_exc())
        print("---------------------------------")

    gc.collect()
    torch.cuda.empty_cache()

