import tensorflow as tf
import numpy as np
import quadprog
import time
import sys, os

try:
    from . import core
except (ImportError, SystemError, ValueError):
    # Not imported as part of a package (e.g. executed directly as a script).
    # Depending on the interpreter version the failed relative import surfaces
    # as ImportError (Python >= 3.6), SystemError (Python 3.3-3.5) or
    # ValueError (Python 2). Make the directory that contains this file
    # importable (not the file path itself) and fall back to a plain import.
    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    import core


# sys.path.append(os.path.abspath('../../../'))
from continual.utils.logx import EpochLogger
from continual.data.dataloader import TaskContinuumDataset, DataLoader
from continual.algos.memory import RandomTaskReplayMemory,ReservoirMemory

# Per-dataset description used to build the model and the replay memory:
# the shape of one input sample and the number of classifier outputs.
_MNIST_DATA_DESCRIPTER = dict(input_shape=[28, 28], num_outputs=10)
_CIFAR100_DATA_DESCRIPTER = dict(input_shape=[32, 32, 3], num_outputs=100)
_CIFAR10_DATA_DESCRIPTER = dict(input_shape=[32, 32, 3], num_outputs=10)

# Lookup table keyed by the raw dataset name.
DATA_DESCRIPTER = dict(
    mnist=_MNIST_DATA_DESCRIPTER,
    cifar100=_CIFAR100_DATA_DESCRIPTER,
    cifar10=_CIFAR10_DATA_DESCRIPTER,
)

# Large negative value (-1e11). In this file it is only referenced by
# commented-out logit-masking code.
NEG_INF = -10.0e10
# Not referenced anywhere else in this file.
log_every_step = 20

