from typing import Tuple, Callable, Dict
from .base import Evaluator
import jax
from jax import numpy as jnp
import numpy as np
from latte_trans.preproc.copy import get_tokenizer, get_eval_dataset


def find(s, ch):
    """Return the positions of every element of ``s`` equal to ``ch``.

    Positions are returned in ascending order; an empty list means ``ch``
    does not occur in ``s``.
    """
    positions = []
    for idx, item in enumerate(s):
        if item == ch:
            positions.append(idx)
    return positions


def get_score(args, tokenizer, x, pred, i):
    """Score the model prediction for example ``i`` of a batch.

    The input ``x[i]`` is decoded and truncated after its first ``"."``.
    The ground truth is the text after the ``"|"`` separator, minus the
    trailing ``"."``:

    - ``prefix_ngram`` uses the last ``"|"``.
    - ``suffix_ngram``, ``copy`` and ``duplicate_ngram`` use the first ``"|"``.
    - ``suffix_ngram`` additionally skips ``args.n_gram`` characters.

    Args:
        args: object exposing ``eval_task`` and, for ``suffix_ngram``,
            ``n_gram``.
        tokenizer: object with a ``decode`` method returning ``str``.
        x: batch of input ids, indexable by ``i``.
        pred: batch of predicted ids, indexable by ``i``.
        i: index of the example to score.

    Returns:
        ``(str_acc, char_acc)``. ``str_acc`` is 1 if the predicted span
        equals the ground truth exactly, otherwise 0. ``char_acc`` is the
        fraction of positions that match. An empty ground truth counts as
        a perfect match, giving ``(1, 1.0)``.

    Raises:
        ValueError: if ``args.eval_task`` is not a supported task, or the
            decoded input contains no ``"|"`` separator.
    """
    x_out = tokenizer.decode(x[i])
    x_out = x_out.split(".")[0] + "."
    pred_out = tokenizer.decode(pred[i])

    task = args.eval_task
    if task == "prefix_ngram":
        index = x_out.rindex("|")
    elif task in ("suffix_ngram", "copy", "duplicate_ngram"):
        index = x_out.index("|")
    else:
        raise ValueError(f"Unsupported eval_task for copy scoring: {task!r}")

    # The prediction slice starts one position before gt (start_idx vs.
    # index + 1), so pred_out[t] is compared with x_out[t + 1]. This is
    # presumably next-token logits with a character-level tokenizer.
    # NOTE(review): confirm against the tokenizer.
    if task == "suffix_ngram":
        gt = x_out[index + 1 + args.n_gram :][:-1]
        start_idx = index + args.n_gram
    else:
        gt = x_out[index + 1 :][:-1]
        start_idx = index

    end_idx = start_idx + len(gt)
    pred_model = pred_out[start_idx:end_idx]

    str_acc = int(gt == pred_model)
    denom = max(len(gt), len(pred_model))
    if denom == 0:
        # Both spans are empty, so they match exactly. Avoid dividing by 0.
        char_acc = 1.0
    else:
        char_acc = sum(map(str.__eq__, gt, pred_model)) / denom

    return str_acc, char_acc


