from __future__ import annotations
from typing import Optional

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import AutoTokenizer, PreTrainedTokenizer
from modeling.processing_aya_vision import AyaVisionProcessor
from modeling.modeling_aya_vision import AyaVisionForConditionalGeneration

from modeling.modeling_llama import LlamaForCausalLM

def get_optimizer(model, args, detector=None):
    """Build an Adam optimizer over the trainable parameters.

    Args:
        model: Module whose parameters with ``requires_grad=True`` are optimized.
        args: Object exposing a ``learning_rate`` attribute used as Adam's ``lr``.
        detector: Optional second module. All of its parameters are appended
            to the optimized set (they are not filtered by ``requires_grad``).

    Returns:
        A ``torch.optim.Adam`` instance. The fused implementation is requested
        first; if PyTorch rejects it for the given parameters, a regular Adam
        is returned instead.
    """
    parameters = [p for p in model.parameters() if p.requires_grad]
    if detector is not None:
        parameters.extend(detector.parameters())

    try:
        return optim.Adam(parameters, lr=args.learning_rate, fused=True)
    except RuntimeError:
        # PyTorch raises RuntimeError when fused kernels are not available for
        # the parameters' device/dtype (e.g. CPU tensors on builds without
        # fused CPU support). Retry without `fused`; any unrelated error is
        # raised again by this second construction.
        return optim.Adam(parameters, lr=args.learning_rate)


def get_model_name(nickname):
    """Translate a short model nickname into its Hugging Face repository id.

    Args:
        nickname: One of ``'aya_8B'``, ``'llama3_1B'`` or ``'llama3_3B'``.

    Returns:
        A ``(model_name, is_vision)`` tuple. ``is_vision`` is True when the
        resolved repository id contains ``"aya"`` (case-insensitive).

    Raises:
        KeyError: If ``nickname`` is not one of the known nicknames.
    """
    known_models = {
        'aya_8B': "CohereLabs/aya-vision-8b",
        'llama3_1B': "meta-llama/Llama-3.2-1B-Instruct",
        'llama3_3B': "meta-llama/Llama-3.2-3B-Instruct",
    }
    model_name = known_models[nickname]
    return model_name, 'aya' in model_name.lower()


