import tensorflow as tf
import numpy as np
import quadprog
import time
import sys, os

try:
    from . import core
except:
    sys.path.append(os.path.abspath(__file__))
    import core


# sys.path.append(os.path.abspath('../../../'))
from continual.utils.logx import EpochLogger
from continual.data.dataloader import TaskContinuumDataset, DataLoader
from continual.algos.memory import RandomTaskReplayMemory,ReservoirMemory

# Per-dataset description: per-example input shape (no batch axis) and the
# number of output classes.
_MNIST_DATA_DESCRIPTER = dict(input_shape=[28, 28], num_outputs=10)
_CIFAR100_DATA_DESCRIPTER = dict(input_shape=[32, 32, 3], num_outputs=100)
_CIFAR10_DATA_DESCRIPTER = dict(input_shape=[32, 32, 3], num_outputs=10)

# Lookup table keyed by the `raw_data` name passed to `reservoir`.
DATA_DESCRIPTER = dict(
    mnist=_MNIST_DATA_DESCRIPTER,
    cifar100=_CIFAR100_DATA_DESCRIPTER,
    cifar10=_CIFAR10_DATA_DESCRIPTER,
)

# Large negative fill value intended for masking logits; not referenced by
# active code in this module. (Earlier candidates tried: -1.0e32.)
NEG_INF = -1.0e11
# Step interval for periodic logging; not referenced by active code in this module.
log_every_step = 20

