import collections
import copy
import logging
import os
import pickle

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

from inclearn.lib import factory, herding, losses, network, schedulers, utils
from inclearn.lib.network import hook
from inclearn.models.finetune import Finetune

EPSILON = 1e-8

logger = logging.getLogger(__name__)


class SSIL(Finetune):
    """Incremental learner with a separated softmax and task-wise distillation.

    From the second task on, the classification loss is split in two:
    samples of the current task are scored with a softmax over the new
    classes only, and samples of earlier tasks (exemplars) with a softmax
    over the old classes only.  A frozen copy of the previous task's network
    supplies soft targets that are distilled separately for each old task.
    """

    def __init__(self, args):
        """Build the learner from the ``args`` configuration mapping.

        Args:
            args: configuration mapping.  ``n_classes_per_task`` is required
                (a missing key raises ``KeyError``); ``distillation_loss`` is
                optional and may carry a ``temperature`` entry.
        """
        super().__init__(args)

        # Width of each per-task slice of the logits used by the distillation
        # loss.  NOTE(review): assumes every task, including the first, adds
        # exactly this many classes -- confirm against the scenario config.
        self.n_classes_per_task = args['n_classes_per_task']

        # Frozen snapshot of the previous task's network; set in
        # ``_after_task`` and only read when ``self._task > 0``.
        self._old_network = None

        self._distillation_loss = args.get("distillation_loss", {})

    def _before_task(self):
        """Grow the linear classifier so it covers every class seen so far.

        Weights and biases of already-known classes are copied into the new
        head; rows for the new classes keep ``nn.Linear``'s default init.
        ``_seen_classes`` / ``_old_seen_classes`` are presumably maintained by
        the base class -- TODO confirm.
        """
        new_classifier = nn.Linear(self._network.features_dim, self._seen_classes).to(self._device)
        if self._task > 0:
            old_head = self._network.classifier
            new_classifier.weight.data[:self._old_seen_classes] = old_head.weight.data
            new_classifier.bias.data[:self._old_seen_classes] = old_head.bias.data

        self._network.classifier = new_classifier

        super()._before_task()

    def _compute_loss(self, inputs, outputs, targets, task_id):
        """Return the training loss for one batch.

        Args:
            inputs: batch fed to the network; re-fed to the frozen old network
                to obtain distillation targets.
            outputs: dict produced by the current network; only ``"logits"``
                is read.
            targets: class index of each sample.
            task_id: task index of each sample; entries different from the
                current task mark exemplar (old-task) samples.

        Returns:
            Scalar loss tensor.  On the first task this is plain
            cross-entropy.  Afterwards it is the separated-softmax
            classification loss plus the task-wise distillation loss; both
            parts are also accumulated into ``self._metrics`` under ``"clf"``
            and ``"dist"``.
        """
        logits = outputs["logits"]
        if self._task == 0:
            return F.cross_entropy(logits, targets)

        batch_size = len(task_id)
        # Use plain ints and explicit split points: ``x[:-n]`` is EMPTY when
        # ``n == 0``, which previously produced a NaN loss (and out-of-range
        # targets for the old-class head) on all-new or all-old batches.
        n_old = int((task_id != self._task).sum().item())
        n_new = batch_size - n_old

        with torch.no_grad():
            old_logits = self._old_network(inputs)["logits"].detach()

        # The split below assumes current-task samples come first in the
        # batch and exemplars last.  NOTE(review): confirm the data loader
        # concatenates batches in that order.
        loss_clf = logits.new_zeros(())
        if n_new > 0:
            # New samples: softmax restricted to the new classes, with the
            # targets shifted into that sub-range.
            curr_out = logits[:n_new, self._old_seen_classes:]
            loss_curr = F.cross_entropy(curr_out, targets[:n_new] - self._old_seen_classes)
            loss_clf = loss_clf + loss_curr * n_new
        if n_old > 0:
            # Exemplars: softmax restricted to the old classes.
            prev_out = logits[n_new:, :self._old_seen_classes]
            loss_prev = F.cross_entropy(prev_out, targets[n_new:])
            loss_clf = loss_clf + loss_prev * n_old
        # Each part is weighted by its share of the batch.
        loss_clf = loss_clf / batch_size
        self._metrics["clf"] += loss_clf.item()

        T = self._distillation_loss.get('temperature', 2)
        loss_dist = logits.new_zeros(())
        for t in range(self._task):
            # Distil each old task's logit slice independently.
            start = t * self.n_classes_per_task
            end = start + self.n_classes_per_task

            soft_target = F.softmax(old_logits[:, start:end] / T, dim=1)
            output_log = F.log_softmax(logits[:, start:end] / T, dim=1)
            # Scaling by T**2 is the usual compensation for the 1/T**2
            # shrinkage of soft-target gradients.
            loss_dist = loss_dist + F.kl_div(output_log, soft_target, reduction='batchmean') * (T ** 2)

        self._metrics["dist"] += loss_dist.item()

        return loss_dist + loss_clf

    def _after_task(self):
        """Freeze a copy of the current network for the next task's
        distillation, then run the network's end-of-task hook."""
        self._old_network = self._network.copy().freeze().to(self._device)
        self._network.on_task_end()

    def _after_task_intensive(self, inc_dataset):
        """Ask the incremental dataset to refresh its exemplar memory."""
        inc_dataset.update_exemplar()




