import tensorflow as tf
import numpy as np
import quadprog
import time
import sys, os

# Import the sibling ``core`` module. The relative import works when this file
# is loaded as part of its package; when it is executed directly as a script
# there is no parent package, so fall back to a plain import.
try:
    from . import core
except (ImportError, SystemError, ValueError):
    # sys.path entries must be directories: add the directory that contains
    # this file (not the file path itself) so ``import core`` can resolve.
    # SystemError/ValueError are what older interpreters raise for a relative
    # import without a parent package.
    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    import core


# sys.path.append(os.path.abspath('../../../'))
from continual.utils.logx import EpochLogger
from continual.data.dataloader import TaskContinuumDataset, DataLoader
from continual.algos.memory import RandomTaskReplayMemory,ReservoirMemory

# Per-dataset description: the shape of one input sample (the batch axis is
# prepended by the model) and the number of output classes.
_MNIST_DATA_DESCRIPTER = dict(input_shape=[28, 28], num_outputs=10)
_CIFAR100_DATA_DESCRIPTER = dict(input_shape=[32, 32, 3], num_outputs=100)
_CIFAR10_DATA_DESCRIPTER = dict(input_shape=[32, 32, 3], num_outputs=10)

# Lookup table keyed by the raw dataset name.
DATA_DESCRIPTER = {
    'mnist': _MNIST_DATA_DESCRIPTER,
    'cifar100': _CIFAR100_DATA_DESCRIPTER,
    'cifar10': _CIFAR10_DATA_DESCRIPTER,
}

# Large negative constant; -1.0e32 was tried as an alternative value.
NEG_INF = -10.0e10
# Step interval intended for intermediate logging.
LOG_EVERY_STEP = 100

