import collections
import copy
import logging
import os
import pickle

import torch
from torch.utils.data import ConcatDataset

from inclearn.lib import factory, herding, losses, network, schedulers, utils
from inclearn.models.finetune import Finetune

# Small epsilon constant (not referenced by the code in this file's visible portion).
EPSILON = 1e-08

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class Joint(Finetune):
    """Joint-training variant: each task trains on the current task's data merged
    with the training data accumulated from previous tasks.

    After a task, the (merged) training set is stored in ``prev_dataset``. At the
    start of the next task's training, it is concatenated with that task's
    training set.
    """

    def __init__(self, args):
        super().__init__(args)
        # Training set accumulated over previous tasks (None until the first
        # task has finished).
        self.prev_dataset = None
        # The ConcatDataset built for the current task. It is remembered so that
        # a repeated ``_training_step`` call within the same task does not append
        # the previous data a second time.
        self._joint_dataset = None

    # ----------
    # Public API
    # ----------
    def _training_step(
        self, train_loader, initial_epoch, nb_epochs, record_bn=True, clipper=None
    ):
        """Train on the current task's data merged with the previous tasks' data.

        The incoming ``train_loader`` is replaced by a loader rebuilt from
        ``self.inc_dataset.cur_train_dataset`` after merging.

        :param train_loader: Ignored; a new loader over the merged data is built.
        :param initial_epoch: Forwarded to the parent implementation.
        :param nb_epochs: Forwarded to the parent implementation.
        :param record_bn: Forwarded to the parent implementation.
        :param clipper: Forwarded to the parent implementation.
        :return: Whatever the parent ``_training_step`` returns.
        """
        inc_dataset = self.inc_dataset
        already_merged = (
            self._joint_dataset is not None
            and inc_dataset.cur_train_dataset is self._joint_dataset
        )
        if self._task > 0 and not already_merged:
            # Prepend the data accumulated from earlier tasks exactly once per task.
            self._joint_dataset = ConcatDataset(
                [self.prev_dataset, inc_dataset.cur_train_dataset]
            )
            inc_dataset.cur_train_dataset = self._joint_dataset

        train_loader = inc_dataset.get_cur_train_loader(
            shuffle=True, num_workers=inc_dataset.args["workers"], drop_last=True
        )
        return super()._training_step(
            train_loader,
            initial_epoch,
            nb_epochs,
            record_bn=record_bn,
            clipper=clipper,
        )

    def _after_task_intensive(self, inc_dataset):
        """Remember this task's training set so the next task can reuse it.

        NOTE(review): this presumably receives the same object as
        ``self.inc_dataset``, so the stored dataset is the merged one. It also
        does not call ``super()._after_task_intensive`` — presumably intentional,
        since joint training keeps all data rather than a memory; confirm.
        """
        self.prev_dataset = inc_dataset.cur_train_dataset
