from typing import Dict, Any
import abc
from tqdm import tqdm
from flax import linen as nn
from jax import numpy as jnp
import jax

from latte_trans.config import Config
from latte_trans.evals.losses import cross_entropy_loss_lm, cross_entropy_loss
from latte_trans.models.modules.seq_layers import Decoder


def cross_entropy_loss_lm2(logits, target, mask, ignore_index=-100):
    """
    Cross-entropy averaged over the positions that are actually scored.
    Ignores masked tokens from the loss. Specifically designed for the copy task

    A position is scored when its target differs from ``ignore_index`` AND its
    mask entry equals 1. The summed negative log-likelihood is divided by the
    number of scored positions, so ignored targets inside the mask no longer
    shrink the mean, and an input with no scored position yields 0 rather
    than NaN.

    Args:
        logits: jnp.array(BLH) - unnormalised scores, classes on the last axis
        target: jnp.array(BL, dtype=long) - class index per position
        mask: array broadcastable against ``target``; entries equal to 1 are kept
        ignore_index: must be a negative value
    Returns:
        Scalar array holding the mean negative log-likelihood.
    """
    scored = (target != ignore_index) & (mask == 1)
    # Indices outside the range [0, num_classes) will be encoded as zeros:
    target_one_hot = nn.one_hot(target, num_classes=logits.shape[-1])
    token_log_prob = jnp.einsum(
        "BLH,BLH->BL", target_one_hot, nn.log_softmax(logits, axis=-1)
    )
    # Normalise by the count of scored positions (not by sum(mask)); the clamp
    # keeps the division finite when nothing is scored.
    n_scored = jnp.maximum(jnp.sum(scored), 1)
    return -jnp.sum(token_log_prob * scored) / n_scored


# NOTE: this definition rebinds the name `cross_entropy_loss_lm` that is also
# imported from latte_trans.evals.losses at the top of the file; code below
# this point uses the local version.
def cross_entropy_loss_lm(logits, target, mask, ignore_index=-100):
    """
    Ignores masked tokens from the loss. Specifically designed for the copy task

    Only positions with ``mask == 1`` whose target is not ``ignore_index``
    contribute. The result is the mean negative log-likelihood over exactly
    those positions; when there are none the loss is 0 instead of NaN.

    Args:
        logits: jnp.array(BLH) - unnormalised scores, classes on the last axis
        target: jnp.array(BL, dtype=long) - class index per position
        mask: array broadcastable against ``target``; entries equal to 1 are kept
        ignore_index: must be a negative value
    Returns:
        Scalar array holding the mean negative log-likelihood.
    """
    valid = (target != ignore_index) & (mask == 1)
    # Indices outside the range [0, num_classes) will be encoded as zeros:
    target = nn.one_hot(target, num_classes=logits.shape[-1])
    log_likelihood = jnp.einsum(
        "BLH,BLH->BL", target, nn.log_softmax(logits, axis=-1)
    )
    total = jnp.sum(log_likelihood * valid)
    # Divide by the number of contributing tokens rather than sum(mask), so
    # ignored targets inside the mask do not bias the mean; clamp to avoid 0/0.
    count = jnp.maximum(jnp.sum(valid), 1)
    return -total / count