class _RandomTaskReplayMemory:
    """
    A simple replay memory for NCCL.

    Samples live in two dense arrays indexed as [task, slot, ...]:
    ``input_memory`` (float32) and ``label_memory`` (int32, ``num_outputs``
    entries per sample). Each task row has ``memory_size`` slots which are
    overwritten cyclically once the write pointer reaches the end.

    NOTE(review): one write pointer (``ptr``) is shared by all tasks and is not
    reset when ``task_idx`` changes - confirm that this is intended.
    """
    def __init__(self, num_tasks, memory_size, input_shape, num_outputs, sample_batch_size):
        """
        num_tasks: number of task rows to allocate.
        memory_size: number of slots per task.
        input_shape: shape of a single input sample (any sequence of ints).
        num_outputs: length of one label row.
        sample_batch_size: number of samples drawn by the sampling methods
            (per previous task for ``sample``, in total for ``sample_for_split``).
        """
        if not isinstance(input_shape, list):
            input_shape = list(input_shape)
        self.input_memory = np.zeros([num_tasks, memory_size] + input_shape, dtype=np.float32)
        self.label_memory = np.zeros([num_tasks, memory_size, num_outputs], dtype=np.int32)
        self.sample_batch_size = sample_batch_size
        self.memory_size = memory_size
        self.num_tasks = num_tasks

        self.num_outputs = num_outputs
        self.num_classes_per_task = int(self.num_outputs / self.num_tasks)
        # next slot to be written (shared by all tasks)
        self.ptr = 0

    def store(self, inputs, labels, task_idx):
        """
        Write a batch into the row of ``task_idx`` starting at the current
        write pointer, wrapping around to slot 0 when the end is reached.
        """
        batch_size = inputs.shape[0]
        start = self.ptr
        end = start + batch_size
        if end > self.memory_size:
            # The batch does not fit behind the pointer: fill the tail of the
            # row first, then continue from the beginning of the row.
            head = self.memory_size - start
            wrapped_end = end - self.memory_size

            self.input_memory[task_idx, start:self.memory_size] = inputs[:head]
            self.label_memory[task_idx, start:self.memory_size] = labels[:head]

            self.input_memory[task_idx, :wrapped_end] = inputs[head:]
            self.label_memory[task_idx, :wrapped_end] = labels[head:]

            self.ptr = wrapped_end
        else:
            self.input_memory[task_idx, start:end] = inputs
            self.label_memory[task_idx, start:end] = labels
            self.ptr = end

    def _sample_mask(self, current_task_idx, num_samples):
        """
        Return a boolean [num_tasks, memory_size] mask selecting ``num_samples``
        distinct slots, drawn uniformly from the rows of tasks
        0 .. current_task_idx - 1.
        """
        num_candidates = current_task_idx * self.memory_size
        sampled_idx = np.random.choice(num_candidates, size=num_samples, replace=False)
        # Use the builtin ``bool``: the ``np.bool`` alias was removed from NumPy.
        mask = np.zeros(self.num_tasks * self.memory_size, dtype=bool)
        mask[sampled_idx] = True
        return mask.reshape(self.num_tasks, self.memory_size)

    def sample(self, current_task_idx):
        """
        memory architecture: [T, M, H, W, C]
        B samples on [T, M]

        Draws ``sample_batch_size * current_task_idx`` stored samples from the
        tasks preceding ``current_task_idx`` and returns ``(inputs, labels)``.
        """
        mask = self._sample_mask(current_task_idx,
                                 self.sample_batch_size * current_task_idx)
        return self.input_memory[mask], self.label_memory[mask]

    def sample_for_split(self, current_task_idx):
        """
        Draws ``sample_batch_size`` stored samples from the tasks preceding
        ``current_task_idx`` and returns ``(inputs, labels, offsets)``.

        ``offsets`` is a float32 array shaped like ``labels``; in every row the
        block of ``num_classes_per_task`` columns that contains the argmax of
        the label row is set to 1.0 and all other columns are 0.0.
        """
        mask = self._sample_mask(current_task_idx, self.sample_batch_size)

        inputs = self.input_memory[mask]
        labels = self.label_memory[mask]

        labels_int = np.argmax(labels, axis=-1)
        offsets = np.zeros_like(labels, np.float32)
        for row, label in enumerate(labels_int):
            block = int(label // self.num_classes_per_task)
            first = self.num_classes_per_task * block
            offsets[row, first:first + self.num_classes_per_task] = 1.0

        return inputs, labels, offsets

class NonconvexContinualLearning:
    """
    TensorFlow (graph mode) classifier trained on a sequence of tasks with replay.

    The graph applies ``network`` twice inside the shared "network" variable scope:
      * on the current batch only -> ``cross_entropy_loss`` / ``train_op``
      * on the replayed batch concatenated with the current batch
        -> ``loss`` / ``nccl_op``

    Both training ops use the same Adam optimizer and clip every gradient
    element-wise to [-_GRAD_CLIP, _GRAD_CLIP].

    ``train_step`` is bound to ``train_step_for_split`` when ``transform`` is
    'split' and to ``train_step_for_general`` otherwise.
    """

    # element-wise clipping bound applied to every gradient
    _GRAD_CLIP = 0.5

    def __init__(self, sess, transform, input_shape, num_outputs, 
                 num_tasks, memory_strength, learning_rate,
                 network=core.mlp, hidden_layers=None, beta_max=1.0, L_upper_bound=10.0, L_estimate=True):
        """
        sess: tf.Session used by every training / evaluation method.
        transform: task transform name; 'split' enables the offset placeholders.
        input_shape: shape of one input sample (without the batch dimension).
        num_outputs: number of classifier outputs (labels are fed as rows of
            this length).
        num_tasks, memory_strength, L_estimate: accepted but not used here.
        learning_rate: learning rate of the Adam optimizer.
        network: callable ``network(inputs, hidden_layers, num_outputs, is_training)``
            returning logits.
        hidden_layers: forwarded to ``network``; defaults to [100, 100].
        beta_max, L_upper_bound: stored as ``beta_max`` / ``L``; not used by the
            graph built here.
        """
        if hidden_layers is None:
            hidden_layers = [100, 100]
        is_split = transform == 'split'

        self.sess = sess

        self.L = L_upper_bound
        self.transform = transform
        self.beta_max = beta_max

        self.learning_rate = learning_rate
        self.num_outputs = num_outputs
        input_batch_shape = [None] + list(input_shape)
        self.is_training = tf.placeholder_with_default(True, shape=None)

        # current data stream 
        self.inputs = tf.placeholder(tf.float32, shape=input_batch_shape, name="inputs")
        self.labels = tf.placeholder(tf.int32, shape=[None, num_outputs], name="labels")
        if is_split:
            self.offset = tf.placeholder(tf.int32, shape=[None, num_outputs], name="offset")

        with tf.variable_scope("network"):
            self.logits = network(self.inputs, hidden_layers, num_outputs, self.is_training)
            if is_split:
                self.pruned_logits = self._pruning(self.logits, self.offset)
                self.cross_entropy_loss = tf.losses.softmax_cross_entropy(self.labels, self.pruned_logits)
            else:
                self.cross_entropy_loss = tf.losses.softmax_cross_entropy(self.labels, self.logits)

        # previous samples
        self.inputs_p = tf.placeholder(tf.float32, shape=input_batch_shape, name='input_p')
        self.labels_p = tf.placeholder(tf.int32, shape=[None, num_outputs], name='label_p')
        if is_split:
            self.offset_p = tf.placeholder(tf.int32, shape=[None, num_outputs], name='offset_p')

        # concat samples (replayed samples first, current batch last)
        self.input_concat = tf.keras.layers.Concatenate(axis=0)([self.inputs_p, self.inputs])
        self.label_concat = tf.keras.layers.Concatenate(axis=0)([self.labels_p, self.labels])
        if is_split:
            self.offset_concat = tf.keras.layers.Concatenate(axis=0)([self.offset_p, self.offset])

        # second application of the network; AUTO_REUSE shares the variables
        # created in the first "network" scope
        with tf.variable_scope("network", reuse=tf.AUTO_REUSE):
            self.logits_con = network(self.input_concat, hidden_layers, num_outputs, self.is_training)
            if is_split:
                self.pruned_logits_con = self._pruning(self.logits_con, self.offset_concat)
                self.cross_entropy_loss_con = tf.losses.softmax_cross_entropy(self.label_concat, self.pruned_logits_con)
            else:
                # keyword arguments: labels/logits are easy to swap by accident
                self.cross_entropy_loss_con = tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=self.label_concat, logits=self.logits_con)

            self.loss = tf.reduce_mean(self.cross_entropy_loss_con)

        self.optimizer = tf.train.AdamOptimizer(self.learning_rate)

        self.current_gvs = self.optimizer.compute_gradients(self.cross_entropy_loss)
        self.gvs = self.optimizer.compute_gradients(self.loss)

        # Both gradients are applied with a fixed weight of 1.0; the adaptive
        # weighting that used to be derived from gradient inner products is
        # not part of the graph.
        self.alpha = 1.0
        self.beta = 1.0

        scaled_cgvs = self._clip_gradients(self.current_gvs)
        scaled_gvs = self._clip_gradients(self.gvs)
        self.nccl_op = self.optimizer.apply_gradients(scaled_gvs)
        self.train_op = self.optimizer.apply_gradients(scaled_cgvs)

        if is_split:
            self.train_step = self.train_step_for_split
        else:
            self.train_step = self.train_step_for_general
        
        # test operators
        self.equal = tf.equal(tf.argmax(self.logits, axis=-1), tf.argmax(self.labels, axis=-1))

    def _clip_gradients(self, grads_and_vars):
        """
        Clip every gradient element-wise to [-_GRAD_CLIP, _GRAD_CLIP].

        Pairs whose gradient is None (variable not reached by the loss) are
        dropped, because tf.clip_by_value cannot handle None.
        """
        return [(tf.clip_by_value(g, -self._GRAD_CLIP, self._GRAD_CLIP), v)
                for g, v in grads_and_vars if g is not None]

    def l_estimate(self):  
        """
        Build ``l_upper_bound_estimate``: the norm of the difference between the
        current-batch gradient and the gradient taken after the one-step update
        ``v - learning_rate * g``, divided by ``learning_rate * ||g||``.

        Not called from ``__init__``; it has to be called explicitly before
        ``average_test_l_estimate`` can be used.

        NOTE(review): the ops restoring the original variable values are created
        but neither fetched nor used as a control dependency in this file, so
        evaluating the estimate presumably leaves the variables modified -
        confirm before enabling.
        """
        current_g = [g for g, _ in self.current_gvs]
        one_step_update = [v - self.learning_rate * g for g, v in self.current_gvs]
        temp_vars = [v for g, v in self.current_gvs]

        assign_one_step_update = [tf.assign(r, v) for r, v in zip(tf.trainable_variables(), one_step_update)]
        with tf.control_dependencies(assign_one_step_update):
            next_step_grad = tf.gradients(self.cross_entropy_loss, tf.trainable_variables())
            with tf.control_dependencies(next_step_grad):
                assign_orginal = [tf.assign(r, v) for r, v in zip(tf.trainable_variables(), temp_vars)]

        self.grad_diff = [g1 - g2 for g1, g2 in zip(current_g, next_step_grad)]
        self.grad_diff_norm = tf.sqrt(self._inner_product(self.grad_diff, self.grad_diff))
        self.var_diff = self.learning_rate * tf.sqrt(self._inner_product(current_g, current_g))
        self.l_upper_bound_estimate = self.grad_diff_norm / self.var_diff

    def _pruning(self, logits, offset_mask):
        """
        Currently a no-op: returns ``logits`` unchanged and ignores
        ``offset_mask`` (masking logits outside the offset is disabled).
        """
        return logits
    
    def _inner_product(self, g1_list, g2_list):
        """Sum of the element-wise products of two equally structured tensor lists."""
        sum_prod = 0
        for g1, g2 in zip(g1_list, g2_list):
            sum_prod += tf.reduce_sum(tf.math.multiply(g1, g2))
        
        return sum_prod

    def train_step_for_general(self, inputs, labels, replaymemory, task_idx, transform=None):
        """
        Run one optimizer step.

        For the first task (``task_idx == 0``) only the current batch is used
        (``train_op``); afterwards a batch from ``replaymemory.sample(task_idx)``
        is fed alongside it (``nccl_op``). ``transform`` is ignored.
        """
        if task_idx == 0:
            feed_dict = {self.inputs:inputs, self.labels:labels}
            self.sess.run(self.train_op, feed_dict=feed_dict)
        else:
            input_p, label_p = replaymemory.sample(task_idx)
            feed_dict = {self.inputs:inputs, self.labels:labels,
                         self.inputs_p:input_p, self.labels_p:label_p}
            
            self.sess.run(self.nccl_op, feed_dict=feed_dict)

    def _generate_offset_mask(self, labels, offset_class):
        """
        labels: one_hot
        offset_class: c_start,  c_end

        Returns an int32 array shaped like ``labels`` with ones in the columns
        ``c_start:c_end`` of every row and zeros elsewhere.
        """
        offset_mask = np.zeros_like(labels, np.int32)
        offset_mask[:, offset_class[0]:offset_class[1]] = 1
        return offset_mask

    def train_step_for_split(self, inputs, labels, replaymemory, task_idx, offset_class):
        """
        Run one optimizer step for the 'split' transform.

        Same scheme as ``train_step_for_general`` but additionally feeds the
        offset masks: one generated from ``offset_class`` for the current batch
        and the one returned by ``replaymemory.sample_for_split(task_idx)`` for
        the replayed batch.
        """
        offset_mask = self._generate_offset_mask(labels, offset_class)
        if task_idx == 0:
            feed_dict = {self.inputs:inputs, self.labels:labels, 
                self.offset:offset_mask}
            self.sess.run(self.train_op, feed_dict=feed_dict)
        else:
            input_p, label_p, offset_p = replaymemory.sample_for_split(task_idx)
            feed_dict = {self.inputs:inputs, self.labels:labels, self.offset:offset_mask,
                         self.inputs_p:input_p, self.labels_p:label_p, self.offset_p:offset_p}

            # Exactly one update per call; a second run of nccl_op (left over
            # from debugging) used to apply the optimizer twice per batch.
            self.sess.run(self.nccl_op, feed_dict=feed_dict)

    def _correct_predictions(self, testloader):
        """
        Evaluate ``equal`` (argmax of logits == argmax of labels) with
        ``is_training`` fed as False; returns one boolean array per batch.
        """
        return [self.sess.run(self.equal, feed_dict={self.inputs: inputs,
                                                     self.labels: labels,
                                                     self.is_training:False})
                for inputs, labels in testloader]

    def test_accuracy(self, testloader):
        """Fraction of correctly classified samples over all batches of ``testloader``."""
        test_results = self._correct_predictions(testloader)
        test_acc = np.mean(np.concatenate(test_results))
        return test_acc

    def running_average_accuracy(self, testloader_list):
        """
        Fraction of correctly classified samples over the samples of all
        loaders in ``testloader_list`` pooled together.
        """
        test_results = list()
        for testloader in testloader_list:
            test_results.extend(self._correct_predictions(testloader))
        test_acc = np.mean(np.concatenate(test_results))
        return test_acc

    def average_test_l_estimate(self, testloader_list, transform_list):
        """
        Mean of ``l_upper_bound_estimate`` over every batch of every loader.

        ``transform_list`` holds one entry per loader; for the 'split'
        transform it is used as the ``offset_class`` of that loader.
        Raises AttributeError unless ``l_estimate()`` has been called before.
        """
        test_results = list()
        for testloader, tr in zip(testloader_list, transform_list):
            for inputs, labels in testloader:
                feed_dict = {self.inputs: inputs,
                             self.labels: labels,
                             self.is_training:False}
                if self.transform == 'split':
                    feed_dict[self.offset] = self._generate_offset_mask(labels, tr)
                l = self.sess.run(self.l_upper_bound_estimate, feed_dict=feed_dict)

                test_results.append(l)
        test_l = np.mean(test_results)
        return test_l


