# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from math import ceil
import torch
import numpy as np
from models.utils.continual_model import ContinualModel
import torchvision.transforms as transforms
from utils.args import add_management_args, add_experiment_args, add_rehearsal_args, ArgumentParser
from utils.buffer import Buffer
import wandb
import scipy.optimize as opt


def get_parser() -> ArgumentParser:
    """Build the command-line parser for the CallyBoth model.

    Registers the shared management, experiment and rehearsal option groups,
    in that order.
    """
    # NOTE(review): CallyBoth reads args.epsilon, args.epsilon_sample and
    # args.partition_method; none of them is registered here — confirm they
    # are added by one of the shared helpers below or elsewhere.
    arg_parser = ArgumentParser(description='Continual learning via Lagrangian Duality')
    for register_group in (add_management_args, add_experiment_args, add_rehearsal_args):
        register_group(arg_parser)
    return arg_parser

class CallyBoth(ContinualModel):
    """Rehearsal-based continual learner trained as a primal-dual (Lagrangian) problem.

    Two families of non-negative multipliers are maintained:

    * sample-wise: ``lambdas_sample`` (one per training example of the current
      task) and ``buf_lambdas_sample`` (one per buffer slot), tied to the
      constraint ``loss(x_i) <= epsilon_sample``;
    * task-wise: ``lambdas_task`` (current task) and ``buf_lambdas_task`` (one
      per past task), tied to the constraint ``mean task loss <= epsilon_task``.

    The network descends the Lagrangian; the multipliers follow projected dual
    ascent. At the end of each task the buffer capacity is re-split across
    tasks using the task multipliers (see ``partition``), and current-task
    samples are selected according to their sample multipliers.
    """

    NAME = 'cally_both'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']

    # Number of highest-multiplier training samples skipped (treated as
    # outliers) when filling the buffer with current-task data. Dataset dependent.
    NUM_OUTLIERS = 100

    def __init__(self, backbone, loss, args, transform):

        super(CallyBoth, self).__init__(backbone, loss, args, transform)

        self.buffer = Buffer(self.args.buffer_size, self.device)
        self.epsilon_task = args.epsilon            # bound on each task's mean loss
        self.epsilon_sample = args.epsilon_sample   # bound on each sample's loss
        self.partitionmethod = args.partition_method
        print(f"Constraint upper bounds: {self.epsilon_task}, {self.epsilon_sample}")
        print(f"Partition method: {self.partitionmethod}")
        self.lr_dual = 0.05  # step size of the dual (multiplier) ascent
        self.task = 0        # number of tasks completed so far (incremented in end_task)

    def partition(self, lambdas, buffer, method='gen_bound', n_min=None):
        """Compute how many buffer slots each task receives.

        With ``method='gen_bound'`` the allocation ``n`` minimises
        ``sum_k lambdas[k] * sqrt(log(n_k) / n_k)`` subject to
        ``sum_k n_k <= buffer.buffer_size`` and
        ``n_min <= n_k <= buffer.buffer_size``. The problem is solved with SLSQP
        from a uniform starting point, and the result is rounded to the
        nearest integer.

        Args:
            lambdas: task multipliers, one per task (tensor or array-like).
            buffer: object exposing ``buffer_size`` (total capacity).
            method: partition strategy; only ``'gen_bound'`` is implemented.
            n_min: lower bound on each task's share; defaults to
                ``buffer_size / (2 * num_tasks)``.

        Returns:
            Float64 tensor of per-task slot counts, ordered like ``lambdas``.

        Raises:
            NotImplementedError: for any other ``method``.
        """
        if method != 'gen_bound':
            raise NotImplementedError(f"Unknown partition method: {method}")

        if torch.is_tensor(lambdas):
            weights = lambdas.detach().cpu().numpy()
        else:
            weights = np.asarray(lambdas, dtype=float)
        num_tasks = len(weights)
        buffer_size = buffer.buffer_size

        if n_min is None:
            n_min = buffer_size / (2 * num_tasks)

        def objective(n, weights):
            return weights @ np.sqrt(np.log(n) / n)

        n0 = np.full(num_tasks, buffer_size / num_tasks)
        # A single constraint row: the shares may not exceed the buffer capacity.
        capacity = opt.LinearConstraint(A=np.ones((1, num_tasks)), lb=-np.inf, ub=buffer_size)
        result = opt.minimize(
            objective, n0, args=(weights,),
            method='SLSQP',  # 'trust-constr' is an alternative
            # Sequence-style bounds need one (lo, hi) pair per variable.
            bounds=[(n_min, buffer_size)] * num_tasks,
            constraints=capacity,
        )
        return torch.from_numpy(np.round(result.x))

    def begin_task(self, dataset):
        """Reset the multipliers at the start of a task.

        Sample multipliers start at zero for every training example of the new
        task and for every buffer slot. From the second task on, the
        current-task multiplier and one multiplier per past task start at one.
        """
        self.lambdas_sample = torch.zeros(len(dataset.train_loader.dataset), requires_grad=False)
        self.buf_lambdas_sample = torch.zeros(self.args.buffer_size, requires_grad=False)

        if self.task > 0:
            self.lambdas_task = torch.ones(1, requires_grad=False)
            self.buf_lambdas_task = torch.ones(self.task, requires_grad=False)

    def _add_current_task_samples(self, dataset, space_per_class):
        """Fill ``space_per_class`` with current-task training samples.

        Training examples are visited in order of decreasing sample multiplier,
        skipping the first ``NUM_OUTLIERS``. An example is kept while its class
        still has free space. ``space_per_class`` is decremented in place, and
        nothing is added to the buffer if no example qualifies.
        """
        train_set = dataset.train_loader.dataset
        order = (-self.lambdas_sample).argsort()

        imgs, targets = [], []
        for i in order[self.NUM_OUTLIERS:].tolist():
            _, target, original_img, _ = train_set[i]
            if space_per_class[target] > 0:
                imgs.append(original_img.unsqueeze(0))
                targets.append(target)
                space_per_class[target] -= 1
            if torch.sum(space_per_class) == 0:
                break

        if imgs:  # torch.cat([]) would raise
            self.buffer.add_data(
                examples=torch.cat(imgs),
                labels=torch.Tensor(targets),
                task_labels=torch.zeros(len(targets)) + self.task,
            )

    def end_task(self, dataset):
        """Rebuild the rehearsal buffer after training on the current task.

        After the first task, the buffer receives up to
        ``buffer_size // N_CLASSES_PER_TASK`` samples per class of that task.
        At later tasks the buffer is emptied and its capacity re-split with
        ``partition``. Past-task samples are then redrawn uniformly at random
        from the old buffer content, and the space left over goes to the
        current task. Current-task samples are chosen by
        ``_add_current_task_samples``.
        """
        n_cls = dataset.N_CLASSES_PER_TASK
        n_classes_total = n_cls * dataset.N_TASKS

        if self.buffer.is_empty():
            samples_per_class = self.args.buffer_size // (n_cls * (self.task + 1))
            space_per_class = torch.zeros(n_classes_total)
            space_per_class[:n_cls] += samples_per_class
            self._add_current_task_samples(dataset, space_per_class)

        else:
            examples, labels, tasks = self.buffer.get_all_data(transform=None)

            # Empty old buffer to replace it with the re-partitioned one.
            self.buffer.empty()

            # n[0] is the current task's share; n[k + 1] is past task k's share.
            n = self.partition(torch.cat((self.lambdas_task, self.buf_lambdas_task), 0),
                               self.buffer, self.partitionmethod)

            # Past tasks: uniform random subsample of the old buffer content.
            space_per_class = torch.zeros(n_classes_total)
            for task in range(self.task):
                begin = task * n_cls
                space_per_class[begin:begin + n_cls] += n[task + 1] // n_cls

            imgs, targets, task_labels = [], [], []
            for i in np.random.permutation(len(labels)):
                label = labels[i]
                if space_per_class[label] > 0:
                    imgs.append(examples[i].unsqueeze(0))
                    targets.append(label)
                    task_labels.append(tasks[i])
                    space_per_class[label] -= 1
                if torch.sum(space_per_class) == 0:
                    break

            if imgs:  # torch.cat([]) would raise
                self.buffer.add_data(
                    examples=torch.cat(imgs),
                    labels=torch.Tensor(targets),
                    task_labels=torch.tensor(task_labels),
                )

            # Current task gets the remaining capacity.
            # NOTE(review): assumes Buffer.empty() resets num_seen_examples so
            # this equals the free slots — confirm in utils.buffer.
            n[0] = self.buffer.buffer_size - self.buffer.num_seen_examples
            begin = self.task * n_cls
            space_per_class = torch.zeros(n_classes_total)
            space_per_class[begin:begin + n_cls] += n[0] // n_cls
            self._add_current_task_samples(dataset, space_per_class)

        self.buffer.create_loader(size=self.args.minibatch_size, model_transform=self.transform)
        self.task += 1

    def _dual_ascent(self, lambdas, violation):
        """Projected dual-ascent step: ``max(0, lambdas + lr_dual * violation)``."""
        return torch.relu(lambdas + self.lr_dual * violation)

    def _observe_without_buffer(self, inputs, labels, indexes):
        """Primal-dual step with the objective and per-sample constraints only."""
        lambdas_sample = self.lambdas_sample[indexes].clone().to(self.device)

        # Forward pass and Lagrangian
        outputs = self.net(inputs)
        loss_sample = self.loss(outputs, labels, reduction='none')
        mean_loss = torch.mean(loss_sample)
        lagrangian = mean_loss + torch.mean(lambdas_sample * (loss_sample - self.epsilon_sample))

        # Primal update
        lagrangian.backward()
        self.opt.step()

        # Dual update (no autograd graph needed)
        with torch.no_grad():
            lambdas_sample = self._dual_ascent(lambdas_sample, loss_sample - self.epsilon_sample)
            self.lambdas_sample[indexes] = lambdas_sample.cpu()

        return mean_loss.item()

    def _observe_with_buffer(self, inputs, labels, indexes):
        """Primal-dual step with sample- and task-wise constraints on the current
        batch and on a buffer minibatch."""
        lambdas_task = self.lambdas_task.clone().to(self.device)
        lambdas_sample = self.lambdas_sample[indexes].clone().to(self.device)

        # Sample from buffer
        buf_indexes, buf_inputs, buf_labels, buf_tasks = self.buffer.get_data_from_loader(
            size=self.args.minibatch_size)
        # reshape(-1) instead of squeeze(): squeeze() turns a one-element label
        # tensor into a 0-d tensor that cannot be paired with (1, C) logits.
        buf_labels = buf_labels.reshape(-1)
        buf_lambdas_task = self.buf_lambdas_task.clone().to(self.device)
        buf_lambdas_sample = self.buf_lambdas_sample[buf_indexes].clone().to(self.device)

        # Forward pass
        outputs = self.net(inputs)
        buf_outputs = self.net(buf_inputs)

        # Sample-wise Lagrangian terms (current batch and buffer batch)
        loss_sample = self.loss(outputs, labels, reduction='none')
        buf_loss_sample = self.loss(buf_outputs, buf_labels, reduction='none')
        lagrangian_sample = (
            torch.mean(lambdas_sample * (loss_sample - self.epsilon_sample))
            + torch.mean(buf_lambdas_sample * (buf_loss_sample - self.epsilon_sample))
        )

        # Task-wise Lagrangian: objective (kept differentiable) + current-task
        # constraint + one constraint per past task.
        mean_loss = torch.mean(loss_sample)
        lagrangian_task = mean_loss + lambdas_task * (mean_loss - self.epsilon_task)

        buf_losses_task = torch.zeros(self.task, device=self.device)
        task_present = torch.zeros(self.task, dtype=torch.bool, device=self.device)
        for task in range(self.task):
            in_task = buf_tasks == task
            if in_task.any():
                buf_losses_task[task] += self.loss(buf_outputs[in_task], buf_labels[in_task])
                task_present[task] = True
        lagrangian_task = lagrangian_task + torch.sum(
            buf_lambdas_task * (buf_losses_task - self.epsilon_task))

        lagrangian = lagrangian_task + lagrangian_sample

        # Primal update
        lagrangian.backward()
        self.opt.step()

        with torch.no_grad():
            # Dual update, sample-wise (current batch, then buffer slots)
            lambdas_sample = self._dual_ascent(lambdas_sample, loss_sample - self.epsilon_sample)
            self.lambdas_sample[indexes] = lambdas_sample.cpu()
            buf_lambdas_sample = self._dual_ascent(buf_lambdas_sample,
                                                   buf_loss_sample - self.epsilon_sample)
            self.buf_lambdas_sample[buf_indexes] = buf_lambdas_sample.cpu()

            # Dual update, task-wise (current task, then past tasks)
            lambdas_task = self._dual_ascent(lambdas_task, mean_loss - self.epsilon_task)
            self.lambdas_task = lambdas_task.cpu()
            # A past task absent from this minibatch has no loss estimate, so
            # its multiplier is left unchanged rather than updated as if its
            # loss were 0.
            buf_lambdas_task = torch.where(
                task_present,
                self._dual_ascent(buf_lambdas_task, buf_losses_task - self.epsilon_task),
                buf_lambdas_task,
            )
            self.buf_lambdas_task = buf_lambdas_task.cpu()

        return mean_loss.item()

    def observe(self, inputs, labels, not_aug_inputs, indexes):
        """Run one primal-dual training step and return the batch's mean loss.

        Args:
            inputs: current-task batch fed to the network.
            labels: targets for ``inputs``.
            not_aug_inputs: unused by this model.
            indexes: positions of the batch examples in the current task's
                training set (presumably; they index ``lambdas_sample``, which is
                sized to that set in ``begin_task``).

        Returns:
            Unweighted mean loss of the current batch, as a Python float.
        """
        self.opt.zero_grad()

        if self.buffer.is_empty():
            return self._observe_without_buffer(inputs, labels, indexes)
        return self._observe_with_buffer(inputs, labels, indexes)
