import tensorflow as tf
import numpy as np
import quadprog

try:
    from . import core
except:
    import sys, os
    sys.path.append(os.path.abspath(__file__))
    import core

class TaskReplayMemory:
    """
    A simple replay memory for GEM, organised as one ring buffer per task.

    Attributes:
        input_memory: float32 array of shape
            ``[num_tasks, memory_size] + input_shape``.
        label_memory: int64 array of shape
            ``[num_tasks, memory_size, num_outputs]``.
        memory_size: number of slots available to each task.
        ptr: next slot to write. A single cursor is shared by all tasks.
    """
    def __init__(self, num_tasks, memory_size, input_shape, num_outputs):
        # Accept tuples (or any sequence) as well as lists for the sample shape.
        input_shape = list(input_shape)
        self.input_memory = np.zeros([num_tasks, memory_size] + input_shape, dtype=np.float32)
        self.label_memory = np.zeros([num_tasks, memory_size, num_outputs], dtype=np.int64)

        self.memory_size = memory_size
        self.ptr = 0

    def store(self, inputs, labels, task_idx):
        """
        Write a batch into the slots of task ``task_idx``, wrapping around
        when the end of the buffer is reached.

        Args:
            inputs: array whose first axis is the batch axis.
            labels: labels aligned with ``inputs`` along the first axis.
            task_idx: index of the task whose memory is written.

        The write cursor always advances by the batch size (modulo
        ``memory_size``). If the batch is larger than the memory, only the
        newest ``memory_size`` samples are kept, placed where a
        sample-by-sample write would have left them.
        """
        batch_size = inputs.shape[0]
        if batch_size == 0 or self.memory_size == 0:
            # Nothing to write (and avoids a modulo by zero below).
            return

        start = self.ptr
        if batch_size > self.memory_size:
            # The oldest samples would be overwritten immediately; skip them
            # but still move the cursor past the slots they would have used.
            dropped = batch_size - self.memory_size
            inputs = inputs[dropped:]
            labels = labels[dropped:]
            start = (start + dropped) % self.memory_size
            batch_size = self.memory_size

        # `head` samples fit before the end of the buffer, `tail` wrap to 0.
        head = min(batch_size, self.memory_size - start)
        tail = batch_size - head

        self.input_memory[task_idx, start:start + head] = inputs[:head]
        self.label_memory[task_idx, start:start + head] = labels[:head]
        if tail:
            self.input_memory[task_idx, :tail] = inputs[head:]
            self.label_memory[task_idx, :tail] = labels[head:]

        self.ptr = (start + batch_size) % self.memory_size



