

from __future__ import annotations

import scipy
from tqdm import tqdm


# isort: off
import torch  # should be imported before bigbench
import bigbench.api.model as model
import bigbench.models.model_utils as model_utils

# isort: on
from safe_rlhf.configs import PROMPT_INPUT
from safe_rlhf.models.pretrained import load_pretrained_models


def _compute_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    label_masks: torch.Tensor,
) -> list[float]:
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')

    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_label_masks = label_masks[..., 1:].contiguous()

    loss = loss_fn(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    ).reshape(shift_labels.size())

    loss = (loss * shift_label_masks).sum(-1)
    return (-loss).cpu().numpy().tolist()


class SafeRLHFModel(model.Model):
    """BIG-bench ``Model`` adapter for a model loaded with ``load_pretrained_models``.

    Every input is wrapped with ``PROMPT_INPUT`` before it is tokenized, both
    for free-form generation and for conditional log-probability scoring.
    """

    def __init__(
        self,
        model_path: str,
        show_progress: bool = False,
        batch_size: int = 4,
        max_length: int = 2048,
    ) -> None:
        """Load the model and tokenizer and store the evaluation settings.

        Args:
            model_path: Forwarded to ``load_pretrained_models`` as the model to load.
            show_progress: Whether batch loops are wrapped in a ``tqdm`` progress bar.
            batch_size: Default number of sequences handled per model call.
            max_length: Passed as ``model_max_length`` when loading and used as
                the truncation length when tokenizing for scoring.
        """
        self._model, self._tokenizer = load_pretrained_models(
            model_path,
            model_max_length=max_length,
            padding_side='left',
            auto_device_mapping=True,
            trust_remote_code=True,
        )

        self._show_progress = show_progress

        self._batch_size = batch_size
        self._max_length = max_length

    def model_data(self):
        """Return the (hard-coded) BIG-bench metadata describing this model."""
        return model.ModelData(
            model_family='Llama',
            model_name='alpaca7B',
            non_embedding_params=0,
            flop_matched_non_embedding_params=0,
            total_params=7_000_000_000,
            training_batch_size=128,
            training_steps=500_000 / 128 * 5,
            # Kept consistent with ``model_name`` and ``total_params`` above.
            description='Alpaca 7B',
            decoding_params={},
        )

    def _batch_starts(self, total: int, batch_size: int, desc: str) -> range | tqdm:
        """Return the start offset of each batch, with a progress bar if enabled."""
        starts = range(0, total, batch_size)
        return tqdm(starts, desc=desc) if self._show_progress else starts

    def generate_text(
        self,
        inputs: str | list[str],
        max_length: int | None = 2048,
        stop_string: str | None = None,
        output_regex: str | None = None,
    ) -> str | list[str]:
        """Generate a continuation for each input.

        Args:
            inputs: A single prompt or a list of prompts.
            max_length: Used as the tokenizer truncation length, as the
                ``max_length`` of ``generate`` and as the length argument of
                ``postprocess_output``. A falsy value falls back to the
                instance's ``max_length``.
            stop_string: Forwarded to ``model_utils.postprocess_output``.
            output_regex: Forwarded to ``model_utils.postprocess_output``.

        Returns:
            The post-processed generation: a single result when ``inputs`` is
            a string, otherwise one result per input.
        """
        max_length = max_length or self._max_length
        single_input = isinstance(inputs, str)
        prompts = [PROMPT_INPUT.format(text) for text in ([inputs] if single_input else inputs)]

        generated = []
        for start in self._batch_starts(len(prompts), self._batch_size, 'Generating text...'):
            batch_text = prompts[start : start + self._batch_size]
            batch = self._tokenizer(
                batch_text,
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=max_length,
            )
            output = self._model.generate(**batch, max_length=max_length)
            output_text = self._tokenizer.batch_decode(output, skip_special_tokens=True)
            # Drop the echoed prompt by character count.
            # NOTE(review): this assumes the decoded output starts with the
            # prompt text verbatim (no truncation, lossless decode) - confirm.
            generated.extend(
                text[len(prompt) :] for prompt, text in zip(batch_text, output_text)
            )

        if single_input:
            generated = generated[0]

        return model_utils.postprocess_output(
            generated,
            max_length,
            stop_string,
            output_regex,
        )

    def cond_log_prob(
        self,
        inputs: str | list[str],
        targets: list[str] | list[list[str]],
        batch_size: int | None = None,
        absolute_normalization: bool | None = None,
    ) -> list[float] | list[list[float]]:
        """Score each target as a continuation of its input.

        Args:
            inputs: A single prompt or a list of prompts.
            targets: The candidate continuations: a list of strings for a
                single prompt, or one list of strings per prompt.
            batch_size: Number of (prompt, target) pairs per forward pass;
                a falsy value falls back to the instance's ``batch_size``.
            absolute_normalization: When falsy, the scores of each prompt are
                shifted by their log-sum-exp so that they form a
                log-distribution over that prompt's targets. When truthy, the
                raw summed token log-probabilities are returned.

        Returns:
            One list of log-probabilities per prompt (a single list when
            ``inputs`` is a string), in the order of ``targets``.
        """
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.special`` is loaded on every SciPy version.
        # pylint: disable-next=import-outside-toplevel
        from scipy.special import logsumexp

        batch_size = batch_size or self._batch_size
        single_input = isinstance(inputs, str)
        if single_input:
            inputs = [inputs]
            targets = [targets]

        prompts = [PROMPT_INPUT.format(text) for text in inputs]

        # Flatten to one entry per (prompt, choice) pair. Each prompt is
        # tokenized once here instead of once per choice.
        flat_rows, flat_prompt_ids, flat_choices = [], [], []
        for row, (prompt, choices) in enumerate(zip(prompts, targets)):
            prompt_ids = self._tokenizer(
                prompt,
                truncation=True,
                max_length=self._max_length,
            )['input_ids']
            for choice in choices:
                flat_rows.append(row)
                flat_prompt_ids.append(prompt_ids)
                flat_choices.append(choice)

        log_probs = []
        for start in self._batch_starts(len(flat_choices), batch_size, 'Computing log prob...'):
            batch = {
                'input_ids': [],
                'token_type_ids': [],
            }
            pairs = zip(
                flat_prompt_ids[start : start + batch_size],
                flat_choices[start : start + batch_size],
            )
            for prompt_ids, choice in pairs:
                # NOTE(review): the choice is tokenized on its own with the
                # tokenizer's default special-token handling; if the tokenizer
                # prepends a BOS token it ends up between prompt and choice
                # and is scored as well - confirm this is intended.
                choice_ids = self._tokenizer(
                    choice,
                    truncation=True,
                    max_length=self._max_length,
                )['input_ids']
                batch['input_ids'].append(prompt_ids + choice_ids)
                # ``token_type_ids`` doubles as the loss mask: 0 for prompt
                # tokens, 1 for the choice tokens that are to be scored.
                batch['token_type_ids'].append([0] * len(prompt_ids) + [1] * len(choice_ids))

            batch = self._tokenizer.pad(batch)
            batch = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}

            with torch.no_grad():
                batch_logits = self._model(
                    batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                ).logits

            log_probs.extend(
                _compute_loss(
                    logits=batch_logits,
                    labels=batch['input_ids'],
                    label_masks=batch['token_type_ids'],
                ),
            )

        # Regroup the flat results into one row of scores per prompt.
        scores = [[] for _ in prompts]
        for row, log_prob in zip(flat_rows, log_probs):
            assert log_prob != 0, 'score should not be zero'
            scores[row].append(log_prob)

        if not absolute_normalization:
            normalized = []
            for score_row in scores:
                if not score_row:
                    # Nothing to normalize; log-sum-exp of an empty row is undefined.
                    normalized.append([])
                    continue
                log_norm = logsumexp(score_row)
                # ``float`` keeps the result plain Python, not NumPy scalars.
                normalized.append([float(score - log_norm) for score in score_row])
            scores = normalized

        if single_input:
            scores = scores[0]

        return scores
