import torch

import os
import time
import logging
import numpy as np

from builder import Network
from utils.lr_scheduler import WarmupCustomLR


def create_logger(cfg):
    """Route root-logger output to a per-run log file and to the console.

    The log file path is
    ``<cfg.output_dir>/<cfg.name>/logs/<dataset>_<backbone>_<pooling>_<scaling>_<seed>.log``.
    The config is printed to stdout and also written through the logger,
    framed by separator lines.

    Handlers are attached to the root logger at most once per log file (file
    handler) and once overall (console handler), so calling this function
    repeatedly does not duplicate log lines.

    Args:
        cfg: config object; reads ``dataset.dataset``, ``backbone.type``,
            ``pooling.type``, ``scaling.type``, ``seed_num``, ``output_dir``
            and ``name``.

    Returns:
        ``(logger, log_file)``: the root logger (level set to INFO) and the
        path of the log file.
    """
    dataset = cfg.dataset.dataset
    net_type = cfg.backbone.type
    pooling_type = cfg.pooling.type
    scaling_type = cfg.scaling.type
    seed_num = cfg.seed_num
    log_dir = os.path.join(cfg.output_dir, cfg.name, 'logs')
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)
    log_name = "{}_{}_{}_{}_{}.log".format(
        dataset, net_type, pooling_type, scaling_type, seed_num)
    log_file = os.path.join(log_dir, log_name)
    # set up logger
    print("=> creating log {}".format(log_file))
    head = "%(asctime)-15s %(message)s"
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Handlers are attached explicitly instead of via logging.basicConfig():
    # basicConfig silently does nothing once the root logger has any handler,
    # which would leave the run without a log file. Handlers are tagged by
    # name so repeated calls do not stack duplicates.
    existing_names = {h.get_name() for h in logger.handlers}

    file_tag = "create_logger.file:" + os.path.abspath(str(log_file))
    if file_tag not in existing_names:
        file_handler = logging.FileHandler(str(log_file))
        file_handler.setFormatter(logging.Formatter(head))
        file_handler.set_name(file_tag)
        logger.addHandler(file_handler)

    console_tag = "create_logger.console"
    if console_tag not in existing_names:
        # Console output keeps the default formatter (message only), as before.
        console = logging.StreamHandler()
        console.set_name(console_tag)
        logger.addHandler(console)

    print(cfg)

    logger.info("--------------------Cfg is set as follow---------------------")
    logger.info(cfg)
    logger.info("-------------------------------------------------------------")
    return logger, log_file


def get_optimizer(cfg, model):
    """Build the optimizer described by ``cfg.train.optimizer`` for ``model``.

    Each trainable parameter (``requires_grad=True``) is put in its own
    parameter group, and frozen parameters are skipped. This layout is kept
    as-is so ``param_groups`` and saved optimizer state stay structurally
    compatible.

    Args:
        cfg: config object; reads ``cfg.train.optimizer.type`` ('SGD' or
            'ADAM'), ``base_lr`` and ``weight_decay``; also reads ``momentum``
            for SGD.
        model: a ``torch.nn.Module`` whose parameters will be optimized.

    Returns:
        A ``torch.optim.SGD`` or ``torch.optim.Adam`` instance.

    Raises:
        NotImplementedError: if ``cfg.train.optimizer.type`` is unsupported
            (previously this surfaced as an ``UnboundLocalError``).
    """
    opt_cfg = cfg.train.optimizer
    base_lr = opt_cfg.base_lr
    params = [{"params": p} for p in model.parameters() if p.requires_grad]

    if opt_cfg.type == 'SGD':
        optimizer = torch.optim.SGD(
            params,
            lr=base_lr,
            momentum=opt_cfg.momentum,
            weight_decay=opt_cfg.weight_decay,
        )
    elif opt_cfg.type == 'ADAM':
        optimizer = torch.optim.Adam(
            params,
            lr=base_lr,
            betas=(0.9, 0.999),
            weight_decay=opt_cfg.weight_decay,
        )
    else:
        raise NotImplementedError(
            "Unsupported optimizer: {}".format(opt_cfg.type)
        )
    return optimizer


