import copy
import json
import logging
import time
import os
from logging.handlers import TimedRotatingFileHandler

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

from lib.ExemplarManager import ExemplarManager
from .lr_scheduler import WarmupMultiStepLR, WarmupCosineAnnealingLR
from lib.model import DDC_Network
from torch.nn import functional as F
import pdb


def create_logger(cfg, file_suffix):
    """Create a logger that writes to a timestamped log file and to the console.

    The log file is
    ``<OUTPUT_DIR>/<NAME>/logs/<dataset>_<backbone>_<module>_<time>.<file_suffix>``
    where ``<time>`` has minute resolution (``%Y-%m-%d-%H-%M``). The config is
    dumped to the logger right after it is created.

    NOTE: the file is attached through ``logging.basicConfig``, which only
    configures the *root* logger when it has no handlers yet. If the root
    logger is already configured (e.g. by an earlier call in the same
    process), records still propagate to the existing root handlers and the
    returned ``log_file`` may never be written.

    Args:
        cfg: experiment config; reads ``DATASET.dataset_name``,
            ``BACKBONE.TYPE``, ``MODULE.TYPE``, ``OUTPUT_DIR`` and ``NAME``.
        file_suffix: log-file extension, also used as the logger's name.

    Returns:
        tuple: ``(logger, log_file)`` -- the INFO-level named logger and the
        path of the log file.
    """
    dataset = cfg.DATASET.dataset_name
    net_type = cfg.BACKBONE.TYPE
    module_type = cfg.MODULE.TYPE
    log_dir = os.path.join(cfg.OUTPUT_DIR, cfg.NAME, "logs")
    # exist_ok: no error if the directory already exists (no check-then-create race).
    os.makedirs(log_dir, exist_ok=True)
    time_str = time.strftime("%Y-%m-%d-%H-%M")
    log_name = "{}_{}_{}_{}.{}".format(dataset, net_type, module_type, time_str, file_suffix)
    log_file = os.path.join(log_dir, log_name)
    # set up logger
    print("=> creating log {}".format(log_file))
    head = "%(asctime)-15s %(message)s"
    logging.basicConfig(filename=str(log_file), format=head)
    logger = logging.getLogger(name=file_suffix)
    logger.setLevel(logging.INFO)
    # getLogger returns the same object for the same name, so adding a console
    # handler on every call would print each message once per call.
    # FileHandler subclasses StreamHandler, hence the exact type check.
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        logger.addHandler(logging.StreamHandler())

    logger.info("---------------------Cfg is set as follow--------------------")
    logger.info(cfg)
    logger.info("-------------------------------------------------------------")
    return logger, log_file

