import warnings

import torch
from torch import optim

from inclearn import models
from inclearn.backbones import resnet
from inclearn.lib import data, schedulers


def get_optimizer(params, optimizer, lr, weight_decay=0.0):
    """Build a ``torch.optim`` optimizer selected by name.

    Args:
        params: parameters (or parameter groups) to optimize.
        optimizer: one of ``"adam"``, ``"adamw"``, ``"sgd"`` or ``"sgd_nesterov"``.
            Both SGD variants use momentum 0.9; ``"sgd_nesterov"`` also enables
            Nesterov momentum.
        lr: learning rate.
        weight_decay: weight decay forwarded to the optimizer.

    Returns:
        The constructed optimizer instance.

    Raises:
        NotImplementedError: if ``optimizer`` is not one of the names above.
    """
    if optimizer == "adam":
        return optim.Adam(params, lr=lr, weight_decay=weight_decay)
    elif optimizer == "adamw":
        return optim.AdamW(params, lr=lr, weight_decay=weight_decay)
    elif optimizer == "sgd":
        return optim.SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9)
    elif optimizer == "sgd_nesterov":
        return optim.SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9, nesterov=True)

    # Name the offending value, like the other factories in this module do.
    raise NotImplementedError(
        "Unknown optimizer {}, must be among {}.".format(
            optimizer, ["adam", "adamw", "sgd", "sgd_nesterov"]
        )
    )


def get_encoder(encoder_type, **kwargs):
    """Construct a backbone network selected by name.

    Args:
        encoder_type: one of ``"resnet18"``, ``"resnet101"``, ``"resnet34"``,
            ``"resnet32"``, ``"slowfast"`` or ``"vit16"``.
        **kwargs: forwarded unchanged to the backbone constructor.

    Returns:
        The constructed backbone.

    Raises:
        NotImplementedError: if ``encoder_type`` is not one of the names above.
    """
    if encoder_type == "resnet18":
        return resnet.resnet18(**kwargs)
    elif encoder_type == "resnet101":
        return resnet.resnet101(**kwargs)
    elif encoder_type == "resnet34":
        return resnet.resnet34(**kwargs)
    elif encoder_type == "resnet32":
        return resnet.resnet32(**kwargs)
    elif encoder_type == 'slowfast':
        # Imported lazily: the module is only loaded when this backbone is requested.
        from inclearn.backbones import slowfast
        return slowfast.SlowFast(**kwargs)
    elif encoder_type == 'vit16':
        # Imported lazily: the module is only loaded when this backbone is requested.
        from inclearn.backbones import vit
        return vit.vit_large_patch16(**kwargs)
    # Previously formatted an undefined name (``convnet_type``), which turned this
    # into a NameError instead of the intended NotImplementedError.
    raise NotImplementedError("Unknown encoder type {}.".format(encoder_type))


def get_model(args):
    """Instantiate the model class selected by ``args["model"]``.

    The name is matched case-insensitively against the registry below, and the
    selected class is called with the full ``args`` mapping.

    Raises:
        NotImplementedError: if the requested name is not registered.
    """
    registry = {
        "finetune": models.Finetune,
        "joint": models.Joint,
        "er": models.ER,
        "agem": models.AGEM,
        "der": models.DER,
        "esmer": models.ESMER,
        "udil": models.UDIL,
        "afc": models.AFC,
        "ssil": models.SSIL,
        "co2l": models.Co2L,
        "ours": models.Ours,
    }

    requested = args["model"]
    model_cls = registry.get(requested.lower())
    if model_cls is None:
        # Report the name exactly as the user spelled it, not the lowercased key.
        raise NotImplementedError(
            "Unknown model {}, must be among {}.".format(requested, list(registry))
        )

    return model_cls(args)


def get_data(args, run_id):
    """Create the ``data.IncrementalDataset`` for ``run_id``.

    ``args`` and ``run_id`` are forwarded positionally to the dataset constructor.
    """
    dataset = data.IncrementalDataset(args, run_id)
    return dataset


def set_device(args):
    """Replace the device ids in ``args["device"]`` with ``torch.device`` objects, in place.

    An id of ``-1`` maps to the CPU; any other id ``i`` maps to ``cuda:i``
    (presumably integer GPU indices — confirm against the config loader).
    Nothing is returned; the caller's mapping is mutated.
    """
    def _as_device(device_id):
        if device_id == -1:
            return torch.device("cpu")
        return torch.device("cuda:{}".format(device_id))

    args["device"] = [_as_device(device_id) for device_id in args["device"]]


