import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from transformers.generation import GenerationConfig
import config as Cf
from tqdm import tqdm
import constants as C
from copy import deepcopy
from secret_key import KEY


def load_model_and_tokenizer(
    model_name, dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, **kwargs
):
    """Load a causal LM and its tokenizer with ``from_pretrained``.

    Both loaders authenticate with ``Cf.AUTHENTICATION_TOKEN``.

    Args:
        model_name: Hub repo id or local path of the checkpoint.
        dtype: Torch dtype the model weights are loaded in.
        device_map: Device placement forwarded to the model loader.
        trust_remote_code: Whether custom code shipped with the checkpoint may
            be executed. Forwarded to BOTH the model and the tokenizer loader,
            so checkpoints with a custom tokenizer class load as well.
        **kwargs: Extra keyword arguments forwarded to
            ``AutoModelForCausalLM.from_pretrained`` only (e.g. ``max_memory``).

    Returns:
        ``(model, tokenizer)``.
    """
    print(f"Loading {model_name=} – {dtype=} – {device_map=}")

    auth_token = Cf.AUTHENTICATION_TOKEN

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=auth_token,
        device_map=device_map,
        dtype=dtype,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )
    # device_map / dtype are model-loading options and are not given to the
    # tokenizer; trust_remote_code is, because tokenizers can ship custom code.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, token=auth_token, trust_remote_code=trust_remote_code
    )

    return model, tokenizer


def _config_dtype(config):
    """Return the torch dtype declared by a model config, or ``None`` if none is declared.

    ``torch_dtype`` is consulted first, then ``dtype``; the first key present in
    ``config.to_dict()`` decides. String values such as ``"bfloat16"`` are mapped
    to the matching ``torch`` attribute; anything that is not a ``torch.dtype``
    yields ``None``.
    """
    declared = config.to_dict()
    for key in ("torch_dtype", "dtype"):
        if key in declared:
            value = getattr(config, key, None)
            if isinstance(value, str):
                value = getattr(torch, value, None)
            return value if isinstance(value, torch.dtype) else None
    return None


def load_model(model_id, device_map, trust_remote_code=True, *, dtype=torch.bfloat16, **kwargs):
    """Load ``model_id`` and its tokenizer, capping memory on the GPUs configured for it.

    The GPU ids come from the comma-separated string ``Cf.MODEL_TO_GPU_LIST[model_id]``
    and their per-device limits from ``Cf.GPUID_TO_MAX_MEMORY``; the resulting
    ``{gpu_id: limit}`` mapping is passed to the model loader as ``max_memory``.

    Args:
        model_id: Hub repo id or local path.
        device_map: Device placement forwarded to the model loader.
        trust_remote_code: Allow custom code shipped with the checkpoint; used for
            the config load and forwarded to the model/tokenizer load.
        dtype: Weight dtype to load with. Defaults to ``torch.bfloat16``. Pass
            ``None`` to use the dtype declared in the model config, falling back
            to bfloat16 when CUDA is available and float32 otherwise.
        **kwargs: Accepted for backward compatibility; currently ignored.

    Returns:
        ``(model, tokenizer, model_shortname)`` where ``model_shortname`` is the
        last ``/``-separated component of ``model_id``.
    """
    gpu_ids = [int(gpu_id) for gpu_id in Cf.MODEL_TO_GPU_LIST[model_id].split(",")]
    current_max_memory: dict[int, str] = {
        gpu_id: Cf.GPUID_TO_MAX_MEMORY[gpu_id] for gpu_id in gpu_ids
    }
    print(f"Loading model {model_id} on gpu(s) {current_max_memory}")

    config = AutoConfig.from_pretrained(
        model_id, token=KEY, trust_remote_code=trust_remote_code
    )
    config_dtype = _config_dtype(config)
    if dtype is None:
        dtype = config_dtype
        if dtype is None:
            dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
            print(f"LOG: No dtype found in config. Using {dtype=}")
    elif config_dtype is not None and config_dtype != dtype:
        # Make an explicit override visible instead of silently ignoring the config.
        print(f"LOG: Config declares {config_dtype}; loading with {dtype=} instead")
    # The logged dtype is now exactly the one passed to the loader below.
    print(f"LOG: Loading {model_id=} – {device_map=} – {dtype=}")

    model_shortname = model_id.split("/")[-1]
    model, tokenizer = load_model_and_tokenizer(
        model_id,
        dtype=dtype,
        device_map=device_map,
        trust_remote_code=trust_remote_code,
        max_memory=current_max_memory,
    )

    return model, tokenizer, model_shortname


