import torch
from torch._six import string_classes
import functools
import numpy as np
import sys
from types import MethodType
import warnings
from ._amp_state import _amp_state, warn_or_err, container_abcs
from .handle import disable_casts
from .scaler import LossScaler
from ._process_optimizer import _process_optimizer
from apex.fp16_utils import convert_network
from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
from ..contrib.optimizers import FP16_Optimizer as FP16_Optimizer_for_fused

if torch.distributed.is_available():
    from ..parallel import DistributedDataParallel as apex_DDP
    from ..parallel.LARC import LARC


def to_type(dtype, t):
    """Cast ``t`` to ``dtype`` if it holds floating-point data.

    Args:
        dtype: target dtype, forwarded to ``.to()``.
        t: a ``torch.Tensor``, or any other object exposing ``.to(dtype)``.

    Returns:
        For a tensor: the converted tensor when ``t`` is floating point,
        otherwise ``t`` itself.  For any other object: whatever
        ``t.to(dtype)`` returns.
    """
    if not isinstance(t, torch.Tensor):
        # Custom batch types are trusted to implement a sensible ``to``.
        return t.to(dtype)

    if not t.is_cuda:
        # Only a warning, not an error: a non-cuda input may be legitimate.
        warnings.warn("An input tensor was not cuda.")

    # Inputs with requires_grad=True are intentionally not flagged; the
    # original author noted that GANs need such inputs to pass through.
    return t.to(dtype) if t.is_floating_point() else t


# Modified from torch.optim.optimizer.py.  This is a bit more general than casted_args in utils.py.
def applier(value, fn):
    """Recursively apply ``fn`` to the tensor-like leaves of ``value``.

    Args:
        value: a tensor, a custom object exposing ``.to``, a mapping, an
            iterable container, or any plain Python value.
        fn: callable applied to every ``torch.Tensor`` and to every object
            that has a ``to`` attribute.

    Returns:
        ``fn(value)`` for tensors and objects with ``to``; strings and numpy
        arrays unchanged; a plain ``dict`` for mappings (keys and values are
        both processed); a container of the same type for other iterables
        (namedtuples included); anything else unchanged.
    """
    if isinstance(value, torch.Tensor):
        return fn(value)
    elif isinstance(value, string_classes):
        # Strings are iterable, so they must be caught before the Iterable branch.
        return value
    elif isinstance(value, np.ndarray):
        return value
    elif hasattr(value, "to"):  # Allow handling of custom batch classes
        return fn(value)
    elif isinstance(value, container_abcs.Mapping):
        return {applier(k, fn): applier(v, fn) for k, v in value.items()}
    elif isinstance(value, tuple) and hasattr(value, "_fields"):
        # namedtuple constructors take one positional argument per field, so
        # handing them a single generator (as the generic branch below does)
        # either raises TypeError or stores the generator as the first field.
        return type(value)(*(applier(v, fn) for v in value))
    elif isinstance(value, container_abcs.Iterable):
        return type(value)(applier(v, fn) for v in value)
    else:
        # Ordinary values such as ints or floats pass through silently.
        # Custom batch classes should provide a .to(dtype) method that
        # converts their floating-point Tensors; they are then picked up by
        # the hasattr(value, "to") branch above.
        return value


def check_models(models):
    """Reject models that are already wrapped in a parallel wrapper.

    Args:
        models: iterable of models to inspect.

    Raises:
        RuntimeError: if any model is an instance of
            ``torch.nn.parallel.DistributedDataParallel``,
            ``apex.parallel.DistributedDataParallel`` or
            ``torch.nn.parallel.DataParallel``.
    """
    # apex_DDP is imported at module level only when torch.distributed is
    # available, so it has to be looked up in this module's namespace.
    # (sys.modules is keyed by module names, never by a local alias such as
    # "apex_DDP", so testing membership there could never succeed.)
    apex_ddp_cls = globals().get('apex_DDP')

    for model in models:
        parallel_type = None
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            parallel_type = "torch.nn.parallel.DistributedDataParallel"
        if apex_ddp_cls is not None and isinstance(model, apex_ddp_cls):
            parallel_type = "apex.parallel.DistributedDataParallel"
        if isinstance(model, torch.nn.parallel.DataParallel):
            parallel_type = "torch.nn.parallel.DataParallel"
        if parallel_type is not None:
            raise RuntimeError("Incoming model is an instance of {}. ".format(parallel_type) +
                "Parallel wrappers should only be applied to the model(s) AFTER \n"
                "the model(s) have been returned from amp.initialize.")


