from typing import Callable, Tuple, Dict
from functools import partial
from tqdm import tqdm
import math
from jax import numpy as jnp
from torch.utils.data import DataLoader
import jax
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from flax import linen as nn
from .base import Evaluator
from .losses import cross_entropy_loss_lm
import time
import numpy as np


def pred_acc_lm(output, labels):
    """Compute next-token language-modelling metrics for one batch.

    Logits at position ``t`` are compared against the label at position
    ``t + 1`` (the last logit and the first label are dropped).

    Args:
        output: mapping with a ``"logits"`` array indexed as
            ``[batch, position, vocab]`` and a precomputed ``"loss"`` value.
        labels: integer array indexed as ``[batch, position]``; entries equal
            to ``-100`` are treated as ignored when counting valid targets.

    Returns:
        Dict with ``"loss"`` (passed through from ``output``), ``"bpc"``
        (loss divided by ``ln 2``), ``"ppl"`` (``exp(loss)``) and
        ``"accuracy"`` (correct predictions / number of non-ignored targets).
    """
    shifted_logits = output["logits"][:, :-1, :]
    targets = labels[:, 1:]
    loss = output["loss"]

    n_valid = (targets != -100).sum()
    # argmax is always >= 0, so it can never match an ignored (-100) target.
    n_correct = jnp.sum(jnp.argmax(shifted_logits, -1) == targets)
    # Clamp the denominator: a batch with no valid targets would otherwise
    # give 0 / 0 = NaN, which then poisons any sum of per-batch metrics.
    accuracy = n_correct / jnp.maximum(n_valid, 1)

    return {
        "loss": loss,
        "bpc": loss / math.log(2),
        "ppl": jnp.exp(loss),
        "accuracy": accuracy,
    }


def pred_acc_lm_mem(output, labels):
    """
    Memory-friendly metrics: derive everything from the precomputed loss.

    No logits are touched here, so nothing beyond the jitted trainer eval
    has to be materialised. ``labels`` is accepted only to keep the same
    signature as ``pred_acc_lm`` and is not used.
    (Alternatively, this could be jitted and distributed the way
    LanguageEvaluatorSeq handles its loss.)

    Returns a dict with "loss", "bpc" (loss / ln 2) and "ppl" (exp(loss)).
    """
    nats = output["loss"]
    bits = nats / math.log(2)
    perplexity = jnp.exp(nats)
    return dict(loss=nats, bpc=bits, ppl=perplexity)


class LanguageEvaluator(Evaluator):
    """Average loss-derived language-modelling metrics over validation batches.

    NOTE(review): ``self._val_loader`` is not created in this class; it is
    presumably set up by ``Evaluator.__init__`` — confirm against the base
    class.
    """

    def __init__(
        self,
        model,
        tokenizer,
        val_data,
        data_collator,
        config,
        rng,
    ) -> None:
        super().__init__(val_data, data_collator, config)

        self._tokenizer = tokenizer
        self._model = model
        self._batchnorm = config.batchnorm
        self._batch_size = config.batch_size
        self._config = config
        self._rng = rng
        # Evaluate every batch unless config.eval_samples caps the count.
        self._total_samples = (
            len(self._val_loader)
            if self._config.eval_samples is None
            else self._config.eval_samples
        )
        assert self._total_samples <= len(self._val_loader)

    def evaluate(
        self,
        trainer_eval_fn: Callable[[str, jax.Array], Tuple[jax.Array]],
        prefix="eval_",
        **kwargs,
    ) -> Dict[str, jax.Array]:
        """Iterate over validation data, get outputs from trainer eval
        and compute metrics.

        Decoupled from the trainer so data-specific evaluation logic can be
        added (e.g. SQuAD split in overlapping windows, generation from
        prompts for language data).

        Args:
            trainer_eval_fn: called with one batch from the validation loader;
                must return a ``(labels, model_output)`` pair. It is expected
                to place data on devices according to the trainer sharding
                and to contain the platform-specific model call.
            prefix: str - prepended to every metric name (eval/test data).
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            Dict mapping ``prefix + metric_name`` to the metric averaged over
            the ``self._total_samples`` evaluated batches.
        """
        scores = {}
        progress_bar = tqdm(
            range(self._total_samples), position=0, leave=True, initial=0
        )
        val_iter = iter(self._val_loader)
        for _ in range(self._total_samples):
            batch = next(val_iter)
            labels, output = trainer_eval_fn(batch)
            metrics = self.compute_metrics(output, labels)
            # Accumulate per-batch metrics; they are averaged after the loop.
            for name, value in metrics.items():
                scores[name] = scores.get(name, 0) + value
            progress_bar.update(1)

        return {
            prefix + name: total / self._total_samples
            for name, total in scores.items()
        }

    def compute_metrics(self, output, labels):
        """Return loss-only metrics (loss, bpc, ppl) for a single batch."""
        return pred_acc_lm_mem(output, labels)


