# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from math import ceil
import torch
import numpy as np
from models.utils.continual_model import ContinualModel
import torchvision.transforms as transforms
from utils.args import add_management_args, add_experiment_args, add_rehearsal_args, ArgumentParser
from utils.buffer import Buffer
import wandb
import scipy.optimize as opt


def get_parser() -> ArgumentParser:
    """
    Build the command-line parser for the Lagrangian-duality continual learner.

    Adds the shared management, experiment and rehearsal arguments, plus
    ``--partitionmethod``, which selects how the replay buffer is split among
    tasks (see ``CallyPerTask.partition``). Restricting it to the supported
    names makes a typo fail at parse time instead of at the end of a task.
    """
    parser = ArgumentParser(description='Continual learning via Lagrangian Duality')
    add_management_args(parser)
    add_experiment_args(parser)
    add_rehearsal_args(parser)
    # NOTE(review): the model also reads ``args.epsilon``, which is not added
    # here; presumably one of the shared add_*_args helpers registers it.
    parser.add_argument('--partitionmethod', type=str, default='linear',
                        choices=['linear', 'gen_bound', 'uniform'],
                        help='How the replay buffer is partitioned among tasks.')
    return parser

class CallyPerTask(ContinualModel):
    """
    Continual learner trained by primal-dual optimisation.

    Each task's loss is treated as a constraint ``loss <= epsilon`` with its own
    dual variable: ``self.lambdas`` for the current task, ``self.buf_lambdas``
    for past tasks. The network minimises the Lagrangian, the duals follow
    projected gradient ascent, and at the end of each task the duals decide
    how the replay buffer is partitioned among tasks.
    """

    NAME = 'cally_per_task'
    COMPATIBILITY = ['class-il', 'domain-il', 'task-il', 'general-continual']
    # Buffer-partition strategies understood by ``partition``.
    PARTITION_METHODS = ('linear', 'gen_bound', 'uniform')

    def __init__(self, backbone, loss, args, transform):

        super(CallyPerTask, self).__init__(backbone, loss, args, transform)

        self.buffer = Buffer(self.args.buffer_size, self.device)
        # Upper bound imposed on every task's loss.
        self.epsilon = args.epsilon
        self.partitionmethod = args.partitionmethod
        # Fail fast: an unknown method would otherwise only surface at the end
        # of the second task, when ``partition`` is first called.
        if self.partitionmethod not in self.PARTITION_METHODS:
            raise ValueError(
                f"Unknown partition method {self.partitionmethod!r}; "
                f"expected one of {self.PARTITION_METHODS}")
        print(f"Constraint upper bound: {self.epsilon}")
        print(f"Partition method: {self.partitionmethod}")
        # Step size of the dual (lambda) ascent.
        self.lr_dual = 0.05
        # Index of the task being trained; incremented at the end of end_task.
        self.task = 0

    def partition(self, lambdas, buffer, method='linear'):
        """
        Compute how many buffer slots each task receives.

        Args:
            lambdas: 1-D tensor of dual variables. ``end_task`` passes the
                current task first, followed by the past tasks in order.
            buffer: object exposing ``buffer_size`` (total capacity).
            method: 'linear', 'gen_bound' or 'uniform'.

        Returns:
            1-D tensor with one integer-valued slot count per task.

        Raises:
            ValueError: if ``method`` is not a supported name.
        """
        num_tasks = len(lambdas)
        buffer_size = buffer.buffer_size

        if method == 'uniform':
            return torch.ones_like(lambdas) * buffer_size // num_tasks

        if method == 'linear':
            sum_duals = torch.sum(lambdas)
            if sum_duals == 0:
                return torch.ones_like(lambdas) * buffer_size // num_tasks
            # Half the buffer is split proportionally to the duals, half
            # uniformly. The shares sum to exactly buffer_size before rounding,
            # so round down: rounding up could over-commit the buffer by up to
            # num_tasks slots.
            shares = lambdas / (2 * sum_duals) + 1 / (2 * num_tasks)
            return torch.floor(shares * buffer_size)

        if method == 'gen_bound':
            # Minimise sum_i lambda_i * sqrt(log(n_i) / n_i) subject to
            # sum_i n_i <= buffer_size and buffer_size / (2T) <= n_i <= buffer_size.
            n_min = buffer_size / (2 * num_tasks)

            def objective(n, duals):
                return duals @ np.sqrt(np.log(n) / n)

            n0 = np.ones(num_tasks) * buffer_size / num_tasks
            capacity = opt.LinearConstraint(A=np.ones((1, num_tasks)), lb=-np.inf, ub=buffer_size)
            result = opt.minimize(
                objective,
                n0,
                args=(lambdas.detach().cpu().numpy(),),
                method='SLSQP',
                # One (lower, upper) pair per task: older SciPy SLSQP rejects a
                # bounds list whose length differs from len(x0).
                bounds=[(n_min, buffer_size)] * num_tasks,
                constraints=capacity,
            )
            return torch.from_numpy(np.round(result.x))

        raise ValueError(
            f"Unknown partition method {method!r}; expected one of {self.PARTITION_METHODS}")

    def begin_task(self, dataset):
        """Reset all dual variables to one at the start of every task after the first."""
        if self.task > 0:
            self.lambdas = torch.ones(1, requires_grad=False)
            self.buf_lambdas = torch.ones(self.task, requires_grad=False)

    @staticmethod
    def _sample_current_task(dataset, space_per_class):
        """
        Randomly draw training samples of the current task until class quotas are met.

        ``space_per_class`` (indexed by class label) is decremented in place.
        Entries left above zero mean the training set ran out of that class.

        Returns:
            (imgs, targets): the stored ``original_img`` of each pick, with a
            leading dimension added, and its target.
        """
        train_set = dataset.train_loader.dataset
        order = np.arange(len(train_set))
        np.random.shuffle(order)

        imgs, targets = [], []
        for idx in order:
            # The dataset yields (img, target, original_img, index); only
            # original_img (presumably the non-augmented copy) is stored.
            _, target, original_img, _ = train_set[idx]
            if space_per_class[target] > 0:
                imgs.append(original_img.unsqueeze(0))
                targets.append(target)
                space_per_class[target] -= 1
            if torch.sum(space_per_class) == 0:
                break
        return imgs, targets

    @staticmethod
    def _resample_buffer(examples, labels, tasks, space_per_class):
        """
        Randomly re-draw samples from a buffer snapshot until class quotas are met.

        ``space_per_class`` (indexed by class label) is decremented in place.

        Returns:
            (imgs, targets, task_labels) with plain-int targets and task labels.
        """
        order = np.arange(len(labels))
        np.random.shuffle(order)

        imgs, targets, task_labels = [], [], []
        for idx in order:
            label = int(labels[idx])
            if space_per_class[label] > 0:
                imgs.append(examples[idx].unsqueeze(0))
                targets.append(label)
                task_labels.append(int(tasks[idx]))
                space_per_class[label] -= 1
            if torch.sum(space_per_class) == 0:
                break
        return imgs, targets, task_labels

    def _add_to_buffer(self, imgs, targets, task_labels):
        """Append sampled images to the buffer; does nothing if none were sampled."""
        # torch.cat raises on an empty list, e.g. when a quota rounds down to 0.
        if not imgs:
            return
        self.buffer.add_data(
            examples=torch.cat(imgs),
            labels=torch.Tensor(targets),
            task_labels=task_labels,
        )

    def end_task(self, dataset):
        """
        Rebuild the replay buffer once training on the current task is done.

        If the buffer is empty, it is filled with a class-balanced random subset
        of the current task's training set. Otherwise its contents are read out
        and the buffer is emptied. ``self.partition`` then splits it among all
        tasks seen so far: ``n[0]`` slots are drawn from the current task's
        training set, and ``n[t + 1]`` slots are re-drawn from the old buffer
        samples of past task ``t``. Each task's slots are split evenly among
        its classes (floor division). Finally the buffer loader is recreated
        and the task counter advanced.
        """
        n_cls = dataset.N_CLASSES_PER_TASK
        n_all_cls = n_cls * dataset.N_TASKS
        cur_begin = self.task * n_cls

        if self.buffer.is_empty():
            samples_per_class = self.args.buffer_size // (n_cls * (self.task + 1))
            space_per_class = torch.zeros(n_all_cls)
            space_per_class[cur_begin:cur_begin + n_cls] += samples_per_class

            imgs, targets = self._sample_current_task(dataset, space_per_class)
            self._add_to_buffer(imgs, targets, torch.zeros(len(targets)) + self.task)

        else:
            # Snapshot the whole buffer, then empty it so it can be refilled.
            examples, labels, tasks = self.buffer.get_all_data(transform=None)
            self.buffer.empty()

            n = self.partition(torch.cat((self.lambdas, self.buf_lambdas), 0),
                               self.buffer, self.partitionmethod)

            # Current task: n[0] slots drawn from its training set.
            space_per_class = torch.zeros(n_all_cls)
            space_per_class[cur_begin:cur_begin + n_cls] += n[0] // n_cls
            print(f"Space per class {space_per_class}")

            imgs, targets = self._sample_current_task(dataset, space_per_class)
            self._add_to_buffer(imgs, targets, torch.zeros(len(targets)) + self.task)
            print(f"Space per class {space_per_class}")

            if not self.args.nowand:
                wandb.log({'buffer': labels.unique(return_counts=True)[1].cpu(),
                        'lambda': self.lambdas,
                        'buf_lambdas': self.buf_lambdas,
                        'sum buf lambdas': torch.sum(self.buf_lambdas),
                        'task': self.task})

            # Past task t: n[t + 1] slots re-drawn from the old buffer contents.
            space_per_class = torch.zeros(n_all_cls)
            for task in range(self.task):
                begin = task * n_cls
                space_per_class[begin:begin + n_cls] += n[task + 1] // n_cls
            print(f"space per class {space_per_class}")

            imgs, targets, task_labels = self._resample_buffer(examples, labels, tasks, space_per_class)
            self._add_to_buffer(imgs, targets, torch.tensor(task_labels))
            print(f"space per class {space_per_class}")

            if torch.sum(space_per_class) != 0:
                print("Careful, buffer is not full because a task difficulty was underestimated.")
                print(f"Num samples in buffer: {self.buffer.num_seen_examples}")

        self.buffer.create_loader(size=self.args.minibatch_size, model_transform=self.transform)
        self.task += 1

    def observe(self, inputs, labels, not_aug_inputs, indexes):
        """
        Run one primal-dual step on a batch of the current task.

        With an empty buffer (first task) this is a plain gradient step on
        ``self.loss``. Otherwise the primal step minimises the Lagrangian

            lambda * (L_cur - eps) + sum_t buf_lambda_t * (L_t - eps),

        where ``L_t`` is the loss on the buffer minibatch samples of past task
        ``t``. The dual step is projected gradient ascent,
        ``lambda <- relu(lambda + lr_dual * (L - eps))``.

        Returns:
            The current-task loss as a Python float.
        """
        self.opt.zero_grad()

        if self.buffer.is_empty():
            outputs = self.net(inputs)
            loss = self.loss(outputs, labels)
            loss.backward()
            self.opt.step()
            return loss.item()

        lambdas = self.lambdas.clone().to(self.device)
        buf_lambdas = self.buf_lambdas.clone().to(self.device)

        # Sample from buffer
        _, buf_inputs, buf_labels, buf_tasks = self.buffer.get_data_from_loader(size=self.args.minibatch_size)

        # Forward pass
        outputs = self.net(inputs)
        buf_outputs = self.net(buf_inputs)

        loss = self.loss(outputs, labels)
        buf_losses = torch.zeros(self.task, device=self.device)
        for task in range(self.task):
            in_task = buf_tasks == task
            # Tasks with fewer than two samples in this minibatch are skipped,
            # presumably because squeeze() would turn a single label into a 0-d tensor.
            if torch.sum(in_task) > 1:
                buf_losses[task] += self.loss(buf_outputs[in_task], buf_labels[in_task].squeeze())

        lagrangian = lambdas * (loss - self.epsilon) + torch.sum(buf_lambdas * (buf_losses - self.epsilon))

        # Primal update
        lagrangian.backward()
        self.opt.step()

        # Dual update: plain tensor arithmetic, no autograd graph needed.
        with torch.no_grad():
            lambdas = torch.relu(lambdas + self.lr_dual * (loss - self.epsilon))
            buf_lambdas = torch.relu(buf_lambdas + self.lr_dual * (buf_losses - self.epsilon))
        self.lambdas = lambdas.cpu()
        self.buf_lambdas = buf_lambdas.cpu()

        return loss.item()
