import collections
import copy
import logging
import os
import pickle

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

from inclearn.lib import factory, herding, losses, network, schedulers, utils
from inclearn.models.finetune import Finetune

EPSILON = 1e-8

logger = logging.getLogger(__name__)


class AGEM(Finetune):
    """Finetuning with an A-GEM style gradient projection.

    For every task after the first one, each batch is split into samples of
    the current task and memory samples (those whose ``task_id`` differs from
    ``self._task``). The gradient of the current-task loss is compared with
    the gradient of the memory loss; when their dot product is negative, the
    current-task gradient is projected so that it no longer has a component
    along the negative memory gradient before the optimizer step.

    Gradients are handled as one flat vector per parameter group, as returned
    by ``self._network.get_group_parameters()``.
    """

    def __init__(self, args):
        """Build the model and allocate one flat gradient buffer per group.

        Args:
            args: Configuration forwarded unchanged to ``Finetune.__init__``.
        """
        super().__init__(args)

        # NOTE(review): the per-parameter sizes are captured once, here. If the
        # parameter groups of the network change afterwards, these sizes (and
        # the buffers below) go stale -- confirm against the network code.
        self._grad_dims = {}
        self._grad_xy = {}  # flat gradient of the current-task loss
        self._grad_er = {}  # flat gradient of the memory loss
        for modality, parameters in self._network.get_group_parameters().items():
            dims = [param.data.numel() for param in parameters]
            total = sum(dims)

            self._grad_dims[modality] = dims
            # ``torch.zeros`` takes a plain int size and a device directly; the
            # legacy ``torch.Tensor(size)`` constructor returned uninitialized
            # memory and was handed a NumPy scalar.
            self._grad_xy[modality] = torch.zeros(total, device=self._device)
            self._grad_er[modality] = torch.zeros(total, device=self._device)

    def _training_step(
        self, train_loader, initial_epoch, nb_epochs, record_bn=True, clipper=None
    ):
        """Run the training epochs ``initial_epoch`` .. ``nb_epochs - 1``.

        Args:
            train_loader: Iterable of batches. Each batch is a dict holding a
                ``"target"`` entry, a ``"task_id"`` entry and the network
                inputs under the remaining keys.
            initial_epoch: First epoch index.
            nb_epochs: Exclusive upper bound of the epoch index.
            record_bn: Unused here; kept for interface compatibility.
            clipper: Optional callable applied to the training network with
                ``Module.apply`` after every optimisation step.
        """
        if len(self._multiple_devices) > 1:
            logger.info("Duplicating model on %d gpus.", len(self._multiple_devices))
            training_network = nn.DataParallel(self._network, self._multiple_devices)
        else:
            training_network = self._network

        for epoch in range(initial_epoch, nb_epochs):
            self._metrics = collections.defaultdict(float)

            self._epoch_percent = epoch / (nb_epochs - initial_epoch)

            prog_bar = tqdm(
                train_loader,
                disable=self._disable_progressbar,
                ascii=True,
                bar_format="{desc}: {percentage:3.0f}% | {n_fmt}/{total_fmt} | {rate_fmt}{postfix}"
            )

            # Stays at 0 when the loader yields nothing, so the summary print
            # below never reads an unbound (or previous-epoch) batch counter.
            nb_batches = 0
            for nb_batches, input_dict in enumerate(prog_bar, start=1):
                targets = input_dict.pop("target")
                task_id = input_dict.pop("task_id")
                # Whatever is left in the dict are the network inputs.
                self._forward_loss(training_network, input_dict, targets, task_id)

                if clipper:
                    training_network.apply(clipper)

                self._print_metrics(prog_bar, epoch, nb_epochs, nb_batches)

            if self._disable_progressbar and nb_batches:
                self._print_metrics(None, epoch, nb_epochs, nb_batches)

            if self._scheduler:
                self._scheduler.step()

    def _forward_loss(
        self,
        training_network,
        inputs,
        targets,
        task_id,
        **kwargs
    ):
        """Do one optimisation step on a batch and return the current-task loss.

        Args:
            training_network: Module called with the dict of inputs; its output
                must expose the logits under the ``"logits"`` key.
            inputs: Dict of input tensors, batch on the first dimension.
            targets: Class targets of the batch.
            task_id: Task index of every sample of the batch; samples whose
                value differs from ``self._task`` are treated as memory
                samples.
            **kwargs: Ignored.

        Returns:
            The cross-entropy loss computed on the current-task samples.

        Raises:
            ValueError: If the current-task loss is NaN. The check runs before
                the optimizer step so the weights are left untouched.
        """
        inputs = {key: item.to(self._device) for key, item in inputs.items()}
        targets = targets.to(self._device)

        mem_inputs, mem_targets = None, None
        if self._task > 0:
            # Select by mask instead of by count: for a batch laid out as
            # [current samples, memory samples] this is the same split, but it
            # also stays correct when the samples are interleaved, and an
            # empty memory part no longer turns into "the whole batch"
            # (``item[-0:]``).
            mem_mask = torch.as_tensor(task_id != self._task, device=self._device)
            if bool(mem_mask.any()):
                cur_mask = torch.as_tensor(task_id == self._task, device=self._device)
                mem_inputs = {key: item[mem_mask] for key, item in inputs.items()}
                mem_targets = targets[mem_mask]
                inputs = {key: item[cur_mask] for key, item in inputs.items()}
                targets = targets[cur_mask]

        self._optimizer.zero_grad()
        outputs = training_network(inputs)
        loss = F.cross_entropy(outputs["logits"], targets)
        self._metrics["loss"] += loss.item()
        loss.backward()

        mem_outputs = None
        if mem_inputs is not None:
            # Keep the current-task gradient before it gets overwritten by the
            # backward pass of the memory loss.
            self._store_grad(self._grad_xy)

            self._optimizer.zero_grad()
            mem_outputs = training_network(mem_inputs)
            penalty = F.cross_entropy(mem_outputs["logits"], mem_targets)
            penalty.backward()
            self._metrics["penalty"] += penalty.item()

            self._store_grad(self._grad_er)

            dot_prod = sum(
                torch.dot(self._grad_xy[modality], self._grad_er[modality])
                for modality in self._grad_xy
            )
            if float(dot_prod) < 0:
                # The two gradients conflict: project the current-task one.
                g_tilde = self._project(gxy=self._grad_xy, ger=self._grad_er)
                self._overwrite_grad(g_tilde)
            else:
                # ``param.grad`` currently holds the memory gradient; put the
                # current-task gradient back before stepping.
                self._overwrite_grad(self._grad_xy)

        if bool(torch.isnan(loss).item()):
            raise ValueError("A loss is NaN: {}".format(self._metrics))

        self._optimizer.step()

        if mem_outputs is not None:
            all_targets = torch.cat([targets, mem_targets])
            all_logits = torch.cat([outputs["logits"], mem_outputs["logits"]])
        else:
            all_targets = targets
            all_logits = outputs["logits"]
        self._metrics["acc"] += (
            all_targets == all_logits.argmax(dim=1)
        ).float().mean().item()

        return loss

    @staticmethod
    def _group_offsets(dims):
        """Return the prefix sums ``[0, d0, d0 + d1, ...]`` of ``dims``."""
        offsets = [0]
        for numel in dims:
            offsets.append(offsets[-1] + numel)
        return offsets

    def _store_grad(self, grads):
        """Copy the current parameter gradients into flat buffers.

        Args:
            grads: Dict mapping each parameter group to a flat tensor; it is
                zeroed and then filled in place. Parameters without a gradient
                leave their segment at zero.
        """
        parameters = self._network.get_group_parameters()
        for modality, dims in self._grad_dims.items():
            flat = grads[modality]
            flat.fill_(0.0)
            # One pass over the sizes instead of re-summing a prefix for every
            # parameter.
            offsets = self._group_offsets(dims)
            for index, param in enumerate(parameters[modality]):
                if param.grad is None:
                    continue
                begin, end = offsets[index], offsets[index + 1]
                flat[begin:end].copy_(param.grad.data.reshape(-1))

    def _overwrite_grad(self, newgrad):
        """Overwrite the parameter gradients with the given flat vectors.

        Args:
            newgrad: Dict mapping each parameter group to a flat gradient
                vector laid out as in ``_store_grad``. Parameters whose
                ``grad`` is ``None`` are skipped.
        """
        parameters = self._network.get_group_parameters()
        for modality, dims in self._grad_dims.items():
            flat = newgrad[modality]
            offsets = self._group_offsets(dims)
            for index, param in enumerate(parameters[modality]):
                if param.grad is None:
                    continue
                begin, end = offsets[index], offsets[index + 1]
                segment = flat[begin:end].contiguous().view(param.grad.data.size())
                param.grad.data.copy_(segment)

    def _project(self, gxy, ger):
        """Project ``gxy`` so it has no component along ``ger``.

        Computes ``gxy - (<gxy, ger> / <ger, ger>) * ger`` where both dot
        products are summed over all parameter groups.

        Args:
            gxy: Dict of flat current-task gradients, one per group.
            ger: Dict of flat memory (reference) gradients, same keys.

        Returns:
            A new dict with the projected gradient of every group.
        """
        # The caller only projects when <gxy, ger> < 0, which implies that
        # ``ger`` is not all zeros, so the denominator is non-zero there.
        sum1 = 0
        sum2 = 0
        for modality in gxy:
            sum1 += torch.dot(gxy[modality], ger[modality])
            sum2 += torch.dot(ger[modality], ger[modality])

        div = sum1 / sum2
        return {modality: gxy[modality] - div * ger[modality] for modality in gxy}

    def _after_task_intensive(self, inc_dataset):
        """Ask the dataset to update its exemplars once a task is finished.

        Args:
            inc_dataset: Dataset object exposing ``update_exemplar()``.
        """
        inc_dataset.update_exemplar()



