import os

import numpy as np
import torch
from huggingface_hub import hf_hub_download, login
from peft import PeftModel
from scipy.special import logsumexp
from transformers import AutoModelForCausalLM, AutoTokenizer


def str2bool(v: str) -> bool:
    """Interpret a CLI-style flag value as a boolean.

    ``v`` is stringified and lower-cased; only "1", "true", "yes", "y" and
    "t" count as True, everything else (including "" and "none") is False.
    """
    normalized = str(v).lower()
    return normalized in ("1", "true", "yes", "y", "t")


def pick_dtype(dtype_str: str):
    """Translate a ``--dtype`` CLI value into a torch dtype.

    Returns None for "auto" (let ``from_pretrained`` decide), otherwise the
    matching ``torch`` dtype.

    Raises:
        ValueError: if ``dtype_str`` is not one of the supported names.
    """
    dtype_by_name = {
        "auto": None,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    if dtype_str in dtype_by_name:
        return dtype_by_name[dtype_str]
    raise ValueError(f"--dtype must be one of {list(dtype_by_name)}")


def load_tokenizer(
    tokenizer_id: str,
    trust_remote_code: bool,
    local_files_only: bool,
    fallback_tokenizer_id: str = "meta-llama/Llama-2-7b-hf",
):
    """Load a tokenizer, falling back to ``fallback_tokenizer_id`` on failure.

    The fallback is best-effort and is now reported. Previously it happened
    silently, which could hide a wrong tokenizer (and therefore wrong token
    ids) behind otherwise normal-looking results.

    Args:
        tokenizer_id: Hub id or local path passed to ``AutoTokenizer``.
        trust_remote_code: Forwarded to ``from_pretrained``.
        local_files_only: Forwarded to ``from_pretrained``.
        fallback_tokenizer_id: Tokenizer loaded when ``tokenizer_id`` fails.

    Returns:
        The tokenizer. When it has no pad token, the pad token is set to the
        EOS token.
    """
    try:
        tok = AutoTokenizer.from_pretrained(
            tokenizer_id,
            trust_remote_code=trust_remote_code,
            local_files_only=local_files_only,
        )
    except Exception as err:
        # Keep the original best-effort fallback, but make it visible.
        print(
            f"failed to load tokenizer {tokenizer_id!r} "
            f"({type(err).__name__}: {err}); "
            f"falling back to {fallback_tokenizer_id!r}"
        )
        tok = AutoTokenizer.from_pretrained(
            fallback_tokenizer_id,
            trust_remote_code=trust_remote_code,
            local_files_only=local_files_only,
        )

    # Make sure we have a pad token for batched padding and generation.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
        tok.pad_token_id = tok.eos_token_id
    return tok


# def load_model_and_tokenizer(
#     model_id_or_path: str,
#     tokenizer_id: str = None,
#     adapter_id_or_path: str = None,
#     epoch: str = None,
#     dtype_str: str = "fp16",
#     use_8bit: bool = False,
#     use_4bit: bool = False,
#     trust_remote_code: bool = False,
#     local_files_only: bool = False,
#     device_map: str = "auto",
#     peft: bool = True,
# ):
#     torch_dtype = pick_dtype(dtype_str)
#
#     offload_folder = os.environ.get("OFFLOAD_FOLDER", "offload_cache")
#     os.makedirs(offload_folder, exist_ok=True)
#
#     # leave 2GB headroom on GPU 0; adjust if needed
#     max_memory = {
#         0: os.environ.get("GPU_MAX_MEM", "46GiB"),
#         "cpu": os.environ.get("CPU_MAX_MEM", "128GiB"),
#     }
#
#     # tokenizer source preference:
#     # 1) explicit tokenizer_id
#     # 2) adapter repo (often contains tokenizer for convenience)
#     # 3) base model
#     tk_src = model_id_or_path or adapter_id_or_path or tokenizer_id
#     tokenizer = load_tokenizer(tk_src, trust_remote_code, local_files_only)
#
#     # quantization options (bitsandbytes)
#     quant_kwargs = {}
#     if use_8bit:
#         quant_kwargs["load_in_8bit"] = True
#     if use_4bit:
#         quant_kwargs["load_in_4bit"] = True
#
#     # base model
#     base = AutoModelForCausalLM.from_pretrained(
#         model_id_or_path,
#         subfolder=epoch if peft is not True else "",
#         device_map=device_map,
#         max_memory=max_memory,
#         offload_state_dict=True,
#         offload_folder=offload_folder,
#         low_cpu_mem_usage=True,  # helps with peak RAM
#         torch_dtype=torch_dtype,
#         trust_remote_code=trust_remote_code,
#         local_files_only=local_files_only,
#         **quant_kwargs,
#     )
#
#     # optionally apply LoRA/PEFT adapters
#     if adapter_id_or_path:
#         if PeftModel is None:
#             raise ImportError(
#                 "peft is required to load adapters. Install with `pip install peft`."
#             )
#         print("loading the adapters at epoch", epoch)
#         print(base)
#         print(adapter_id_or_path)
#         print(epoch)
#         hf_hub_download(
#             adapter_id_or_path,
#             "adapter_config.json",
#             subfolder=epoch,
#             local_files_only=local_files_only,
#         )
#         hf_hub_download(
#             adapter_id_or_path,
#             "adapter_model.safetensors",
#             subfolder=epoch,
#             local_files_only=local_files_only,
#         )
#         try:
#             base = PeftModel.from_pretrained(
#                 base,
#                 adapter_id_or_path,
#                 subfolder=epoch,
#                 is_trainable=False,
#                 local_files_only=local_files_only,
#                 **quant_kwargs,
#             )
#         except:
#             print("not loaded since base")
#             pass
#
#     base.eval()
#     # sanity: prove adapters are applied
#     try:
#         print("adapters:", base.get_adapter_names())
#     except Exception:
#         print("not loaded")
#         pass
#
#     return base, tokenizer


def assert_adapters_attached(m):
    """Raise unless ``m`` is a ``PeftModel`` with at least one adapter config.

    On success, prints the configured adapter names and the active adapter.
    It also prints how many submodules have "lora_A"/"lora_B" in their
    qualified name.

    Raises:
        RuntimeError: if ``m`` is not a ``PeftModel``, or its ``peft_config``
            is missing, not a dict, or empty.
    """
    if not isinstance(m, PeftModel):
        raise RuntimeError("Adapters not attached: model is not a PeftModel")

    config = getattr(m, "peft_config", None)
    if not isinstance(config, dict) or not config:
        raise RuntimeError("Adapters not attached: empty or missing peft_config")

    adapter_names = list(config)
    active = getattr(m, "active_adapter", None)
    print("adapters:", adapter_names, "active:", active)

    lora_count = sum(
        1
        for module_name, _ in m.named_modules()
        if "lora_A" in module_name or "lora_B" in module_name
    )
    print("num lora modules:", lora_count)
    print("num lora modules:", len(lora_modules))


def _epoch_subfolder(epoch: str | int | None) -> str | None:
    """Map ``epoch`` to an adapter-repo subfolder name ``"epoch-<epoch>"``.

    Returns None when ``epoch`` is None. Values that already start with
    ``"epoch-"`` are returned unchanged.
    """
    if epoch is None:
        return None
    sub = str(epoch)
    return sub if sub.startswith("epoch-") else f"epoch-{sub}"


def load_model_and_tokenizer(
    model_id_or_path: str,
    tokenizer_id: str | None = None,
    adapter_id_or_path: str | None = None,
    epoch: str | int | None = None,
    dtype_str: str = "fp16",
    use_8bit: bool = False,
    use_4bit: bool = False,
    trust_remote_code: bool = False,
    local_files_only: bool = True,
    device_map: str = "auto",
    peft: bool = True,
):
    """Load a causal LM and tokenizer, optionally attaching PEFT adapters.

    Args:
        model_id_or_path: Hub id or local path of the base model.
        tokenizer_id: Tokenizer to load. Defaults to ``model_id_or_path``.
        adapter_id_or_path: PEFT adapter repo/path. Only used when ``peft``
            is True.
        epoch: Selects the adapter subfolder ``"epoch-<epoch>"``. With None,
            ``subfolder=None`` is passed to ``PeftModel.from_pretrained``.
        dtype_str: One of the names accepted by ``pick_dtype``.
        use_8bit: Load the base model with ``load_in_8bit``.
        use_4bit: Load the base model with ``load_in_4bit``. Mutually
            exclusive with ``use_8bit``.
        trust_remote_code: Forwarded to ``from_pretrained``.
        local_files_only: Forwarded to ``from_pretrained``.
        device_map: Forwarded to ``from_pretrained``.
        peft: Whether to attach adapters at all.

    Returns:
        ``(model, tokenizer)``. The model is in eval mode. It is a
        ``PeftModel`` when adapters were attached.

    Raises:
        ValueError: if both quantization flags are set, or ``dtype_str`` is
            invalid.
        RuntimeError: if adapter loading yields a model without adapters.
    """
    if use_8bit and use_4bit:
        # Previously this combination silently loaded an unquantized model.
        raise ValueError("use_8bit and use_4bit are mutually exclusive")

    # Log in only when the (project-specific) token variable is set. A missing
    # variable used to raise KeyError even for public or locally cached models.
    hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if hf_token:
        login(token=hf_token)

    torch_dtype = pick_dtype(dtype_str)
    tokenizer = load_tokenizer(
        tokenizer_id or model_id_or_path, trust_remote_code, local_files_only
    )

    quant_kwargs = {}
    if use_8bit:
        quant_kwargs["load_in_8bit"] = True
    elif use_4bit:
        quant_kwargs["load_in_4bit"] = True
    if quant_kwargs:
        print("quantization:", quant_kwargs)

    print("loading base model:", model_id_or_path)
    base = AutoModelForCausalLM.from_pretrained(
        model_id_or_path,
        device_map=device_map,
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
        trust_remote_code=trust_remote_code,
        local_files_only=local_files_only,
        **quant_kwargs,
    )
    base.eval()

    if peft and adapter_id_or_path:
        sub = _epoch_subfolder(epoch)
        print("loading adapters:", adapter_id_or_path, "subfolder:", sub)
        base = PeftModel.from_pretrained(
            base,
            adapter_id_or_path,
            subfolder=sub,
            is_trainable=False,
            local_files_only=local_files_only,
        )
        # Fail loudly (and print the adapter names) instead of silently
        # evaluating the bare base model.
        assert_adapters_attached(base)
        base.eval()

    return base, tokenizer


def answer_prob(prompts, answers, model, tokenizer, reduction="mean"):
    padding_side = tokenizer.padding_side

    # Concatenate each question and answer pair and encode them
    inputs = [p + a for p, a in zip(prompts, answers)]
    encoding = tokenizer(inputs, padding="longest", return_tensors="pt").to(
        model.device
    )

    # Get model's output (logits) for the batch
    with torch.no_grad():  # Disable gradient calculation for inference
        logits = model(
            **encoding, labels=encoding["input_ids"], prompts=prompts, answers=answers
        ).logits

    # Shift logits and labels to align for calculating the probability of answer tokens
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = encoding["input_ids"][..., 1:].contiguous()

    # Flatten the logits and labels to calculate loss easily across the batch
    flatten_logits = shift_logits.view(-1, shift_logits.size(-1))
    flatten_labels = shift_labels.view(-1)

    # Calculate loss using CrossEntropy to get log probabilities, then negate for actual log probs
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    loss = loss_fct(flatten_logits, flatten_labels) * -1
    loss = loss.view(shift_labels.size())

    answer_log_probs = []
    for i, (_, attn_mask) in enumerate(
        zip(encoding["input_ids"], encoding["attention_mask"])
    ):
        # Find the start index of the answer and its actual length
        if padding_side == "right":
            answer_start = len(tokenizer.encode(prompts[i])) - 1
            answer_end = sum(attn_mask) - 1
            answer_length = answer_end - answer_start
        else:
            answer_start = (
                (attn_mask == 0).sum() + len(tokenizer.encode(prompts[i])) - 1
            )
            answer_end = len(encoding["input_ids"][i])
            answer_length = answer_end - answer_start

        # Select log probabilities corresponding to the actual answer tokens
        answer_log_probs.append(loss[i, answer_start : answer_start + answer_length])

    if reduction == "mean":
        answer_log_probs = [log_probs.mean().cpu() for log_probs in answer_log_probs]
    elif reduction == "sum":
        answer_log_probs = [log_probs.sum().cpu() for log_probs in answer_log_probs]
    else:
        answer_log_probs = [log_probs.cpu() for log_probs in answer_log_probs]

    return answer_log_probs


def normalized_prob(prob_answer, all_probs):
    """Return ``prob_answer`` divided by the total of ``all_probs``.

    The total is taken with ``np.sum``, so ``all_probs`` may be any
    array-like of numbers.
    """
    total = np.sum(all_probs)
    return prob_answer / total


def log_truth_ratio(log_probs_perturbed, log_prob_paraphrased):
    """Exponentiated difference of the mean perturbed and paraphrased log-probs.

    Despite the name, the value returned is the ratio itself (``exp`` of the
    log-ratio). If the inputs are log-probabilities, this is the geometric
    mean of the perturbed probabilities divided by the paraphrased
    probability.
    """
    mean_perturbed = np.mean(np.array(log_probs_perturbed))
    return np.exp(mean_perturbed - log_prob_paraphrased)


def perplexity(text_batch, model, tokenizer, reduction="mean"):
    """Perplexity of each text in ``text_batch`` under ``model``.

    Per-token NLLs are averaged only over valid next-token predictions, where
    both the context token and the target token are non-padding according to
    ``attention_mask``.

    The previous version had three problems:
    - It divided by the padded sequence length, which deflated perplexity for
      shorter texts in a batch.
    - It used ``ignore_index=pad_token_id``, which also dropped real EOS tokens
      when pad == eos.
    - Under left padding it kept the pad -> first-token prediction.

    Args:
        text_batch: list of strings, tokenized together with padding.
        model: object exposing ``.device`` whose call result has ``.logits``.
        tokenizer: tokenizer returning ``input_ids`` and ``attention_mask``.
        reduction: "mean" returns ``exp`` of the mean of the per-text average
            NLLs (each text weighted equally). Any other value returns one
            perplexity per text.

    Returns:
        A float for "mean", otherwise a list of floats. A text with no valid
        prediction (at most one real token) contributes an NLL of 0, i.e.
        perplexity 1.
    """
    encoding = tokenizer(text_batch, padding=True, return_tensors="pt").to(
        model.device
    )
    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]

    with torch.no_grad():  # inference only; loss is computed below
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Logit row t scores token t + 1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = input_ids[..., 1:].contiguous()

    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    token_nll = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    ).view(shift_labels.size())

    # A prediction is valid only if both its context and target are real tokens.
    valid = attention_mask[..., :-1].bool() & attention_mask[..., 1:].bool()
    # masked_fill (not multiply) so NaN/inf at padded positions cannot leak in;
    # accumulate in float32 for precision with half-precision models.
    token_nll = token_nll.float().masked_fill(~valid, 0.0)
    n_valid = valid.sum(dim=1).clamp(min=1)
    avg_nll = token_nll.sum(dim=1) / n_valid

    if reduction == "mean":
        ppl = avg_nll.mean().exp()
    else:
        ppl = avg_nll.exp()
    return ppl.cpu().tolist()
