import enum
import gc
import time
import json
import logging
import os
from copy import deepcopy
from pathlib import Path
from typing import Optional, Iterable, Callable, Any, Dict

import bitsandbytes.functional
import torch
import torch.nn.functional as F
from peft.tuners.lora import LoraLayer

from bof4.evaluation.evalplus import evaluate_with_evalplus, DEFAULT_EVALPLUS_TASKS
from bof4.evaluation.harness import evaluate_with_harness, DEFAULT_HARNESS_TASKS
from bof4.quantization import Quantizer
from bof4.quantization.quant_util import linear_layers, load_model_and_tokenizer
from bof4.serialization import NumpyJsonEncoder
from bof4.util import load_layer_from_safetensors

# Logger for this module; child of the package logger hierarchy.
_logger = logging.getLogger(name=__name__)

def _is_lora_model(model):
    """Return True if any submodule of ``model`` is a PEFT ``LoraLayer``."""
    return any(
        isinstance(submodule, LoraLayer)
        for _, submodule in model.named_modules()
    )

@torch.inference_mode()
def eval_quant_error(
    repo_id, quantized_model, error_function=F.mse_loss
):
    """Compute the mean per-parameter quantization error of a 4-bit model.

    Every layer yielded by ``linear_layers`` whose weight is a bitsandbytes
    ``Params4bit`` is dequantized and compared against the original weight
    loaded from ``repo_id`` via ``load_layer_from_safetensors``. Layers whose
    weight is not ``Params4bit`` are skipped with a warning.

    Args:
        repo_id: Identifier passed to ``load_layer_from_safetensors`` to fetch
            the reference (unquantized) weights.
        quantized_model: Model whose linear layers are inspected.
        error_function: Loss taking ``(input, target, reduction=...)``, e.g.
            ``F.mse_loss`` (default) or ``F.l1_loss``. It is called with
            ``reduction="sum"``.

    Returns:
        Summed error divided by the total number of compared parameters, as a
        Python float. Returns ``nan`` (and logs a warning) when no quantized
        layer was found, e.g. for an unquantized model, instead of raising
        ZeroDivisionError.
    """
    total_error = 0.0
    sum_params = 0
    for name, module in linear_layers(quantized_model):
        if not isinstance(module.weight, bitsandbytes.modules.Params4bit):
            _logger.warning("skipping unquantized layer %s", name)
            continue
        # NOTE(review): dequantize_nf4 is used for every Params4bit layer; with
        # an explicit quant_state bitsandbytes is expected to take the quant
        # type from that state — confirm for the installed version.
        dequantized_weights = bitsandbytes.functional.dequantize_nf4(
            module.weight,
            quant_state=module.weight.quant_state,
        ).to(torch.float32)
        # Load on the GPU as before, then move next to the dequantized tensor
        # so models sharded across several devices compare on one device.
        original_weights = load_layer_from_safetensors(
            repo_id,
            name + ".weight",
            device="cuda",
        ).to(device=dequantized_weights.device, dtype=torch.float32)
        err = error_function(dequantized_weights, original_weights, reduction="sum")
        # Accumulate in a Python float (double precision) rather than a
        # float32 device tensor.
        total_error += err.item()
        sum_params += original_weights.numel()
    if sum_params == 0:
        _logger.warning(
            "no 4-bit quantized layers found; quantization error is undefined"
        )
        return float("nan")
    return total_error / sum_params

class Benchmarks(enum.Enum):
    """Benchmark groups selectable in ``evaluate_quantized_model``."""

    # Mean absolute / mean squared weight error versus the original
    # checkpoint (stored as "mae" / "mse").
    ERRORS = "errors"
    # Tasks run through ``evaluate_with_harness`` (stored under "lm_harness").
    HARNESS = "harness"
    # Tasks run through ``evaluate_with_evalplus`` (stored under "evalplus").
    EVALPLUS = "evalplus"

