"""
python -m core.training.objectives
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from rich import print


class ModelOutput:
    """Lightweight container for the values returned by a ``compute_loss`` call.

    Every field is optional and defaults to ``None``; each objective fills in
    only the quantities it produces.
    """

    def __init__(self, loss=None, logprobs=None, kl=None, rewards=None):
        # loss: value to backpropagate; logprobs / kl / rewards: per-example
        # statistics reported alongside it (may be None).
        self.loss, self.logprobs, self.kl, self.rewards = loss, logprobs, kl, rewards


class BaseLoss:
    """Common base for the objectives in this module.

    Holds an optional configuration mapping and offers a next-token
    cross-entropy helper.
    """

    def __init__(self, config=None):
        if config is None:
            config = {}
        self.config = config

    def cross_entropy_loss(self, logits, labels, vocab_size, reduction='mean'):
        """Next-token cross-entropy between ``logits`` and ``labels``.

        The logits at sequence position t are scored against the label at
        position t + 1, so the last logit and the first label are dropped.
        Labels equal to -100 are skipped (``F.cross_entropy``'s default
        ``ignore_index``).

        Args:
            logits: tensor of shape (..., seq_len, vocab_size).
            labels: integer tensor of shape (..., seq_len).
            vocab_size: size of the last dimension of ``logits``.
            reduction: forwarded unchanged to ``F.cross_entropy``.
        """
        flat_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
        flat_labels = labels[..., 1:].contiguous().view(-1)
        return F.cross_entropy(flat_logits, flat_labels, reduction=reduction)


class CrossEntropyLoss(BaseLoss):
    """Supervised next-token cross-entropy with an optional KL penalty to a reference model."""

    def __init__(self, config=None):
        super().__init__(config)

    def compute_loss(self, model, inputs, tokenizer=None, reduction='mean', **kwargs):
        """Cross-entropy on ``inputs['labels']`` plus ``kl_beta`` * mean token KL.

        Args:
            model: causal LM called as ``model(input_ids=..., attention_mask=...)``;
                its output must expose ``.logits`` and the model must expose
                ``model.config.vocab_size``.
            inputs: mapping with ``input_ids``, ``attention_mask`` and ``labels``
                tensors; labels use -100 for positions that are not trained on.
            tokenizer: unused; accepted for interface parity with other objectives.
            reduction: forwarded to ``F.cross_entropy``.
            kl_beta (kwarg, default 0.0): weight of the KL penalty.
            ref_model (kwarg, optional): reference model, evaluated under
                ``torch.no_grad`` for the KL penalty.

        Returns:
            ModelOutput with ``loss`` and ``logprobs`` (bsz x 1, detached mean
            log-probability of the labeled tokens; 0 for rows with none).
        """
        input_ids = inputs['input_ids'].cuda()
        attention_mask = inputs['attention_mask'].cuda()
        labels = inputs['labels'].cuda()

        # logprobs[t] scores token t + 1, so the mask for that token is labels[t + 1].
        action_mask = (labels[..., 1:] != -100).long()
        # Clamp avoids 0/0 -> NaN for rows without any labeled token.
        num_actions = action_mask.sum(dim=-1).clamp(min=1)

        def token_logprobs(logits):
            # Log-probability the model assigns to each actual next token.
            return (
                F.log_softmax(logits[..., :-1, :], dim=-1)
                .gather(dim=-1, index=input_ids[..., 1:].unsqueeze(-1))
                .squeeze(-1)
            )

        output = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = output.logits
        loss = self.cross_entropy_loss(logits, labels, model.config.vocab_size, reduction=reduction)
        logprobs = token_logprobs(logits)

        kl_beta = kwargs.get('kl_beta', 0.0)
        ref_model = kwargs.get('ref_model', None)
        if kl_beta > 0.0 and ref_model is not None:
            with torch.no_grad():
                ref_logprobs = token_logprobs(
                    ref_model(input_ids=input_ids, attention_mask=attention_mask).logits
                )
            kl = 0.5 * (logprobs - ref_logprobs).pow(2)
            # Only labeled tokens contribute; prompt and padding positions are masked out.
            avg_kl = (kl * action_mask).sum(dim=-1) / num_actions
            loss = loss + kl_beta * avg_kl.mean()

        avg_logprobs = ((logprobs * action_mask).sum(dim=-1) / num_actions).unsqueeze(-1).detach()
        return ModelOutput(
            loss=loss,
            logprobs=avg_logprobs,  # bsz x 1
        )


class REINFORCELoss(BaseLoss):
    """REINFORCE with a response-level outcome reward and one sample per prompt.

    Per-token loss is ``-(reward * logprob - kl_beta * kl)``, averaged over the
    response tokens of each sequence and then over the batch.
    """

    def __init__(self, config=None):
        super().__init__(config)
        # Imported here rather than at module level (presumably to avoid an
        # import cycle with core.data -- TODO confirm).
        from torch.nn.utils.rnn import pad_sequence
        from core.data import tokenize_messages
        self.tokenize_messages = tokenize_messages
        self.pad_sequence = pad_sequence

    def compute_loss(
        self,
        model,
        ref_model,
        reward_model,
        batch_on_policy_texts,
        batch_parsed_on_policy_texts,
        batch_ends_with_eos,
        datapoints,
        tokenizer,
        max_seq_length,  # allowable training sequence length
        max_new_tokens,  # max generation tokens
        **kwargs,
    ):
        """Compute the REINFORCE loss for one on-policy generation per prompt.

        Args:
            model: policy causal LM; output must expose ``.logits``.
            ref_model: reference model for the KL penalty, or None.
            reward_model: object with ``compute_outcome_reward(text, datapoint)``.
            batch_on_policy_texts: bsz x 1 list of raw generations.
            batch_parsed_on_policy_texts: bsz x 1 list of parsed generations
                (these are what get rewarded).
            batch_ends_with_eos: bsz x 1 list of flags; generations that did
                not end with EOS receive zero reward.
            datapoints: bsz dicts, each with a ``messages`` list.
            tokenizer: tokenizer providing ``pad_token_id``.
            max_seq_length, max_new_tokens: their sum bounds tokenization length.
            kl_beta (kwarg, default 0.0): weight of the KL penalty.

        Returns:
            ModelOutput with ``loss``, ``logprobs`` (bsz x 1, detached),
            ``kl`` (bsz x 1) and ``rewards`` (bsz x 1, before any baseline).
        """
        kl_beta = kwargs.get('kl_beta', 0.0)

        # This objective supports exactly one generation per prompt.
        assert len(batch_on_policy_texts[0]) == len(batch_parsed_on_policy_texts[0]) == len(batch_ends_with_eos[0]) == 1

        on_policy_texts = [texts[0] for texts in batch_on_policy_texts]
        parsed_on_policy_texts = [texts[0] for texts in batch_parsed_on_policy_texts]
        ends_with_eos = [flags[0] for flags in batch_ends_with_eos]

        # Response-level rewards, zeroed for generations that did not end with EOS.
        rewards = [
            reward_model.compute_outcome_reward(parsed_text, datapoint)
            for parsed_text, datapoint in zip(parsed_on_policy_texts, datapoints)
        ]
        rewards = torch.tensor(rewards).float().cuda()
        ends_with_eos = torch.tensor(ends_with_eos).float().cuda()
        rewards = rewards * ends_with_eos
        orig_outcome_rewards = rewards.clone().detach()

        # Rebuild each conversation with the on-policy text as the final assistant turn.
        messages_without_last_assistant = [
            dp['messages'][:-1] if dp['messages'][-1]['role'] == 'assistant' else dp['messages']
            for dp in datapoints
        ]
        batch_on_policy_tokenized = [
            self.tokenize_messages(
                messages + [dict(role='assistant', content=text)],
                tokenizer,
                max_seq_length + max_new_tokens,
            )
            for messages, text in zip(messages_without_last_assistant, on_policy_texts)
        ]

        input_ids = self.pad_sequence(
            [t['input_ids'] for t in batch_on_policy_tokenized],
            batch_first=True, padding_value=tokenizer.pad_token_id,
        ).cuda()
        attention_mask = self.pad_sequence(
            [t['attention_mask'] for t in batch_on_policy_tokenized],
            batch_first=True, padding_value=0,
        ).cuda()
        labels = self.pad_sequence(
            [t['labels'] for t in batch_on_policy_tokenized],
            batch_first=True, padding_value=-100,
        ).cuda()

        bsz = input_ids.size(0)

        # logprobs[t] scores token t + 1, so the mask for that token is labels[t + 1].
        action_mask = (labels[..., 1:] != -100).long()
        # Clamp avoids 0/0 -> NaN for rows without any labeled token.
        num_actions = action_mask.sum(dim=-1).clamp(min=1)

        def token_logprobs(logits):
            # Log-probability the model assigns to each actual next token.
            return (
                F.log_softmax(logits[..., :-1, :], dim=-1)
                .gather(dim=-1, index=input_ids[..., 1:].unsqueeze(-1))
                .squeeze(-1)
            )

        logprobs = token_logprobs(model(input_ids=input_ids, attention_mask=attention_mask).logits)

        # KL to the reference model ##########################################
        if ref_model is not None and kl_beta > 0.0:
            with torch.no_grad():
                ref_logprobs = token_logprobs(
                    ref_model(input_ids=input_ids, attention_mask=attention_mask).logits
                )
            kl = 0.5 * (logprobs - ref_logprobs).pow(2)
            avg_kl = (kl * action_mask).sum(dim=-1) / num_actions  # bsz
        else:
            kl = 0.0
            avg_kl = torch.zeros(bsz, device=input_ids.device)  # bsz
        #######################################################################

        avg_logprobs = ((logprobs * action_mask).sum(dim=-1) / num_actions).unsqueeze(-1).detach()

        # kl is 0.0 when the penalty is disabled, so this reduces to -reward * logprob.
        loss = -(rewards.to(logprobs.device).unsqueeze(-1) * logprobs - kl_beta * kl)
        loss = ((loss * action_mask).sum(dim=-1) / num_actions).mean()
        return ModelOutput(
            loss=loss,
            logprobs=avg_logprobs,                       # bsz x 1
            kl=avg_kl.unsqueeze(-1),                     # bsz x 1
            rewards=orig_outcome_rewards.unsqueeze(-1),  # bsz x 1
        )


class GRPOLoss(BaseLoss):
    """Group-relative policy optimization with several generations per prompt."""

    def __init__(self, config=None):
        super().__init__(config)
        # Imported here rather than at module level (presumably to avoid an
        # import cycle with core.data -- TODO confirm).
        from torch.nn.utils.rnn import pad_sequence
        from core.data import tokenize_messages
        self.tokenize_messages = tokenize_messages
        self.pad_sequence = pad_sequence

    def compute_loss(
        self,
        model,
        ref_model,
        reward_model,
        batch_on_policy_texts,
        batch_parsed_on_policy_texts,
        batch_ends_with_eos,
        datapoints,
        tokenizer,
        **kwargs,
    ):
        """
        batch_on_policy_texts (List[List[str]]): shape = bsz x num_generation_per_prompt
        batch_parsed_on_policy_texts (List[List[str]]): shape = bsz x num_generation_per_prompt
        batch_ends_with_eos (List[List[bool]]): shape = bsz x num_generation_per_prompt
        datapoints (List[Dict]): shape = bsz, each with a ``messages`` list
        tokenizer (Tokenizer): must provide ``pad_token_id``
        kl_beta (float, kwarg, default 0.0): weight of the KL penalty
        max_seq_length (int, kwarg, required): allowable training sequence length
        max_new_tokens (int, kwarg, required): max generation tokens

        GRPO:
        1. rewards are normalized within each prompt's group (mean / std) to
           give one advantage per generation
        2. that advantage is broadcast over the generation's tokens; loss and
           KL are computed per token and averaged over response tokens

        Returns:
            ModelOutput with ``loss``, ``logprobs`` (bsz x num_generation_per_prompt,
            detached), ``kl`` (bsz x 1) and ``rewards``
            (bsz x num_generation_per_prompt, before normalization).
        """
        # `or 0.0` also covers an explicit kl_beta=None, which previously crashed
        # on `kl_beta * kl` below.
        kl_beta = kwargs.get('kl_beta') or 0.0
        max_seq_length = kwargs.get('max_seq_length', None)
        max_new_tokens = kwargs.get('max_new_tokens', None)

        # Number of generations per group needs to be greater than 1
        assert len(batch_on_policy_texts[0]) == len(batch_parsed_on_policy_texts[0]) == len(batch_ends_with_eos[0]) > 1

        bsz = len(datapoints)
        num_generation_per_prompt = len(batch_on_policy_texts[0])

        # Rebuild each conversation with every on-policy text as the final assistant turn.
        batch_messages_without_last_assistant = [
            dp['messages'][:-1] if dp['messages'][-1]['role'] == 'assistant' else dp['messages']
            for dp in datapoints
        ]
        # Flattened prompt-major: index = prompt * num_generation_per_prompt + generation.
        flat_tokenized = [
            self.tokenize_messages(
                messages + [dict(role='assistant', content=text)],
                tokenizer,
                max_seq_length + max_new_tokens,
            )
            for messages, texts in zip(batch_messages_without_last_assistant, batch_on_policy_texts)
            for text in texts
        ]

        input_ids = self.pad_sequence(
            [t['input_ids'] for t in flat_tokenized],
            batch_first=True, padding_value=tokenizer.pad_token_id,
        ).cuda()
        attention_mask = self.pad_sequence(
            [t['attention_mask'] for t in flat_tokenized],
            batch_first=True, padding_value=0,
        ).cuda()
        labels = self.pad_sequence(
            [t['labels'] for t in flat_tokenized],
            batch_first=True, padding_value=-100,
        ).cuda()

        # logprobs[t] scores token t + 1, so the mask for that token is labels[t + 1].
        action_mask = (labels[..., 1:] != -100).long()
        # Clamp avoids 0/0 -> NaN for rows without any labeled token.
        num_actions = action_mask.sum(dim=-1).clamp(min=1)

        ends_with_eos = torch.tensor(batch_ends_with_eos).float().to(input_ids.device)

        outcome_rewards = torch.tensor([
            [float(reward_model.compute_outcome_reward(text, datapoint)) for text in parsed_texts]
            for parsed_texts, datapoint in zip(batch_parsed_on_policy_texts, datapoints)
        ]).to(input_ids.device)  # shape = bsz x num_generation_per_prompt
        outcome_rewards = outcome_rewards * ends_with_eos
        orig_outcome_rewards = outcome_rewards.clone().detach()

        # Group-relative advantage; 1e-8 keeps identical-reward groups at advantage 0.
        or_mean = outcome_rewards.mean(dim=1, keepdim=True)
        or_std = outcome_rewards.std(dim=1, keepdim=True) + 1e-8
        advantages = ((outcome_rewards - or_mean) / or_std).reshape(bsz * num_generation_per_prompt, 1)
        # shape = (bsz * num_generation_per_prompt) x 1, broadcast over tokens

        def token_logprobs(logits):
            # Log-probability the model assigns to each actual next token.
            return (
                F.log_softmax(logits[..., :-1, :], dim=-1)
                .gather(dim=-1, index=input_ids[..., 1:].unsqueeze(-1))
                .squeeze(-1)
            )

        # logits: (bsz * num_generation_per_prompt) x seq_length x vocab_size
        logprobs = token_logprobs(model(input_ids=input_ids, attention_mask=attention_mask).logits)

        # KL to the reference model ##########################################
        if ref_model is not None and kl_beta > 0.0:
            with torch.no_grad():
                ref_logprobs = token_logprobs(
                    ref_model(input_ids=input_ids, attention_mask=attention_mask).logits
                )
            kl = 0.5 * (logprobs - ref_logprobs).pow(2)
            avg_kl = (kl * action_mask).sum(dim=-1) / num_actions
            avg_kl = avg_kl.reshape(bsz, num_generation_per_prompt).mean(dim=1, keepdim=True)
        else:
            kl = 0.0
            avg_kl = torch.zeros(bsz, 1, device=input_ids.device)
        #######################################################################

        # Detached for reporting, consistent with the other objectives.
        avg_logprobs = ((logprobs * action_mask).sum(dim=-1) / num_actions)
        avg_logprobs = avg_logprobs.reshape(bsz, num_generation_per_prompt).detach()

        loss = -(advantages * logprobs - kl_beta * kl)  # shape = (bsz * num_generation_per_prompt) x (seq_length - 1)
        loss = ((loss * action_mask).sum(dim=-1) / num_actions).mean()

        return ModelOutput(
            loss=loss,
            logprobs=avg_logprobs,         # bsz x num_generation_per_prompt
            kl=avg_kl,                     # bsz x 1
            rewards=orig_outcome_rewards,  # bsz x num_generation_per_prompt
        )


class DPOLoss(BaseLoss):
    """Direct preference optimization on (chosen, rejected) response pairs.

    Uses the length-normalized (mean per-token) response log-probability as
    each sequence's score.
    """

    def __init__(self, config=None):
        super().__init__(config)

    def compute_loss(
        self,
        model,
        ref_model,
        inputs,
        kl_beta,
        tokenizer,
        local_rank=0,
    ):
        """Return ``-logsigmoid(kl_beta * (margin_policy - margin_ref)).mean()``.

        ``margin`` is the mean chosen-response log-prob minus the mean
        rejected-response log-prob (see ``model_forward_for_scores``).
        """
        scores = self.model_forward_for_scores(model, inputs, tokenizer, local_rank)

        with torch.no_grad():
            ref_scores = self.model_forward_for_scores(ref_model, inputs, tokenizer, local_rank)
            ref_scores = ref_scores.detach()

        scores = kl_beta * scores - kl_beta * ref_scores
        loss = -F.logsigmoid(scores).mean()
        return loss

    def model_forward_for_scores(self, model, inputs, tokenizer, local_rank):
        """Per-example mean chosen log-prob minus mean rejected log-prob.

        Args:
            model: causal LM whose output exposes ``.logits``.
            inputs: mapping with ``{chosen,rejected}_{input_ids,attention_mask,labels}``
                tensors; labels are assumed aligned with input_ids (not
                pre-shifted), with -100 on positions that are not scored.
            tokenizer: provides ``pad_token_id``.
            local_rank: unused; kept for signature compatibility.

        Returns:
            Tensor of shape (bsz,).
        """
        inputs = {k: v.cuda() for k, v in inputs.items()}
        pad_token_id = tokenizer.pad_token_id

        chosen_outputs = model(
            input_ids=inputs['chosen_input_ids'],
            attention_mask=inputs['chosen_attention_mask'],
            labels=inputs['chosen_labels'],
        )
        rejected_outputs = model(
            input_ids=inputs['rejected_input_ids'],
            attention_mask=inputs['rejected_attention_mask'],
            labels=inputs['rejected_labels'],
        )

        chosen_logprobs = self._mean_response_logprob(
            chosen_outputs.logits, inputs['chosen_labels'], inputs['chosen_attention_mask'], pad_token_id,
        )
        rejected_logprobs = self._mean_response_logprob(
            rejected_outputs.logits, inputs['rejected_labels'], inputs['rejected_attention_mask'], pad_token_id,
        )
        return chosen_logprobs - rejected_logprobs

    @staticmethod
    def _mean_response_logprob(logits, labels, attention_mask, pad_token_id):
        """Mean log-prob of the scored label tokens of each sequence, shape (bsz,).

        Labels equal to -100 or ``pad_token_id`` are ignored, except the last
        non-padded position (by ``attention_mask``), which is always kept so a
        trailing <eos> still counts when eos == pad. Assumes right padding.
        """
        # Causal LM: logits at position t score the token at position t + 1.
        logprobs = logits[:, :-1, :].log_softmax(dim=-1)
        targets = labels[:, 1:]
        ignore = (targets == -100) | (targets == pad_token_id)
        # -100 is not a valid gather index; those positions are masked out below anyway.
        token_logprobs = logprobs.gather(dim=-1, index=targets.clamp(min=0).unsqueeze(-1)).squeeze(-1)

        # Label position p maps to index p - 1 in the shifted targets.
        last = (attention_mask.sum(dim=-1) - 2).clamp(min=0)
        ignore[torch.arange(ignore.size(0), device=ignore.device), last] = False

        # Divide by the number of KEPT tokens (clamped so empty rows give 0, not NaN).
        num_kept = (~ignore).sum(dim=-1).clamp(min=1)
        return token_logprobs.masked_fill(ignore, 0.0).sum(dim=-1) / num_kept