class NonconvexContinualLearning:
    """TensorFlow 1.x graph plus train/eval steps for replay-based continual learning.

    The constructor builds the network, three softmax cross-entropy losses
    (all samples, replayed samples only, current samples only) and two
    training ops that share one Adam optimizer:

    * ``train_op``  -- plain step on the unweighted loss, used for task 0.
    * ``nccl_op``   -- step on ``alpha * g_prev + beta * g_cur`` where
      ``g_prev`` / ``g_cur`` are the gradients of the replay / current losses
      and ``alpha`` / ``beta`` are computed from their inner product.

    The ``offset`` placeholder is a per-sample float mask: non-zero marks a
    replayed (previous-task) sample, zero marks a current-task sample.

    Args:
        sess: ``tf.Session`` used for every ``run`` call of this object.
        transform: task transform name; ``'split'`` selects
            ``train_step_for_split``, anything else ``train_step_for_general``.
        input_shape: shape of one input sample (no batch axis).
        num_outputs: number of classes; labels are fed one-hot.
        num_tasks: accepted for interface compatibility; unused here.
        memory_strength: accepted for interface compatibility; unused here.
        learning_rate: optimizer learning rate, also used in the beta formula.
        network: callable ``network(inputs, hidden_layers, num_outputs,
            is_training)`` returning the logits tensor.
        hidden_layers: forwarded unchanged to ``network``.
        beta_max: upper clipping bound for ``beta`` when the inner product is
            positive.
        L_upper_bound: value stored as ``self.L`` and used in the beta formula.
        L_estimate: accepted for interface compatibility; unused here.
    """

    def __init__(self, sess, transform, input_shape, num_outputs,
                 num_tasks, memory_strength, learning_rate,
                 network=core.mlp, hidden_layers=[100, 100], beta_max=1.0, L_upper_bound=10.0, L_estimate=True):
        self.sess = sess

        self.L = L_upper_bound
        self.transform = transform
        self.beta_max = beta_max

        self.learning_rate = learning_rate
        self.num_outputs = num_outputs
        # list(...) lets callers pass a tuple as well as a list.
        input_batch_shape = [None] + list(input_shape)
        self.is_training = tf.placeholder_with_default(True, shape=None)

        # Current data stream (optionally concatenated with replayed samples).
        self.inputs = tf.placeholder(tf.float32, shape=input_batch_shape, name="inputs")
        self.labels = tf.placeholder(tf.int32, shape=[None, num_outputs], name="labels")
        self.offset = tf.placeholder(tf.float32, shape=[None], name="offset")

        with tf.variable_scope("network"):
            self.logits = network(self.inputs, hidden_layers, num_outputs, self.is_training)
            self.cross_entropy_loss = tf.losses.softmax_cross_entropy(self.labels, self.logits)
            self.loss = self.cross_entropy_loss

            # Loss weighted by the offset mask: replayed samples.
            self.loss_p = tf.losses.softmax_cross_entropy(self.labels, self.logits, weights=self.offset)

            # Loss weighted by the logical complement of the mask: current samples.
            _offset = tf.dtypes.cast(tf.logical_not(tf.dtypes.cast(self.offset, tf.bool)), tf.float32)
            self.loss_c = tf.losses.softmax_cross_entropy(self.labels, self.logits,
                weights=_offset)

        # Momentum and Adagrad optimizers were tried as alternatives.
        self.optimizer = tf.train.AdamOptimizer(self.learning_rate)

        # Collected once so every gradient list is aligned with the same variables.
        params = tf.trainable_variables()
        self.previous_g = tf.gradients(self.loss_p, params)
        self.current_g = tf.gradients(self.loss_c, params)

        self.gvs = self.optimizer.compute_gradients(self.loss)

        # Inner product and squared norms of the two gradient lists.
        self.ip = self._inner_product(self.previous_g, self.current_g)
        self.gp_norm_square = self._inner_product(self.previous_g, self.previous_g)
        self.gc_norm_square = self._inner_product(self.current_g, self.current_g)

        lr = 1.0

        # By the definition of the TensorFlow optimizer (which multiplies the
        # applied gradient by learning_rate), beta_H is divided by learning_rate.
        # alpha: 1 - ip / |g_prev|^2 when the gradients conflict (ip <= 0), else 1.
        self.alpha = tf.where(self.ip <= 0, lr * (1.0 - self.ip / self.gp_norm_square), lr)
        # Adjust beta: clipped to [0, beta_max] when ip > 0, else 1.
        beta_H = (self.ip * (1 - self.learning_rate * self.L)) / (self.L * self.gc_norm_square * self.learning_rate)
        beta_H = tf.clip_by_value(beta_H, 0, self.beta_max)
        self.beta = tf.where(self.ip <= 0, lr, beta_H)

        # Each gradient is clipped element-wise to [-0.5, 0.5] before being
        # combined (other variants clipped only g_cur, or the combined sum).
        grad = [self.alpha * tf.clip_by_value(gp, -0.5, 0.5) + self.beta * tf.clip_by_value(gc, -0.5, 0.5)
            for gp, gc in zip(self.previous_g, self.current_g)]
        _nccl_gvs = list(zip(grad, params))

        scaled_gvs = [(tf.clip_by_value(g, -0.5, 0.5), v) for g, v in self.gvs]
        self.train_op = self.optimizer.apply_gradients(scaled_gvs)
        self.nccl_op = self.optimizer.apply_gradients(_nccl_gvs)

        # Bind the public train_step entry point according to the transform.
        if transform == 'split':
            self.train_step = self.train_step_for_split
        else:
            self.train_step = self.train_step_for_general

        # Test operator: per-sample boolean, True where argmax(logits) == argmax(labels).
        self.equal = tf.equal(tf.argmax(self.logits, axis=-1), tf.argmax(self.labels, axis=-1))

        # l_estimate() is intentionally not called here (see its docstring).

    def l_estimate(self):
        """Build ops that estimate ``|grad diff| / |parameter diff|`` after one step.

        Defines ``self.grad_diff``, ``self.grad_diff_norm``, ``self.var_diff``
        and ``self.l_upper_bound_estimate``.

        NOTE(review): this method reads ``self.current_gvs``, which is not
        assigned anywhere in this class, so calling it as-is raises
        ``AttributeError``. The call in ``__init__`` is disabled. The code is
        kept unchanged until the intended gradient/variable pairing is confirmed.
        """
        one_step_update = [v - self.learning_rate * g for g, v in self.current_gvs]
        temp_vars = [v for g, v in self.current_gvs]

        assign_one_step_update = [tf.assign(r, v) for r, v in zip(tf.trainable_variables(), one_step_update)]
        with tf.control_dependencies(assign_one_step_update):
            next_step_grad = tf.gradients(self.cross_entropy_loss, tf.trainable_variables())
            with tf.control_dependencies(next_step_grad):
                assign_orginal = [tf.assign(r, v) for r, v in zip(tf.trainable_variables(), temp_vars)]

        self.grad_diff = [g1 - g2 for g1, g2 in zip(self.current_g, next_step_grad) ]
        self.grad_diff_norm = tf.sqrt(self._inner_product(self.grad_diff, self.grad_diff))
        self.var_diff = self.learning_rate *tf.sqrt(self._inner_product(self.current_g, self.current_g))
        self.l_upper_bound_estimate = self.grad_diff_norm / self.var_diff

    def _inner_product(self, g1_list, g2_list):
        """Return the sum over paired tensors of ``reduce_sum(g1 * g2)``."""
        sum_prod = 0
        for g1, g2 in zip(g1_list, g2_list):
            sum_prod += tf.reduce_sum(tf.math.multiply(g1, g2))

        return sum_prod

    def _run_plain_step(self, inputs, labels):
        """Run ``train_op`` on the given batch (no replay, no offset mask)."""
        self.sess.run(self.train_op, feed_dict={self.inputs: inputs, self.labels: labels})

    def _run_nccl_step(self, inputs, labels, input_p, label_p):
        """Run ``nccl_op`` on the current batch followed by the replayed batch.

        The offset mask is 0 for every current sample and 1 for every replayed
        sample, in the same order as the concatenated inputs. ``nccl_op``
        depends on the ``offset`` placeholder through both weighted losses, so
        the mask must always be fed.
        """
        offset_mask = np.concatenate([np.zeros(len(labels)), np.ones(len(label_p))], axis=0)
        feed_dict = {
            self.inputs: np.concatenate([inputs, input_p], axis=0),
            self.labels: np.concatenate([labels, label_p], axis=0),
            self.offset: offset_mask,
        }
        self.sess.run(self.nccl_op, feed_dict=feed_dict)

    def train_step_for_general(self, inputs, labels, replaymemory, task_idx, transform=None):
        """Run one training step for non-split transforms.

        Task 0 uses the plain ``train_op``. Later tasks draw a batch from
        ``replaymemory.sample(task_idx)`` and run ``nccl_op`` on the
        concatenation of current and replayed samples.

        Args:
            inputs: current-task input batch.
            labels: one-hot labels for ``inputs``.
            replaymemory: object whose ``sample(task_idx)`` returns
                ``(inputs, labels)`` of replayed samples.
            task_idx: zero-based index of the current task.
            transform: unused; present so both train steps share a call shape.
        """
        if task_idx == 0:
            self._run_plain_step(inputs, labels)
            return

        input_p, label_p = replaymemory.sample(task_idx)
        # Previously the offset mask was not fed on this path although nccl_op
        # requires it; the shared helper now builds and feeds it.
        self._run_nccl_step(inputs, labels, input_p, label_p)

    def _generate_offset_mask(self, labels, offset_class):
        """Return an int32 array shaped like ``labels`` with a class range set to 1.

        Args:
            labels: one-hot label array.
            offset_class: ``(c_start, c_end)``; columns ``c_start:c_end`` are
                set to 1, all others stay 0.
        """
        offset_mask = np.zeros_like(labels, np.int32)
        offset_mask[:, offset_class[0]:offset_class[1]] = 1
        return offset_mask

    def train_step_for_split(self, inputs, labels, replaymemory, task_idx, offset_class):
        """Run one training step for the ``'split'`` transform.

        Same flow as ``train_step_for_general`` but the replay batch comes from
        ``replaymemory.sample_for_split(task_idx)``, which also returns an
        offset value that is currently ignored.

        Args:
            inputs: current-task input batch.
            labels: one-hot labels for ``inputs``.
            replaymemory: object providing ``sample_for_split(task_idx)``.
            task_idx: zero-based index of the current task.
            offset_class: unused here (a class-range mask was tried and disabled).
        """
        if task_idx == 0:
            self._run_plain_step(inputs, labels)
            return

        input_p, label_p, offset_p = replaymemory.sample_for_split(task_idx)
        self._run_nccl_step(inputs, labels, input_p, label_p)

    def _correct_predictions(self, testloader):
        """Return a list with one boolean correctness array per batch of ``testloader``."""
        results = []
        for inputs, labels in testloader:
            correct = self.sess.run(self.equal, feed_dict={self.inputs: inputs,
                                                           self.labels: labels,
                                                           self.is_training: False})
            results.append(correct)
        return results

    def test_accuracy(self, testloader):
        """Return the fraction of correctly classified samples in ``testloader``.

        Runs with ``is_training`` fed as False. As before, an empty loader
        makes ``np.concatenate`` raise ``ValueError``.
        """
        return np.mean(np.concatenate(self._correct_predictions(testloader)))

    def running_average_accuracy(self, testloader_list):
        """Return the accuracy over all samples of all loaders pooled together.

        Every sample has equal weight, so loaders with more samples contribute
        more to the result.
        """
        test_results = []
        for testloader in testloader_list:
            test_results.extend(self._correct_predictions(testloader))
        return np.mean(np.concatenate(test_results))

    def average_test_l_estimate(self, testloader_list, transform_list):
        """Return the mean of ``l_upper_bound_estimate`` over all test batches.

        Requires ``l_estimate()`` to have been called first, since that is the
        only place ``self.l_upper_bound_estimate`` is defined.

        NOTE(review): for ``'split'`` the fed mask comes from
        ``_generate_offset_mask`` and has the shape of ``labels``, while the
        ``offset`` placeholder is declared with shape ``[None]`` -- confirm the
        intended mask before enabling this path. For other transforms no offset
        is fed at all.

        Args:
            testloader_list: iterable of test loaders, one per task.
            transform_list: per-task transform values, aligned with the loaders.
        """
        test_results = list()
        for testloader, tr in zip(testloader_list, transform_list):

            for i, data in enumerate(testloader, 0):
                inputs, labels = data
                if self.transform == 'split':
                    offset_mask = self._generate_offset_mask(labels, tr)
                    l = self.sess.run(self.l_upper_bound_estimate, feed_dict={self.inputs: inputs,
                                                            self.labels: labels,
                                                            self.offset: offset_mask,
                                                            self.is_training:False})
                else:
                    l = self.sess.run(self.l_upper_bound_estimate, feed_dict={self.inputs: inputs,
                                        self.labels: labels,
                                        self.is_training:False})

                test_results.append(l)
        test_l = np.mean(test_results)
        return test_l