def get_logger(cfg, file_suffix):
    """Return a named logger that writes INFO-level records to a daily-rotating file.

    The file is
    ``<OUTPUT_DIR>/<NAME>/logs/<dataset>_<backbone>_<module>_<time>.<file_suffix>``
    (``<time>`` is ``%Y-%m-%d-%H-%M``) and is rotated once per day
    (``TimedRotatingFileHandler(when='D')``). No console handler is added and
    ``basicConfig`` is not called; records still propagate to any handlers
    already installed on the root logger.

    Args:
        cfg: experiment config; reads ``DATASET.dataset_name``,
            ``BACKBONE.TYPE``, ``MODULE.TYPE``, ``OUTPUT_DIR`` and ``NAME``.
        file_suffix: log-file extension, also used as the logger's name.

    Returns:
        tuple: ``(logger, log_file)`` -- the INFO-level named logger and the
        path of the log file.
    """
    dataset = cfg.DATASET.dataset_name
    net_type = cfg.BACKBONE.TYPE
    module_type = cfg.MODULE.TYPE
    log_dir = os.path.join(cfg.OUTPUT_DIR, cfg.NAME, "logs")
    os.makedirs(log_dir, exist_ok=True)
    time_str = time.strftime("%Y-%m-%d-%H-%M")
    log_name = "{}_{}_{}_{}.{}".format(dataset, net_type, module_type, time_str, file_suffix)
    log_file = os.path.join(log_dir, log_name)

    logger = logging.getLogger(file_suffix)
    logger.setLevel(logging.INFO)

    # getLogger returns a shared object per name. Calling this twice within the
    # same minute resolves to the same file, and a second handler for it would
    # write every record twice. FileHandler stores its path as an absolute path.
    target = os.path.abspath(log_file)
    already_attached = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target
        for h in logger.handlers
    )
    if not already_attached:
        info_handler = TimedRotatingFileHandler(log_file,
                                                when='D',
                                                encoding='utf-8')
        info_handler.setLevel(logging.INFO)
        info_handler.setFormatter(logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
        logger.addHandler(info_handler)

    return logger, log_file


def create_valid_logger(cfg):
    """Configure the root logger for a test/validation run and return it.

    Logs go to ``<dir>/logs/Test_<dataset>_<backbone>_<module>_<time>.log``
    where ``<dir>`` is two levels above ``cfg.TEST.MODEL_FILE``
    (e.g. ``a/b/c/model.pth`` -> ``a/b/logs``).

    NOTE: the file is attached with ``logging.basicConfig``, which does nothing
    if the root logger already has handlers; in that case ``log_file`` may not
    receive any records.

    Args:
        cfg: config; reads ``DATASET.DATASET``, ``BACKBONE.TYPE``,
            ``MODULE.TYPE`` and ``TEST.MODEL_FILE``.

    Returns:
        tuple: ``(logger, log_file)`` where ``logger`` is the root logger.
    """
    # NOTE(review): the training-time loggers in this file read
    # ``DATASET.dataset_name``; confirm ``DATASET.DATASET`` is intended here.
    dataset = cfg.DATASET.DATASET
    net_type = cfg.BACKBONE.TYPE
    module_type = cfg.MODULE.TYPE

    # Two dirname() calls strip the file name and its parent directory. This
    # keeps the leading "/" of absolute paths (split('/') + join dropped it)
    # and does not raise on paths with fewer than three components.
    test_model_path = os.path.dirname(os.path.dirname(cfg.TEST.MODEL_FILE))
    log_dir = os.path.join(test_model_path, "logs")
    os.makedirs(log_dir, exist_ok=True)
    time_str = time.strftime("%Y-%m-%d-%H-%M")
    log_name = "Test_{}_{}_{}_{}.log".format(dataset, net_type, module_type, time_str)
    log_file = os.path.join(log_dir, log_name)
    # set up logger
    print("=> creating log {}".format(log_file))
    head = "%(asctime)-15s %(message)s"
    logging.basicConfig(filename=str(log_file), format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Only one plain console handler on the root logger; repeated calls would
    # otherwise echo every message once per call. FileHandler subclasses
    # StreamHandler, hence the exact type check.
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        logger.addHandler(logging.StreamHandler())

    logger.info("---------------------Start Testing--------------------")
    logger.info("Test model: {}".format(cfg.TEST.MODEL_FILE))

    return logger, log_file


def get_optimizer(cfg, model, BASE_LR=None, **kwargs):
    """Build the optimizer selected by ``cfg.TRAIN.OPTIMIZER.TYPE``.

    Every parameter of ``model`` with ``requires_grad=True`` becomes its own
    param group; all groups share the same hyper-parameters.

    Args:
        cfg: config; reads ``TRAIN.OPTIMIZER.TYPE``, ``BASE_LR``,
            ``WEIGHT_DECAY`` and, for the SGD variants, ``MOMENTUM``.
        model: ``torch.nn.Module`` whose trainable parameters are optimized.
        BASE_LR: learning rate overriding ``TRAIN.OPTIMIZER.BASE_LR``. A falsy
            value (``None`` or ``0``) falls back to the config value.
        **kwargs: for ``'SGDWithExtraWeightDecay'`` must contain
            ``num_class_list`` and ``classifier_shape``.

    Returns:
        torch.optim.Optimizer: Nesterov ``SGD``, ``SGDWithExtraWeightDecay``
        or ``Adam``.

    Raises:
        NotImplementedError: for any other optimizer type (previously this
            surfaced as an ``UnboundLocalError``).
    """
    opt_cfg = cfg.TRAIN.OPTIMIZER
    base_lr = BASE_LR if BASE_LR else opt_cfg.BASE_LR
    params = [{"params": p} for p in model.parameters() if p.requires_grad]
    opt_type = opt_cfg.TYPE

    if opt_type == "SGD":
        return torch.optim.SGD(
            params,
            lr=base_lr,
            momentum=opt_cfg.MOMENTUM,
            weight_decay=opt_cfg.WEIGHT_DECAY,
            nesterov=True,
        )
    if opt_type == 'SGDWithExtraWeightDecay':
        return SGDWithExtraWeightDecay(
            params,
            kwargs['num_class_list'],
            kwargs['classifier_shape'],
            lr=base_lr,
            momentum=opt_cfg.MOMENTUM,
            weight_decay=opt_cfg.WEIGHT_DECAY,
            nesterov=True,
        )
    if opt_type == "ADAM":
        return torch.optim.Adam(
            params,
            lr=base_lr,
            betas=(0.9, 0.999),
            weight_decay=opt_cfg.WEIGHT_DECAY,
        )
    raise NotImplementedError("Unsupported Optimizer: {}".format(opt_type))


def get_scheduler(cfg, optimizer, lr_step=None):
    """Create the learning-rate scheduler named by ``cfg.TRAIN.LR_SCHEDULER.TYPE``.

    Args:
        cfg: config; reads ``TRAIN.LR_SCHEDULER.TYPE``, ``LR_STEP``,
            ``LR_FACTOR`` and, for ``"warmup"``, ``WARM_EPOCH``.
        optimizer: the optimizer whose learning rate is scheduled.
        lr_step: milestones that replace ``LR_STEP`` when truthy.

    Returns:
        ``MultiStepLR`` for ``"multistep"`` or ``WarmupMultiStepLR`` for
        ``"warmup"``.

    Raises:
        NotImplementedError: for any other scheduler type.
    """
    sched_cfg = cfg.TRAIN.LR_SCHEDULER
    milestones = lr_step if lr_step else sched_cfg.LR_STEP
    kind = sched_cfg.TYPE

    if kind == "multistep":
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones, gamma=sched_cfg.LR_FACTOR
        )
    if kind == "warmup":
        return WarmupMultiStepLR(
            optimizer,
            milestones,
            gamma=sched_cfg.LR_FACTOR,
            warmup_epochs=sched_cfg.WARM_EPOCH,
        )
    raise NotImplementedError("Unsupported LR Scheduler: {}".format(kind))


def get_model(cfg, num_classes, device):
    """Instantiate a ``DDC_Network`` in ``"train"`` mode and place it on hardware.

    When ``cfg.CPU_MODE`` is truthy the model is moved to ``device``. Otherwise
    it is wrapped in ``torch.nn.DataParallel`` and moved with ``.cuda()``, and
    ``device`` is not used.
    """
    network = DDC_Network(cfg, mode="train", num_classes=num_classes)
    if not cfg.CPU_MODE:
        return torch.nn.DataParallel(network).cuda()
    return network.to(device)


def get_category_list(annotations, num_classes, cfg):
    """Count annotations per category and collect their category ids.

    Args:
        annotations: iterable of dicts whose ``"category_id"`` value is used
            as a list index.
        num_classes: length of the returned count list.
        cfg: not used by this function.

    Returns:
        tuple: ``(num_list, cat_list)`` -- per-category counts indexed by
        category id, and the category id of every annotation in input order.
    """
    print("Weight List has been produced")
    counts = [0 for _ in range(num_classes)]
    labels = []
    for record in annotations:
        label = record["category_id"]
        counts[label] += 1
        labels.append(label)
    return counts, labels


class _RequiredParameter:
    """Sentinel type marking an optimizer argument that callers must supply.

    Its only instance is the module-level ``required`` object defined below.
    """

    def __repr__(self):
        return "<required parameter>"


# Default for mandatory optimizer arguments such as ``lr``.
required = _RequiredParameter()


class SGDWithExtraWeightDecay(torch.optim.Optimizer):
    """SGD with (optional Nesterov) momentum plus an extra per-row decay on one weight.

    The update is classic coupled-L2 SGD: ``grad += weight_decay * p``, then
    momentum, then ``p -= lr * update``. In addition, for a parameter whose
    shape equals ``classifier_shape`` an extra term
    ``(weight_decay / num_class_list[c]) * p[c]`` is added to row ``c`` of the
    gradient. ``num_class_list`` is presumably the per-class sample count
    (so rarer classes get stronger decay) -- confirm at the call site.

    NOTE(review): the extra term alternates. ``self.first`` toggles on every
    matching parameter, so the first match is skipped, the next one gets the
    extra decay, and so on, across steps as well. With a single matching
    parameter the extra decay is therefore applied only on every second
    ``step()``. Preserved as-is; confirm this is intended.

    Weight decay is accumulated into ``p.grad`` in place, and updates go
    through ``.data``, so no autograd graph is recorded.
    """

    def __init__(self, params, num_class_list, classifier_shape, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        """
        Args:
            params: iterable of parameters or param-group dicts.
            num_class_list: 1-D tensor with one entry per classifier row
                (indexed with ``[:, None]`` and ``.repeat``).
            classifier_shape: shape of the parameter that receives the extra
                decay; ``classifier_shape[-1]`` is its number of columns.
            lr, momentum, dampening, weight_decay, nesterov: as in
                ``torch.optim.SGD``.

        Raises:
            ValueError: on negative lr/momentum/weight_decay, or Nesterov
                without positive momentum and zero dampening.
        """
        # Element-wise extra decay, shape (len(num_class_list), classifier_shape[-1]):
        # row c holds weight_decay / num_class_list[c].
        self.extra_weight_decay = weight_decay / num_class_list[:, None].repeat(1, classifier_shape[-1])
        self.classifier_shape = classifier_shape
        # Alternation flag for the extra decay (see class docstring).
        self.first = True

        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGDWithExtraWeightDecay, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGDWithExtraWeightDecay, self).__setstate__(state)
        # Older pickled states may lack 'nesterov'.
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The closure's loss, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # Compared as tuples so a list ``classifier_shape`` also matches
        # (torch.Size == list is always False).
        classifier_shape = tuple(self.classifier_shape)

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(p.data, alpha=weight_decay)
                    if tuple(d_p.shape) == classifier_shape:
                        if self.first:
                            self.first = False
                        else:
                            # .to() is a no-op when already on p's device; it
                            # avoids a CPU/CUDA mismatch otherwise.
                            extra = self.extra_weight_decay.to(d_p.device)
                            d_p.add_(extra * p.data)
                            self.first = True

                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.data.add_(d_p, alpha=-group['lr'])

        return loss


def to_one_hot(y, classes):
    """Return a float32 one-hot tensor of shape ``(len(y), classes)`` for integer labels ``y``."""
    n_rows = len(y)
    encoded = np.zeros((n_rows, classes), dtype=np.float32)
    # Row i gets a 1.0 in column y[i].
    encoded[range(n_rows), y] = 1.0
    return torch.from_numpy(encoded)


def get_cls_num_list(num_classes, training_dataset):
    """Return a list with the number of samples per class in ``training_dataset``.

    Each item's class label is read as ``item[1]`` (e.g. ``(image, label)``
    pairs). Every item is fetched, which may be slow for large image datasets.
    """
    counts = [0 for _ in range(num_classes)]
    for sample in training_dataset:
        label = sample[1]
        counts[label] += 1
    return counts


# def resume_ddc_tasks(cfg, resumed_file, model_path, pre_task_model):
#     with open(resumed_file, 'r') as fr:
#         breakpoint_data = json.load(fr)
#     dataset_split_handler = eval(cfg.DATASET.dataset)(cfg, breakpoint_data["split_selected_data"])
#     dataset_split_handler.get_dataset()
#     exemplar_manager = ExemplarManager(cfg.exemplar_manager.memory_budget, cfg.exemplar_manager.mng_approach,
#                                        cfg.exemplar_manager.store_original_imgs, cfg.exemplar_manager.norm_exemplars,
#                                        cfg.exemplar_manager.centroid_order,
#                                        img_transform_for_val=exemplar_img_transform_for_val,
#                                        img_transform_for_train=exemplar_img_transform_for_train,
#                                        device=device)
#     self.exemplar_manager.resume_manager(breakpoint_data)
#     self.resume_model(breakpoint_data)