def get_scheduler(cfg, optimizer):
    """Create the learning-rate scheduler selected by ``cfg.train.lr_scheduler.type``.

    Selection rules for ``type``:
      * contains ``'multistep'`` -> ``MultiStepLR`` with milestones ``lr_step``
        and decay ``lr_factor``;
      * otherwise contains ``'cosine'`` -> ``CosineAnnealingLR`` with
        ``eta_min``; its period is ``T_max`` when ``T_max > 0``, else
        ``cfg.train.num_epochs``;
      * exactly ``'none'`` -> no scheduler (``None``);
      * anything else raises ``NotImplementedError``.

    If ``type`` additionally starts with ``'warmup'``, the scheduler chosen
    above is passed as ``after_scheduler`` to the project's ``WarmupCustomLR``
    together with ``warmup_epochs``. Its exact warmup behaviour is defined in
    ``utils.lr_scheduler``.

    Returns:
        A scheduler instance, or ``None``.
    """
    sched_cfg = cfg.train.lr_scheduler
    kind = sched_cfg.type
    lr_sched = torch.optim.lr_scheduler

    if 'multistep' in kind:
        inner = lr_sched.MultiStepLR(
            optimizer, sched_cfg.lr_step, gamma=sched_cfg.lr_factor
        )
    elif 'cosine' in kind:
        configured_period = sched_cfg.T_max
        if configured_period > 0:
            period = configured_period
        else:
            # non-positive T_max means "anneal over the whole run"
            period = cfg.train.num_epochs
        inner = lr_sched.CosineAnnealingLR(
            optimizer, T_max=period, eta_min=sched_cfg.eta_min
        )
    elif kind == 'none':
        inner = None
    else:
        raise NotImplementedError("Unsupported LR Scheduler: {}".format(kind))

    if not kind.startswith('warmup'):
        return inner
    return WarmupCustomLR(
        optimizer,
        warmup_epochs=sched_cfg.warmup_epochs,
        after_scheduler=inner,
    )


def get_model(cfg, num_classes, rank):
    """Instantiate the project ``Network`` and place it for training.

    Args:
        cfg: config object; passed to ``Network``. This function also reads
            ``cfg.cpu_mode`` and ``cfg.ddp``.
        num_classes: number of output classes forwarded to ``Network``.
        rank: CUDA device index used when not in CPU mode.

    Returns:
        The network, moved to ``cuda(rank)`` unless ``cfg.cpu_mode`` is set,
        and wrapped in ``DistributedDataParallel`` when ``cfg.ddp`` is set.
        The DDP wrapper disables buffer broadcasting and enables unused
        parameter detection.
    """
    net = Network(cfg, num_classes=num_classes)
    net = net if cfg.cpu_mode else net.cuda(rank)

    if not cfg.ddp:
        return net

    return torch.nn.parallel.DistributedDataParallel(
        net,
        broadcast_buffers=False,
        find_unused_parameters=True,
    )


def get_category_list(targets, num_classes, cfg):
    """Tally class ids from a sequence of targets.

    Args:
        targets: iterable whose elements support ``.item()`` and yield an int
            class id usable as a list index (e.g. the rows of a 1-D tensor).
        num_classes: length of the returned histogram.
        cfg: unused; kept for interface compatibility.

    Returns:
        ``(num_list, ctgy_list)``: ``num_list[c]`` is how many targets have id
        ``c``, and ``ctgy_list`` holds the ids as Python ints in input order.
    """
    labels = []
    counts = [0] * num_classes
    for target in targets:
        labels.append(target.item())
        counts[labels[-1]] += 1
    return counts, labels


# The mixup_data and mixup_criterion functions are adapted from the official
# PyTorch implementation of mixup:
# => https://github.com/facebookresearch/mixup-cifar10.git
def mixup_data(x, y, alpha=1.0, rank=None):
    """Returns mixed inputs, pairs of targets, and lambda.

    Each sample is blended with a randomly chosen partner from the same batch:
    ``mixed_x = lam * x + (1 - lam) * x[index]`` for a random permutation
    ``index`` of the batch dimension (dim 0).

    Args:
        x: input batch tensor, indexed along dim 0. Any number of trailing
            dims is accepted, including none (1-D input).
        y: targets aligned with ``x`` along dim 0.
        alpha: concentration of the Beta(alpha, alpha) distribution that
            ``lam`` is drawn from. ``alpha <= 0`` disables mixing (``lam = 1``).
        rank: if not None, the permutation is moved to CUDA device ``rank``;
            otherwise it stays on the CPU.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``y_a`` is ``y`` itself and
        ``y_b = y[index]``.
    """
    # lam comes from NumPy's global RNG; the permutation from torch's CPU RNG.
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size)
    if rank is not None:
        index = index.cuda(rank)

    # x[index] rather than x[index, :]: identical for ndim >= 2 and also
    # works for 1-D inputs, where x[index, :] raises IndexError.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the losses against both target sets produced by ``mixup_data``.

    Args:
        criterion: callable ``criterion(pred, target) -> loss``.
        pred: model output for the mixed batch.
        y_a: original targets.
        y_b: permuted targets.
        lam: mixing coefficient applied to the ``y_a`` loss; ``1 - lam`` is
            applied to the ``y_b`` loss.

    Returns:
        ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