def check_params_fp32(models):
    """Report floating-point parameters and buffers that are half or not on CUDA.

    Every floating-point parameter and buffer of each model is inspected.
    A tensor whose type string contains ``Half`` is reported through
    ``warn_or_err``; otherwise a tensor that is not on a CUDA device is
    reported.  Non-floating-point tensors are ignored.

    Args:
        models: iterable of models to inspect.
    """
    def _inspect(kind, plural, name, tensor):
        # kind/plural only select the wording ("param"/"parameters" or
        # "buffer"/"buffers"); the checks themselves are identical.
        if not tensor.is_floating_point():
            return
        type_name = tensor.type()
        if 'Half' in type_name:
            advice = ("When using amp.initialize, you do not need to call .half() on your model\n"
                      "before passing it, no matter what optimization level you choose.")
        elif not tensor.is_cuda:
            advice = ("When using amp.initialize, you need to provide a model with {}\n"
                      "located on a CUDA device before passing it no matter what optimization level\n"
                      "you chose. Use model.to('cuda') to use the default device.".format(plural))
        else:
            return
        warn_or_err("Found {} {} with type {}, expected torch.cuda.FloatTensor.\n".format(
            kind, name, type_name) + advice)

    for model in models:
        for param_name, param in model.named_parameters():
            _inspect("param", "parameters", param_name, param)

        # Backward compatibility for PyTorch 0.4: without named_buffers(),
        # fall back to the module's own name -> buffer mapping.
        if hasattr(model, 'named_buffers'):
            named_bufs = model.named_buffers()
        else:
            named_bufs = model._buffers.items()
        for buf_name, buf in named_bufs:
            _inspect("buffer", "buffers", buf_name, buf)


def check_optimizers(optimizers):
    """Reject optimizers that are already wrapped in an FP16_Optimizer.

    Args:
        optimizers: iterable of optimizer objects to inspect.

    Raises:
        RuntimeError: if any entry is an instance of ``FP16_Optimizer_general``
            or ``FP16_Optimizer_for_fused``.
    """
    for candidate in optimizers:
        # The fused wrapper is tested first so that it takes precedence when
        # an object matches both classes.
        if isinstance(candidate, FP16_Optimizer_for_fused):
            wrapper_name = "apex.optimizers.FP16_Optimizer"
        elif isinstance(candidate, FP16_Optimizer_general):
            wrapper_name = "apex.fp16_utils.FP16_Optimizer"
        else:
            continue

        raise RuntimeError("An incoming optimizer is an instance of {}. ".format(wrapper_name) +
                           "The optimizer(s) passed to amp.initialize() must be bare \n"
                           "instances of either ordinary Pytorch optimizers, or Apex fused \n"
                           "optimizers.\n")


class O2StateDictHook(object):
    """State-dict hook that converts half-precision entries to float32.

    Instances are callable with the ``(module, state_dict, prefix,
    local_metadata)`` signature and rewrite ``state_dict`` in place.
    """

    def __init__(self, fn):
        # Kept as an attribute for callers; __call__ performs its own
        # float32 conversion and does not invoke it.
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata):
        """Replace every half-precision tensor in ``state_dict`` with a float32 copy.

        Args:
            module: unused.
            state_dict: mapping that is modified in place.
            prefix: unused.
            local_metadata: unused.

        Returns:
            None; ``state_dict`` is mutated.
        """
        # Snapshot the items so the mapping can be reassigned while looping.
        for key, entry in list(state_dict.items()):
            # A state dict may hold non-tensor values (for example a module's
            # extra state); those have no .type() and are left as they are.
            if not isinstance(entry, torch.Tensor):
                continue
            if 'Half' in entry.type():
                state_dict[key] = entry.to(torch.float32)