def cross_entropy_seq_lm(logits, lables, ignore_index=-100):
    """Per-position cross entropy, averaged over the batch axis.

    The sequence axis is kept so the loss can be inspected as a function of
    the position in the sequence.

    Args:
        logits: jnp.array(BLH) - unnormalised scores over H classes.
        lables: jnp.array(BL, dtype=long) - target class ids.
        ignore_index: label value whose positions contribute zero loss.
            Any label outside ``[0, H)`` (e.g. every negative value) also
            contributes zero because its one-hot encoding is all zeros.

    Returns:
        Array of shape (1, L). Each entry is the summed token loss at that
        position divided by the batch size B. Ignored tokens are NOT removed
        from the denominator.
    """
    valid = lables != ignore_index
    one_hot = jax.nn.one_hot(lables, num_classes=logits.shape[-1])
    token_loss = -jnp.einsum(
        "BLH,BLH->BL", one_hot, jax.nn.log_softmax(logits, axis=-1)
    )
    # Honour ignore_index explicitly instead of relying only on one_hot
    # returning zeros for out-of-range labels.
    token_loss = jnp.where(valid, token_loss, 0.0)
    return token_loss.sum(axis=0, keepdims=True) / logits.shape[0]


class LanguageEvaluatorSeq(Evaluator):
    """
    Get next-token loss / PPL per sequence position, i.e. on increasing
    context length.
    """

    def __init__(
        self,
        model,
        tokenizer,
        val_data,
        data_collator,
        config,
        eval_batch_size,
        rng,
    ) -> None:
        # NOTE(review): Evaluator.__init__ is not called here; the loader is
        # built locally with ``eval_batch_size``. Confirm the base class does
        # not set up anything else this evaluator needs.
        self._config = config
        self._val_loader = DataLoader(
            val_data,
            batch_size=eval_batch_size,
            shuffle=False,
            collate_fn=data_collator,
            drop_last=True,
        )

        self._tokenizer = tokenizer
        self._model = model
        self._batchnorm = config.batchnorm
        self._batch_size = eval_batch_size
        self._rng = rng

        # Evaluate every batch unless config.eval_samples caps the count.
        self._total_samples = (
            len(self._val_loader)
            if self._config.eval_samples is None
            else self._config.eval_samples
        )
        assert self._total_samples <= len(self._val_loader)

    def evaluate(
        self,
        trainer_eval_fn: Callable[[str, jax.Array], Tuple[jax.Array]],
        prefix="eval_",
        **kwargs,
    ) -> Dict[str, jax.Array]:
        """Iterate over validation data, get logits from trainer eval and
        compute the loss per sequence position.

        Args:
            trainer_eval_fn: called with one batch whose ``"labels"`` entry
                has been set to ``None`` (so the trainer does not need to
                compute a loss). Must return a pair whose second element is
                a mapping containing ``"logits"``.
            prefix: accepted for interface compatibility; the returned keys
                are currently NOT prefixed.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            Dict with
              - ``"loss_seq_len"``: loss per position, averaged over batches,
              - ``"PPL"``: ``exp`` of that loss,
              - ``"PPL_mean"``: ``exp`` of the running mean of the loss over
                positions ``0..t``.

        Raises:
            ValueError: if there are no batches to evaluate.
        """
        if self._total_samples <= 0:
            raise ValueError("LanguageEvaluatorSeq needs at least one batch")

        progress_bar = tqdm(
            range(self._total_samples), position=0, leave=True, initial=0
        )

        # Distribute the loss computation across devices along the batch axis.
        num_devices = len(jax.devices())
        devices = mesh_utils.create_device_mesh((num_devices,))
        mesh = Mesh(devices, axis_names=("B",))
        per_shard_loss = jax.checkpoint(
            partial(cross_entropy_seq_lm, ignore_index=-100)
        )

        def _batch_mean_loss(logits, lables):
            # Each device only sees its own batch shard. Since out_specs
            # declares the result replicated, average over the "B" axis so
            # every device really holds the same, full-batch value.
            return jax.lax.pmean(per_shard_loss(logits, lables), axis_name="B")

        loss_fn = jax.jit(
            shard_map(
                _batch_mean_loss,
                mesh,
                in_specs=(
                    PartitionSpec("B"),
                    PartitionSpec("B"),
                ),
                out_specs=PartitionSpec(),
                check_rep=False,
            ),
        )

        total_loss = 0
        # zip with the range first: stops after exactly total_samples batches
        # without pulling an extra batch from the loader.
        for _, batch in zip(range(self._total_samples), self._val_loader):
            # Shift targets by one; the trainer gets no labels so it does
            # not compute a loss on its side.
            lables = batch.pop("labels")[:, 1:]
            batch["labels"] = None
            _, output = trainer_eval_fn(batch)
            # Drop the last logit but keep the sequence-length dimension.
            logits = jax.lax.stop_gradient(output["logits"][:, :-1, :])
            lables = jax.lax.stop_gradient(lables)
            # Accumulate as a host-side numpy array.
            total_loss += np.array(loss_fn(logits, lables))
            progress_bar.update(1)
            # Release references to the previous computation.
            del output, logits, lables

        total_loss /= self._total_samples
        # Running mean of the loss over positions 0..t.
        total_loss_avg = total_loss.cumsum(axis=1) / (
            1 + np.arange(total_loss.shape[1])
        )
        return {
            "loss_seq_len": total_loss,
            "PPL": np.exp(total_loss),
            "PPL_mean": np.exp(total_loss_avg),
        }

    def compute_metrics(self, output, labels):
        """Return loss, bpc, ppl and accuracy for a single batch."""
        return pred_acc_lm(output, labels)