def load_model(
    model_name_or_path: str,
    hf_token: str,
    trust_remote_code: bool = False,
    bf16: bool = True,
    device=None,
    compile=False
) -> tuple[LlamaForCausalLM | AyaVisionForConditionalGeneration, PreTrainedTokenizer | AyaVisionProcessor]:
    """Load a Llama or Aya-Vision model together with its tokenizer/processor.

    The model family is decided once from ``model_name_or_path``
    (case-insensitive): a name containing ``"llama"`` selects the Llama path,
    otherwise a name containing ``"aya"`` selects the Aya-Vision path. The
    tokenizer/processor choice follows the same decision, so the two can never
    disagree.

    Args:
        model_name_or_path: Hub id or local path of the checkpoint.
        hf_token: Token forwarded to every ``from_pretrained`` call.
        trust_remote_code: Forwarded to the model's ``from_pretrained``.
        bf16: Currently not consulted; weights are always requested as
            ``torch.bfloat16``.
        device: Forwarded as ``device_map``. When ``str(device) == 'cpu'`` no
            attention implementation is requested; otherwise
            ``"flash_attention_2"`` is.
        compile: When true, the model is wrapped with
            ``torch.compile(model, dynamic=True)``.

    Returns:
        ``(model, processor)`` for Aya-Vision checkpoints and
        ``(model, tokenizer)`` for Llama checkpoints.

    Raises:
        ValueError: If the name matches neither supported family. This is
            raised before anything is loaded.
    """
    lowered_name = model_name_or_path.lower()
    if 'llama' in lowered_name:
        family = 'llama'
    elif 'aya' in lowered_name:
        family = 'aya'
    else:
        raise ValueError(
            f"Unsupported model {model_name_or_path!r}: expected a name containing 'llama' or 'aya'"
        )

    if family == 'aya':
        chat_template = "{{ bos_token }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># System Preamble\nYou are a helpful and honest assistant.\n<|END_OF_TURN_TOKEN|>\n{%- for message in messages -%}\n    <|START_OF_TURN_TOKEN|>{{ message.role | replace(\"user\", \"<|USER_TOKEN|>\") | replace(\"assistant\", \"<|CHATBOT_TOKEN|><|START_RESPONSE|>\") | replace(\"system\", \"<|SYSTEM_TOKEN|>\") }}\n    {%- if message.content is defined -%}\n        {%- if message.content is string -%}\n{{ message.content }}\n        {%- else -%}\n            {%- for item in message.content | selectattr('type', 'equalto', 'image') -%}\n<image>\n            {%- endfor -%}\n            {%- for item in message.content | selectattr('type', 'equalto', 'text') -%}\n{{ item.text }}\n            {%- endfor -%}\n        {%- endif -%}\n    {%- elif message.message is defined -%}\n        {%- if message.message is string -%}\n{{ message.message }}\n        {%- else -%}\n            {%- for item in message.message | selectattr('type', 'equalto', 'image') -%}\n<image>\n            {%- endfor -%}\n            {%- for item in message.message | selectattr('type', 'equalto', 'text') -%}\n{{ item.text }}\n            {%- endfor -%}\n        {%- endif -%}\n    {%- endif -%}\n    {%- if message.role == \"assistant\" -%}\n<|END_RESPONSE|>\n    {%- endif -%}\n<|END_OF_TURN_TOKEN|>\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>\n{%- endif -%}\n"
        processor = AyaVisionProcessor.from_pretrained(
            model_name_or_path, chat_template=chat_template, token=hf_token
        )
        tokenizer = processor.tokenizer
    else:
        processor = None
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token=hf_token)

    # Padding id 0 is required below; it is filled in when the tokenizer
    # does not define a pad token at all.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = 0
    assert tokenizer.pad_token_id == 0, (
        f"expected pad_token_id == 0, got {tokenizer.pad_token_id}"
    )

    attention_backend = "flash_attention_2" if str(device) != 'cpu' else None

    if family == 'llama':
        model = LlamaForCausalLM.from_pretrained(
            model_name_or_path,
            trust_remote_code=trust_remote_code,
            attn_implementation=attention_backend,
            torch_dtype=torch.bfloat16,
            device_map=device,
            token=hf_token
        )
    else:
        # NOTE(review): this branch passes `_attn_implementation` (leading
        # underscore) whereas the Llama branch passes `attn_implementation`;
        # kept exactly as before -- confirm this is intentional.
        model = AyaVisionForConditionalGeneration.from_pretrained(
            model_name_or_path,
            trust_remote_code=trust_remote_code,
            _attn_implementation=attention_backend,
            torch_dtype=torch.bfloat16,
            device_map=device,
            token=hf_token
        )
        # Freeze everything outside `language_model`, the token embedding,
        # and decoder layers with index below 16; other parameters keep the
        # requires_grad value they were loaded with.
        for name, param in model.named_parameters():
            if not name.startswith('language_model'):
                param.requires_grad = False
            if name == "language_model.model.embed_tokens.weight":
                param.requires_grad = False
            if name.startswith("language_model.model.layers.") and int(name.split('.')[3]) < 16:
                param.requires_grad = False
            print(name, param.requires_grad)

    if compile:
        print('compiling model')
        model = torch.compile(model, dynamic=True)

    if processor is not None:
        return model, processor
    return model, tokenizer