def nccl(raw_data, transform, num_tasks, seed, shuffle_tasks,
         batch_size, learning_rate, sample_memory_batch,
         network_architecture, epochs_per_task, memory_size,
         l_upper_bound, memory, beta_max, logger_kwargs):
    """Train a NonconvexContinualLearning model over a sequence of tasks.

    For every task the model is trained for ``epochs_per_task`` epochs; each
    batch is stored in the replay memory after its training step. At the end
    of every epoch the accuracy on the first task, on the current task and on
    all tasks seen so far is logged.

    Args:
        raw_data: key into ``DATA_DESCRIPTER`` (e.g. ``'mnist'``).
        transform: transform name forwarded to ``TaskContinuumDataset``.
        num_tasks: number of tasks.
        seed: seed for TensorFlow, NumPy and the task continuum.
        shuffle_tasks: forwarded to ``TaskContinuumDataset``.
        batch_size: batch size for the train and test loaders.
        learning_rate: model learning rate.
        sample_memory_batch: replay sample batch size.
        network_architecture: hidden layer sizes; a single ``-1`` selects
            ``core.resnet`` instead of ``core.mlp``.
        epochs_per_task: epochs to train on each task.
        memory_size: replay memory capacity.
        l_upper_bound: forwarded to the model as ``L_upper_bound``.
        memory: ``"random"`` or ``"reservoir"``.
        beta_max: forwarded to the model.
        logger_kwargs: keyword arguments for ``EpochLogger``.

    Raises:
        ValueError: if ``memory`` is neither ``"random"`` nor ``"reservoir"``.
    """
    # These two statements must stay first: save_config receives locals(), so
    # no additional local may be bound before it.
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    tf.set_random_seed(seed)
    np.random.seed(seed)

    descriptor = DATA_DESCRIPTER[raw_data]
    input_shape = descriptor['input_shape']
    num_outputs = descriptor['num_outputs']

    task_continuum = TaskContinuumDataset(raw_data, transform, num_tasks, seed,
                                          shuffle_tasks=shuffle_tasks,
                                          train_samples_per_task=None)

    # A single "-1" entry requests the ResNet; anything else is an MLP spec.
    use_resnet = len(network_architecture) == 1 and int(network_architecture[0]) == -1
    network = core.resnet if use_resnet else core.mlp

    if memory not in ("random", "reservoir"):
        raise ValueError("Wrong type of memory")
    Memory = RandomTaskReplayMemory if memory == "random" else ReservoirMemory

    with tf.Session() as sess:
        # Build the model graph.
        model = NonconvexContinualLearning(
            sess=sess, transform=task_continuum.transform, input_shape=input_shape,
            num_outputs=num_outputs, num_tasks=num_tasks,
            memory_strength=None, learning_rate=learning_rate,
            network=network, hidden_layers=network_architecture,
            beta_max=beta_max, L_upper_bound=l_upper_bound)

        # Replay memory shared by all tasks.
        replaymemory = Memory(num_tasks=num_tasks, memory_size=memory_size,
                              input_shape=input_shape, num_outputs=num_outputs,
                              sample_batch_size=sample_memory_batch)

        sess.run(tf.global_variables_initializer())

        # Model saving setup.
        logger.setup_tf_saver(sess, inputs={'inputs': model.inputs},
                              outputs={'labels': model.labels})

        step_count = 0
        testloader_list = []
        transform_list = []
        start_time = time.time()
        epoch_count = 0

        for task_idx, (task_transform, train_data, test_data) in enumerate(task_continuum):
            trainloader = DataLoader(train_data, batch_size=batch_size, drop_last=True)
            testloader = DataLoader(test_data, batch_size=batch_size)
            testloader_list.append(testloader)
            transform_list.append(task_transform)

            for _ in range(epochs_per_task):
                for inputs, labels in trainloader:
                    model.train_step(inputs, labels, replaymemory, task_idx, task_transform)
                    replaymemory.store(inputs, labels, task_idx)
                    step_count += 1

                # End-of-epoch evaluation; the first appended loader belongs to task 0.
                first_task_acc = model.test_accuracy(testloader_list[0])
                current_task_acc = model.test_accuracy(testloader)
                seen_tasks_acc = model.running_average_accuracy(testloader_list)
                logger.store(TestAccT1=first_task_acc, TestAccCT=current_task_acc,
                             TestAvgAcc=seen_tasks_acc)

                # Log info about the epoch.
                logger.log_tabular('Task', task_idx)
                logger.log_tabular('Epoch', epoch_count)
                logger.log_tabular('steps', step_count)
                logger.log_tabular('TestAccT1', average_only=True)
                logger.log_tabular('TestAccCT', average_only=True)
                logger.log_tabular('TestAvgAcc', average_only=True)
                logger.log_tabular('Time', time.time() - start_time)
                logger.dump_tabular()
                epoch_count += 1


