import tensorflow as tf
import numpy as np

try:
    from . import core
except ImportError:
    # Not imported as part of a package (e.g. run directly as a script):
    # make the directory that contains this file importable and load the
    # sibling ``core`` module from there. sys.path entries must be
    # directories, not file paths.
    import sys, os
    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    import core

# Logit value assigned to classes masked out by ``_pruning`` in split tasks.
# Presumably kept finite (instead of -inf) to avoid inf/NaN arithmetic in the
# softmax cross-entropy.
NEG_INF = -1.0e32

class RandomTaskReplayMemory:
    """Episodic replay memory for NCCL with one ring buffer per task.

    Task ``t`` owns ``memory_size`` slots in ``input_memory[t]`` and
    ``label_memory[t]``. ``store`` writes a batch into one task's buffer,
    overwriting that task's oldest rows once it is full. ``sample`` and
    ``sample_for_split`` draw, without replacement, from the *filled* slots of
    all tasks with index ``< current_task_idx``.

    Attributes:
        input_memory: float32 array, shape ``[num_tasks, memory_size] + input_shape``.
        label_memory: int32 array, shape ``[num_tasks, memory_size, num_outputs]``.
        task_ptr: per-task index of the next slot to overwrite.
        num_stored: per-task number of filled slots (at most ``memory_size``).
        ptr: write pointer of the most recently written task (kept for
            backward compatibility).
    """

    def __init__(self, num_tasks, memory_size, input_shape, num_outputs, sample_batch_size):
        input_shape = list(input_shape)
        self.input_memory = np.zeros([num_tasks, memory_size] + input_shape, dtype=np.float32)
        self.label_memory = np.zeros([num_tasks, memory_size, num_outputs], dtype=np.int32)
        self.sample_batch_size = sample_batch_size
        self.memory_size = memory_size
        self.num_tasks = num_tasks

        self.num_outputs = num_outputs
        # Width of the contiguous class block per task used by sample_for_split.
        self.num_classes_per_task = int(self.num_outputs / self.num_tasks)

        # Separate write pointers and fill counts per task, so a task that
        # stored fewer than memory_size rows never exposes zero-filled
        # (label-less) slots to the samplers.
        self.task_ptr = np.zeros(num_tasks, dtype=np.int64)
        self.num_stored = np.zeros(num_tasks, dtype=np.int64)
        self.ptr = 0

    def store(self, inputs, labels, task_idx):
        """Write a batch into task ``task_idx``'s ring buffer.

        Rows are written from the task's write pointer onward, wrapping around
        to slot 0. If the batch holds more than ``memory_size`` rows, only its
        last ``memory_size`` rows are kept, because the earlier ones would be
        overwritten within this same call anyway.

        Args:
            inputs: array of shape ``[B] + input_shape``.
            labels: array of shape ``[B, num_outputs]``.
            task_idx: index of the task buffer to write into.
        """
        batch_size = inputs.shape[0]
        if batch_size == 0:
            return

        start = int(self.task_ptr[task_idx])
        slots = (start + np.arange(batch_size)) % self.memory_size
        if batch_size > self.memory_size:
            slots = slots[-self.memory_size:]
            inputs = inputs[-self.memory_size:]
            labels = labels[-self.memory_size:]

        self.input_memory[task_idx, slots] = inputs
        self.label_memory[task_idx, slots] = labels

        self.task_ptr[task_idx] = (start + batch_size) % self.memory_size
        self.num_stored[task_idx] = min(int(self.num_stored[task_idx]) + batch_size, self.memory_size)
        self.ptr = int(self.task_ptr[task_idx])

    def _filled_flat_indices(self, current_task_idx):
        """Return flat indices ``t * memory_size + slot`` of filled slots of tasks ``< current_task_idx``."""
        pools = [t * self.memory_size + np.arange(int(self.num_stored[t]))
                 for t in range(current_task_idx)]
        if not pools:
            return np.zeros(0, dtype=np.int64)
        return np.concatenate(pools)

    def _draw(self, current_task_idx, num_samples):
        """Sample up to ``num_samples`` stored rows without replacement.

        When every slot of the earlier tasks is filled, this makes the same
        ``np.random.choice`` call as drawing from ``range(current_task_idx *
        memory_size)``. If fewer rows are stored than requested, all of them
        are returned.
        """
        pool = self._filled_flat_indices(current_task_idx)
        num_samples = min(num_samples, pool.size)
        picks = pool[np.random.choice(pool.size, size=num_samples, replace=False)]
        # Return rows in (task, slot) order, as the previous boolean-mask
        # implementation did.
        picks = np.sort(picks)

        flat_inputs = self.input_memory.reshape((-1,) + self.input_memory.shape[2:])
        flat_labels = self.label_memory.reshape(-1, self.num_outputs)
        return flat_inputs[picks], flat_labels[picks]

    def sample(self, current_task_idx):
        """Draw ``sample_batch_size`` rows per previous task, pooled over tasks ``[0, current_task_idx)``.

        Returns:
            ``(inputs, labels)`` taken from the memory arrays of shape
            ``[T, M, ...]``.
        """
        return self._draw(current_task_idx, self.sample_batch_size * current_task_idx)

    def sample_for_split(self, current_task_idx):
        """Draw ``sample_batch_size`` rows from tasks ``[0, current_task_idx)`` plus per-row class masks.

        Returns:
            ``(inputs, labels, offsets)``. ``offsets`` is a float32 array shaped
            like ``labels``. It holds 1.0 on the block of
            ``num_classes_per_task`` columns that contains each row's argmax
            label, and 0.0 elsewhere.
        """
        inputs, labels = self._draw(current_task_idx, self.sample_batch_size)

        labels_int = np.argmax(labels, axis=-1)
        offsets = np.zeros_like(labels, np.float32)
        width = self.num_classes_per_task
        if width > 0:
            label_block = labels_int // width
            column_block = np.arange(labels.shape[-1]) // width
            offsets[column_block[None, :] == label_block[:, None]] = 1.0

        return inputs, labels, offsets