class Reservoir:
    """Replay-based continual learner built as a TensorFlow 1.x graph.

    A single network is created under the ``current`` variable scope and
    trained with Adam on softmax cross-entropy plus ``1e-5`` times the scope's
    regularization loss. Every gradient is clipped element-wise to
    ``[-0.5, 0.5]`` before being applied. From the second task on
    (``task_idx != 0``), each training batch is extended with samples drawn
    from a replay memory.

    Args:
        sess: TensorFlow session used to run every op.
        transform: Task transform name. ``'split'`` adds the ``offset``
            placeholder and selects ``train_step_for_split``; any other value
            selects ``train_step_for_general``.
        input_shape: Per-example input shape (no batch axis).
        num_outputs: Number of classes, i.e. the width of the one-hot labels.
        num_tasks: Not used by this class; kept for interface compatibility.
        memory_strength: Not used by this class; kept for interface compatibility.
        learning_rate: Adam learning rate.
        network: Callable ``network(inputs, hidden_layers, num_outputs,
            is_training)`` returning the logits tensor.
        hidden_layers: Forwarded unchanged to ``network``; ``None`` means
            ``[100, 100]``.
    """

    def __init__(self, sess, transform, input_shape, num_outputs,
                 num_tasks, memory_strength, learning_rate,
                 network=core.mlp, hidden_layers=None):
        # Build a fresh list per instance instead of sharing a mutable default.
        if hidden_layers is None:
            hidden_layers = [100, 100]
        self.sess = sess
        self.transform = transform
        self.learning_rate = learning_rate
        self.num_outputs = num_outputs
        # list(...) so tuple shapes work as well as lists.
        input_batch_shape = [None] + list(input_shape)
        # Defaults to training mode; the evaluation helpers feed False.
        self.is_training = tf.placeholder_with_default(True, shape=None)

        self.inputs = tf.placeholder(tf.float32, shape=input_batch_shape, name="inputs")
        self.labels = tf.placeholder(tf.int32, shape=[None, num_outputs], name="labels")
        if transform == "split":
            # 0/1 class mask per example (see _generate_offset_mask).
            self.offset = tf.placeholder(tf.int32, shape=[None, num_outputs], name="offset")

        with tf.variable_scope("current"):
            self.logits = network(self.inputs, hidden_layers, num_outputs, self.is_training)
            if transform == 'split':
                self.pruned_logits = self._pruning(self.logits, self.offset)
                train_logits = self.pruned_logits
            else:
                train_logits = self.logits
            self.cross_entropy_loss = tf.losses.softmax_cross_entropy(self.labels, train_logits)
            self.loss_c = self.cross_entropy_loss + 0.00001 * tf.losses.get_regularization_loss('current')

        self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
        gvs = self.optimizer.compute_gradients(self.loss_c)
        # compute_gradients yields None for variables the loss does not depend
        # on; tf.clip_by_value(None, ...) would raise, so those pairs are skipped.
        clipped_gvs = [(tf.clip_by_value(g, -0.5, 0.5), v) for g, v in gvs if g is not None]
        # NOTE(review): if `network` uses batch normalization, its UPDATE_OPS are
        # not attached to train_op here -- confirm against core.
        self.train_op = self.optimizer.apply_gradients(clipped_gvs)

        if transform == 'split':
            self.train_step = self.train_step_for_split
        else:
            self.train_step = self.train_step_for_general

        # Per-example correctness: argmax of the (unmasked) logits vs. argmax of
        # the one-hot labels. Used by the evaluation helpers below.
        self.equal = tf.equal(tf.argmax(self.logits, axis=-1), tf.argmax(self.labels, axis=-1))

    def l_estimate(self):
        """Placeholder for the smoothness-constant (L) estimator.

        Currently a no-op: it does not build ``self.l_upper_bound_estimate``,
        so ``average_test_l_estimate`` cannot run until this is implemented.
        """
        pass

    def _pruning(self, logits, offset_mask):
        """Return the logits used for the split-task loss.

        Masking by ``offset_mask`` is currently disabled, so ``logits`` is
        returned unchanged and ``offset_mask`` is ignored.
        """
        return logits

    def _inner_product(self, g1_list, g2_list):
        """Return the sum over paired tensors of ``reduce_sum(g1 * g2)``.

        Returns the Python int ``0`` when the lists are empty.
        """
        sum_prod = 0
        for g1, g2 in zip(g1_list, g2_list):
            sum_prod += tf.reduce_sum(tf.math.multiply(g1, g2))
        return sum_prod

    def train_step_for_general(self, inputs, labels, replaymemory, task_idx, transform=None):
        """Run one optimizer step, mixing in replayed samples after the first task.

        Args:
            inputs, labels: Current batch (labels one-hot).
            replaymemory: Object whose ``sample(task_idx)`` returns
                ``(inputs, labels)`` arrays to append to the batch.
            task_idx: Index of the current task; 0 means no replay.
            transform: Unused; accepted so both train-step variants share a
                call signature.
        """
        if task_idx != 0:
            input_p, label_p = replaymemory.sample(task_idx)
            inputs = np.concatenate([inputs, input_p], axis=0)
            labels = np.concatenate([labels, label_p], axis=0)
        self.sess.run(self.train_op, feed_dict={self.inputs: inputs, self.labels: labels})

    def _generate_offset_mask(self, labels, offset_class):
        """Build a 0/1 int32 mask with the same shape as ``labels``.

        Args:
            labels: One-hot label array.
            offset_class: ``(c_start, c_end)``; columns ``c_start:c_end`` are
                set to 1 in every row, all others are 0.
        """
        offset_mask = np.zeros_like(labels, np.int32)
        offset_mask[:, offset_class[0]:offset_class[1]] = 1
        return offset_mask

    def train_step_for_split(self, inputs, labels, replaymemory, task_idx, offset_class):
        """Run one optimizer step for split tasks, mixing in replay after the first task.

        Args:
            inputs, labels: Current batch (labels one-hot).
            replaymemory: Object whose ``sample_for_split(task_idx)`` returns
                ``(inputs, labels, offset_masks)`` arrays to append.
            task_idx: Index of the current task; 0 means no replay.
            offset_class: ``(c_start, c_end)`` class range of the current task.
        """
        offset_mask = self._generate_offset_mask(labels, offset_class)
        if task_idx != 0:
            input_p, label_p, offset_p = replaymemory.sample_for_split(task_idx)
            inputs = np.concatenate([inputs, input_p], axis=0)
            labels = np.concatenate([labels, label_p], axis=0)
            offset_mask = np.concatenate([offset_mask, offset_p], axis=0)
        feed_dict = {self.inputs: inputs, self.labels: labels, self.offset: offset_mask}
        self.sess.run(self.train_op, feed_dict=feed_dict)

    def _correct_predictions(self, testloader):
        """Return one per-example correctness array per batch of ``testloader`` (eval mode)."""
        results = []
        for inputs, labels in testloader:
            results.append(self.sess.run(self.equal, feed_dict={self.inputs: inputs,
                                                                self.labels: labels,
                                                                self.is_training: False}))
        return results

    def test_accuracy(self, testloader):
        """Return the fraction of correctly classified examples in ``testloader``.

        Raises:
            ValueError: if ``testloader`` yields no batches (nothing to concatenate).
        """
        return np.mean(np.concatenate(self._correct_predictions(testloader)))

    def running_average_accuracy(self, testloader_list):
        """Return accuracy pooled over every example of every loader.

        This is an example-weighted mean, not a mean of per-task accuracies.

        Raises:
            ValueError: if the loaders yield no batches at all.
        """
        results = [e for testloader in testloader_list
                   for e in self._correct_predictions(testloader)]
        return np.mean(np.concatenate(results))

    def average_test_l_estimate(self, testloader_list, transform_list):
        """Return the mean of ``l_upper_bound_estimate`` over all test batches.

        Raises:
            NotImplementedError: ``l_upper_bound_estimate`` has not been built
                (``l_estimate`` is currently a no-op placeholder).
        """
        l_op = getattr(self, 'l_upper_bound_estimate', None)
        if l_op is None:
            raise NotImplementedError(
                "l_upper_bound_estimate is not built; Reservoir.l_estimate is a placeholder")

        test_results = []
        for testloader, tr in zip(testloader_list, transform_list):
            for inputs, labels in testloader:
                feed_dict = {self.inputs: inputs, self.labels: labels, self.is_training: False}
                if self.transform == 'split':
                    feed_dict[self.offset] = self._generate_offset_mask(labels, tr)
                test_results.append(self.sess.run(l_op, feed_dict=feed_dict))
        return np.mean(test_results)

