# The task of training a linear head to use learned representations to perform classification.

import tensorflow as tf
from keras.datasets import mnist
import numpy as np
from functools import reduce
import matplotlib.pyplot as plt

from . import task

# Linear head architecture constants.
# Length of each pretrained feature vector fed to the head; the flat weight vector is
# reshaped to [representation_dimension, n_classes] using this value.
representation_dimension = 2048
# Number of output classes of the head (the data directory used below is named
# "cifar100_representations").
n_classes = 100

class CNN():
    """
    A small convolutional network evaluated from explicitly supplied weight lists.
    """

    @staticmethod
    def evaluate(x, weights_list_list):
        """
        Run the network on ``x`` and return its output.

        Args:
            x: input images passed directly to ``tf.nn.conv2d`` (presumably NHWC — TODO confirm).
            weights_list_list: one weight list per layer. ``[0]`` holds the kernel and bias of
                a stride-1 "SAME" convolution; each middle entry is handed to
                ``CNN.poolblock``; ``[-1]`` holds the matrix and bias of a dense layer applied
                to the features at spatial position (0, 0).

        NOTE(review): ``CNN.poolblock`` is not defined in this file, so a call with more than
        two layers raises ``AttributeError``. This class is not referenced elsewhere in the file.
        """
        # Input convolution followed by its bias.
        hidden = tf.nn.conv2d(x, weights_list_list[0][0], 1, "SAME")
        hidden = hidden + weights_list_list[0][1]

        # Thread the activations through every intermediate block in order.
        for block_weights in weights_list_list[1:-1]:
            hidden = CNN.poolblock(hidden, block_weights)

        # Dense read-out on the (0, 0) spatial feature vector of each example.
        dense_matrix = weights_list_list[-1][0]
        corner_features = hidden[:, 0, 0, :]
        return tf.einsum('io,ni->no', dense_matrix, corner_features) + weights_list_list[-1][1]