def format_input(prompts, tokenizer, padding):
    """Tokenise ``prompts`` into PyTorch tensors using the given ``padding`` strategy.

    Returns whatever the tokenizer call returns (for HF tokenizers, a
    ``BatchEncoding`` that can be moved with ``.to(device)``).
    """
    return tokenizer(prompts, return_tensors="pt", padding=padding)


def run_inference(model, inputs, new_token_count=1, do_sample=False, **kwargs):
    """Run ``model.generate`` on pre-tokenised ``inputs`` under inference mode.

    The call always asks for a dict-style result carrying both ``scores`` and
    raw ``logits``.

    Args:
        model: Object exposing ``name_or_path`` and ``generate`` (an HF causal LM).
        inputs: Mutable mapping of model inputs. For checkpoints whose
            ``name_or_path`` starts with ``"allenai"``, its ``token_type_ids``
            entry is deleted IN PLACE before generating.
        new_token_count: Forwarded as ``max_new_tokens``.
        do_sample: Forwarded unchanged; it should only affect the predicted
            token, not the logits.
        **kwargs: Additional keyword arguments for ``generate``.

    Returns:
        The object returned by ``model.generate``.
    """
    strip_token_types = model.name_or_path.startswith("allenai")
    if strip_token_types and "token_type_ids" in inputs:
        del inputs["token_type_ids"]

    fixed_options = dict(
        max_new_tokens=new_token_count,
        do_sample=do_sample,
        output_scores=True,
        output_logits=True,
        return_dict_in_generate=True,
    )
    with torch.inference_mode():
        return model.generate(**inputs, **fixed_options, **kwargs)


def batch_inference(
    model, tokenizer, prompts, batch_size=10, new_token_count=1, **kwargs
):
    """Generate for ``prompts`` in batches and return per-step log-probabilities.

    Side effects: sets ``tokenizer.padding_side = "left"`` and
    ``tokenizer.pad_token = tokenizer.eos_token``. These settings persist on the
    tokenizer object after the call.

    Args:
        model: HF causal LM; inputs are moved to ``model.device``.
        tokenizer: Tokenizer matching ``model``.
        prompts: Sliceable sequence of prompt strings.
        batch_size: Number of prompts per ``generate`` call.
        new_token_count: Forwarded to ``run_inference`` as the generation budget.
        **kwargs: Extra ``generate`` keyword arguments.

    Returns:
        CPU tensor of log-softmax-normalised logits, one row per prompt, with the
        generated steps on dim 1 and the vocabulary on dim 2.
        NOTE(review): HF returns *up to* ``new_token_count`` steps per batch. If
        batches stop after different numbers of steps, the final ``torch.cat``
        fails.
    """
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token

    selected_prompt_batches = [
        prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)
    ]

    output_logproba = []
    for prompt_batch in selected_prompt_batches:
        inputs = format_input(prompt_batch, tokenizer, padding=True).to(model.device)
        batch_outputs = run_inference(
            model,
            inputs,
            new_token_count=new_token_count,
            pad_token_id=tokenizer.eos_token_id,
            **kwargs,
        )

        # `logits` is a tuple with one (batch, vocab) tensor per generated step.
        # Stack it to (batch, steps, vocab) and normalise over the vocabulary.
        # The generated `sequences` are not needed here, so they are not copied
        # to the CPU.
        step_logits = torch.stack(batch_outputs.logits, dim=1)
        output_logproba.append(torch.log_softmax(step_logits, dim=-1).cpu())

    return torch.cat(output_logproba, dim=0)