class NonconvexContinualLearning:
    """TF1 graph and training steps for Nonconvex Continual Learning (NCCL).

    Two structurally identical copies of ``network`` are built:

    * scope ``"current"``: fed with the current task's batch
      (``inputs`` / ``labels`` [/ ``offset``]);
    * scope ``"previous"``: fed with a replay-memory batch
      (``inputs_p`` / ``labels_p`` [/ ``offset_p``]).

    The first task is trained with plain SGD on the current copy
    (``train_op``). Later tasks use ``nccl_op``, which applies
    ``alpha * g_p + beta * g_c`` to both copies. Here ``g_p`` and ``g_c`` are
    the memory and current gradients, and ``alpha`` and ``beta`` depend on the
    sign of ``<g_p, g_c>``.

    Both gradients must be evaluated at the same parameters. The "previous"
    copy holds its own variables and ``train_op`` never updates them. So before
    the first NCCL step that follows any plain-SGD step, its weights are
    overwritten with the current ones (``sync_op``). Both copies then receive
    identical updates, so they stay equal afterwards.
    """

    def __init__(self, sess, transform, input_shape, num_outputs,
                 num_tasks, memory_strength, learning_rate,
                 network=core.mlp, hidden_layers=None, L_coeff=0.7, delta=0.5):
        """Build placeholders, both network copies, and all update ops.

        Args:
            sess: session used by the training and evaluation methods.
            transform: task transformation name. ``'split'`` masks logits of
                classes outside the active task via the ``offset`` placeholders.
            input_shape: per-example input shape (sequence of ints).
            num_outputs: number of output classes.
            num_tasks: number of tasks (not used by the graph).
            memory_strength: unused; kept for interface compatibility.
            learning_rate: base SGD step size.
            network: callable ``network(inputs, hidden_layers, num_outputs,
                is_training)`` that returns logits.
            hidden_layers: forwarded to ``network``; ``None`` means ``[100, 100]``.
            L_coeff: sets the smoothness constant ``L = 2 * L_coeff / learning_rate``.
            delta: step-size factor used when the two gradients conflict.
        """
        if hidden_layers is None:
            # Avoid a shared mutable default argument; same value as before.
            hidden_layers = [100, 100]
        is_split = transform == 'split'

        self.sess = sess
        self.learning_rate = learning_rate
        self.num_outputs = num_outputs
        input_batch_shape = [None] + list(input_shape)

        # ---- current-task branch -------------------------------------------
        self.inputs = tf.placeholder(tf.float32, shape=input_batch_shape, name="inputs")
        self.labels = tf.placeholder(tf.int32, shape=[None, num_outputs], name="labels")
        if is_split:
            self.offset = tf.placeholder(tf.int32, shape=[None, num_outputs], name="offset")

        self.is_training = tf.placeholder_with_default(True, shape=None)

        with tf.variable_scope("current"):
            self.logits = network(self.inputs, hidden_layers, num_outputs, self.is_training)
            if is_split:
                self.pruned_logits = self._pruning(self.logits, self.offset)
                self.cross_entropy_loss = tf.losses.softmax_cross_entropy(self.labels, self.pruned_logits)
            else:
                self.cross_entropy_loss = tf.losses.softmax_cross_entropy(self.labels, self.logits)

        # ---- replay-memory branch ------------------------------------------
        self.inputs_p = tf.placeholder(tf.float32, shape=input_batch_shape, name='input_p')
        self.labels_p = tf.placeholder(tf.int32, shape=[None, num_outputs], name='label_p')
        if is_split:
            self.offset_p = tf.placeholder(tf.int32, shape=[None, num_outputs], name='offset_p')
        with tf.variable_scope("previous", reuse=tf.AUTO_REUSE):
            self.logits_p = network(self.inputs_p, hidden_layers, num_outputs, self.is_training)
            if is_split:
                self.pruned_logits_p = self._pruning(self.logits_p, self.offset_p)
                self.cross_entropy_loss_p = tf.losses.softmax_cross_entropy(self.labels_p, self.pruned_logits_p)
            else:
                self.cross_entropy_loss_p = tf.losses.softmax_cross_entropy(self.labels_p, self.logits_p)

        # ---- gradients -----------------------------------------------------
        # NOTE(review): every gradient fed to this optimizer is already scaled
        # by its own step size, so the optimizer rate would naturally be 1.0.
        # 0.9999 shrinks all updates by 0.01%; kept as-is, since its purpose is
        # not evident from this file.
        self.optimizer = tf.train.GradientDescentOptimizer(0.9999)
        # Gradients of the current loss w.r.t. all trainable variables. Those of
        # the "previous" copy come back as None and are dropped for train_op.
        self.current_gvs = self.optimizer.compute_gradients(self.cross_entropy_loss)

        # Both lists are in creation order, so the i-th entries are the same
        # parameter of the two network copies.
        current_vars = tf.trainable_variables(scope='current')
        previous_vars = tf.trainable_variables(scope='previous')

        self.previous_g = self._dense_gradients(self.cross_entropy_loss_p, previous_vars)
        self.current_g = self._dense_gradients(self.cross_entropy_loss, current_vars)

        # ---- NCCL step sizes -----------------------------------------------
        self.ip = self._inner_product(self.previous_g, self.current_g)
        self.gp_norm_square = self._inner_product(self.previous_g, self.previous_g)
        self.gc_norm_square = self._inner_product(self.current_g, self.current_g)

        self.delta = delta
        self.L = (2.0 * L_coeff) / self.learning_rate

        # Conflicting gradients (ip <= 0):
        #   alpha = delta * lr * (1 - ip / ||g_p||^2),  beta = delta * lr
        # otherwise:
        #   alpha = lr,  beta = ip / (L * ||g_c||^2)
        # NOTE(review): alpha is NaN if ip == 0 and ||g_p||^2 == 0 at the same time.
        self.alpha = tf.where(self.ip <= 0, self.delta * learning_rate * (1.0 - self.ip / self.gp_norm_square), learning_rate)
        self.beta = tf.where(self.ip <= 0, self.delta * learning_rate, self.ip / (self.L * self.gc_norm_square))

        grad = [self.alpha * gp + self.beta * gc for gp, gc in zip(self.previous_g, self.current_g)]

        # The same combined step goes to both copies, keeping them equal once
        # they have been synchronized.
        nccl_c_op = list(zip(grad, current_vars))
        nccl_p_op = list(zip(grad, previous_vars))
        self.nccl_gvs = nccl_c_op + nccl_p_op
        scaled_cgvs = [(learning_rate * g, v) for g, v in self.current_gvs if g is not None]

        self.train_op = self.optimizer.apply_gradients(scaled_cgvs)
        self.nccl_op = self.optimizer.apply_gradients(self.nccl_gvs)

        # Copy the current weights into the "previous" copy so that g_p is
        # evaluated at the same parameters as g_c.
        self.sync_op = tf.group(*[p.assign(c) for c, p in zip(current_vars, previous_vars)])
        self._previous_in_sync = False

        if is_split:
            self.train_step = self.train_step_for_split
        else:
            self.train_step = self.train_step_for_general

        # ---- evaluation ----------------------------------------------------
        # Per-example correctness on the unmasked logits of the current copy.
        self.equal = tf.equal(tf.argmax(self.logits, axis=-1), tf.argmax(self.labels, axis=-1))

    def _pruning(self, logits, offset_mask):
        """Replace logits where ``offset_mask != 1`` with ``NEG_INF``."""
        equal = tf.equal(offset_mask, 1)
        pl = tf.where(equal, logits, NEG_INF * tf.ones_like(logits, dtype=tf.float32))
        return pl

    @staticmethod
    def _dense_gradients(loss, var_list):
        """Return ``tf.gradients(loss, var_list)`` with ``None`` replaced by zeros.

        This keeps the arithmetic below well-defined for variables the loss
        does not depend on.
        """
        grads = tf.gradients(loss, var_list)
        return [tf.zeros_like(v) if g is None else g for g, v in zip(grads, var_list)]

    def _inner_product(self, g1_list, g2_list):
        """Return the scalar sum of elementwise products over paired gradient tensors."""
        sum_prod = 0
        for g1, g2 in zip(g1_list, g2_list):
            sum_prod += tf.reduce_sum(tf.math.multiply(g1, g2))

        return sum_prod

    def _run_sgd_step(self, feed_dict):
        """Run a plain SGD step on the current copy."""
        self.sess.run(self.train_op, feed_dict=feed_dict)
        # Only the "current" copy moved, so the copies differ again.
        self._previous_in_sync = False

    def _run_nccl_step(self, feed_dict):
        """Run an NCCL step, first syncing the "previous" copy if needed."""
        if not self._previous_in_sync:
            self.sess.run(self.sync_op)
            self._previous_in_sync = True
        self.sess.run(self.nccl_op, feed_dict=feed_dict)

    def train_step_for_general(self, inputs, labels, replaymemory, task_idx, transform=None):
        """Run one training step for non-split tasks.

        Task 0 uses plain SGD. Later tasks run an NCCL step using a batch
        from ``replaymemory.sample(task_idx)``. ``transform`` is unused.
        """
        if task_idx == 0:
            self._run_sgd_step({self.inputs: inputs, self.labels: labels})
        else:
            input_p, label_p = replaymemory.sample(task_idx)
            feed_dict = {self.inputs: inputs, self.labels: labels,
                         self.inputs_p: input_p, self.labels_p: label_p}
            self._run_nccl_step(feed_dict)

    def _generate_offset_mask(self, labels, offset_class):
        """Build a mask shaped like ``labels``.

        The mask is 1 on columns ``[offset_class[0], offset_class[1])`` and 0
        elsewhere.

        labels: one_hot
        offset_class: c_start, c_end
        """
        offset_mask = np.zeros_like(labels)
        offset_mask[:, offset_class[0]:offset_class[1]] = 1.0
        return offset_mask

    def train_step_for_split(self, inputs, labels, replaymemory, task_idx, offset_class):
        """Run one training step for split tasks with masked logits.

        ``offset_class`` is the ``(c_start, c_end)`` column range of the
        current task. Task 0 uses plain SGD. Later tasks run an NCCL step using
        ``replaymemory.sample_for_split(task_idx)``.
        """
        offset_mask = self._generate_offset_mask(labels, offset_class)
        if task_idx == 0:
            self._run_sgd_step({self.inputs: inputs, self.labels: labels,
                                self.offset: offset_mask})
        else:
            input_p, label_p, offset_p = replaymemory.sample_for_split(task_idx)
            feed_dict = {self.inputs: inputs, self.labels: labels, self.offset: offset_mask,
                         self.inputs_p: input_p, self.labels_p: label_p, self.offset_p: offset_p}
            self._run_nccl_step(feed_dict)

    def _collect_equal(self, testloader):
        """Return per-batch boolean correctness arrays for ``testloader`` in eval mode."""
        results = []
        for inputs, labels in testloader:
            results.append(self.sess.run(self.equal, feed_dict={self.inputs: inputs,
                                                                self.labels: labels,
                                                                self.is_training: False}))
        return results

    def test_accuracy(self, testloader):
        """Return the fraction of correctly classified examples in ``testloader``."""
        return np.mean(np.concatenate(self._collect_equal(testloader)))

    def running_average_accuracy(self, testloader_list):
        """Return the accuracy pooled over every example of every loader in ``testloader_list``."""
        test_results = []
        for testloader in testloader_list:
            test_results.extend(self._collect_equal(testloader))
        return np.mean(np.concatenate(test_results))