class GradientEpisodicMemory:
    """
    Gradient Episodic Memory (GEM) model built with the TF1-style graph API.

    The constructor builds the classifier, its softmax cross-entropy loss, the
    gradient ops and one assign op per trainable variable. Parameter updates
    are computed in NumPy (optionally after projecting the gradient with a
    quadratic program) and written back to the graph through the assign ops.
    """
    def __init__(self, sess, transform, input_shape, num_outputs,
                 num_tasks, memory_strength, learning_rate,
                 network=core.mlp, hidden_layers=[100, 100]):
        """
        Args:
            sess: session used for every ``run`` call of this model.
            transform: name of the task transform. Only ``'incremental'`` is
                special-cased: it additionally creates the ``offset``
                placeholder, which is not used by the loss yet.
            input_shape: per-sample input shape (list or tuple, no batch axis).
            num_outputs: number of output units / label columns.
            num_tasks: accepted for interface compatibility; not used here.
            memory_strength: accepted for interface compatibility; not used
                here.
            learning_rate: step size, used both for the optimizer object and
                for the NumPy update in ``train_step``.
            network: callable invoked as
                ``network(inputs, hidden_layers, num_outputs)`` that returns
                the logits tensor.
            hidden_layers: layer specification forwarded to ``network``.
        """
        self.sess = sess
        self.learning_rate = learning_rate

        # The default list is shared between calls; hand the network a copy so
        # it can never be mutated through this constructor.
        if isinstance(hidden_layers, list):
            hidden_layers = list(hidden_layers)

        # list(...) lets callers pass the shape as a tuple as well as a list.
        input_batch_shape = [None] + list(input_shape)
        self.inputs = tf.placeholder(tf.float32, shape=input_batch_shape)
        self.labels = tf.placeholder(tf.int64, shape=[None, num_outputs])

        self.logits = network(self.inputs, hidden_layers, num_outputs)
        if transform == 'incremental':
            # TODO: design the per-task output mask; the offset is not used yet.
            self.offset = tf.placeholder(tf.int64, shape=[None,])
        self.cross_entropy_loss = tf.losses.softmax_cross_entropy(self.labels, self.logits)

        self.optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        # List of (gradient, variable) pairs; running it yields NumPy values.
        self.get_gradients = self.optimizer.compute_gradients(self.cross_entropy_loss)
        self.trainable_variables = tf.trainable_variables()

        # One placeholder + assign op per variable so that new parameter
        # values computed outside the graph can be written back.
        with tf.variable_scope("assign_placeholder"):
            self.assign_placeholder = [
                tf.placeholder(tf.float32, shape=v.shape, name=v.name.split(":")[0])
                for v in self.trainable_variables
            ]

        self.assign_op = [
            tf.assign(ref, value)
            for ref, value in zip(self.trainable_variables, self.assign_placeholder)
        ]

        # test operators
        self.equal = tf.equal(tf.argmax(self.logits, axis=-1), tf.argmax(self.labels, axis=-1))

    def _memory_gradients(self, replaymemory, task):
        """Run the gradient ops on the whole stored memory of ``task``."""
        feed_dict = {self.inputs: replaymemory.input_memory[task],
                     self.labels: replaymemory.label_memory[task]}
        return self.sess.run(self.get_gradients, feed_dict=feed_dict)

    def compute_previous_task_gradients(self, replaymemory, task_idx):
        """
        Evaluate the loss gradients on the memory of every task before
        ``task_idx``.

        Returns:
            A list with one grads-and-vars result per earlier task, or
            ``None`` when ``task_idx`` is the first task.
        """
        if task_idx <= 0:
            return None
        # TODO: per-task output offsets (restricting logits/labels to
        # [offset1:offset2] for each past task) are not applied yet.
        return [self._memory_gradients(replaymemory, task) for task in range(task_idx)]

    def compute_current_task_gradients(self, replaymemory, task_idx):
        """Evaluate the loss gradients on the memory of task ``task_idx``."""
        return self._memory_gradients(replaymemory, task_idx)

    def _inner_product(self, g1, g2):
        """Return a tensor holding the sum of the elementwise product of g1 and g2."""
        elementwise = tf.math.multiply(g1, g2)
        return tf.reduce_sum(elementwise)

    def flatten_gradients(self, grads_and_vars):
        """
        Flatten every gradient of a NumPy grads-and-vars list to 1-D.

        Returns:
            A list of ``(flat_grad, var)`` pairs in the input order.
        """
        return [(grad.reshape(-1), var) for grad, var in grads_and_vars]

    def concatenate_and_flatten_gradient(self, grads_and_vars):
        """Return all gradients of ``grads_and_vars`` as a single 1-D array."""
        flat_grads = [g for g, _ in self.flatten_gradients(grads_and_vars)]
        return np.concatenate(flat_grads)

    def _project2cone2(self, gradient, memories, margin=0.5, eps=1e-3):
        """
        Project ``gradient`` using the previous-task gradients ``memories``.

        Args:
            gradient: flat gradient of the current task, shape ``(d,)``.
            memories: stacked flat gradients of earlier tasks, shape ``(t, d)``.
            margin: lower bound applied to every dual variable.
            eps: value added to the diagonal of the Gram matrix.

        Returns:
            float32 array of shape ``(d,)`` equal to
            ``v @ memories + gradient``, with ``v`` the QP solution.
        """
        gradient = gradient.astype(np.double)
        memories = memories.astype(np.double)

        t = memories.shape[0]
        gram = np.dot(memories, memories.transpose())
        # Symmetrise and add eps on the diagonal before handing it to the
        # solver (presumably to keep the matrix positive definite).
        gram = 0.5 * (gram + gram.transpose()) + np.eye(t) * eps
        q = np.dot(memories, gradient) * -1
        # Identity constraint matrix with bound `margin`; per the quadprog
        # documentation this encodes v >= margin elementwise.
        constraint = np.eye(t)
        bound = np.zeros(t) + margin
        v = quadprog.solve_qp(gram, q, constraint, bound)[0]
        projected = np.dot(v, memories) + gradient
        return projected.astype(np.float32)

    def _convert_to_grads_and_vars(self, gradient, grads_and_vars):
        """
        Split the flat ``gradient`` back into one chunk per variable.

        Chunk sizes are taken from the gradients in ``grads_and_vars``. The
        returned gradients stay flat; ``train_step`` reshapes them.
        """
        converted_gvs = []
        cursor = 0
        for g, v in grads_and_vars:
            # int(...) because np.prod(()) is the float 1.0 for scalar
            # variables, which is not a valid slice bound.
            num_elements = int(np.prod(g.shape))
            converted_gvs.append((gradient[cursor:cursor + num_elements], v))
            cursor += num_elements

        return converted_gvs

    def gradient_projection_quadratic_programming(self, prev_g, current_g):
        """
        to get a projected gradient
        args:
            prev_g: gradients from previous tasks (``None`` for the first task)
            current_g: gradients from the current task
        returns:
            ``current_g`` unchanged when there are no previous tasks or no
            previous-task gradient has a negative dot product with the current
            one; otherwise the projected gradient as grads-and-vars pairs.
        """
        if prev_g is None:
            # First task: nothing to protect, use the current gradient as is.
            return current_g

        prev_cfgs = np.stack([self.concatenate_and_flatten_gradient(p) for p in prev_g])
        current_cfg = self.concatenate_and_flatten_gradient(current_g)
        dot_products = np.matmul(prev_cfgs, current_cfg)
        if not np.any(dot_products < 0):
            return current_g

        projected = self._project2cone2(current_cfg, prev_cfgs)
        return self._convert_to_grads_and_vars(projected, current_g)

    def train_step(self, update_gvs):
        """
        Apply one gradient-descent step.

        The new value ``var - learning_rate * grad`` is computed in NumPy for
        every pair of ``update_gvs`` and written to the graph through the
        assign ops. Gradients may be flat; they are reshaped to the variable.
        """
        feed_dict = {
            ph: var - self.learning_rate * grad.reshape(var.shape)
            for (grad, var), ph in zip(update_gvs, self.assign_placeholder)
        }
        self.sess.run(self.assign_op, feed_dict=feed_dict)

    def test_accuracy(self, testloader):
        """
        Return the mean of the ``equal`` op over all batches of ``testloader``.

        ``testloader`` must yield ``(inputs, labels)`` pairs.
        """
        test_results = [
            self.sess.run(self.equal, feed_dict={self.inputs: inputs,
                                                 self.labels: labels})
            for inputs, labels in testloader
        ]
        return np.mean(np.concatenate(test_results))