def str2bool(value):
    """Convert a command-line string to a bool.

    ``type=bool`` in argparse treats every non-empty string (including
    ``"False"``) as True, so an explicit converter is used instead.

    Args:
        value: a bool (returned unchanged) or a string.

    Returns:
        True for ``true/t/yes/y/1``, False for ``false/f/no/n/0`` (case
        insensitive) and for the empty string, which was the only input that
        produced False with ``type=bool``.

    Raises:
        ValueError: for any other string; argparse reports it as a usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError("expected a boolean value, got %r" % (value,))


def build_arg_parser():
    """Return the argparse parser for the command-line interface of this script."""
    import argparse
    parser = argparse.ArgumentParser()
    # dataset setup
    parser.add_argument('--i', default='raw', help='raw datasets for continual learning')
    parser.add_argument('--num_tasks', default=3, type=int, help='number of different tasks')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--raw_data', default='mnist', help='raw dataset for using continual learning')
    parser.add_argument('--transform', default='rotation', help='transform transformation')
    parser.add_argument('--min_angle', default=0, type=float, help='minimum rotation angle')
    parser.add_argument('--max_angle', default=90, type=float, help='maximum rotation angle')

    # learning hyperparameters
    parser.add_argument('--batch_size', default=100, type=int, help="batch size")
    parser.add_argument('--epochs_per_task', default=1, type=int, help="epochs per task")
    # str2bool instead of bool so that "--shuffle_tasks False" really disables it.
    parser.add_argument('--shuffle_tasks', default=True, type=str2bool, help="shuffle tasks")
    parser.add_argument('--learning_rate', default=0.03, type=float, help="learning_rate")
    parser.add_argument('--l_upper_bound', default=0.7, type=float, help="choose a proper upper bound of L")
    parser.add_argument('--beta_max', default=1.0, type=float, help="beta clipping value")
    parser.add_argument('--memory', default='random', help="memory type")

    # logging setup
    parser.add_argument('--log_every_step', default=100, type=int, help="logging step")
    parser.add_argument('--exp_name', type=str, default='nccl')

    # network and memory setup
    parser.add_argument('--memory_size', default=500, type=int, help="memory_size")
    parser.add_argument('--sample_memory_batch', default=100, type=int, help="sample size of memory")
    # type=int so layer sizes given on the command line match the int defaults
    # instead of arriving as strings.
    parser.add_argument('--network_architecture', nargs='*', type=int, default=[400,400], help='fc-layers or convnet, convnet=-1')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    from continual.utils.run_utils import setup_logger_kwargs
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    nccl(args.raw_data, args.transform, args.num_tasks, args.seed, args.shuffle_tasks,
         args.batch_size, args.learning_rate, args.sample_memory_batch, args.network_architecture,
         args.epochs_per_task, args.memory_size, args.l_upper_bound, args.memory, args.beta_max, logger_kwargs)


    