# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import torch
import torchvision
from argparse import Namespace
from utils.conf import get_device
from ..optimizers import get_optimizer, LR_Scheduler, get_apd_optimizer, MultiStepLRScheduler
import torch.distributed as dist
try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None
# import torch.cuda.amp as amp



class ContinualModel(nn.Module):
    """
    Continual learning model.

    Holds a backbone network together with its loss, parsed arguments, data
    transform and logger, and manages the per-task optimizer / learning-rate
    scheduler plus the (optionally apex-amp initialized) DistributedDataParallel
    wrapping of ``self.net``. Subclasses implement :meth:`observe`.
    """
    NAME = None
    COMPATIBILITY = []

    def __init__(self, backbone: nn.Module, loss: nn.Module,
                 args: Namespace, transform: torchvision.transforms, logger):
        super().__init__()

        # ``self.net`` starts as the bare backbone; ``set_task`` later moves it
        # to CUDA and wraps it in DistributedDataParallel.
        self.net = backbone
        self.logger = logger
        self.loss = loss
        self.args = args
        self.transform = transform
        self.device = get_device()

    def forward(self, x: torch.Tensor):
        """
        Computes a forward pass through the backbone.

        Reads ``self.net.module``, so it assumes ``set_task`` has already
        wrapped ``self.net`` in DistributedDataParallel, and that the wrapped
        model exposes a ``backbone`` attribute.
        :param x: batch of inputs
        :return: the result of the computation
        """
        return self.net.module.backbone.forward(x)

    def observe(self, inputs: torch.Tensor, labels: torch.Tensor,
                not_aug_inputs: torch.Tensor):
        """
        Compute a training step over a given batch of examples.

        Meant to be overridden by subclasses; the base implementation does
        nothing and returns None.
        :param inputs: batch of examples
        :param labels: ground-truth labels
        :param not_aug_inputs: presumably the same batch before augmentation
            (TODO confirm against the data pipeline)
        :return: the value of the loss function
        """
        pass

    def set_optimizer_and_lr_scheduler(self, task_id, len_train_lodaer):
        """
        Build (or reuse) the optimizer and build a fresh LR scheduler for a task.

        Every learning rate (initial, final, warmup) is scaled by
        ``args.train.task_lr_decay ** task_id``. The optimizer is (re)created
        when ``args.reinit_opt_per_task`` is set or when none exists yet;
        otherwise the existing ``self.opt`` is kept and only the scheduler is
        rebuilt.

        :param task_id: index of the current task
        :param len_train_lodaer: number of iterations per epoch
        :return: the LR scheduler driving ``self.opt``
        :raises ValueError: if a multistep milestone is not strictly greater
            than ``args.train.warmup_epochs``
        :raises NotImplementedError: for an unsupported ``lr_schedule.type``
        """
        decay = self.args.train.task_lr_decay ** task_id
        task_init_lr = self.args.init_lr * decay
        final_lr = self.args.train.min_lr * decay
        warmup_lr = self.args.train.warmup_lr * decay

        self.logger.info(
            f"optimizer_type: {self.args.optimizer}, warmup_lr: {warmup_lr}, "
            f"base_lr: {task_init_lr}, final_lr: {final_lr}"
        )

        if self.args.reinit_opt_per_task or not hasattr(self, 'opt'):
            self.opt = get_optimizer(
                self.args.optimizer, self.net,
                lr=task_init_lr,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay,
                partial_freeze=self.args.partial_freeze,
            )
            self.logger.info('OPTIMIZER is defined now.')

        schedule_type = self.args.train.lr_schedule.type
        if schedule_type == 'cosine':
            return LR_Scheduler(
                optimizer=self.opt,
                warmup_epochs=self.args.train.warmup_epochs,
                warmup_lr=warmup_lr,
                num_epochs=self.args.num_epochs,
                base_lr=task_init_lr,
                final_lr=final_lr,
                iter_per_epoch=len_train_lodaer,
            )

        if schedule_type == 'multistep':
            warmup_epochs = self.args.train.warmup_epochs
            milestones = self.args.train.lr_schedule.multisteps
            # Milestones (in epochs) are converted below into iteration offsets
            # measured from the end of warmup, so each must lie strictly after
            # the warmup period. Explicit check instead of ``assert`` so it is
            # not stripped under ``python -O``.
            bad_milestones = [m for m in milestones if not warmup_epochs < m]
            if bad_milestones:
                raise ValueError(
                    f"lr_schedule.multisteps {bad_milestones} must all be greater "
                    f"than train.warmup_epochs ({warmup_epochs})"
                )
            warmup_steps = int(len_train_lodaer * warmup_epochs)
            decay_steps = int(len_train_lodaer * (self.args.num_epochs - warmup_epochs))
            multi_steps = [(m - warmup_epochs) * len_train_lodaer for m in milestones]
            return MultiStepLRScheduler(
                self.opt,
                milestones=multi_steps,
                gamma=self.args.train.lr_schedule.gamma,
                task_init_lr=task_init_lr,
                warmup_lr=warmup_lr,
                warmup_t=warmup_steps,
                decay_steps=decay_steps,
            )

        raise NotImplementedError(f"Unsupported lr_schedule.type: {schedule_type!r}")

    def _wrap_in_ddp(self):
        """Move ``self.net`` to CUDA and wrap it in DistributedDataParallel."""
        self.net = torch.nn.parallel.DistributedDataParallel(
            self.net.to('cuda'),
            device_ids=[self.args.local_rank],
            find_unused_parameters=True,
        )

    def set_task(self, task_id, len_train_lodaer):
        """
        Prepare the model for training on task ``task_id``.

        Builds the optimizer / LR scheduler for the task (stored in
        ``self.lr_scheduler``). On the first task, or whenever ``self.net`` is
        not yet wrapped, the network is moved to CUDA and wrapped in
        DistributedDataParallel, after ``apex.amp.initialize`` when mixed
        precision is enabled (``args.amp_opt_level != "O0"``). With amp and
        ``args.reinit_opt_per_task``, amp is re-initialized for the freshly
        built optimizer and the model is re-wrapped on later tasks as well.

        :param task_id: index of the task about to start
        :param len_train_lodaer: number of iterations per epoch
        :raises ImportError: if mixed precision is requested but apex is not
            installed
        """
        use_amp = self.args.amp_opt_level != "O0"
        if use_amp and amp is None:
            # ``amp`` is None when ``from apex import amp`` failed; fail fast
            # with a clear message instead of an AttributeError on None.
            raise ImportError(
                f"amp_opt_level={self.args.amp_opt_level!r} requires NVIDIA apex, "
                f"which is not installed; install apex or use amp_opt_level='O0'"
            )

        self.lr_scheduler = self.set_optimizer_and_lr_scheduler(
            task_id=task_id, len_train_lodaer=len_train_lodaer)

        first_wrap = task_id == 0 or not hasattr(self.net, 'module')
        if use_amp:
            if first_wrap:
                self.net, self.opt = amp.initialize(
                    self.net.to('cuda'), self.opt, opt_level=self.args.amp_opt_level)
                self._wrap_in_ddp()
            elif self.args.reinit_opt_per_task:
                # A new optimizer was just created, so amp is re-initialized on
                # the unwrapped model (``.module``) and the result re-wrapped.
                # NOTE(review): the CPU round-trip looks redundant; confirm
                # before removing it.
                self.net, self.opt = amp.initialize(
                    self.net.module.to('cpu').to('cuda'), self.opt,
                    opt_level=self.args.amp_opt_level)
                self._wrap_in_ddp()
        elif first_wrap:
            self._wrap_in_ddp()

    def end_task(self, logger=None):
        """
        End-of-task hook; the base implementation is a no-op.

        :param logger: optional logger (unused by the base implementation)
        """
        pass