def nccl(raw_data, transform, num_tasks, seed, shuffle_tasks,
         batch_size, learning_rate, sample_memory_batch,
         network_architecture, epochs_per_task, memory_size, 
         l_upper_bound, memory, beta_max, logger_kwargs):
    """
    Train a NonconvexContinualLearning model on a task continuum and log the
    test accuracies after every epoch.

    raw_data: key into DATA_DESCRIPTER ('mnist', 'cifar100' or 'cifar10').
    network_architecture: hidden layer list handed to the network; a single
        entry equal to -1 selects core.resnet instead of core.mlp.
    memory: 'random' or 'reservoir'; any other value raises ValueError.
    logger_kwargs: keyword arguments for EpochLogger.
    The remaining arguments are forwarded to the dataset, data loaders,
    model and replay memory.
    """
    # Keep these two statements first: save_config receives locals(), which at
    # this point holds only the call arguments and the logger.
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    tf.set_random_seed(seed)
    np.random.seed(seed)

    descriptor = DATA_DESCRIPTER[raw_data]
    input_shape = descriptor['input_shape']
    num_outputs = descriptor['num_outputs']

    task_continuum = TaskContinuumDataset(raw_data, transform, num_tasks, seed,
        shuffle_tasks=shuffle_tasks, train_samples_per_task=None)

    use_resnet = len(network_architecture) == 1 and int(network_architecture[0]) == -1
    network = core.resnet if use_resnet else core.mlp

    if memory not in ("random", "reservoir"):
        raise ValueError("Wrong type of memory")
    Memory = RandomTaskReplayMemory if memory == "random" else ReservoirMemory

    with tf.Session() as sess:
        # build the model graph
        model = NonconvexContinualLearning(
            sess=sess, transform=task_continuum.transform, input_shape=input_shape,
            num_outputs=num_outputs, num_tasks=num_tasks,
            memory_strength=None, learning_rate=learning_rate,
            network=network, hidden_layers=network_architecture,
            beta_max=beta_max, L_upper_bound=l_upper_bound)

        # replay memory shared by all tasks
        replaymemory = Memory(
            num_tasks=num_tasks, memory_size=memory_size,
            input_shape=input_shape, num_outputs=num_outputs,
            sample_batch_size=sample_memory_batch)

        sess.run(tf.global_variables_initializer())

        # model saving
        logger.setup_tf_saver(
            sess,
            inputs={'inputs': model.inputs, 'inputs_p': model.inputs_p},
            outputs={'labels': model.labels, 'labels_p': model.labels_p})

        # training state
        total_steps = 0
        testloader_list = []
        transform_list = []
        start_time = time.time()
        total_epochs = 0

        for task_idx, (task_transform, train_data, test_data) in enumerate(task_continuum):
            # loaders of the task that is about to be trained
            trainloader = DataLoader(train_data, batch_size=batch_size)
            testloader = DataLoader(test_data, batch_size=batch_size)
            testloader_list.append(testloader)
            transform_list.append(task_transform)

            for _ in range(epochs_per_task):
                for inputs, labels in trainloader:
                    # train first, then remember the batch for later tasks
                    model.train_step(inputs, labels, replaymemory, task_idx, task_transform)
                    replaymemory.store(inputs, labels, task_idx)
                    total_steps += 1

                # accuracy on the first task, the current task and all tasks seen so far
                logger.store(
                    TestAccT1=model.test_accuracy(testloader_list[0]),
                    TestAccCT=model.test_accuracy(testloader),
                    TestAvgAcc=model.running_average_accuracy(testloader_list))

                # one table row per epoch
                total_epochs += 1
                logger.log_tabular('Task', task_idx)
                logger.log_tabular('Epoch', total_epochs)
                for key in ('TestAccT1', 'TestAccCT', 'TestAvgAcc'):
                    logger.log_tabular(key, with_min_and_max=True)
                logger.log_tabular('Time', time.time() - start_time)
                logger.dump_tabular()
                # print(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))