@torch.inference_mode()
def evaluate_quantized_model(
    model,
    tokenizer,
    repo_id: Optional[str] = None,
    benchmarks: Optional[Iterable[Benchmarks | str]] = None,
    harness_benchmarks=tuple(DEFAULT_HARNESS_TASKS),
    evalplus_benchmarks=tuple(DEFAULT_EVALPLUS_TASKS),
    chat_template: Optional[str] = None,
) -> dict[str, Any]:
    """Run the selected benchmarks on ``model`` and collect their metrics.

    Args:
        model: Model to evaluate.
        tokenizer: Tokenizer for ``model``. Its ``chat_template`` is
            overwritten when ``chat_template`` is given.
        repo_id: Source of the original weights for the ``ERRORS`` benchmark
            (forwarded to ``eval_quant_error``). Required when ``ERRORS`` runs
            on a model without LoRA layers.
        benchmarks: Benchmarks to run, as ``Benchmarks`` members or their
            string values (e.g. ``"harness"``). Defaults to
            ``[Benchmarks.ERRORS, Benchmarks.HARNESS]``.
        harness_benchmarks: Task names forwarded to ``evaluate_with_harness``.
        evalplus_benchmarks: Tasks forwarded to ``evaluate_with_evalplus``.
        chat_template: Optional chat template. When set, the harness is also
            asked to apply it.

    Returns:
        Dict containing ``"lm_harness"`` and/or ``"evalplus"`` results and,
        for ``ERRORS`` on non-LoRA models, ``"mae"`` and ``"mse"`` floats.

    Raises:
        ValueError: If a ``benchmarks`` entry is not a valid ``Benchmarks``
            value, or if ``ERRORS`` is requested for a non-LoRA model without
            ``repo_id``. Both checks run before any evaluation starts.
    """
    if benchmarks is None:
        benchmarks = [Benchmarks.ERRORS, Benchmarks.HARNESS]
    # Normalize to enum members so string values work too; unknown entries
    # raise ValueError instead of being silently ignored.
    selected = {Benchmarks(b) for b in benchmarks}

    run_errors = Benchmarks.ERRORS in selected
    is_adapter = run_errors and _is_lora_model(model)
    # Fail fast: otherwise the missing repo_id only surfaces after the slow
    # harness/evalplus runs, and their results are lost with the exception.
    if run_errors and not is_adapter and repo_id is None:
        raise ValueError(
            "repo_id is required to compute quantization errors "
            "(Benchmarks.ERRORS) for a non-adapter model"
        )

    if chat_template is not None:
        tokenizer.chat_template = chat_template

    results = {}

    if Benchmarks.HARNESS in selected:
        results["lm_harness"] = evaluate_with_harness(
            model,
            tokenizer,
            list(harness_benchmarks),
            apply_chat_template=chat_template is not None,
        )
        _logger.info("Metrics after harness eval: %s", results)
        torch.cuda.empty_cache()

    if Benchmarks.EVALPLUS in selected:
        results["evalplus"] = evaluate_with_evalplus(
            model, tokenizer, tasks=evalplus_benchmarks, resume=False
        )
        _logger.info("Metrics after evalplus eval: %s", results)
        torch.cuda.empty_cache()

    if run_errors:
        if is_adapter:
            _logger.info("Skipping MAE/MSE evaluation for adapter model.")
        else:
            results["mae"] = eval_quant_error(repo_id, model, F.l1_loss)
            results["mse"] = eval_quant_error(repo_id, model, F.mse_loss)
        _logger.info("Metrics after error eval: %s", results)
        torch.cuda.empty_cache()

    return results


def _get_quantizer_metadata(quantizer):
    if quantizer is None:
        return None
    else:
        return json.loads(json.dumps(quantizer.to_dict(), cls=NumpyJsonEncoder))

def _get_metadata(model_id, quantizer):
    """Create metadata for better reproducibility.

    Returns a dict with:
        quantizer: JSON-normalized quantizer config, or None if unquantized.
        model: ``model_id`` as given.
        repo_hash: Output of ``git rev-parse HEAD`` in the current working
            directory, or None if git is unavailable or fails.
        slurm_job: Value of the ``SLURM_JOB_ID`` environment variable, or None.
        job_start_time: ``time.time()`` at the moment of the call.
    """
    import subprocess

    metadata = {
        "quantizer": _get_quantizer_metadata(quantizer),
        "model": model_id
    }
    try:
        # stderr is discarded so "not a git repository" does not pollute logs.
        metadata["repo_hash"] = subprocess.check_output(
            ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL
        ).decode("ascii").strip()
    except (OSError, subprocess.SubprocessError):
        # git missing (OSError) or non-zero exit (CalledProcessError): record
        # no hash. KeyboardInterrupt/SystemExit are no longer swallowed.
        metadata["repo_hash"] = None
    metadata["slurm_job"] = os.environ.get("SLURM_JOB_ID")
    metadata["job_start_time"] = time.time()
    return metadata

