import logging
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from abc import abstractmethod
from utils.toolkit import accuracy, count_parameters


class BaseLearner(object):
    """Shared scaffolding for class-incremental learners.

    Reads hyper-parameters from ``args``, builds the network, creates the
    train/test loaders for the current task, sets up an Adam optimizer with a
    cosine learning-rate schedule, and evaluates predictions. Subclasses
    provide the actual optimisation loop through ``_train_function``.
    """

    def __init__(self, args):
        """Store hyper-parameters from ``args`` (a dict) and build the network.

        Args:
            args: configuration dict. ``"model_name"`` is required by
                ``_build_network``; every other key has a default below.
        """
        self.args = args
        self.dataset = args.get("dataset", "cifar100")
        # NOTE(review): class_num (the initial-task size) is also used as the
        # per-task class count in _test/_evaluate, i.e. tasks are assumed to be
        # equally sized (init_cls == increment) -- confirm for your configs.
        self.class_num = args.get("init_cls", 10)
        self.increment = args.get("increment", 10)
        self.total_sessions = args.get("total_sessions", 10)
        self.epochs = args.get("epochs", 20)
        self.lrate = args.get("lrate", 0.0005)            # encoder learning rate
        self.lrate_decay = args.get("lrate_decay", 0.1)
        self.weight_decay = args.get("weight_decay", 0.0)
        self.fc_lrate = args.get("fc_lrate", 0.0005)      # classifier learning rate
        self.batch_size = args.get("batch_size", 128)
        self.num_workers = args.get("num_workers", 16)
        self.topk = args.get("topk", 5)
        self.cur_task = -1
        self.known_classes = 0  # Number of classes seen so far
        self.total_classes = 0  # Total number of classes in the current task
        self.network = self._build_network()

        self.debug = args.get("debug", False)
        # "device" is a list of device ids: the first one is the primary device
        # and the full list drives DataParallel. The default must be a list as
        # well, otherwise the indexing below fails when "device" is absent.
        devices = args.get("device", [0])
        self.device = devices[0]
        self.multiple_gpus = devices

    @property
    def feature_dim(self):
        """Feature dimension of the (possibly DataParallel-wrapped) network."""
        if isinstance(self.network, nn.DataParallel):
            return self.network.module.feature_dim
        else:
            return self.network.feature_dim

    def build_train_loader(self, data_manager):
        """Create ``self.train_loader`` over the classes of the current task only."""
        train_dataset = data_manager.get_dataset(
            np.arange(self.known_classes, self.total_classes),
            source='train', mode='train')
        self.train_loader = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True,
            num_workers=self.num_workers)

    def build_test_loader(self, data_manager):
        """Create ``self.test_loader`` over every class seen so far."""
        test_dataset = data_manager.get_dataset(
            np.arange(0, self.total_classes),
            source='test', mode='test')
        self.test_loader = DataLoader(
            test_dataset, batch_size=self.batch_size, shuffle=False,
            num_workers=self.num_workers)

    def build_optimizer(self, parameters, lr, weight_decay, num_epochs):
        """Return ``(optimizer, scheduler)``: Adam plus cosine annealing over ``num_epochs``.

        ``lr``/``weight_decay`` are defaults; per-group values in
        ``parameters`` take precedence.
        """
        optimizer = optim.Adam(parameters, lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

        return optimizer, scheduler

    def before_task(self, data_manager):
        """Advance to the next task and grow the classifier for its classes."""
        self.cur_task += 1
        self.total_classes = self.known_classes + data_manager.get_task_size(self.cur_task)
        self.network.update_fc()

    def after_task(self):
        """Mark the current task's classes as known."""
        self.known_classes = self.total_classes

    def incremental_train(self, task, data_manager):
        """Build the current task's train loader and run training on it."""
        self.build_train_loader(data_manager)
        logging.info('Task {} learning on class {}-{}'.format(task, self.known_classes, self.total_classes))
        self._train(task, self.train_loader)

    def _train(self, task, train_loader):
        """Freeze, set up the optimizer, and delegate to ``_train_function``.

        The network is wrapped in ``nn.DataParallel`` when several devices are
        configured. It is always unwrapped afterwards, even if training raises,
        so ``self.network`` never stays wrapped.
        """
        self.network.to(self.device)
        self.freeze_network()

        logging.info('All params: {}'.format(count_parameters(self.network)))
        logging.info('Trainable params: {}'.format(count_parameters(self.network, True)))

        # Collected before wrapping, while the submodules are still directly
        # reachable as attributes of self.network.
        encoder_params = self.network.image_encoder.parameters()
        cls_params = [p for p in self.network.classifier_pool.parameters() if p.requires_grad]

        if len(self.multiple_gpus) > 1:
            self.network = nn.DataParallel(self.network, self.multiple_gpus)

        try:
            # Separate learning rates for the encoder and the classifier.
            network_params = [
                {'params': encoder_params, 'lr': self.lrate, 'weight_decay': self.weight_decay},
                {'params': cls_params, 'lr': self.fc_lrate, 'weight_decay': self.weight_decay},
            ]
            optimizer, scheduler = self.build_optimizer(
                network_params, self.lrate, self.weight_decay, self.epochs)
            self.run_epoch = self.epochs

            # to be implemented by subclasses
            self._train_function(task, train_loader, optimizer, scheduler)
        finally:
            if isinstance(self.network, nn.DataParallel):
                self.network = self.network.module

    @abstractmethod
    def _train_function(self, task, train_loader, optimizer, scheduler):
        """Run the training loop for ``task``; subclasses implement this.

        NOTE(review): BaseLearner does not use ABCMeta, so ``@abstractmethod``
        is not enforced. A subclass that forgets to override this silently
        skips training.
        """
        pass

    def incremental_test(self, data_manager):
        """Evaluate on all classes seen so far.

        Returns:
            ``(cnn_accy, cnn_accy_with_task, cnn_accy_task)``: the ``_evaluate``
            dict for class-incremental predictions, the ``_evaluate`` dict for
            predictions restricted to the ground-truth task's logits, and the
            fraction of samples whose task id was predicted correctly.
        """
        self.build_test_loader(data_manager)
        y_pred, y_pred_with_task, y_true, y_pred_task, y_true_task = self._test(self.test_loader)
        cnn_accy = self._evaluate(y_pred, y_true)
        cnn_accy_with_task = self._evaluate(y_pred_with_task, y_true)
        cnn_accy_task = (y_pred_task == y_true_task).sum().item() / len(y_pred_task)

        return cnn_accy, cnn_accy_with_task, cnn_accy_task

    def _test(self, test_loader):
        """Collect top-1 predictions over ``test_loader``.

        Each loader item is unpacked as ``(_, inputs, targets)``. Task ids are
        derived as ``label // self.class_num``, which assumes equally sized tasks.

        Returns:
            ``(y_pred, y_pred_with_task, y_true, y_pred_task, y_true_task)``.
            The first three are 1-D numpy arrays; the last two are 1-D CPU
            tensors of task ids. All five have one entry per sample.
        """
        self.network.eval()
        model = self.network.module if isinstance(self.network, nn.DataParallel) else self.network

        y_pred, y_true = [], []
        y_pred_with_task = []
        y_pred_task, y_true_task = [], []

        with torch.no_grad():
            for _, inputs, targets in test_loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                outputs = model.interface(inputs)  # logits
                task_ids = torch.div(targets, self.class_num, rounding_mode='trunc')
                offsets = task_ids * self.class_num  # first class index of each sample's task
                y_true_task.append(task_ids.cpu())

                # Only the top-1 column is consumed downstream. Flattening the
                # whole [bs, topk] index matrix would make predictions topk
                # times longer than the targets.
                k = min(self.topk, outputs.size(1))
                topk_idx = torch.topk(outputs, k=k, dim=1, largest=True, sorted=True)[1]  # [bs, k]
                predicts = topk_idx[:, 0]
                y_pred_task.append(torch.div(predicts, self.class_num, rounding_mode='trunc').cpu())

                # Task-aware prediction: restrict each row to its ground-truth
                # task's logit window [offset, offset + class_num).
                cols = offsets.unsqueeze(1) + torch.arange(self.class_num, device=outputs.device)
                outputs_with_task = outputs.gather(1, cols)
                predicts_with_task = outputs_with_task.argmax(dim=1) + offsets

                y_pred.append(predicts.cpu().numpy())
                y_pred_with_task.append(predicts_with_task.cpu().numpy())
                y_true.append(targets.cpu().numpy())

        return (
            np.concatenate(y_pred),
            np.concatenate(y_pred_with_task),
            np.concatenate(y_true),
            torch.cat(y_pred_task),
            torch.cat(y_true_task)
        )

    def _evaluate(self, y_pred, y_true):
        """Return ``{'grouped': ..., 'top1': grouped['total']}`` from ``accuracy``."""
        ret = {}
        grouped = accuracy(y_pred, y_true, self.known_classes, self.class_num)
        ret['grouped'] = grouped
        ret['top1'] = grouped['total']
        return ret

    def freeze_network(self):
        """Disable gradients for every parameter of ``self.network``.

        NOTE(review): this freezes everything. Subclasses presumably override
        it (or re-enable params afterwards) so the current task's parameters
        train -- confirm.
        """
        for param in self.network.parameters():
            param.requires_grad_(False)

    def _build_network(self):
        """Instantiate the network selected by ``args['model_name']`` (case-insensitive).

        Raises:
            ValueError: if the model name is not recognised.
        """
        model_name = self.args['model_name'].lower()

        # Imports are local so only the selected model's module is loaded.
        if model_name == "baseline":    # Baseline
            from models.net import Net
        elif model_name == "ewclora":   # EWC-LoRA
            from models.net_ewc import Net
        else:
            raise ValueError("Model {} is not defined".format(self.args['model_name']))
        return Net(self.args)
