import collections
import copy
import logging
import os
import pickle

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

from inclearn.lib import factory, herding, losses, network, schedulers, utils
from inclearn.lib.network import hook
from inclearn.models.base import IncrementalLearner

EPSILON = 1e-8

logger = logging.getLogger(__name__)


class Finetune(IncrementalLearner):
    """Incremental learner that simply fine-tunes one network on each task.

    The only training signal visible in this class is a cross-entropy loss on
    the network's ``"logits"`` output (see ``_compute_loss``). A new optimizer
    and learning-rate scheduler are created before every task.

    Args:
        args: Configuration mapping. Keys read here:
            ``n_classes``, ``n_tasks``, ``nlabel``, ``device`` (a sequence;
            the first entry is the primary device), ``optimizer``, ``lr``,
            ``weight_decay``, ``epochs``, ``scheduling``, ``lr_decay``,
            ``dataset``, and the optional ``no_progressbar`` (default
            ``False``), ``groupwise_factors`` (default ``{}``) and ``arch``
            (default ``None``).
    """

    def __init__(self, args):
        super().__init__()
        self._n_classes = args['n_classes']
        self._n_tasks = args['n_tasks']

        # Stored for completeness; not read anywhere in this class.
        self._nlabel = args["nlabel"]
        self._disable_progressbar = args.get("no_progressbar", False)

        # The first device hosts the model and the batches; the full list is
        # only used to decide whether to wrap the model in DataParallel.
        self._device = args["device"][0]
        self._multiple_devices = args["device"]

        self._opt_name = args["optimizer"]
        self._lr = args["lr"]
        self._weight_decay = args["weight_decay"]
        self._n_epochs = args["epochs"]

        self._scheduling = args["scheduling"]
        self._lr_decay = args["lr_decay"]
        self._groupwise_factors = args.get("groupwise_factors", {})

        self._network = network.BasicNet(
            args['dataset'],
            self._n_classes,
            arch=args.get("arch", None),
            device=self._device,
        )

    # ----------
    # Public API
    # ----------

    def _group_lr(self, group_name):
        """Return the learning rate for a parameter group, or ``None`` to skip it.

        Without ``groupwise_factors`` every group uses the base learning rate.
        Otherwise the base rate is scaled by the group's factor (default
        ``1.0``). A list-valued factor holds ``[first_task, later_tasks]``
        values, and a factor of ``0`` excludes the group from the optimizer.
        """
        if not self._groupwise_factors:
            return self._lr

        factor = self._groupwise_factors.get(group_name, 1.0)
        if isinstance(factor, list):
            # self._task is not set in this class; presumably maintained by
            # the base learner -- TODO confirm.
            factor = factor[0] if self._task == 0 else factor[1]
        if factor == 0.:
            return None
        return self._lr * factor

    def _before_task(self):
        """Create a fresh optimizer and LR scheduler for the upcoming task."""
        params = []
        for group_name, group_params in self._network.get_group_parameters().items():
            group_lr = self._group_lr(group_name)
            if group_lr is None:
                # Factor of zero: the group is left out of the optimizer.
                continue
            params.append({"params": group_params, "lr": group_lr})
            logger.info("Group: %s, lr: %s.", group_name, group_lr)

        self._optimizer = factory.get_optimizer(
            params, self._opt_name, self._lr, self._weight_decay
        )

        self._scheduler = factory.get_lr_scheduler(
            self._scheduling,
            self._optimizer,
            nb_epochs=self._n_epochs,
            lr_decay=self._lr_decay,
            task=self._task
        )

    def _train_task(self, train_loader):
        """Train on the current task for the configured number of epochs."""
        logger.debug("nb %s.", len(train_loader.dataset))
        self._training_step(train_loader, 0, self._n_epochs)

    def _training_step(
        self, train_loader, initial_epoch, nb_epochs, record_bn=True, clipper=None
    ):
        """Run the optimisation loop over ``range(initial_epoch, nb_epochs)``.

        Args:
            train_loader: Iterable of dict batches. Each batch must contain
                ``"target"`` and ``"task_id"`` entries; the remaining entries
                are passed to the network as its input dict.
            initial_epoch: First epoch index (inclusive).
            nb_epochs: Last epoch index (exclusive).
            record_bn: Unused here; kept for interface compatibility.
            clipper: Optional callable applied to the network (through
                ``Module.apply``) after every optimizer step.
        """
        if len(self._multiple_devices) > 1:
            logger.info("Duplicating model on {} gpus.".format(len(self._multiple_devices)))
            training_network = nn.DataParallel(self._network, self._multiple_devices)
        else:
            training_network = self._network

        for epoch in range(initial_epoch, nb_epochs):
            # Running sums for the epoch; averaged per batch when printed.
            self._metrics = collections.defaultdict(float)

            # NOTE(review): when initial_epoch > 0 this ratio exceeds 1;
            # confirm whether (epoch - initial_epoch) was intended.
            self._epoch_percent = epoch / (nb_epochs - initial_epoch)

            prog_bar = tqdm(
                train_loader,
                disable=self._disable_progressbar,
                ascii=True,
                bar_format="{desc}: {percentage:3.0f}% | {n_fmt}/{total_fmt} | {rate_fmt}{postfix}"
            )

            # Initialised up front so the post-loop summary below stays valid
            # even if the loader yields no batch.
            nb_batches = 0
            for nb_batches, input_dict in enumerate(prog_bar, start=1):
                targets = input_dict.pop("target")
                task_id = input_dict.pop("task_id")
                inputs = input_dict

                self._optimizer.zero_grad()
                loss = self._forward_loss(
                    training_network,
                    inputs,
                    targets,
                    task_id
                )
                loss.backward()
                self._optimizer.step()

                if clipper:
                    training_network.apply(clipper)

                self._print_metrics(prog_bar, epoch, nb_epochs, nb_batches)

            # With the bar disabled nothing was displayed above, so log one
            # summary line per epoch instead.
            if self._disable_progressbar:
                self._print_metrics(None, epoch, nb_epochs, nb_batches)

            if self._scheduler:
                self._scheduler.step()

    def _print_metrics(self, prog_bar, epoch, nb_epochs, nb_batches):
        """Report the per-batch average of every accumulated metric.

        The message becomes the progress-bar description when ``prog_bar`` is
        given, and is logged at INFO level when ``prog_bar`` is ``None``.
        """
        # Guard against a zero batch count to avoid ZeroDivisionError.
        denominator = max(nb_batches, 1)
        pretty_metrics = ", ".join(
            "{}: {}".format(metric_name, round(metric_value / denominator, 3))
            for metric_name, metric_value in self._metrics.items()
        )
        message = "T{}/{}, E{}/{} => {}".format(
            self._task + 1, self._n_tasks, epoch + 1, nb_epochs, pretty_metrics
        )

        if prog_bar is None:
            logger.info(message)
        else:
            prog_bar.set_description(message)

    def _forward_loss(
        self,
        training_network,
        inputs,
        targets,
        task_id,
        **kwargs
    ):
        """Run a forward pass, compute the loss and update running metrics.

        Args:
            training_network: Module to call (possibly a DataParallel wrapper).
            inputs: Dict of tensors forwarded to the network as one argument.
            targets: Tensor of class indices.
            task_id: Task identifiers of the batch, forwarded to
                ``_compute_loss``.
            **kwargs: Extra keyword arguments forwarded to ``_compute_loss``.

        Returns:
            The loss tensor, ready for ``backward()``.

        Raises:
            ValueError: If the loss is NaN.
        """
        inputs = {key: item.to(self._device) for key, item in inputs.items()}
        targets = targets.to(self._device)

        outputs = training_network(inputs)

        loss = self._compute_loss(inputs, outputs, targets, task_id, **kwargs)
        if torch.isnan(loss).item():
            raise ValueError("A loss is NaN: {}".format(self._metrics))

        self._metrics["loss"] += loss.item()
        self._metrics["acc"] += (targets == outputs['logits'].argmax(dim=1)).float().mean().item()

        return loss

    def _after_task(self):
        """Notify the network that the current task is finished."""
        self._network.on_task_end()

    def _eval_task(self, data_loader):
        """Predict class probabilities for every batch of ``data_loader``.

        Args:
            data_loader: Iterable of dict batches with ``"target"`` and
                ``"task_id"`` entries (converted with ``.numpy()``, so they
                are assumed to be CPU tensors -- TODO confirm) plus the
                network inputs.

        Returns:
            Tuple ``(ypred, ytrue, zid)`` of NumPy arrays: softmax
            probabilities, ground-truth targets and task ids, each
            concatenated over all batches.

        Raises:
            ValueError: If ``data_loader`` yields no batch (raised by
                ``np.concatenate`` on an empty list).
        """
        ypred = []
        ytrue = []
        zid = []

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            for input_dict in data_loader:
                targets = input_dict.pop("target").numpy()
                task_id = input_dict.pop("task_id").numpy()

                ytrue.append(targets)
                zid.append(task_id)

                inputs = {key: item.to(self._device) for key, item in input_dict.items()}
                logits = self._network(inputs)["logits"].detach()

                preds = F.softmax(logits, dim=-1)
                ypred.append(preds.cpu().numpy())

        ypred = np.concatenate(ypred)
        ytrue = np.concatenate(ytrue)
        zid = np.concatenate(zid)

        return ypred, ytrue, zid

    # -----------
    # Private API
    # -----------

    def _compute_loss(self, inputs, outputs, targets, task_id, **kwargs):
        """Return the cross-entropy between ``outputs["logits"]`` and ``targets``.

        ``inputs``, ``task_id`` and any extra keyword arguments are accepted
        for interface compatibility with ``_forward_loss`` but are not used.
        """
        return F.cross_entropy(outputs["logits"], targets)