def continual_learning(task_continuum, data_loader, num_tasks, input_shape, 
                       num_outputs, batch_size, learning_rate, sample_memory_batch,
                       network_architecture, epochs_per_task, memory_size, log_every_step, l_coeff, delta):
    """Train an NCCL model sequentially on every task of ``task_continuum``.

    Args:
        task_continuum: iterable of ``(transform, train_data, test_data)`` per
            task. It must also expose a ``transform`` attribute; ``'split'``
            enables logit masking.
        data_loader: the loader *class* (not an instance). It is called as
            ``data_loader(dataset, batch_size=batch_size)`` and iterated for
            ``(inputs, labels)`` batches.
        num_tasks, input_shape, num_outputs: problem dimensions.
        batch_size: training/test batch size.
        learning_rate, l_coeff, delta: NCCL hyperparameters.
        sample_memory_batch: replay-memory sample size.
        network_architecture: hidden layer sizes for ``core.mlp``; a single
            ``-1`` selects ``core.resnet``.
        epochs_per_task: passes over each task's training data.
        memory_size: replay slots per task.
        log_every_step: evaluate and print at steps ``1, 1 + k, 1 + 2k, ...``
            of each epoch, where ``k = log_every_step``. Values ``<= 0``
            disable logging.
    """
    if len(network_architecture) == 1 and int(network_architecture[0]) == -1:
        network = core.resnet
    else:
        network = core.mlp

    with tf.Session() as sess:
        model = NonconvexContinualLearning(sess=sess, transform=task_continuum.transform, input_shape=input_shape, 
                                       num_outputs=num_outputs, num_tasks=num_tasks,
                                       memory_strength=None, learning_rate=learning_rate,
                                       network=network, hidden_layers=network_architecture, L_coeff=l_coeff, delta=delta)

        replaymemory = RandomTaskReplayMemory(num_tasks=num_tasks, memory_size=memory_size, 
                                        input_shape=input_shape, num_outputs=num_outputs, sample_batch_size=sample_memory_batch)

        sess.run(tf.global_variables_initializer())

        total_steps = 0
        testloader_list = []
        for task_idx, (transform, train_data, test_data) in enumerate(task_continuum, 0):
            trainloader = data_loader(train_data, batch_size=batch_size)
            testloader = data_loader(test_data, batch_size=batch_size)
            testloader_list.append(testloader)

            if task_idx == 0:
                first_task_testloader = testloader
            print(f"----------------- New task: {task_idx} ------------------------")
            for epoch in range(epochs_per_task):
                for i, (inputs, labels) in enumerate(trainloader, 0):
                    # Store before training so the memory already holds the
                    # current batch (sampling only reads earlier tasks).
                    replaymemory.store(inputs, labels, task_idx)

                    model.train_step(inputs, labels, replaymemory, task_idx, transform)

                    # Same schedule as `i % k == 1` for k > 1 (steps 1, 1+k, ...),
                    # but it also logs every step for k == 1 and never divides
                    # by zero.
                    if log_every_step > 0 and (i - 1) % log_every_step == 0:
                        test_acc_ft = model.test_accuracy(first_task_testloader)
                        test_acc = model.test_accuracy(testloader)
                        running_test_acc = model.running_average_accuracy(testloader_list)

                        logging_string = (f"Task: {task_idx}, Epoch: {epoch}, Steps: {i}, Iterations: {total_steps} \t" 
                        f"First Test Acc: {test_acc_ft}, Curr. Task Acc: {test_acc}, "
                        f"Average test Acc: {running_test_acc}")

                        print(logging_string)

                    total_steps += 1


