import collections
import copy
import logging
import os
import pickle

import numpy as np
import torch
import scipy
import scipy.linalg
from scipy.spatial.distance import cdist
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

from inclearn.lib import factory, herding, losses, network, schedulers, utils
from inclearn.lib.network import hook
from inclearn.models.finetune import Finetune

EPSILON = 1e-8

logger = logging.getLogger(__name__)


class DER(Finetune):
    """Finetune variant that replays stored logits and labels of old samples.

    For tasks after the first one, the loss combines a cross-entropy on the
    current-task samples, an MSE between current and previously stored logits
    (weighted by ``alpha``) and a cross-entropy on replayed samples (weighted
    by ``beta``).

    NOTE(review): this looks like Dark Experience Replay (DER++) -- confirm
    against the project documentation.
    """

    def __init__(self, args):
        """Initialise the parent model and read the loss weights.

        :param args: Mapping of options; ``alpha`` and ``beta`` are optional
                     and both default to 1.0.
        """
        super().__init__(args)

        self.alpha = args.get('alpha', 1.0)
        self.beta = args.get('beta', 1.0)

    def _training_step(
        self, train_loader, initial_epoch, nb_epochs, record_bn=True, clipper=None
    ):
        """Run the optimisation loop from ``initial_epoch`` up to ``nb_epochs``.

        :param train_loader: Iterable of dict batches. Each batch must hold
                             the keys ``target`` and ``task_id`` and, once
                             ``self._task > 0``, ``past_logits``; whatever
                             remains is forwarded as the network inputs.
        :param initial_epoch: First epoch index (inclusive).
        :param nb_epochs: Last epoch index (exclusive).
        :param record_bn: Unused here; kept for interface compatibility.
        :param clipper: Optional callable applied to the network (through
                        ``Module.apply``) after every optimizer step.
        """
        if len(self._multiple_devices) > 1:
            logger.info("Duplicating model on %d gpus.", len(self._multiple_devices))
            training_network = nn.DataParallel(self._network, self._multiple_devices)
        else:
            training_network = self._network

        for epoch in range(initial_epoch, nb_epochs):
            # Metrics are accumulated per epoch, so reset them each time.
            self._metrics = collections.defaultdict(float)

            self._epoch_percent = epoch / (nb_epochs - initial_epoch)

            prog_bar = tqdm(
                train_loader,
                disable=self._disable_progressbar,
                ascii=True,
                bar_format="{desc}: {percentage:3.0f}% | {n_fmt}/{total_fmt} | {rate_fmt}{postfix}"
            )
            # Keeps `i` bound for the end-of-epoch print even if the loader
            # yields no batch.
            i = 0
            for i, input_dict in enumerate(prog_bar, start=1):
                # The pops mutate the batch dict: what is left are the inputs.
                targets = input_dict.pop("target")
                task_id = input_dict.pop("task_id")
                past_logits = input_dict.pop("past_logits") if self._task > 0 else None
                inputs = input_dict

                self._optimizer.zero_grad()
                loss = self._forward_loss(
                    training_network,
                    inputs,
                    targets,
                    task_id,
                    past_logits=past_logits
                )
                loss.backward()
                self._optimizer.step()

                if clipper:
                    training_network.apply(clipper)

                self._print_metrics(prog_bar, epoch, nb_epochs, i)

            # Without a progress bar, still report the metrics once per epoch.
            if self._disable_progressbar:
                self._print_metrics(None, epoch, nb_epochs, i)

            if self._scheduler:
                self._scheduler.step()

    def _compute_loss(self, inputs, outputs, targets, task_id, past_logits):
        """Compute the training loss for one batch.

        On the first task this is a plain cross-entropy. Afterwards the batch
        is cut into three consecutive segments: current-task samples, old
        samples regularised towards their stored logits (``alpha`` term) and
        old samples trained on their labels (``beta`` term).

        NOTE(review): the slicing assumes the batch is ordered with all
        current-task samples first, followed by the old samples -- confirm
        against the sampler that builds the batches.

        :param inputs: Unused here; kept for interface compatibility.
        :param outputs: Network outputs; only ``outputs["logits"]`` is read.
        :param targets: Ground-truth labels of the batch.
        :param task_id: Task index of every sample in the batch.
        :param past_logits: Stored logits aligned with the batch; only read
                            when ``self._task > 0``.
        :return: Scalar loss tensor.
        """
        logits = outputs["logits"]
        if self._task == 0:
            return F.cross_entropy(logits, targets)

        # Plain ints: `torch.split` needs sizes summing exactly to the batch
        # length, so an odd number of old samples must not lose one to `// 2`.
        nb_old = int((task_id != self._task).sum())
        nb_new = len(task_id) - nb_old
        nb_alpha = nb_old // 2
        nb_beta = nb_old - nb_alpha
        sizes = [nb_new, nb_alpha, nb_beta]

        outputs_new, outputs_alpha, outputs_beta = torch.split(logits, sizes)
        targets_new, _, targets_beta = torch.split(targets, sizes)

        # Mean-reduced losses evaluate to NaN on an empty segment, hence each
        # term is only added when its segment holds at least one sample.
        loss = logits.new_zeros(())

        if nb_new > 0:
            loss_clf = F.cross_entropy(outputs_new, targets_new)
            self._metrics["clf"] += loss_clf.item()
            loss = loss + loss_clf

        if nb_alpha > 0:
            old_outputs_alpha = past_logits[nb_new:nb_new + nb_alpha].to(self._device)
            loss_alpha = self.alpha * F.mse_loss(outputs_alpha, old_outputs_alpha)
            self._metrics["old_mse"] += loss_alpha.item()
            loss = loss + loss_alpha

        if nb_beta > 0:
            loss_beta = self.beta * F.cross_entropy(outputs_beta, targets_beta)
            self._metrics["old_clf"] += loss_beta.item()
            loss = loss + loss_beta

        return loss

    def _after_task_intensive(self, inc_dataset):
        """Store the network logits of the current training set, then update exemplars.

        NOTE(review): the network mode (train/eval) is left untouched here --
        confirm the caller has put it in the intended mode.

        :param inc_dataset: Dataset object providing ``get_cur_train_loader``,
                            ``update_cur_train_predictions`` and
                            ``update_exemplar``.
        """
        predictions = []
        # Inference only: no autograd graph is needed for the stored logits.
        with torch.no_grad():
            for input_dict in inc_dataset.get_cur_train_loader(shuffle=False, num_workers=8):
                inputs = {
                    key: item.to(self._device)
                    for key, item in input_dict.items()
                    if key not in ("target", "task_id")
                }
                predictions.append(self._network(inputs)['logits'].detach().cpu().numpy())
        predictions = np.concatenate(predictions)
        inc_dataset.update_cur_train_predictions(predictions)

        inc_dataset.update_exemplar()