class CopyEvaluator(Evaluator):
    """Length-sweep evaluator for the copy family of synthetic tasks.

    For every target length in ``np.arange(config.min_eval_len,
    config.max_eval_len)``, the evaluator:

    - draws ``config.eval_num_batches`` batches via ``get_eval_dataset``;
    - greedily decodes the model logits;
    - scores each example with ``get_score``.
    """

    def __init__(
        self,
        model,
        tokenizer,
        TO_TOKEN,
        config,
    ) -> None:

        self._tokenizer = tokenizer
        self._model = model
        # batchnorm / batch_size are stored but not read anywhere in this class.
        self._batchnorm = config.batchnorm
        self._batch_size = config.batch_size
        self._config = config
        self._TO_TOKEN = TO_TOKEN

    def compute_metrics(self, output, labels):
        """Return ``output["loss"]`` unchanged; ``labels`` is ignored."""
        return output["loss"]

    def _score_length(self, trainer_eval_fn, ood_length) -> Tuple[np.ndarray, float]:
        """Evaluate every batch at one target length.

        Returns:
            A tuple of two values:

            - an array of exact-match accuracies, one entry per batch, each
              normalised by that batch's own size;
            - the mean character accuracy over all scored examples.

        Raises:
            ValueError: if ``config.eval_num_batches`` < 1 or a sampled batch
                is empty. Either would otherwise cause a division by zero or
                an unbound name.
        """
        num_batches = self._config.eval_num_batches
        if num_batches < 1:
            raise ValueError(
                f"config.eval_num_batches must be >= 1, got {num_batches}"
            )

        str_acc_batch = np.zeros(num_batches)
        char_acc_sum = 0.0
        num_examples = 0

        for jj in range(num_batches):
            # A fresh dataset is built for each batch and only its first
            # batch is used.
            long_dataset = get_eval_dataset(
                self._config,
                self._tokenizer,
                self._TO_TOKEN,
                target_min_len=ood_length,
                target_max_len=ood_length,
            )
            batch = next(iter(long_dataset))
            x = batch["input_ids"]
            if len(x) == 0:
                raise ValueError(f"empty evaluation batch at length {ood_length}")

            _labels, output = trainer_eval_fn(batch)
            # Greedy decoding: take the arg-max token at every position.
            pred = jnp.argmax(output["logits"], axis=-1)
            pred = np.array(jax.device_get(pred))

            batch_str_hits = 0
            for i in range(len(x)):
                str_acc, char_acc = get_score(
                    self._config, self._tokenizer, x, pred, i
                )
                batch_str_hits += str_acc
                char_acc_sum += char_acc

            # Normalise by THIS batch's size. The previous version used the
            # last batch's size for every batch.
            str_acc_batch[jj] = batch_str_hits / len(x)
            num_examples += len(x)

        return str_acc_batch, char_acc_sum / num_examples

    def evaluate(
        self,
        trainer_eval_fn: Callable[[str, jax.Array], Tuple[jax.Array]],
        prefix="eval_",
        **kwargs,
    ) -> Dict[str, jax.Array]:
        """Measure exact-match and character accuracy for each target length.

        Lengths are taken from ``np.arange(config.min_eval_len,
        config.max_eval_len)``, so ``max_eval_len`` itself is excluded. A
        summary line per length is printed to stdout.

        Args:
            trainer_eval_fn: called as ``trainer_eval_fn(batch)``. It must
                return ``(labels, output)`` with ``output["logits"]`` holding
                the model logits. ``labels`` is ignored.
            prefix: accepted but currently unused.
                NOTE(review): returned keys are not prefixed.
            **kwargs: ignored.

        Returns:
            Dict with keys ``char_accuracy_len_{i}``,
            ``str_acc_mean_list_len_{i}`` and ``str_acc_std_list_len_{i}``.
            ``i`` is the position in the length range, so the actual length
            is ``config.min_eval_len + i``.

        Raises:
            ValueError: when at least one length is evaluated and either
                ``config.eval_num_batches`` < 1 or a sampled batch is empty.
        """
        lengths = np.arange(self._config.min_eval_len, self._config.max_eval_len)

        str_acc_mean_list = []
        str_acc_std_list = []
        char_accuracy_list = []

        for ood_length in lengths:
            str_acc_batch, mean_char_acc = self._score_length(
                trainer_eval_fn, ood_length
            )
            mean_str_acc = np.mean(str_acc_batch)
            std_str_acc = np.std(str_acc_batch)

            str_acc_mean_list.append(mean_str_acc)
            str_acc_std_list.append(std_str_acc)
            char_accuracy_list.append(mean_char_acc)

            print(
                f"{self._config.eval_task}; len {ood_length}: {mean_str_acc} +- {std_str_acc}; char: {mean_char_acc}"
            )
        print("\n")

        # Key order matches the original: all char keys, then means, then stds.
        res = {}
        res.update(
            {f"char_accuracy_len_{i}": v for i, v in enumerate(char_accuracy_list)}
        )
        res.update(
            {f"str_acc_mean_list_len_{i}": v for i, v in enumerate(str_acc_mean_list)}
        )
        res.update(
            {f"str_acc_std_list_len_{i}": v for i, v in enumerate(str_acc_std_list)}
        )
        return res