def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
    """Apply the Amp configuration in ``properties`` to models and optimizers.

    Args:
        models: a ``torch.nn.Module`` or a list of them.
        optimizers: a ``torch.optim.Optimizer`` (or a ``LARC`` instance when
            that class was imported), a list of optimizers, or ``None``.
        properties: configuration object.  This function reads its
            ``cast_model_type``, ``keep_batchnorm_fp32``, ``loss_scale`` and
            ``patch_torch_functions`` attributes and forwards the object to
            ``_process_optimizer``.
        num_losses: number of ``LossScaler`` instances stored in
            ``_amp_state.loss_scalers``.
        cast_model_outputs: optional dtype that the patched ``forward``
            methods cast their outputs to.  When ``properties.cast_model_type``
            is set and this is ``None``, outputs are cast to ``torch.float32``.

    Returns:
        The model(s) and optimizer(s) in the same single-object/list form in
        which they were passed.  When ``optimizers`` was not a list and no
        optimizer was supplied, only the model(s) are returned.

    Raises:
        TypeError: if ``models`` or ``optimizers`` has an unsupported type.
        RuntimeError: propagated from ``check_models`` / ``check_optimizers``.
    """
    # Imported here rather than at module level.
    # NOTE(review): presumably to avoid a circular import - confirm.
    from .amp import init as amp_init

    # Normalize `optimizers` to a list, remembering the caller's form so the
    # return value can mirror it.
    optimizers_was_list = False
    if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)):
        optimizers = [optimizers]
    elif optimizers is None:
        optimizers = []
    elif isinstance(optimizers, list):
        optimizers_was_list = True
        check_optimizers(optimizers)
    else:
        # Gives check_optimizers the chance to raise its more specific
        # RuntimeError before the generic TypeError below.
        check_optimizers([optimizers])
        raise TypeError("optimizers must be either a single optimizer or a list of optimizers.")

    # Same normalization for `models`.
    if isinstance(models, torch.nn.Module):
        models_was_list = False
        models = [models]
    elif isinstance(models, list):
        models_was_list = True
    else:
        raise TypeError("models must be either a single model or a list of models.")

    check_models(models)

    if not _amp_state.allow_incoming_model_not_fp32:
        check_params_fp32(models)

    # In the future, when FP16_Optimizer can be deprecated and master weights can
    # become an attribute, remember to stash master weights before casting the model.

    if properties.cast_model_type:
        # Cast the model itself; convert_network is used when batchnorm is to
        # be kept in fp32, a plain .to() otherwise.
        if properties.keep_batchnorm_fp32:
            for model in models:
                convert_network(model, properties.cast_model_type)
        else:
            for model in models:
                model.to(properties.cast_model_type)

        input_caster = functools.partial(to_type, properties.cast_model_type)
        if cast_model_outputs is not None:
            output_caster = functools.partial(to_type, cast_model_outputs)
        else:
            output_caster = functools.partial(to_type, torch.float32)

        for model in models:
            # Patch the forward method to cast incoming data to the correct type, and
            # outgoing data to float32, so "the user never needs to call .half()."
            # I like writing things explicitly more than decorators.
            # old_fwd is passed as an argument so each model's wrapper closes
            # over its own original forward.
            def patch_forward(old_fwd):
                def new_fwd(*args, **kwargs):
                    output = old_fwd(*applier(args, input_caster),
                                     **applier(kwargs, input_caster))
                    return applier(output, output_caster)
                return new_fwd

            model.forward = patch_forward(model.forward)

        # State dict trick to recast any preexisting per-param state tensors
        for optimizer in optimizers:
            optimizer.load_state_dict(optimizer.state_dict())

        # patch model.state_dict() to return float32 params
        # The hook is registered on every submodule of every model.
        for model in models:
            for module in model.modules():
                module._register_state_dict_hook(O2StateDictHook(functools.partial(to_type, torch.float32)))

    elif cast_model_outputs is not None:
        # No model cast requested: inputs are left alone and only the outputs
        # of forward are cast.
        output_caster = functools.partial(to_type, cast_model_outputs)

        for model in models:
            def patch_forward(old_fwd):
                def new_fwd(*args, **kwargs):
                    output = old_fwd(*args, **kwargs)
                    return applier(output, output_caster)
                return new_fwd

            model.forward = patch_forward(model.forward)

    # Replace each optimizer in place with its processed version; when the
    # caller passed a list, that same list object is updated.
    for i, optimizer in enumerate(optimizers):
        optimizers[i] = _process_optimizer(optimizer, properties)

    # One loss scaler per loss, stored on the global Amp state (any previous
    # list is discarded).
    _amp_state.loss_scalers = []
    for _ in range(num_losses):
        _amp_state.loss_scalers.append(LossScaler(properties.loss_scale,
                                                  min_loss_scale=_amp_state.min_loss_scale,
                                                  max_loss_scale=_amp_state.max_loss_scale))

    if properties.patch_torch_functions:
        # handle is unused here. It's accessible later through a global value anyway.
        handle = amp_init(loss_scale=properties.loss_scale, verbose=(_amp_state.verbosity == 2))
        for optimizer in optimizers:
            # Disable Amp casting for the optimizer step, because it should only be
            # applied to FP32 master params anyway.
            def patch_step(old_step):
                # old_step is the already-bound optimizer.step, so the `self`
                # supplied by MethodType below is not forwarded to it.
                def new_step(self, *args, **kwargs):
                    with disable_casts():
                        output = old_step(*args, **kwargs)
                    return output
                return new_step

            optimizer.step = MethodType(patch_step(optimizer.step), optimizer)

    # Mirror the caller's input form: lists stay lists, single objects come
    # back unwrapped.
    if optimizers_was_list:
        if models_was_list:
            return models, optimizers
        else:
            return models[0], optimizers
    else:
        if models_was_list:
            if len(optimizers) == 0:
                return models
            else:
                return models, optimizers[0]
        else:
            if len(optimizers) == 0:
                return models[0]
            else:
                return models[0], optimizers[0]
