from arg_handler import parse_args
from constants import DEVICE_AUTO, EVAL_PREFIX_FLOAT, EVAL_PREFIX_QUANTIZED
from evaluate.model_eval import evaluate_model

from models import (
    load_model,
    setup_device,
    prep_model_on_device,
)
from data_loading import build_dataloaders
from project_utils.helpers import ensure_model_pad_id
from project_utils.set_seed import set_seed
from project_utils.timers import SegmentTimer
from quantization.model_quantization import quantize_model


def _run_eval(model, tokenizer, device, run_config):
    """Evaluate ``model`` on ``device`` with the settings held in ``run_config``.

    Shared by the float and the quantized evaluation so both pass exactly the
    same arguments to ``evaluate_model``.

    Args:
        model: Model to evaluate; callers run ``prep_model_on_device`` first.
        tokenizer: Tokenizer returned together with the model by ``load_model``.
        device: Value returned by ``setup_device``.
        run_config: Run configuration; ``model_name``, ``disable_thinking``,
            ``batch_size``, ``max_samples``, ``num_fewshot`` and ``eval_tasks``
            are read from it.

    Returns:
        The result of ``evaluate_model``; ``main`` iterates it via ``.items()``.
    """
    return evaluate_model(
        model=model,
        model_name=run_config.model_name,
        tokenizer=tokenizer,
        device=device,
        disable_thinking=run_config.disable_thinking,
        batch_size=run_config.batch_size,
        max_samples=run_config.max_samples,
        num_fewshot=run_config.num_fewshot,
        tasks=run_config.eval_tasks,
    )


def main() -> None:
    """Top-level experiment flow.

    Steps, in order:
      1. Parse the run / quantization / optimization configs and seed the run.
      2. Load the model and tokenizer and build the dataloaders.
      3. If ``run_config.run_float_eval`` is set, evaluate the unquantized
         model and print its results.
      4. Quantize the model with the calibration loader and print the
         returned quantization info.
      5. Evaluate the quantized model and print its results.
      6. Print the total elapsed time.
    """
    # Function-local import, as in the original flow; only the final total needs it.
    import time

    #########################################
    # Prepare Configs, Model and Data
    #########################################
    run_config, quant_config, opt_config = parse_args()
    set_seed(run_config.seed)

    # Keep our own start timestamp for the final total rather than reading
    # SegmentTimer's private ``_start``: whether that attribute means "overall
    # start" or "start of the current segment" is an implementation detail of
    # that class.
    run_start = time.perf_counter()
    timer = SegmentTimer()

    model, tokenizer = load_model(
        model_name=run_config.model_name,
        device=run_config.device,
        base_dtype=run_config.base_dtype,
        model_path=run_config.model_path,
        use_model_path=run_config.use_model_path,
    )

    # Only the calibration loader is consumed below; the second loader is
    # unpacked to match the return shape but not used in this flow.
    calib_loader, _eval_loader = build_dataloaders(
        run_config, model_config=model.config, tokenizer=tokenizer
    )

    #########################################
    # Run Float model evaluation
    #########################################
    if run_config.run_float_eval:
        print("Running float-eval mode")
        device = setup_device(DEVICE_AUTO)
        model = prep_model_on_device(model, run_config.model_name, device)
        float_eval_result = _run_eval(model, tokenizer, device, run_config)
        # Re-prepare the model for the configured device before quantization
        # (presumably undoing the move to the auto-selected eval device —
        # confirm against prep_model_on_device).
        model = prep_model_on_device(model, run_config.model_name, run_config.device)
        timer.segment('Float model evaluation')
        print("Float model evaluation results:")
        for k, v in float_eval_result.items():
            print(f"  {EVAL_PREFIX_FLOAT} {k}: {v}")

    #########################################
    # Run model quantization
    #########################################
    model = ensure_model_pad_id(model, tokenizer)
    quantized_model, quant_info = quantize_model(
        model=model,
        run_config=run_config,
        quant_config=quant_config,
        opt_config=opt_config,
        calibration_data=calib_loader,
    )
    timer.segment('Model quantization')

    print("\nQuantization Info:")
    for k, v in quant_info.items():
        print(f"  {k}: {v}")

    #########################################
    # Quantized model evaluation
    #########################################
    print("\nRunning quantized model evaluation")
    device = setup_device(DEVICE_AUTO)
    quantized_model = prep_model_on_device(quantized_model, run_config.model_name, device)
    eval_result = _run_eval(quantized_model, tokenizer, device, run_config)

    timer.segment('Quantized model evaluation')

    print("Quantized model evaluation results:")
    # NOTE(review): one leading space here vs. two for the float results above;
    # left unchanged in case downstream tooling parses this output.
    for k, v in eval_result.items():
        print(f" {EVAL_PREFIX_QUANTIZED} {k}: {v}")

    total_time = time.perf_counter() - run_start
    print(f"\nExperiment complete. Total time: {total_time:.2f}s")

# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()