def _merge_dicts_recursively(d1: dict, d2: dict) -> dict:
    merged = deepcopy(d1)
    for key, value in d2.items():
        if key in merged:
            if isinstance(value, dict) and isinstance(merged[key], dict):
                merged[key] = _merge_dicts_recursively(merged[key], value)
            else:
                merged[key] = value
        else:
            merged[key] = value
    return merged

@torch.no_grad()
def evaluate_quantizers(
    model_id: str,
    eval_function: Callable[[Any, Any], Dict[str, Any]],
    quantizers: Iterable[Optional[Quantizer]],
    output_dir: str | os.PathLike,
    model_dtype=torch.bfloat16,
    disable_flash_attn: bool = False,
    skip_existing: bool = True,
    merge_results: bool = False,
):
    """Load ``model_id`` with each quantizer, evaluate it, and save the results.

    For each entry in ``quantizers`` (``None`` means the unquantized model),
    the model is loaded via ``load_model_and_tokenizer``. It is passed with its
    tokenizer to ``eval_function``. The returned dict, extended with
    ``job_end_time`` and reproducibility ``metadata``, is written as JSON to
    ``<output_dir>/<name>[_<block_size>]_results.json``.

    Args:
        model_id: Model identifier for the loader; also recorded in metadata.
        eval_function: Called as ``eval_function(model, tokenizer)``. Must
            return a JSON-serializable dict.
        quantizers: Quantizers to evaluate. Any iterable works, including a
            generator. Every non-None quantizer must have a ``name``.
        output_dir: Directory for result files; created if missing.
        model_dtype: dtype forwarded to ``load_model_and_tokenizer``.
        disable_flash_attn: If True, forward ``attn_implementation=None`` via
            ``overwrite_hf_model_kwargs``.
        skip_existing: Skip quantizers whose result file already exists.
        merge_results: When the result file exists (and is not skipped),
            merge the new results into the previous ones instead of
            replacing them.

    Raises:
        ValueError: If a non-None quantizer has no name. This is checked
            before any model is loaded.
    """
    # Materialize once: the quantizers are traversed twice (validation, then
    # evaluation), and a generator would be exhausted by the first pass.
    quantizers = list(quantizers)
    for quantizer in quantizers:
        if quantizer is not None and quantizer.name is None:
            raise ValueError("Quantizers must be named")

    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    for i, quantizer in enumerate(quantizers):
        if quantizer is None:
            full_name = "unquantized"
        else:
            full_name = quantizer.name + (
                f"_{quantizer.block_size}" if hasattr(quantizer, "block_size") else ""
            )

        _logger.info(f"Evaluating model {model_id} with quant: {full_name}")
        result_path = output_dir / f"{full_name}_results.json"

        if result_path.exists() and skip_existing:
            _logger.info(f"Skipping {result_path}. File already exists.")
            continue

        # Collected only for runs that actually execute (it shells out to git).
        metadata = _get_metadata(model_id, quantizer)
        metadata["sub_job_id"] = i

        hf_args = {}
        if disable_flash_attn:
            hf_args["attn_implementation"] = None

        model, tokenizer = load_model_and_tokenizer(
            model_id,
            quantizer=quantizer,
            model_dtype=model_dtype,
            overwrite_hf_model_kwargs=hf_args
        )

        model.eval()
        results = eval_function(model, tokenizer)

        if result_path.exists() and merge_results:
            with open(result_path, "r") as result_file:
                prev_results = json.load(result_file)
            results = _merge_dicts_recursively(prev_results, results)

        results["job_end_time"] = time.time()
        results["metadata"] = metadata

        _logger.info(f"{full_name}: {results}")

        # Serialize before opening the file. A serialization error must not
        # leave a truncated result file that skip_existing would later treat
        # as a finished run.
        serialized = json.dumps(results, indent=4)
        with open(result_path, "w") as out_file:
            out_file.write(serialized)

        del model, tokenizer
        gc.collect()
        torch.cuda.empty_cache()