def generate_first_tokens(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
) -> tuple[str, str]:
    """Continue ``prompt`` without sampling until the opening answer delimiter appears.

    The token budget is ``C.MAX_NEW_TOKENS_REASONING`` when
    ``Cf.ALLOW_GENERATION_TOKENS`` is set, and ``C.MAX_NEW_TOKENS_NO_REASONING``
    otherwise. Generation uses ``stop_strings`` to halt at ``OPENING_DELIM``, or
    at ``OPENING_DELIM`` followed by a newline.

    Side effect: if the tokenizer defines no pad token, its ``pad_token_id`` is
    set to the EOS id. This persists on the tokenizer object.

    Returns:
        ``(response, raw)``. ``raw`` is the full decoded sequence (prompt plus
        continuation, special tokens kept). ``response`` is ``raw`` when that
        ends with ``OPENING_DELIM``. Otherwise it is ``prompt + OPENING_DELIM``,
        i.e. the generated text is dropped and the delimiter is forced.
    """
    max_new_tokens = (
        C.MAX_NEW_TOKENS_REASONING
        if Cf.ALLOW_GENERATION_TOKENS
        else C.MAX_NEW_TOKENS_NO_REASONING
    )

    # Pad with EOS when the tokenizer has no dedicated pad token.
    if "pad_token" not in tokenizer.special_tokens_map:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Override a copy so the model's own generation_config stays untouched.
    generation_config: GenerationConfig = deepcopy(model.generation_config)
    overrides = {
        "pad_token_id": tokenizer.pad_token_id,
        "num_return_sequences": 1,
        "max_new_tokens": max_new_tokens,
        "use_cache": True,
        "output_scores": False,
        "output_logits": False,
        "return_dict_in_generate": False,
        "do_sample": False,
        # Unset the sampling knobs; decoding here does not sample.
        "temperature": None,
        "top_k": None,
        "top_p": None,
    }
    for attr_name, attr_value in overrides.items():
        setattr(generation_config, attr_name, attr_value)

    # The prompt is expected to already contain any special tokens it needs.
    encoded: dict[str, torch.Tensor] = tokenizer(
        prompt, return_tensors="pt", padding=True, add_special_tokens=False,
    )
    input_ids: torch.Tensor = encoded["input_ids"].to(model.device)

    assert len(input_ids.shape) == 2, input_ids.shape

    with torch.inference_mode():
        output: torch.LongTensor = model.generate(
            input_ids,
            attention_mask=encoded["attention_mask"].to(model.device),
            generation_config=generation_config,
            stop_strings=[
                C.prompt_parameter.OPENING_DELIM,
                C.prompt_parameter.OPENING_DELIM + "\n",
            ],
            tokenizer=tokenizer,  # generate needs it to match stop_strings
            do_sample=False,
        )

    assert output.shape[0] == len(input_ids), (output.shape, len(input_ids))
    assert output.shape[1] <= max_new_tokens + input_ids.shape[1], (
        output.shape,
        max_new_tokens,
        input_ids.shape,
    )
    assert len(output.shape) == 2, output.shape

    decoded: list[str] = tokenizer.batch_decode(output, skip_special_tokens=False)
    assert len(decoded) == 1, (len(decoded), decoded)
    raw_text: str = decoded[0]

    delim = C.prompt_parameter.OPENING_DELIM
    # NOTE(review): a run that stops on the OPENING_DELIM + "\n" stop string does
    # not end with OPENING_DELIM, so it takes the fallback and its generated text
    # is dropped. Confirm that this is intended.
    response = raw_text if raw_text.endswith(delim) else prompt + delim

    return response, raw_text


