import tensorflow as tf
import numpy as np

from nlp_transformer import masked_accuracy


def get_evaluator(args, model, loss_fn, datasets, large_datasets, test_label_pairs):
    """Build and return an ``Evaluator`` wired to the given model and data.

    All arguments are forwarded unchanged to ``Evaluator.__init__``.
    """
    return Evaluator(
        args=args,
        model=model,
        loss_fn=loss_fn,
        datasets=datasets,
        large_datasets=large_datasets,
        test_label_pairs=test_label_pairs,
    )


class Evaluator(object):
    """Batched evaluation helpers for a model with one or two output heads.

    Two-headed mode: labels ``y`` are a list ``[y1, y2]`` and each sample row
    ``y[k][i]`` is indexed by a class id and compared against 1 (presumably
    one-hot rows -- confirm with the data pipeline). Single-output mode: ``y``
    is not a list and per-batch accuracy is delegated to ``masked_accuracy``.
    """

    def __init__(self, args, model, loss_fn, datasets, large_datasets,
                 test_label_pairs):
        """Store the model, loss function and evaluation data.

        Args:
            args: Config object; ``args.batch_size`` sets the batch size of
                every batched loop in this class.
            model: Callable applied to a slice of inputs; in two-headed mode it
                returns a pair of score tensors.
            loss_fn: Callable ``loss_fn(y, y_hat)`` used by ``batch_process``.
            datasets: Iterable of ``(x, y)`` pairs used by ``evaluate_all``.
            large_datasets: Iterable of ``(x, y)`` pairs used by
                ``large_evaluate_all``.
            test_label_pairs: Iterable of ``(label1, label2)`` pairs; stored as
                a set for the ``sg_acc`` membership test in ``get_accuracy``.
        """
        self.args = args
        self.model = model
        self.datasets = datasets
        self.large_datasets = large_datasets
        self.test_label_pairs = set(test_label_pairs)
        self.loss_fn = loss_fn

    @staticmethod
    def _slice_labels(y, i, j):
        """Return rows ``i:j`` of ``y``, slicing every head when ``y`` is a list."""
        if isinstance(y, list):
            return [y_k[i:j] for y_k in y]
        return y[i:j]

    def forward(self, x):
        """Run the two-headed model over ``x`` in batches of ``args.batch_size``.

        Returns:
            Two lists of predicted class ids (argmax over the last axis), one
            per head, each with one entry per sample of ``x``.
        """
        y1_hat, y2_hat = [], []
        size = self.args.batch_size
        for i in range(0, len(x), size):
            j = min(i + size, len(x))
            y1, y2 = self.model(x[i:j])
            y1_hat.extend(tf.argmax(y1, -1).numpy())
            y2_hat.extend(tf.argmax(y2, -1).numpy())
        return y1_hat, y2_hat

    def get_accuracy(self, y_hat, y):
        """Score two-headed predictions against labels.

        Args:
            y_hat: Pair ``(y1_hat, y2_hat)`` of predicted class ids, as returned
                by ``forward``.
            y: Pair ``[y1, y2]`` of label rows; ``y[k][i][c] == 1`` marks class
                ``c`` as correct for sample ``i`` of head ``k``.

        Returns:
            ``(acc1, acc2, acc, sg_acc)``: per-head accuracy, joint accuracy
            (both heads correct), and the fraction of samples whose predicted
            pair appears in ``test_label_pairs`` (independent of correctness).

        Raises:
            ValueError: if ``y`` contains no samples.
        """
        n_samples = len(y[0])
        if n_samples == 0:
            raise ValueError("cannot compute accuracy on an empty dataset")
        y1_hat_list = y_hat[0]
        y2_hat_list = y_hat[1]
        hit1, hit2, hit, sg_hit = 0, 0, 0, 0
        for i in range(n_samples):
            y1_hat = y1_hat_list[i]
            y2_hat = y2_hat_list[i]
            if (y1_hat, y2_hat) in self.test_label_pairs:
                sg_hit += 1
            h1 = y[0][i][y1_hat] == 1
            h2 = y[1][i][y2_hat] == 1
            if h1:
                hit1 += 1
            if h2:
                hit2 += 1
            if h1 and h2:
                hit += 1
        acc = hit / n_samples
        acc1 = hit1 / n_samples
        acc2 = hit2 / n_samples
        sg_acc = sg_hit / n_samples
        return acc1, acc2, acc, sg_acc

    def _batched_accuracy(self, x, y):
        """Sample-weighted mean of per-batch ``compute_accuracy`` over ``(x, y)``.

        Forward passes only: no loss or gradients are computed.

        Raises:
            ValueError: if ``x`` is empty.
        """
        size = self.args.batch_size
        total_acc, total_weight = 0, 0
        for i in range(0, len(x), size):
            j = min(i + size, len(x))
            y_hat = self.model(x[i:j])
            acc = self.compute_accuracy(y_hat, self._slice_labels(y, i, j))
            weight = j - i
            total_acc += weight * acc
            total_weight += weight
        if total_weight == 0:
            raise ValueError("cannot evaluate an empty dataset")
        return total_acc / total_weight

    def evaluate(self, x, y):
        """Return ``(acc1, acc2, acc, sg_acc)`` for the dataset ``(x, y)``.

        For single-output labels (``y`` not a list) only the overall accuracy
        is computed; the other three slots are returned as 0.
        """
        if not isinstance(y, list):
            # Accuracy needs only forward passes; building a gradient tape and
            # concatenating full-model gradients per batch would be wasted work.
            return 0, 0, self._batched_accuracy(x, y), 0
        y_hat = self.forward(x)
        return self.get_accuracy(y_hat, y)

    def test_evaluate(self, x, y):
        """Alias of ``evaluate``."""
        return self.evaluate(x, y)

    def evaluate_datasets(self, datasets):
        """Evaluate each ``(x, y)`` pair and flatten the results.

        Returns:
            A flat list: the four metrics of each dataset followed by a
            ``"\\t"`` separator entry.
        """
        ret = []
        for dataset in datasets:
            ret.extend(
                self.evaluate(dataset[0], dataset[1]))
            ret.append("\t")
        return ret

    def evaluate_all(self):
        """Evaluate every dataset in ``self.datasets``."""
        return self.evaluate_datasets(self.datasets)

    def large_evaluate_all(self):
        """Evaluate every dataset in ``self.large_datasets``."""
        return self.evaluate_datasets(self.large_datasets)

    def get_output(self, y):
        """Convert per-head scores or label rows to a class-id matrix.

        Args:
            y: Sequence of per-head arrays/tensors, each assumed to be
                ``(n_samples, n_classes_k)`` -- TODO confirm for all callers.

        Returns:
            Integer array of shape ``(n_samples, n_heads)`` holding the argmax
            over the last axis of each head.
        """
        # Argmax each head separately: heads may have different class counts,
        # and a single np.asarray over all heads would then be ragged.
        return np.stack(
            [np.argmax(np.asarray(y_k), axis=-1) for y_k in y], axis=-1)

    def compute_accuracy(self, y_hat, y):
        """Accuracy of one batch of raw model outputs.

        Single-output labels defer to ``masked_accuracy``. For two-headed
        labels a sample counts as a hit only when the first two heads are
        both correct.
        """
        if not isinstance(y, list):
            return masked_accuracy(y, y_hat).numpy()
        pred = self.get_output(y_hat)
        true = self.get_output(y)
        hits = (true[:, 0] == pred[:, 0]) & (true[:, 1] == pred[:, 1])
        return float(np.mean(hits))

    def get_value(self, g):
        """Flatten a gradient to a 1-D NumPy array; ``None`` passes through.

        ``tf.IndexedSlices`` gradients are densified first.
        """
        if g is None:
            return None
        if isinstance(g, tf.IndexedSlices):
            g = tf.convert_to_tensor(g)
        return g.numpy().flatten()

    def batch_process(self, x, y):
        """Forward + backward pass on one batch.

        Returns:
            ``(grads, loss, acc)``: all non-``None`` trainable-variable
            gradients flattened and concatenated into one 1-D array, the loss
            as a NumPy value, and the batch accuracy from ``compute_accuracy``.
        """
        model, loss_fn = self.model, self.loss_fn
        with tf.GradientTape() as tape:
            y_hat = model(x)
            loss = loss_fn(y, y_hat)
        grads = tape.gradient(loss, model.trainable_variables)
        grads = [self.get_value(g) for g in grads]
        grads = [g for g in grads if g is not None]
        grads = np.concatenate(grads, 0)
        acc = self.compute_accuracy(y_hat, y)
        return grads, loss.numpy(), acc

    def get_gradient_loss_acc(self, data):
        """Sample-weighted average of gradient, loss and accuracy over batches.

        Args:
            data: ``(x, y)``; ``y`` may be a single label array or a list of
                per-head label arrays.

        Returns:
            ``(all_grads, all_loss, all_acc)`` averaged with each batch
            weighted by its number of samples.

        Raises:
            ValueError: if ``x`` is empty.
        """
        x, y = data
        if len(x) == 0:
            raise ValueError("cannot compute gradients on an empty dataset")
        size = self.args.batch_size

        all_grads, all_loss, all_acc = None, 0, 0
        all_weight = 0
        for i in range(0, len(x), size):
            j = min(i + size, len(x))
            grads, loss, acc = self.batch_process(
                x[i:j], self._slice_labels(y, i, j))
            weight = j - i

            if all_grads is None:
                all_grads = weight * grads
            else:
                # Every batch must yield the same flattened parameter vector.
                assert all_grads.shape == grads.shape
                all_grads += weight * grads
            all_loss += weight * loss
            all_acc += weight * acc
            all_weight += weight
        all_grads /= all_weight
        all_loss /= all_weight
        all_acc /= all_weight
        return all_grads, all_loss, all_acc
