import itertools
import numpy as np
import time
import torch

from tqdm import tqdm
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
from typing import (
    Dict,
    List,
    Tuple,
)

from .data_utils import get_test_dataset


@torch.no_grad()
def eff_eval(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    dataset: str,
    sequence_length: int,
    generated_sequence_length: int,
    batch_size: int,
    device: torch.device,
    evaluation_batch_number: int = 10,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
) -> Dict[str, str]:
    """ Measure generation throughput and CUDA memory usage of the model.

    The first `evaluation_batch_number` batches of the test dataloader are
    moved to `device` and passed to `model.generate` with sampling enabled.
    Each call is timed between two CUDA synchronizations and the peak
    allocated memory is recorded.

    Args:
        model (PreTrainedModel): The model to be evaluated.
        tokenizer (PreTrainedTokenizerBase): The tokenizer. Its EOS token id
            is passed to `generate` as the padding token id.
        dataset (str): The dataset name forwarded to `get_test_dataset`.
        sequence_length (int): The prompt sequence length.
        generated_sequence_length (int): The number of tokens to generate on
            top of the prompt (`max_length` is the sum of both lengths).
        batch_size (int): The batch size.
        device (torch.device): The CUDA device to be used.
        evaluation_batch_number (int): The maximum number of batches to time.
        temperature (float): The sampling temperature.
        top_k (int): The top-k sampling parameter.
        top_p (float): The top-p sampling parameter.

    Returns:
        Dict[str, str]: The metrics, each formatted as a string with two
            decimals. `total_memory`, `weights_memory` and `activation_memory`
            are byte counts divided by 1024**3; `full_throughput` and
            `partial_throughput` are token counts divided by the total
            generation time in seconds.

    Raises:
        ValueError: If the dataloader yields no batch to evaluate.
    """
    model.eval()
    torch.cuda.empty_cache()
    # Memory allocated before any data is loaded; reported as the weights.
    weights_memory = torch.cuda.memory_allocated(device=device)

    dataloader = get_test_dataset(
        tokenizer=tokenizer,
        data_name=dataset,
        sequence_length=sequence_length,
        batch_size=batch_size,
    )

    bytes_per_gib = 1024**3
    tokens_per_sequence = sequence_length + generated_sequence_length

    peak_memory = 0
    min_start_memory = float('inf')
    total_time = 0.0
    full_token_number = 0
    partial_token_number = 0
    evaluated_batches = 0

    for batch in tqdm(
            iterable=itertools.islice(dataloader, evaluation_batch_number),
            desc='[Evaluating Throughput]',
            dynamic_ncols=True,
    ):
        input_ids = batch.to(device=device)
        evaluated_batches += 1

        # Token counts are nominal: they are derived from the requested
        # lengths, not from the shape of the generated output.
        # NOTE(review): if generation can stop before `max_length`, these
        # counts overestimate the real number of tokens - confirm.
        full_token_number += input_ids.shape[0] * tokens_per_sequence
        partial_token_number += input_ids.shape[0] * generated_sequence_length

        torch.cuda.reset_peak_memory_stats(device=device)
        min_start_memory = min(
            min_start_memory,
            torch.cuda.memory_allocated(device=device),
        )
        # Synchronize on both sides so only the generation call is timed.
        torch.cuda.synchronize(device=device)

        start_time = time.perf_counter()
        model.generate(
            input_ids=input_ids,
            pad_token_id=tokenizer.eos_token_id,
            max_length=tokens_per_sequence,
            do_sample=True,
            use_cache=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        torch.cuda.synchronize(device=device)
        total_time += time.perf_counter() - start_time

        peak_memory = max(
            peak_memory,
            torch.cuda.max_memory_allocated(device=device),
        )

    if evaluated_batches == 0:
        raise ValueError(
            'No batch was evaluated: the dataloader is empty or '
            'evaluation_batch_number is not positive.')

    # Activation memory is the highest peak minus the smallest allocation
    # observed right before a generation call, over all evaluated batches.
    activation_memory = peak_memory - min_start_memory

    return {
        'total_memory': format(peak_memory / bytes_per_gib, '.2f'),
        'weights_memory': format(weights_memory / bytes_per_gib, '.2f'),
        'activation_memory': format(activation_memory / bytes_per_gib, '.2f'),
        'full_throughput': format(full_token_number / total_time, '.2f'),
        'partial_throughput': format(partial_token_number / total_time, '.2f'),
    }


@torch.no_grad()
def ppl_eval(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    datasets: List[str],
    sequence_length: int,
    batch_size: int,
    device: torch.device,
) -> Tuple[Dict[str, float], Dict[str, List[List[int]]]]:
    """ Evaluate the perplexity of the model.

    Args:
        model (PreTrainedModel): The model to be evaluated.
        tokenizer (PreTrainedTokenizerBase): The tokenizer.
        datasets (List[str]): The test datasets. This value can be 'c4' or 'wikitext2'.
        sequence_length (int): The sequence length.
        batch_size (int): The batch size.
        device (torch.device): The device to be used.

    Returns:
        Tuple[Dict[str, float], Dict[str, List[List[int]]]]: A tuple containing
            the perplexity and the chosen experts counts, both keyed by dataset
            name. The counts hold one list of per-expert counts for each entry
            of the router logits, followed by a last list with the totals over
            all of them. The counts are None when the model has no
            `molos_config`, and an empty list when no batch produced router
            logits. The perplexity is NaN when no batch contributed a loss.
    """

    all_chosen_experts_counts = {}
    ppls = {}

    # The expert settings are read from `model.molos_config`, so the counts
    # can only be collected for models that carry it.
    collect_experts_counts = hasattr(model, 'molos_config')
    loss_function = torch.nn.CrossEntropyLoss(reduction='none')

    for dataset in datasets:
        dataloader = get_test_dataset(
            tokenizer=tokenizer,
            data_name=dataset,
            sequence_length=sequence_length,
            batch_size=batch_size,
        )

        dataset_chosen_experts_counts = [] if collect_experts_counts else None

        nlls = []
        for batch in tqdm(
                iterable=dataloader,
                desc=f'[Evaluating {dataset}]',
                dynamic_ncols=True,
        ):
            batch = batch.to(device=device)
            output = model(
                batch,
                use_cache=False,
            )

            router_logits = getattr(output, 'router_logits', None)
            if router_logits is not None and \
                    dataset_chosen_experts_counts is not None:
                expert_number = model.molos_config.ex_num
                selected_expert_number = model.molos_config.selected_ex_num

                # One list of per-expert counts for each router logits entry.
                batch_chosen_experts_counts = []
                for layer_router_logits in router_logits:
                    chosen_experts_indices = torch.topk(
                        input=layer_router_logits,
                        k=selected_expert_number,
                        dim=-1,
                    ).indices
                    layer_counts = torch.bincount(
                        input=chosen_experts_indices.reshape(-1),
                        minlength=expert_number,
                    )
                    # A single `tolist` transfers all the counts at once.
                    batch_chosen_experts_counts.append(
                        layer_counts[:expert_number].tolist())

                # The last entry sums the counts over all the layers.
                all_layer_chosen_experts_counts = [0] * expert_number
                for layer_counts in batch_chosen_experts_counts:
                    all_layer_chosen_experts_counts = [
                        x + y for x, y in zip(
                            all_layer_chosen_experts_counts,
                            layer_counts,
                            strict=True,
                        )
                    ]
                batch_chosen_experts_counts.append(
                    all_layer_chosen_experts_counts)

                if not dataset_chosen_experts_counts:
                    # First batch with router logits: it seeds the totals.
                    dataset_chosen_experts_counts = batch_chosen_experts_counts
                else:
                    dataset_chosen_experts_counts = [
                        [x + y for x, y in zip(previous, current, strict=True)]
                        for previous, current in zip(
                            dataset_chosen_experts_counts,
                            batch_chosen_experts_counts,
                            strict=True,
                        )
                    ]

            lm_logits = output.logits
            # Batches with non-finite logits are left out of the perplexity.
            if torch.isfinite(input=lm_logits).all():
                # Position t predicts token t + 1.
                shift_logits = lm_logits[:, :-1, :]
                shift_labels = batch[:, 1:]

                loss = loss_function(
                    shift_logits.reshape(-1, shift_logits.size(-1)),
                    shift_labels.reshape(-1),
                )
                nlls.append(loss)

        all_chosen_experts_counts[dataset] = dataset_chosen_experts_counts

        if nlls:
            mean_nll = torch.cat(
                tensors=nlls,
                dim=-1,
            ).mean().item()
            ppls[dataset] = np.exp(mean_nll).item()
        else:
            # No batch contributed a loss, so the perplexity is undefined.
            ppls[dataset] = float('nan')

    return ppls, all_chosen_experts_counts


@torch.no_grad()
def qa_load_balance_eval(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    datasets: List[str],
    sequence_length: int,
    batch_size: int,
    device: torch.device,
) -> Dict[str, List[List[int]]]:
    """ Count how often each expert is chosen on the given datasets.

    Every batch of each dataset is passed through the model. For each entry
    of the output router logits, the top `molos_config.selected_ex_num`
    experts are taken along the last dimension and counted.

    Args:
        model (PreTrainedModel): The model to be evaluated.
        tokenizer (PreTrainedTokenizerBase): The tokenizer.
        datasets (List[str]): The dataset names forwarded to `get_test_dataset`.
        sequence_length (int): The sequence length.
        batch_size (int): The batch size.
        device (torch.device): The device to be used.

    Returns:
        Dict[str, List[List[int]]]: The chosen experts counts keyed by dataset
            name. The counts hold one list of per-expert counts for each entry
            of the router logits, followed by a last list with the totals over
            all of them. The counts are None when the model has no
            `molos_config`, and an empty list when no batch produced router
            logits.
    """
    all_chosen_experts_counts = {}

    # The expert settings are read from `model.molos_config`, so the counts
    # can only be collected for models that carry it.
    collect_experts_counts = hasattr(model, 'molos_config')

    for dataset in datasets:
        dataloader = get_test_dataset(
            tokenizer=tokenizer,
            data_name=dataset,
            sequence_length=sequence_length,
            batch_size=batch_size,
        )

        dataset_chosen_experts_counts = [] if collect_experts_counts else None

        for batch in tqdm(
                iterable=dataloader,
                desc=f'[Evaluating {dataset}]',
                dynamic_ncols=True,
        ):
            batch = batch.to(device=device)
            output = model(
                batch,
                use_cache=False,
            )

            router_logits = getattr(output, 'router_logits', None)
            if router_logits is None or dataset_chosen_experts_counts is None:
                continue

            expert_number = model.molos_config.ex_num
            selected_expert_number = model.molos_config.selected_ex_num

            # One list of per-expert counts for each router logits entry.
            batch_chosen_experts_counts = []
            for layer_router_logits in router_logits:
                chosen_experts_indices = torch.topk(
                    input=layer_router_logits,
                    k=selected_expert_number,
                    dim=-1,
                ).indices
                layer_counts = torch.bincount(
                    input=chosen_experts_indices.reshape(-1),
                    minlength=expert_number,
                )
                # A single `tolist` transfers all the counts at once.
                batch_chosen_experts_counts.append(
                    layer_counts[:expert_number].tolist())

            # The last entry sums the counts over all the layers.
            all_layer_chosen_experts_counts = [0] * expert_number
            for layer_counts in batch_chosen_experts_counts:
                all_layer_chosen_experts_counts = [
                    x + y for x, y in zip(
                        all_layer_chosen_experts_counts,
                        layer_counts,
                        strict=True,
                    )
                ]
            batch_chosen_experts_counts.append(all_layer_chosen_experts_counts)

            if not dataset_chosen_experts_counts:
                # First batch with router logits: it seeds the totals.
                dataset_chosen_experts_counts = batch_chosen_experts_counts
            else:
                dataset_chosen_experts_counts = [
                    [x + y for x, y in zip(previous, current, strict=True)]
                    for previous, current in zip(
                        dataset_chosen_experts_counts,
                        batch_chosen_experts_counts,
                        strict=True,
                    )
                ]

        all_chosen_experts_counts[dataset] = dataset_chosen_experts_counts

    return all_chosen_experts_counts