def _build_options(options, values):
    """Render answer options one per line as ``<OPENING_DELIM>option<CLOSING_DELIM> value``.

    Options and values are paired with ``zip``, so any surplus items in the
    longer sequence are ignored. An empty input yields ``""``.
    """
    rendered = (
        f"{C.prompt_parameter.OPENING_DELIM}{choice}{C.prompt_parameter.CLOSING_DELIM} {text}"
        for choice, text in zip(options, values)
    )
    return "\n".join(rendered)


def build_staA_prompt(prompt, question, choices, values, tokenizer):
    """Build and tokenise the chat prompt for one Scenario/Question/Answers item.

    The instruction text depends on ``Cf.ALLOW_GENERATION_TOKENS``: either a
    free-form reply whose final answer is bracketed, or a bracketed answer only.
    If the tokenizer's chat template contains the substring ``"system"``, the
    first line of the rendered text (the instructions) becomes a system message
    and the rest becomes the user message. Otherwise everything is sent as a
    single user message.

    NOTE(review): ``prompt`` is stripped and its last character is always dropped,
    presumably a trailing period since the template adds ``"."`` itself.
    Confirm that every scenario ends with one.

    Returns:
        The token ids from ``tokenizer.apply_chat_template`` (``tokenize=True``,
        ``return_dict=False``) with the generation prompt appended.
    """
    assert len(prompt) * len(question) * len(choices) * len(values) > 0

    if Cf.ALLOW_GENERATION_TOKENS:
        template = "You are given a Scenario, a Question, and a set of possible Answers. Select one Answer as your reply. The Answers are {options_list}. Your selected final Answer will be contained within double square brackets: {options_list_brackets}. Do not use square brackets elsewhere in your reply.\nScenario: {scenario}. Question: {question} Answers:\n{options_with_values}\n"
    else:
        template = "You are given a Scenario, a Question, and a set of possible Answers. Select one Answer as your reply. The Answers are {options_list}. Your selected Answer will be contained within double square brackets: {options_list_brackets} and your reply will contain nothing else.\nScenario: {scenario}. Question: {question} Answers:\n{options_with_values}\n"

    bracketed_choices = ", ".join(
        f"{C.prompt_parameter.OPENING_DELIM}{choice}{C.prompt_parameter.CLOSING_DELIM}"
        for choice in choices
    )
    full_text = template.format(
        options_list=", ".join(choices),
        options_list_brackets=bracketed_choices,
        scenario=prompt.strip()[:-1],
        question=question.strip(),
        options_with_values=_build_options(options=choices, values=values),
    )

    messages = []
    user_text = full_text
    if "system" in tokenizer.chat_template:
        # Everything before the first newline is the instruction paragraph.
        system_text, _, user_text = full_text.partition("\n")
        messages.append({"role": "system", "content": system_text})
    messages.append({"role": "user", "content": user_text})

    # Qwen3 non-thinking mode would additionally need enable_thinking=False here.
    return tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=False,
    )


def patch_chat_template(model_name, tokenizer, prompt):
    """Decode a tokenised chat prompt back to text, optionally removing an opening think tag.

    ``model_name`` is currently unused. When ``Cf.ALLOW_GENERATION_TOKENS`` is
    false, a ``<think>`` tag placed directly after either of the two assistant
    headers handled below is removed, presumably so the model answers without
    a reasoning section. When the flag is true, the decoded text is returned
    unchanged.
    """
    text = tokenizer.decode(prompt)
    if Cf.ALLOW_GENERATION_TOKENS:
        return text

    for with_think, without_think in (
        ("<｜Assistant｜><think>", "<｜Assistant｜>"),
        ("<|im_start|>assistant\n<think>", "<|im_start|>assistant\n"),
    ):
        text = text.replace(with_think, without_think)
    return text