def approx_kl_divergence(
    log_probs: torch.Tensor,
    log_probs_ref: torch.Tensor,
    action_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    """Element-wise KL estimate ``exp(d) - d - 1`` with ``d = log_probs_ref - log_probs``.

    Both inputs are cast to float32 before subtracting. When ``action_mask``
    is given, ``d`` is multiplied by it first, so positions where the mask is
    zero evaluate to exactly 0.

    Args:
        log_probs: Log-probabilities under the current policy.
        log_probs_ref: Log-probabilities under the reference policy.
        action_mask: Optional multiplicative mask, or None for no masking.

    Returns:
        A float32 tensor of per-element estimates (no reduction is applied).
    """
    delta = log_probs_ref.float() - log_probs.float()
    if action_mask is not None:
        delta = delta * action_mask

    return torch.exp(delta) - delta - 1


def masked_mean(
        tensor: torch.Tensor,
        mask: Optional[torch.Tensor],
        dim: Optional[int] = None,
        norm_items_per_row: Optional[int] = None
) -> torch.Tensor:
    """Average ``tensor`` along ``dim``, counting only positions selected by ``mask``.

    Args:
        tensor: Values to average.
        mask: Multiplicative mask; when None a plain ``tensor.mean`` over
            ``dim`` is returned.
        dim: Dimension to reduce over.
        norm_items_per_row: Divisor for the masked sum. When None, the sum of
            ``mask`` along ``dim`` is used.

    Returns:
        The masked sum of ``tensor`` divided by the chosen divisor.
    """
    if mask is None:
        return tensor.mean(axis=dim)

    if norm_items_per_row is None:
        divisor = mask.sum(axis=dim)
    else:
        divisor = norm_items_per_row

    masked_total = (tensor * mask).sum(axis=dim)
    return masked_total / divisor


class GRPOLoss(nn.Module):
    """Clipped-ratio policy loss with an added KL penalty towards a reference policy.

    Args:
        clip_eps_low: The probability ratio is clamped from below at ``1 - clip_eps_low``.
        clip_eps_high: The probability ratio is clamped from above at ``1 + clip_eps_high``.
        kl_weight: Multiplier applied to the per-token KL estimate.
    """

    def __init__(self, clip_eps_low=0.2, clip_eps_high=0.2, kl_weight=0.01) -> None:
        super().__init__()
        self.clip_eps_low = clip_eps_low
        self.clip_eps_high = clip_eps_high
        self.kl_weight = kl_weight

    def forward(
        self,
        log_probs: torch.Tensor,
        log_probs_old: torch.Tensor,
        log_probs_ref: torch.Tensor,
        completion_mask: torch.Tensor,
        advantages: torch.Tensor,
        norm_items_per_row: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the scalar loss and the mean KL estimate.

        Args:
            log_probs: Log-probabilities under the policy being optimized.
            log_probs_old: Log-probabilities under the policy that produced the samples.
            log_probs_ref: Log-probabilities under the reference policy.
            completion_mask: Mask applied to both the KL term and the loss average.
            advantages: Advantage values multiplied with the probability ratio.
            norm_items_per_row: Divisor handed to ``masked_mean`` for each row.

        Returns:
            ``(loss, kl_mean)``: the masked per-row loss averaged over rows,
            and the plain mean of the per-token KL estimate.
        """
        kl = approx_kl_divergence(
            log_probs=log_probs,
            log_probs_ref=log_probs_ref,
            action_mask=completion_mask,
        )

        lower_bound = 1 - self.clip_eps_low
        upper_bound = 1 + self.clip_eps_high

        importance = torch.exp(log_probs - log_probs_old)
        unclipped_objective = importance * advantages
        clipped_objective = torch.clamp(importance, lower_bound, upper_bound) * advantages

        # Take the more pessimistic of the two objectives, then add the KL penalty.
        per_token_loss = -torch.min(unclipped_objective, clipped_objective) + self.kl_weight * kl

        per_row_loss = masked_mean(
            per_token_loss,
            completion_mask,
            dim=-1,
            norm_items_per_row=norm_items_per_row,
        )
        # NOTE(review): kl.mean() averages over every position, including the
        # ones the mask zeroed out -- confirm that is the intended metric.
        return per_row_loss.mean(), kl.mean()


def sequence_log_probs_from_logits(
    logits: torch.Tensor, output_ids: torch.Tensor
) -> torch.Tensor:
    """Pick the log-probability of each id in ``output_ids`` from ``logits``.

    A log-softmax is taken over the last dimension of ``logits`` and the entry
    at each id is gathered, so the result has the shape of ``output_ids``.
    """
    gather_index = output_ids.unsqueeze(-1)
    all_log_probs = F.log_softmax(logits, dim=-1)
    return torch.gather(all_log_probs, -1, gather_index).squeeze(-1)

@torch.compile(dynamic=True)
def compiled_sequence_log_probs_from_logits(
    logits: torch.Tensor, output_ids: torch.Tensor
) -> torch.Tensor:
    """Gather per-token log-probabilities, computing the log-softmax in float32.

    ``logits`` is cast to float32 before the log-softmax over its last
    dimension; the entry at each id in ``output_ids`` is then gathered, so the
    result has the shape of ``output_ids``. The function is wrapped with
    ``torch.compile(dynamic=True)``.
    """
    gather_index = output_ids.unsqueeze(-1)
    all_log_probs = F.log_softmax(logits.to(torch.float32), dim=-1)
    return torch.gather(all_log_probs, -1, gather_index).squeeze(-1)

@torch.compile(dynamic=True)
def iterative_sequence_log_probs_from_logits(logits, output_ids):
    """Gather per-token log-probabilities one batch row at a time.

    Each row of ``logits`` is cast to float32, log-softmaxed over its last
    dimension, and indexed with the matching row of ``output_ids``. The rows
    are then stacked back into a single tensor. Rows are handled separately
    with the aim of reducing peak memory consumption.
    """
    rows = [
        F.log_softmax(row_logits.to(torch.float32), dim=-1)
        .gather(dim=-1, index=row_labels.unsqueeze(-1))
        .squeeze(-1)
        for row_logits, row_labels in zip(logits, output_ids)
    ]
    return torch.stack(rows)

def sequences_log_probs(
    model: LlamaForCausalLM,
    sequence_ids: torch.Tensor
) -> torch.Tensor:
    """Score every token of ``sequence_ids`` (except the first) under ``model``.

    The model is run once without a KV cache. The logits at position ``t`` are
    paired with the token at position ``t + 1``, so the result is one column
    shorter than ``sequence_ids``.

    Args:
        model: Model whose ``forward`` output exposes a ``"logits"`` entry.
        sequence_ids: Token ids passed to the model as ``input_ids``.

    Returns:
        Log-probabilities computed by ``compiled_sequence_log_probs_from_logits``.
    """
    logits = model.forward(input_ids=sequence_ids, use_cache=False)["logits"]

    # Drop the final logit (nothing follows it) and the first token
    # (nothing predicts it) so predictions and targets line up.
    predictions = logits[:, :-1]
    targets = sequence_ids[:, 1:]
    return compiled_sequence_log_probs_from_logits(
        logits=predictions,
        output_ids=targets,
    )




def get_per_token_logps(model, inputs, num_logits_to_keep):
    """Log-probabilities of the last ``num_logits_to_keep`` input tokens.

    Args:
        model: Model whose ``forward`` accepts ``logits_to_keep`` and whose
            output exposes a ``"logits"`` entry.
        inputs: Mapping unpacked as keyword arguments to ``model.forward``;
            it must contain ``'input_ids'``.
        num_logits_to_keep: Number of trailing tokens to score.

    Returns:
        Log-probabilities computed by ``compiled_sequence_log_probs_from_logits``.
    """
    # One extra logit is requested because the final one is discarded below.
    output = model.forward(
        **inputs,
        use_cache=False,
        logits_to_keep=num_logits_to_keep + 1,
    )

    # (B, L-1, V): the last logit predicts the token after the sequence, so drop it.
    kept_logits = output["logits"][:, :-1, :]

    target_ids = inputs['input_ids'][:, -num_logits_to_keep:]
    # Log-probabilities of the tokens that were actually fed in.
    return compiled_sequence_log_probs_from_logits(kept_logits, target_ids)


def iterative_sequences_log_probs(
    model: LlamaForCausalLM,
    sequence_ids: torch.Tensor
) -> torch.Tensor:
    """Score every token of ``sequence_ids`` (except the first), row by row.

    The model is run once without a KV cache. The logits at position ``t`` are
    paired with the token at position ``t + 1``, and the log-probabilities are
    gathered by ``iterative_sequence_log_probs_from_logits``.

    Args:
        model: Model whose ``forward`` output exposes a ``"logits"`` entry.
        sequence_ids: Token ids passed to the model as ``input_ids``.

    Returns:
        Log-probabilities, one column shorter than ``sequence_ids``.
    """
    logits = model.forward(input_ids=sequence_ids, use_cache=False)["logits"]

    # Shift by one so each logit is matched with the token it predicts.
    predictions = logits[:, :-1]
    targets = sequence_ids[:, 1:]
    return iterative_sequence_log_probs_from_logits(
        logits=predictions,
        output_ids=targets,
    )