if __name__ == "__main__":
    import sys, os
    # TODO: set to the absolute path of the project root containing `continual/`.
    your_abs_path_of_this_project = "path_to"
    sys.path.append(your_abs_path_of_this_project)
    from continual.data.dataloader import TaskContinuumDataset, DataLoader

    _MNIST_DATA_DESCRIPTER = {'input_shape':[28,28], 'num_outputs':10}
    _CIFAR100_DATA_DESCRIPTER = {'input_shape':[32,32,3], 'num_outputs':100}
    _CIFAR10_DATA_DESCRIPTER = {'input_shape':[32,32,3], 'num_outputs':10}

    DATA_DESCRIPTER = {'mnist': _MNIST_DATA_DESCRIPTER,
                       'cifar100': _CIFAR100_DATA_DESCRIPTER,
                       'cifar10': _CIFAR10_DATA_DESCRIPTER}


    import argparse

    def _str_to_bool(value):
        """Parse a command-line boolean.

        ``type=bool`` is wrong for this: ``bool("False")`` is True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    # dataset setup
    parser.add_argument('--i', default='raw', help='raw datasets for continual learning')
    parser.add_argument('--num_tasks', default=3, type=int, help='number of different tasks')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--raw_data', default='mnist', help='raw dataset for using continual learning')
    parser.add_argument('--transform', default='rotation', help='transform transformation')
    parser.add_argument('--min_angle', default=0, type=float, help='minimum rotation angle')
    parser.add_argument('--max_angle', default=90, type=float, help='maximum rotation angle')

    # learning hyperparameters
    parser.add_argument('--batch_size', default=100, type=int, help="batch size")
    parser.add_argument('--epochs_per_task', default=1, type=int, help="epochs per task")
    parser.add_argument('--shuffle_tasks', default=True, type=_str_to_bool, help="shuffle tasks")
    parser.add_argument('--learning_rate', default=0.01, type=float, help="learning_rate")
    parser.add_argument('--l_coeff', default=0.7, type=float,
                        help="coefficient c of the smoothness constant L = 2c / learning_rate")
    parser.add_argument('--delta', default=0.5, type=float,
                        help="NCCL step-size factor used when the gradients conflict")

    # logging setup
    parser.add_argument('--log_every_step', default=20, type=int, help="logging step")

    # network and memory setup
    parser.add_argument('--memory_size', default=500, type=int, help="memory_size") 
    parser.add_argument('--sample_memory_batch', default=100, type=int, help="sample size of memory")
    parser.add_argument('--hid', nargs='*', type=int, default=[400,400], help='fc-layers or convnet, convnet=-1')


    args = parser.parse_args()

    task_continuum = TaskContinuumDataset(args.raw_data, args.transform, 
                                   args.num_tasks, args.seed,
                                   shuffle_tasks=args.shuffle_tasks, train_samples_per_task=None)

    continual_learning(task_continuum, DataLoader, args.num_tasks, DATA_DESCRIPTER[args.raw_data]['input_shape'], 
                       DATA_DESCRIPTER[args.raw_data]['num_outputs'], args.batch_size, args.learning_rate,
                       args.sample_memory_batch, args.hid, args.epochs_per_task, args.memory_size, args.log_every_step,
                       args.l_coeff, args.delta)


    