if __name__ == "__main__":
    import argparse

    def _str2bool(value):
        """
        Parse a command line boolean.

        ``type=bool`` must not be used with argparse: bool("False") is True, so
        every non-empty value would enable the flag.
        """
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))

    parser = argparse.ArgumentParser()
    # dataset setup
    parser.add_argument('--i', default='raw', help='raw datasets for continual learning')
    parser.add_argument('--num_tasks', default=3, type=int, help='number of different tasks')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--raw_data', default='mnist', help='raw dataset for using continual learning')
    parser.add_argument('--transform', default='rotation', help='transform transformation')
    parser.add_argument('--min_angle', default=0, type=float, help='minimum rotation angle')
    parser.add_argument('--max_angle', default=90, type=float, help='maximum rotation angle')

    # learning hyperparameters
    parser.add_argument('--batch_size', default=100, type=int, help="batch size")
    parser.add_argument('--epochs_per_task', default=1, type=int, help="epochs per task")
    parser.add_argument('--shuffle_tasks', default=True, type=_str2bool, help="shuffle tasks")
    parser.add_argument('--learning_rate', default=0.03, type=float, help="learning_rate")
    parser.add_argument('--l_upper_bound', default=0.7, type=float, help="choose a proper upper bound of L")
    parser.add_argument('--beta_max', default=1.0, type=float, help="beta clipping value")
    parser.add_argument('--memory', default='random', help="memory type")

    # logging setup
    parser.add_argument('--exp_name', type=str, default='nccl')

    # network and memory setup
    parser.add_argument('--memory_size', default=500, type=int, help="memory_size") 
    parser.add_argument('--sample_memory_batch', default=100, type=int, help="sample size of memory")
    # type=int keeps values given on the command line consistent with the
    # integer default instead of passing them on as strings
    parser.add_argument('--network_architecture', nargs='*', default=[400,400], type=int,
                        help='fc-layers or convnet, convnet=-1')
    args = parser.parse_args()

    from continual.utils.run_utils import setup_logger_kwargs
    logger_kwargs = setup_logger_kwargs(args.exp_name, args.seed)

    nccl(args.raw_data, args.transform, args.num_tasks, args.seed, args.shuffle_tasks, 
         args.batch_size, args.learning_rate, args.sample_memory_batch, args.network_architecture,
         args.epochs_per_task, args.memory_size, args.l_upper_bound, args.memory, args.beta_max, logger_kwargs)


    