class LMHeadCopy(nn.Module):
    """Next-token prediction head for the copy task, built on ``Decoder``."""

    config: Config
    vocab_size: int
    # pad_id is declared but not read anywhere inside this module.
    pad_id: int
    ignore_index: int

    @nn.compact
    def __call__(
        self,
        input_ids: jnp.array,
        mask: jnp.array,
        labels: jnp.array = None,
        train: bool = False,
    ) -> Dict[str, jnp.array]:
        """
        Run the decoder on the shifted inputs and score the shifted labels.

        Args:
            input_ids: jnp.array(BL) - input ids; the last position is dropped
            mask: jnp.array(BL) - loss mask; the first position is dropped
            labels: jnp.array(BL) - the first position is dropped. Despite the
                ``None`` default it is sliced unconditionally, so it must be given
            train: bool - forwarded to the decoder; also switches batch norm
                between batch statistics and running averages
        Returns:
            out: Dict[str, jnp.array] - "loss" and "logits"
        """
        cfg = self.config
        backbone = Decoder(
            vocab_size=self.vocab_size,
            nlayers=cfg.nlayers,
            hidden_dim=cfg.hidden_dim,
            max_seq_len=cfg.max_seq_len,
            pos_embed_max_len=cfg.pos_embed_max_len,
            L=cfg.L,
            unroll=cfg.unroll,
            nheads=cfg.nheads,
            dropout=cfg.dropout,
            prenorm=cfg.prenorm,
            batchnorm=cfg.batchnorm,
            block_type=cfg.block_type,
            attention_type=cfg.attention_type,
        )
        vocab_proj = nn.Dense(self.vocab_size, dtype=jnp.float32)

        # Shift by one: the output at position t is compared with token t + 1.
        shifted_inputs = input_ids[:, :-1]
        shifted_labels = labels[:, 1:]
        shifted_mask = mask[:, 1:]
        hidden = backbone(
            shifted_inputs, train=train, do_inference=False, cache=None
        )  # BLH

        # With prenorm enabled a final normalisation is applied before the head.
        if cfg.prenorm:
            final_norm = (
                nn.BatchNorm(use_running_average=not train, momentum=0.9)
                if cfg.batchnorm
                else nn.LayerNorm()
            )
            hidden = final_norm(hidden)

        logits = vocab_proj(hidden)  # BLH -> BLV
        # Masked-out and ignore_index positions are excluded inside the loss.
        loss = cross_entropy_loss_lm(
            logits=logits,
            target=shifted_labels,
            mask=shifted_mask,
            ignore_index=self.ignore_index,
        )
        return {"loss": loss, "logits": logits}


class LMHeadCopyHugg(nn.Module):
    """Copy-task head that uses a Hugging Face Flax GPT-Neo module as backbone."""

    config: Config
    vocab_size: int
    # pad_id is declared but not read anywhere inside this module.
    pad_id: int
    ignore_index: int

    @nn.compact
    def __call__(
        self,
        input_ids: jnp.array,
        mask: jnp.array,
        labels: jnp.array = None,
        train: bool = False,
    ) -> Dict[str, jnp.array]:
        """
        Run GPT-Neo on the shifted inputs and score the shifted labels.

        Args:
            input_ids: jnp.array(BL) - input ids; the last position is dropped
            mask: jnp.array(BL) - loss mask; the first position is dropped
            labels: jnp.array(BL) - the first position is dropped. Despite the
                ``None`` default it is sliced unconditionally, so it must be given
            train: bool - the backbone runs with ``deterministic=not train``
        Returns:
            out: Dict[str, jnp.array] - "loss" and "logits"
        """
        # Imported lazily so `transformers` is only needed when this head is used.
        from transformers import FlaxGPTNeoForCausalLM, GPTNeoConfig

        cfg = self.config
        gpt_neo_config = GPTNeoConfig(
            bos_token_id=0,
            eos_token_id=0,
            hidden_size=cfg.hidden_dim,
            intermediate_size=cfg.hidden_dim * 4,
            num_attention_heads=cfg.nheads,
            num_hidden_layers=cfg.nlayers,
            vocab_size=self.vocab_size,
        )
        backbone = FlaxGPTNeoForCausalLM(gpt_neo_config).module

        vocab_proj = nn.Dense(self.vocab_size, dtype=jnp.float32)

        # Shift by one: the output at position t is compared with token t + 1.
        shifted_inputs = input_ids[:, :-1]
        shifted_labels = labels[:, 1:]
        shifted_mask = mask[:, 1:]

        batch_size, seq_length = shifted_inputs.shape[:2]
        # Every position is attended to; no padding mask is passed on.
        attention_mask = jnp.ones((batch_size, seq_length))
        positions = jnp.arange(seq_length, dtype="i4")[None, :]
        position_ids = jnp.broadcast_to(positions, (batch_size, seq_length))

        backbone_out = backbone(
            shifted_inputs,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=not train,
        )
        # NOTE(review): the original comment labels this tensor BLH, but it is
        # the `.logits` field of a causal-LM module, so its last axis is
        # presumably already vocabulary-sized - confirm the extra Dense is intended.
        features = backbone_out.logits

        logits = vocab_proj(features)  # -> BLV
        # Masked-out and ignore_index positions are excluded inside the loss.
        loss = cross_entropy_loss_lm(
            logits=logits,
            target=shifted_labels,
            mask=shifted_mask,
            ignore_index=self.ignore_index,
        )
        return {"loss": loss, "logits": logits}