def continual_learning(task_continuum, data_loader, num_tasks, input_shape,
                       num_outputs, batch_size, learning_rate, network_architecture,
                       epochs_per_task, memory_size, log_every_step):
    """
    Train a GEM model on a sequence of tasks and print test accuracies.

    Args:
        task_continuum: object with a ``transform`` attribute that iterates
            over ``(transform, train_data, test_data)`` triples, one per task.
        data_loader: the loader class itself (not an instance). It is called
            as ``data_loader(data, batch_size=...)`` and must yield
            ``(inputs, labels)`` batches. Test loaders are iterated more than
            once, so they must be re-iterable.
        num_tasks: number of tasks, used to size the replay memory.
        input_shape: per-sample input shape.
        num_outputs: number of output units.
        batch_size: batch size handed to ``data_loader``.
        learning_rate: step size of the model.
        network_architecture: hidden layer specification of the network.
        epochs_per_task: passes over each task's training data.
        memory_size: replay memory slots per task.
        log_every_step: logging period in steps; values <= 0 disable logging.
    """
    with tf.Session() as sess:
        # intialize model
        model = GradientEpisodicMemory(sess=sess, transform=task_continuum.transform, input_shape=input_shape,
                                       num_outputs=num_outputs, num_tasks=num_tasks,
                                       memory_strength=None, learning_rate=learning_rate,
                                       network=core.mlp, hidden_layers=network_architecture)

        # define a replay memory
        replaymemory = TaskReplayMemory(num_tasks=num_tasks, memory_size=memory_size,
                                        input_shape=input_shape, num_outputs=num_outputs)

        sess.run(tf.global_variables_initializer())

        # Logging fires when `i % log_every_step` equals this phase. Using
        # `1 % log_every_step` keeps the existing schedule (steps 1, 11, ...)
        # for periods > 1 and also logs every step when the period is 1.
        logging_enabled = log_every_step > 0
        log_phase = 1 % log_every_step if logging_enabled else None

        # start training
        total_steps = 0
        first_task_testloader = None
        for task_idx, (_transform, train_data, test_data) in enumerate(task_continuum):
            # define train loader and test loader
            trainloader = data_loader(train_data, batch_size=batch_size)
            testloader = data_loader(test_data, batch_size=batch_size)
            if task_idx == 0:
                # Kept to track accuracy on the first task during later tasks.
                first_task_testloader = testloader

            for epoch in range(epochs_per_task):
                for i, (inputs, labels) in enumerate(trainloader):
                    replaymemory.store(inputs, labels, task_idx)

                    # Gradients come from the replay memory, which now holds
                    # the batch that was just stored.
                    previous_task_g = model.compute_previous_task_gradients(replaymemory, task_idx)
                    curr_task_g = model.compute_current_task_gradients(replaymemory, task_idx)
                    update_gvs = model.gradient_projection_quadratic_programming(previous_task_g,
                                                                                 curr_task_g)

                    model.train_step(update_gvs)
                    if logging_enabled and i % log_every_step == log_phase:
                        test_acc_ft = model.test_accuracy(first_task_testloader)
                        test_acc = model.test_accuracy(testloader)
                        print(f"Task: {task_idx}, Epoch: {epoch}, Steps: {i}, Iterations: {total_steps} \t First Task  Test Acc: {test_acc_ft}, Current Task Test Acc: {test_acc}")

                    total_steps += 1