def get_sampler(args):
    """Resolve the sampler class named by ``args["sampler"]``.

    Returns ``None`` when ``args["sampler"]`` is ``None``. Otherwise the name is
    lowercased and stripped, then mapped to a class on ``inclearn.lib.data``
    (the class itself is returned, not an instance).

    Raises:
        ValueError: if the name matches no known sampler.
    """
    raw_name = args["sampler"]
    if raw_name is None:
        return None

    name = raw_name.lower().strip()
    # Map to attribute *names* so that only the selected class is looked up on ``data``.
    attr_by_name = {
        "npair": "NPairSampler",
        "triplet": "TripletSampler",
        "tripletsemihard": "TripletCKSampler",
        "balance": "ImbalancedSampler",
    }
    if name not in attr_by_name:
        raise ValueError("Unknown sampler {}.".format(name))

    return getattr(data, attr_by_name[name])


def get_lr_scheduler(
    scheduling_config, optimizer, nb_epochs, lr_decay=0.1, warmup_config=None, task=0
):
    """Build a learning-rate scheduler for ``optimizer``, optionally wrapped in warmup.

    Supported ``"type"`` values: ``"step"``, ``"exponential"``, ``"plateau"``,
    ``"cosine"``, ``"cosine_with_restart"`` and ``"cosine_annealing_with_restart"``.

    Args:
        scheduling_config: ``None`` (no scheduler), or a dict with a ``"type"`` key
            plus type-specific options. A bare type string, or a list of milestone
            epochs (interpreted as a ``"step"`` schedule), is still accepted but
            emits a ``DeprecationWarning``.
        optimizer: the optimizer whose learning rate is scheduled.
        nb_epochs: ``T_max`` for ``"cosine"`` and the default cycle length for
            ``"cosine_with_restart"``.
        lr_decay: decay factor used by ``"step"`` when the config's ``"gamma"`` is
            missing or falsy.
        warmup_config: optional dict. When truthy and not skipped for ``task``, the
            scheduler is wrapped in ``schedulers.GradualWarmupScheduler`` with the
            dict's entries forwarded as keyword arguments.
        task: index of the current task. ``only_first_step`` (default False) limits
            warmup to task 0; ``exclude_first_step`` (default True) skips task 0.

    Returns:
        The scheduler, or ``None`` when ``scheduling_config`` is ``None``.

    Raises:
        ValueError: for an unknown scheduling type.
    """
    if scheduling_config is None:
        return None
    elif isinstance(scheduling_config, str):
        warnings.warn("Use a dict not a string for scheduling config!", DeprecationWarning)
        scheduling_config = {"type": scheduling_config}
    elif isinstance(scheduling_config, list):
        warnings.warn("Use a dict not a list for scheduling config!", DeprecationWarning)
        scheduling_config = {"type": "step", "epochs": scheduling_config}

    sched_type = scheduling_config["type"]

    if sched_type == "step":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            scheduling_config["epochs"],
            # A missing (or falsy) gamma falls back to ``lr_decay``.
            gamma=scheduling_config.get("gamma") or lr_decay
        )
    elif sched_type == "exponential":
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, scheduling_config["gamma"])
    elif sched_type == "plateau":
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=scheduling_config["gamma"]
        )
    elif sched_type == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, nb_epochs, eta_min=scheduling_config.get("min_lr", 0.)
        )
    elif sched_type == "cosine_with_restart":
        scheduler = schedulers.CosineWithRestarts(
            optimizer,
            t_max=scheduling_config.get("cycle_len", nb_epochs),
            factor=scheduling_config.get("factor", 1.)
        )
    elif sched_type == "cosine_annealing_with_restart":
        # Default eta_min to 0 (torch's own default, and what "cosine" uses); passing
        # None made CosineAnnealingWarmRestarts fail when "min_lr" was absent.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=1, T_mult=2, eta_min=scheduling_config.get("min_lr", 0.)
        )
    else:
        raise ValueError("Unknown LR scheduling type {}.".format(scheduling_config["type"]))

    if warmup_config:
        if warmup_config.get("only_first_step", False) and task != 0:
            pass
        elif warmup_config.get("exclude_first_step", True) and task == 0:
            pass
        else:
            print("Using WarmUp")
            # NOTE(review): every key of warmup_config is forwarded, including
            # "only_first_step"/"exclude_first_step" if present — confirm that
            # GradualWarmupScheduler accepts (or ignores) them.
            scheduler = schedulers.GradualWarmupScheduler(
                optimizer=optimizer, after_scheduler=scheduler, **warmup_config
            )

    return scheduler