class ClassificationHeadTask(task.Task):
    """
    This class represents the task of using a linear classification head on pretrained feature vectors.

    The head's parameters are one flat weight vector of length
    ``representation_dimension * n_classes + n_classes``. The leading
    ``representation_dimension * n_classes`` entries are reshaped into the weight matrix of
    shape ``[representation_dimension, n_classes]``, and the final ``n_classes`` entries are
    the bias.

    Training examples are streamed from ``shuffled_inputs_<k>.npy`` and
    ``shuffled_targets_<k>.npy`` in ``data_directory`` for ``k`` in
    ``0 .. n_training_files - 1``. Shard number ``n_training_files`` is used as the
    validation set.
    """

    # Directory holding the precomputed representations and their targets.
    data_directory = "./cifar100_representations"
    # Training shards are numbered 0 .. n_training_files - 1; shard n_training_files is the
    # validation shard.
    n_training_files = 99
    # Number of interleaved chunks a full dataset is split into during evaluation
    # (presumably to bound the size of each TensorFlow call).
    n_evaluation_chunks = 256

    def __init__(self):
        self.batch_size = 256
        # Index of the currently loaded training shard; -1 means none is loaded yet, so the
        # first sample_new_batch() call loads shard 0.
        self.data_file_number = -1
        # Read position inside the currently loaded shard.
        self.batch_offset = 0
        # Number of rows in the currently loaded shard.
        self.data_in_batch = 0
        # Coefficient of the L2 penalty applied to every parameter (weights and bias).
        self.regularization = 0.00005
        self.sample_new_batch()

    def _load_array(self, file_name):
        """
        Load one ``.npy`` array from ``self.data_directory``.
        """
        with open(f"{self.data_directory}/{file_name}", "rb") as f:
            return np.load(f)

    def _load_next_training_file(self):
        """
        Advance to the next training shard (wrapping back to shard 0) and load it.
        """
        self.data_file_number = (self.data_file_number + 1) % self.n_training_files
        self.training_representations = self._load_array(f"shuffled_inputs_{self.data_file_number}.npy")
        self.training_targets = self._load_array(f"shuffled_targets_{self.data_file_number}.npy")
        self.data_in_batch = self.training_representations.shape[0]
        self.batch_offset = 0

    def sample_new_batch(self):
        """
        Read the next ``self.batch_size`` training examples.

        The examples are stored in ``self.batch_representations`` and ``self.batch_targets``.
        Rows are read sequentially, continuing where the previous call stopped. When a shard
        is exhausted the next one is loaded, so a batch may span a shard boundary.
        """
        # Zero-length seeds keep the original dtype promotion: representations become at
        # least float64 and targets at least int32.
        representation_chunks = [np.zeros([0, representation_dimension], dtype=np.float64)]
        target_chunks = [np.zeros([0], dtype=np.int32)]
        batch_left_to_read = self.batch_size
        while batch_left_to_read > 0:
            if self.batch_offset == self.data_in_batch:
                self._load_next_training_file()
            data_to_use = min(self.data_in_batch - self.batch_offset, batch_left_to_read)
            end = self.batch_offset + data_to_use
            representation_chunks.append(self.training_representations[self.batch_offset:end])
            target_chunks.append(self.training_targets[self.batch_offset:end])
            batch_left_to_read -= data_to_use
            self.batch_offset = end
        # Concatenate once per batch rather than once per shard crossed.
        self.batch_representations = np.concatenate(representation_chunks, axis=0)
        self.batch_targets = np.concatenate(target_chunks, axis=0)

    def get_initialization(self):
        """
        Get an initialization for the linear head in flattened weight vector form.

        Returns an all-zero float64 tensor of length
        ``representation_dimension * n_classes + n_classes``.
        """
        return tf.convert_to_tensor(np.zeros([representation_dimension*n_classes + n_classes], dtype=np.float64))

    def __call__(self, weights, new_batch=True):
        """
        Return the regularized classification loss of the head on the current training batch.

        Args:
            weights: flat parameter vector of the linear head.
            new_batch: when true (the default), advance to a fresh training batch first.
        """
        if new_batch:
            self.sample_new_batch()

        return self.evaluate_loss_on_batch(weights, (self.batch_representations, self.batch_targets))

    def _logits(self, weights, representations):
        """
        Apply the linear head encoded by the flat ``weights`` vector to ``representations``.
        """
        weight = tf.reshape(weights[:-n_classes], [representation_dimension, n_classes])
        bias = weights[-n_classes:]
        return tf.einsum("nr,rc->nc", representations, weight) + bias

    def evaluate_loss_on_batch(self, weights, batch):
        """
        Produce the classification loss given a batch consisting of representations and true classes.

        The result is the mean softmax cross-entropy plus ``self.regularization`` times the
        squared L2 norm of all parameters.
        """
        representations, classes = batch
        logits = self._logits(weights, representations)
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(classes, tf.int32), logits=logits)
        return tf.math.reduce_mean(loss) + self.regularization*tf.math.reduce_sum(weights**2)

    def _average_over_dataset(self, per_batch_metric, weights, dataset):
        """
        Compute the example-weighted mean of ``per_batch_metric`` over a whole dataset.

        The dataset is split into ``self.n_evaluation_chunks`` interleaved (strided) chunks.
        Each chunk's metric is weighted by the chunk size, so the result equals the metric
        averaged over all examples.

        Empty chunks are skipped. They occur when the dataset has fewer rows than chunks, and
        their per-batch mean would be NaN, which would poison the total even when multiplied
        by a zero weight.

        Raises:
            ValueError: if the dataset has no examples.
        """
        representations, classes = dataset
        n_representations = representations.shape[0]
        if n_representations == 0:
            raise ValueError("cannot evaluate on an empty dataset")
        n_chunks = self.n_evaluation_chunks
        total = 0
        for i in range(n_chunks):
            chunk_representations = representations[i::n_chunks]
            chunk_size = chunk_representations.shape[0]
            if chunk_size == 0:
                continue
            chunk_classes = classes[i::n_chunks]
            total = total + chunk_size*per_batch_metric(weights, (chunk_representations, chunk_classes))
        return total / n_representations

    def evaluate_loss_on_dataset(self, weights, dataset):
        """
        Evaluate the full-batch regularized loss on a ``(representations, classes)`` dataset.
        """
        return self._average_over_dataset(self.evaluate_loss_on_batch, weights, dataset)

    def _load_validation_set(self):
        """
        Load the held-out validation shard as ``(representations, targets)``.
        """
        return (
            self._load_array(f"shuffled_inputs_{self.n_training_files}.npy"),
            self._load_array(f"shuffled_targets_{self.n_training_files}.npy"),
        )

    def _load_test_set(self):
        """
        Load the test set as ``(representations, targets)``.
        """
        # The test arrays are indexed with an extra axis 1 that the training shards do not
        # have; only index 0 along that axis is used.
        test_set_representations = self._load_array("shuffled_test_inputs.npy")[:, 0, :]
        test_set_targets = self._load_array("shuffled_test_targets.npy")[:, 0]
        return test_set_representations, test_set_targets

    def evaluate_validation_loss(self, weights):
        """
        Evaluate the full-batch loss on the validation set.
        """
        return self.evaluate_loss_on_dataset(weights, self._load_validation_set())

    def evaluate_test_loss(self, weights):
        """
        Evaluate the full-batch loss on the test set.
        """
        return self.evaluate_loss_on_dataset(weights, self._load_test_set())

    def evaluate_accuracy_on_batch(self, weights, batch):
        """
        Produce the classification top-1 accuracy given a batch consisting of representations and true classes.
        """
        representations, classes = batch
        logits = self._logits(weights, representations)
        predictions = tf.math.argmax(logits, axis=1)
        # argmax yields int64; cast the labels explicitly instead of relying on implicit
        # conversion inside ``==``.
        correct = tf.math.equal(predictions, tf.cast(classes, predictions.dtype))
        return tf.math.reduce_mean(tf.cast(correct, tf.float64))

    def evaluate_accuracy_on_dataset(self, weights, dataset):
        """
        Evaluate the full-batch top-1 accuracy on a ``(representations, classes)`` dataset.
        """
        return self._average_over_dataset(self.evaluate_accuracy_on_batch, weights, dataset)

    def evaluate_validation_accuracy(self, weights):
        """
        Evaluate the full-batch top-1 accuracy on the validation set.
        """
        return self.evaluate_accuracy_on_dataset(weights, self._load_validation_set())

    def evaluate_test_accuracy(self, weights):
        """
        Evaluate the full-batch top-1 accuracy on the test set.
        """
        return self.evaluate_accuracy_on_dataset(weights, self._load_test_set())