if __name__ == "__main__":
    # Script entry point: parse the command line, build the task continuum
    # and run GEM training on it.
    import argparse
    import os
    import sys

    # Must be on the path before the project import below.
    sys.path.append("/abs_path_to_nccl/nonconvex-continual-learning")
    from continual.data.dataloader import TaskContinuumDataset, DataLoader

    _MNIST_DATA_DESCRIPTER = {'input_shape':[28,28], 'num_outputs':10}
    _CIFAR100_DATA_DESCRIPTER = {'input_shape':[32,32], 'num_outputs':100}

    DATA_DESCRIPTER = {'mnist': _MNIST_DATA_DESCRIPTER,
                       'cifar100': _CIFAR100_DATA_DESCRIPTER}

    def _str2bool(text):
        """Parse a command-line boolean; `type=bool` treats any non-empty string as True."""
        lowered = str(text).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    parser = argparse.ArgumentParser()
    # dataset setup
    parser.add_argument('--i', default='raw', help='raw datasets for continual learning')
    parser.add_argument('--num_tasks', default=3, type=int, help='number of different tasks')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--raw_data', default='mnist', help='raw dataset for using continual learning')
    parser.add_argument('--transform', default='rotation', help='transform transformation')
    parser.add_argument('--min_angle', default=0, type=float, help='minimum rotation angle')
    parser.add_argument('--max_angle', default=90, type=float, help='maximum rotation angle')

    # learning hyperparameters
    parser.add_argument('--batch_size', default=128, type=int, help="batch size")
    parser.add_argument('--epochs_per_task', default=1, type=int, help="epochs per task")
    parser.add_argument('--shuffle_tasks', default=True, type=_str2bool, help="shuffle tasks")
    parser.add_argument('--learning_rate', default=0.1, type=float, help="learning_rate")

    # logging setup
    parser.add_argument('--log_every_step', default=10, type=int, help="logging step")

    # network and memory setup
    parser.add_argument('--memory_size', default=128, type=int, help="memory_size")
    # type=int so that sizes given on the command line are numbers like the
    # default, not strings.
    parser.add_argument('--hid', nargs='*', type=int, default=[100,100], help='fc-layers or convnet, convnet=-1')

    args = parser.parse_args()

    # --shuffle_tasks is now honoured; its default (True) matches the value
    # that used to be hard-coded here.
    task_continuum = TaskContinuumDataset(args.raw_data, args.transform,
                                   args.num_tasks, args.seed,
                                   shuffle_tasks=args.shuffle_tasks, train_samples_per_task=20000)

    data_descripter = DATA_DESCRIPTER[args.raw_data]
    continual_learning(task_continuum, DataLoader, args.num_tasks, data_descripter['input_shape'],
                       data_descripter['num_outputs'], args.batch_size,
                       args.learning_rate, args.hid, args.epochs_per_task, args.memory_size, args.log_every_step)


    