def _select_network(network_architecture):
    """Return core.resnet for the sentinel architecture ``[-1]``, else core.mlp."""
    if len(network_architecture) == 1 and int(network_architecture[0]) == -1:
        return core.resnet
    return core.mlp


def _select_memory(memory):
    """Map the ``memory`` option to a replay-memory class; raise ValueError if unknown."""
    if memory == "random":
        return RandomTaskReplayMemory
    if memory == "reservoir":
        return ReservoirMemory
    raise ValueError("Wrong type of memory")


def reservoir(raw_data, transform, num_tasks, seed, shuffle_tasks,
         batch_size, learning_rate, sample_memory_batch,
         network_architecture, epochs_per_task, memory_size, 
         memory, logger_kwargs):
    """Train a ``Reservoir`` replay learner over a task sequence and log accuracy.

    After every epoch the logger records: the task index, the cumulative epoch
    count, accuracy on the first task's test set (TestAccT1), on the current
    task's test set (TestAccCT), on all test sets seen so far pooled by example
    (TestAvgAcc), and elapsed wall-clock time.

    Args:
        raw_data: Key of ``DATA_DESCRIPTER`` (``'mnist'``, ``'cifar100'``, ``'cifar10'``).
        transform, num_tasks, shuffle_tasks: Forwarded to ``TaskContinuumDataset``.
        seed: Seeds TensorFlow and NumPy and is forwarded to the dataset.
        batch_size: Batch size of both train and test loaders.
        learning_rate: Learning rate of the model's optimizer.
        sample_memory_batch: ``sample_batch_size`` of the replay memory.
        network_architecture: Hidden-layer widths for ``core.mlp``, or ``[-1]``
            to select ``core.resnet``. String entries (e.g. from a CLI) are
            converted to int.
        epochs_per_task: Passes over each task's training data.
        memory_size: Capacity passed to the replay memory.
        memory: ``'random'`` (RandomTaskReplayMemory) or ``'reservoir'`` (ReservoirMemory).
        logger_kwargs: Keyword arguments for ``EpochLogger``.

    Raises:
        ValueError: ``memory`` is not a known replay-memory type.
        KeyError: ``raw_data`` is not in ``DATA_DESCRIPTER``.
    """
    # set logger; save_config must run before any other local exists so the
    # saved config only holds the call arguments (and the logger).
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    tf.set_random_seed(seed)
    np.random.seed(seed)

    input_shape = DATA_DESCRIPTER[raw_data]['input_shape']
    num_outputs = DATA_DESCRIPTER[raw_data]['num_outputs']

    # Validate options before the (potentially expensive) dataset load.
    Memory = _select_memory(memory)
    network = _select_network(network_architecture)
    # CLI values may arrive as strings; the network expects integer widths.
    hidden_layers = [int(h) for h in network_architecture]

    task_continuum = TaskContinuumDataset(raw_data, transform, num_tasks, seed,
        shuffle_tasks=shuffle_tasks, train_samples_per_task=None)

    with tf.Session() as sess:
        model = Reservoir(sess=sess, transform=task_continuum.transform, input_shape=input_shape,
                          num_outputs=num_outputs, num_tasks=num_tasks,
                          memory_strength=None, learning_rate=learning_rate,
                          network=network, hidden_layers=hidden_layers)

        replaymemory = Memory(num_tasks=num_tasks, memory_size=memory_size,
                              input_shape=input_shape, num_outputs=num_outputs,
                              sample_batch_size=sample_memory_batch)

        sess.run(tf.global_variables_initializer())

        # NOTE(review): 'labels' is an input placeholder, not a model output;
        # model.logits may have been intended here -- confirm before relying on
        # the saved signature.
        logger.setup_tf_saver(sess, inputs={'inputs': model.inputs}, outputs={'labels': model.labels})

        testloader_list = []
        start_time = time.time()
        total_epochs = 0

        for task_idx, (transform, train_data, test_data) in enumerate(task_continuum):
            trainloader = DataLoader(train_data, batch_size=batch_size)
            testloader = DataLoader(test_data, batch_size=batch_size)
            testloader_list.append(testloader)
            first_task_testloader = testloader_list[0]

            for _ in range(epochs_per_task):
                for inputs, labels in trainloader:
                    # `transform` doubles as the (c_start, c_end) class range
                    # for split tasks.
                    model.train_step(inputs, labels, replaymemory, task_idx, transform)
                    replaymemory.store(inputs, labels, task_idx)

                logger.store(TestAccT1=model.test_accuracy(first_task_testloader),
                             TestAccCT=model.test_accuracy(testloader),
                             TestAvgAcc=model.running_average_accuracy(testloader_list))

                # Log info about epoch
                total_epochs += 1
                logger.log_tabular('Task', task_idx)
                logger.log_tabular('Epoch', total_epochs)
                logger.log_tabular('TestAccT1', with_min_and_max=True)
                logger.log_tabular('TestAccCT', with_min_and_max=True)
                logger.log_tabular('TestAvgAcc', with_min_and_max=True)
                logger.log_tabular('Time', time.time() - start_time)
                logger.dump_tabular()


