import tensorflow as tf
from algorithm import *

class Trainer:
    """Trainer that combines per-source gradients (replay memory + active tasks).

    Per batch, gradient/loss pairs are appended to ``batch_gradients`` /
    ``batch_losses`` by ``GetMemoryGradient`` and ``GetCurrentGradient``, and
    ``Update`` merges them with the method named by ``method`` (mgda, pcgrad,
    gradnorm, cvw, rlw, dwa, graddrop, maxdo) before one optimizer step.
    NOTE(review): ``Update`` prepends the memory entry for cvw/gradnorm/dwa, so
    the caller presumably runs ``GetMemoryGradient`` before
    ``GetCurrentGradient`` each batch — confirm against the training loop.
    """

    def __init__(self, model, datastream, args):
        """Set up losses, metrics, optimizers and per-method bookkeeping.

        Args:
            model: model called as ``model(images, mask, training=True)``.
            datastream: object exposing ``TimeLine`` (per-task ``(start, end)``
                batch indices), ``TrainStream``, ``MaskSet``, ``MemStream`` and
                ``UpdataeMemory``; ``len(datastream)`` is the number of tasks.
            args: namespace read for ``imp_method``, ``optimizer``, ``lr`` and
                ``with_mem``.
        """
        super().__init__()
        self.args = args
        self.method = args.imp_method
        self.datastream = datastream
        self.task_num = len(datastream)
        self.BuildObjective(model)
        self.BuildOptimizer(args.optimizer, args.lr)

        # Batch level: one gradient / loss per source (memory, active tasks).
        self.batch_gradients = []
        self.batch_losses = []

        self.loss_mem = []  # store the losses from the latest 5 batches (not read in this class)
        self.loss_start = []  # store the initial losses for each task (not read in this class)
        self.mem_loss_start = []  # initial memory loss when a task ends (not read in this class)

        # GradNorm: initial loss of each task (keyed by task id) and of the memory
        # (one entry per task end).
        self.init_losses = {}
        self.init_mem_losses = []

        # CVW: continuous record of losses per task and for the memory.
        self.losses_continuous = {k: [] for k in range(self.task_num)}
        self.gradnorm_continuous = {k: [] for k in range(self.task_num)}
        self.losses_continuous_curr = []
        self.gradnorm_continuous_curr = []
        self.losses_continuous_mean_curr = []
        self.mem_losses_continuous = []
        self.mem_gradnorm_continuous = []

        # DWA: loss ratio l_{k-1}/l_{k-2} per task, for this batch's active tasks,
        # and for the memory.
        self.previous_loss_ratios = {k: -1. for k in range(self.task_num)}
        self.previous_loss_ratios_curr = []
        self.mem_previous_loss_ratios = 0.

    def BuildObjective(self, model):
        """Store the model and create the loss objects and per-task metrics."""
        self.model = model
        none_reduction = tf.keras.losses.Reduction.NONE
        self.split_loss_object = tf.keras.losses.SparseCategoricalCrossentropy(reduction=none_reduction)
        self.split_loss_object_for_onehot = tf.keras.losses.CategoricalCrossentropy(reduction=none_reduction)
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
        self.loss_object_for_onehot = tf.keras.losses.CategoricalCrossentropy()
        self.train_loss = [tf.keras.metrics.Mean(name='train_loss_{}'.format(i)) for i in range(self.task_num)]
        self.test_loss = [tf.keras.metrics.Mean(name='test_loss_{}'.format(i)) for i in range(self.task_num)]

        self.train_acc = [tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy_{}'.format(i)) for i in range(self.task_num)]
        self.test_acc = [tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy_{}'.format(i)) for i in range(self.task_num)]
        self.previous_losses = []

    def BuildOptimizer(self, optimizer, learning_rate):
        """Create the model optimizer and the image optimizer.

        Args:
            optimizer: ``'sgd'`` or ``'adam'``.
            learning_rate: learning rate passed to both optimizers.

        Raises:
            ValueError: if ``optimizer`` is not recognised.
        """
        factories = {
            'sgd': tf.keras.optimizers.SGD,
            'adam': tf.keras.optimizers.Adam,
        }
        try:
            factory = factories[optimizer]
        except KeyError:
            raise ValueError('Invalid optimizer {}'.format(optimizer)) from None
        self.optimizer = factory(learning_rate)
        self.image_optimizer = factory(learning_rate)

    def _reduce_loss(self, loss_fn, labels, predictions):
        """Apply an unreduced loss and average it over ``labels.shape[0]`` examples."""
        per_example_loss = loss_fn(labels, predictions)
        avg_loss = tf.nn.compute_average_loss(per_example_loss, global_batch_size=labels.shape[0])
        return avg_loss, per_example_loss

    def compute_loss(self, labels, predictions):
        """Return ``(batch-averaged loss, per-example loss)`` for sparse integer labels."""
        return self._reduce_loss(self.split_loss_object, labels, predictions)

    def compute_loss_for_onehot(self, labels, predictions):
        """Return ``(batch-averaged loss, per-example loss)`` for one-hot labels."""
        return self._reduce_loss(self.split_loss_object_for_onehot, labels, predictions)

    def train_step(self, images, labels, task_id, mask):
        """Compute one task's gradient and loss and append them to the batch lists.

        When ``task_id`` is given, that task's train loss/accuracy metrics are
        updated as well.
        """
        with tf.GradientTape() as tape:
            predictions = self.model(images, mask, training=True)
            loss, _ = self.compute_loss(labels, predictions)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.batch_gradients.append(gradients)
        self.batch_losses.append(loss)
        if task_id is not None:
            self.train_loss[task_id](loss)
            self.train_acc[task_id](labels, predictions)

    def train_step_for_onehot(self, images, labels, mask):
        """Compute the gradient and loss for one-hot labels and append them to the batch lists."""
        # A single gradient() call is made, so a non-persistent tape suffices and
        # its resources are released right after use.
        with tf.GradientTape() as tape:
            predictions = self.model(images, mask, training=True)
            loss, _ = self.compute_loss_for_onehot(labels, predictions)
        gradients = tape.gradient(loss, self.model.trainable_variables)

        self.batch_gradients.append(gradients)
        self.batch_losses.append(loss)

    def reset_state(self):
        """Reset every task's train loss and accuracy metric."""
        for task in range(self.task_num):
            self.train_loss[task].reset_states()
            self.train_acc[task].reset_states()

    def InitializeBatch(self):
        """Clear the per-batch gradient and loss lists."""
        self.batch_losses = []
        self.batch_gradients = []

    def _uses_memory(self):
        """Return True when a replay-memory mode (``'ur'`` / ``'ur_reduce'``) is enabled."""
        return self.args.with_mem in ('ur', 'ur_reduce')

    def _is_active(self, task_id, batch_id):
        """Return True when ``batch_id`` lies within task ``task_id``'s (start, end) range."""
        start, end = self.datastream.TimeLine[task_id]
        return start <= batch_id <= end

    def GetMemoryGradient(self, batch_id):
        """Append the replay-memory gradient/loss for this batch, if memory is in use.

        Memory is replayed only after the earliest task has ended. Also records
        the memory's GradNorm initial loss, its DWA ratio and its loss history.
        """
        if not self._uses_memory():
            return
        end_times = [end for _, end in self.datastream.TimeLine]
        if batch_id <= min(end_times):
            return
        just_after_end = {end + 1 for end in end_times}
        two_after_end = {end + 2 for end in end_times}

        images, labels, masks = self.datastream.MemStream.get_next()
        self.train_step_for_onehot(images, labels, masks)  # memory labels are one-hot
        mem_loss = self.batch_losses[-1]

        # GradNorm: record the memory's initial loss right after a task ends.
        if batch_id in just_after_end:
            self.init_mem_losses.append(mem_loss)

        # DWA: l_{k-1}/l_{k-2} over past memory losses. It is 1 for the first two
        # batches after any task end (memory is updated at task end in Update) and
        # whenever fewer than two past losses exist (incl. the very first memory batch).
        history = self.mem_losses_continuous
        if batch_id in just_after_end or batch_id in two_after_end or len(history) < 2:
            self.mem_previous_loss_ratios = 1.
        else:
            self.mem_previous_loss_ratios = history[-1] / history[-2]

        # Update the continuous memory losses.
        self.mem_losses_continuous.append(mem_loss)

    def GetCurrentGradient(self, batch_id):
        """Append gradient/loss for every task active at ``batch_id`` and update bookkeeping.

        Resets and refills the per-batch lists used by CVW
        (``losses_continuous_mean_curr``) and DWA (``previous_loss_ratios_curr``),
        in the same task order as the appended gradients.
        """
        # Reset the per-batch ("curr") records.
        self.losses_continuous_mean_curr = []
        self.losses_continuous_curr = []
        self.gradnorm_continuous_curr = []
        self.previous_loss_ratios_curr = []
        self._g_c = []

        for task_id in range(len(self.datastream)):
            if not self._is_active(task_id, batch_id):
                continue
            start = self.datastream.TimeLine[task_id][0]
            images, labels = self.datastream.TrainStream[task_id].get_next()
            # NOTE(review): a new tf.Variable is created every batch, as in the
            # original; its purpose is not visible here — confirm it is needed.
            self.train_step(tf.Variable(images), labels, task_id, self.datastream.MaskSet[task_id])
            history = self.losses_continuous[task_id]

            # CVW: mean of this task's previous losses; -1.0 on its first batch.
            if history:
                self.losses_continuous_mean_curr.append(tf.reduce_mean(history))
            else:
                self.losses_continuous_mean_curr.append(-1.0)

            # DWA: l_{k-1}/l_{k-2} from this task's two most recent past losses
            # (1 until two exist), computed before it is collected so the current
            # batch uses the up-to-date ratio.
            if len(history) < 2:
                self.previous_loss_ratios[task_id] = 1.
            else:
                self.previous_loss_ratios[task_id] = history[-1] / history[-2]
            self.previous_loss_ratios_curr.append(self.previous_loss_ratios[task_id])

            loss = self.batch_losses[-1]
            history.append(loss)

            # GradNorm: record the init loss on this task's own first batch only.
            if batch_id == start:
                self.init_losses[task_id] = loss

            self.losses_continuous_curr.append(history)

    def _combine_gradients(self, method, i):
        """Merge the gradients in ``batch_gradients`` into one update using ``method``.

        Raises:
            ValueError: if ``method`` is not recognised.
        """
        if method == 'mgda':  # MGDA
            return mgda.ComputeGradient(self.batch_gradients)
        if method == 'pcgrad':  # PCGrad
            return pcgrad.ComputeGradient(self.batch_gradients)
        if method == 'gradnorm':  # GradNorm
            # Only tasks active at batch i contribute to batch_losses, so pair
            # them with their own init losses (same task order as
            # GetCurrentGradient); the memory init loss goes first when present.
            task_init_losses = [self.init_losses[t] for t in range(len(self.datastream))
                                if self._is_active(t, i)]
            if self.init_mem_losses:
                init_losses = [tf.reduce_mean(self.init_mem_losses)] + task_init_losses
            else:
                init_losses = task_init_losses
            return gradnorm.ComputeGradient(self.batch_gradients, self.batch_losses, init_losses)
        if method == 'cvw':  # CV-Weighting
            # Prepend the memory entry: mean of its past losses, or -1.0 when
            # only the current one exists (same convention as for tasks).
            if len(self.mem_losses_continuous) > 1:
                self.losses_continuous_mean_curr.insert(0, tf.reduce_mean(self.mem_losses_continuous[:-1]))
            elif len(self.mem_losses_continuous) == 1:
                self.losses_continuous_mean_curr.insert(0, -1.0)
            return cvw.ComputeGradient(self.batch_gradients, self.batch_losses, self.losses_continuous_mean_curr)
        if method == 'rlw':  # Random Loss Weighting
            return rlw.ComputeGradient(self.batch_gradients)
        if method == 'dwa':  # Dynamic Weight Average: uses loss ratios, not mean losses
            if self.mem_losses_continuous:
                previous_loss_ratios = [self.mem_previous_loss_ratios] + self.previous_loss_ratios_curr
            else:
                previous_loss_ratios = list(self.previous_loss_ratios_curr)
            return dwa.ComputeGradient(self.batch_gradients, previous_loss_ratios)
        if method == 'graddrop':
            return graddrop.ComputeGradient(self.batch_gradients)
        if method == 'maxdo':
            return maxdo.ComputeGradient(self.batch_gradients)
        raise ValueError('Invalid method')

    def Update(self, method, i):
        """Apply one optimizer step from this batch's gradients, then refresh memory.

        Args:
            method: gradient-combination method used when there are two or more
                gradients; a single gradient is applied directly.
            i: current batch index (compared against ``TimeLine``; assumed to be
                the same ``batch_id`` passed to ``GetCurrentGradient``).
        """
        if len(self.batch_gradients) == 1:  # one source only
            d = self.batch_gradients[0]
            self.optimizer.apply_gradients(zip(d, self.model.trainable_variables))
        elif self.batch_gradients:  # two or more sources
            d = self._combine_gradients(method, i)
            self.optimizer.apply_gradients(zip(d, self.model.trainable_variables))
            # Original intent: release GPU memory. Presumably drops the list's
            # references to the combined tensors (assumes a list is returned).
            d.clear()

        # When a task ends at this batch, rebuild the replay memory for it.
        if self._uses_memory():
            for task_id in range(len(self.datastream)):
                if i == self.datastream.TimeLine[task_id][1]:
                    self.datastream.UpdataeMemory(task_id, self.args.with_mem, mem_split='ref')  # Construct datastream