if __name__ == "__main__":
    import argparse

    def _str2bool(value):
        """Parse a boolean command-line value.

        Plain ``type=bool`` maps every non-empty string -- including "False" --
        to True, so the flag could never be turned off from the CLI.
        """
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1'):
            return True
        if normalized in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    parser = argparse.ArgumentParser()
    # dataset setup
    # NOTE: --i, --min_angle, --max_angle and --l_upper_bound are parsed for
    # CLI compatibility but are not forwarded to reservoir().
    parser.add_argument('--i', default='raw', help='raw datasets for continual learning')
    parser.add_argument('--num_tasks', default=3, type=int, help='number of different tasks')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--raw_data', default='mnist', help='raw dataset for using continual learning')
    parser.add_argument('--transform', default='rotation', help='transform transformation')
    parser.add_argument('--min_angle', default=0, type=float, help='minimum rotation angle')
    parser.add_argument('--max_angle', default=90, type=float, help='maximum rotation angle')

    # learning hyperparameters
    parser.add_argument('--batch_size', default=100, type=int, help="batch size")
    parser.add_argument('--epochs_per_task', default=1, type=int, help="epochs per task")
    parser.add_argument('--shuffle_tasks', default=True, type=_str2bool, help="shuffle tasks")
    parser.add_argument('--learning_rate', default=0.03, type=float, help="learning_rate")
    parser.add_argument('--l_upper_bound', default=0.7, type=float, help="choose a proper upper bound of L")
    parser.add_argument('--memory', default='random', help="memory type")

    # logging setup
    parser.add_argument('--exp_name', type=str, default='nccl')

    # network and memory setup
    parser.add_argument('--memory_size', default=500, type=int, help="memory_size")
    parser.add_argument('--sample_memory_batch', default=100, type=int, help="sample size of memory")
    # type=int so CLI values match the integer default; -1 selects the convnet.
    parser.add_argument('--network_architecture', nargs='*', type=int, default=[400, 400],
                        help='fc-layers or convnet, convnet=-1')
    args = parser.parse_args()

    from continual.utils.run_utils import setup_logger_kwargs
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    reservoir(args.raw_data, args.transform, args.num_tasks, args.seed, args.shuffle_tasks,
              args.batch_size, args.learning_rate, args.sample_memory_batch, args.network_architecture,
              args.epochs_per_task, args.memory_size, args.memory, logger_kwargs)


    