#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import torch
from torch import nn
from typing import Optional
from options import logger
# from venv import logger
import math

from .normalization import build_normalization_layer, NORM_LAYER_CLS

# Tuple built from the entries of NORM_LAYER_CLS (imported from .normalization).
# NOTE(review): presumably kept as a tuple so it can be passed straight to
# isinstance(); confirm against the callers of this name.
norm_layers_tuple = tuple(NORM_LAYER_CLS)

def get_complex_normalization_layer(
        opts,
        num_features: int,
        norm_type: Optional[str] = None,
        num_groups: Optional[int] = None,
        *args,
        **kwargs
) -> nn.Module:
    """
    Helper function to get a normalization layer for complex-valued inputs.

    The returned module wraps two separate layers, each created by
    ``build_normalization_layer`` with the same arguments: one is applied to
    the real part of the input and the other to the imaginary part.

    Args:
        opts: Options object forwarded to ``build_normalization_layer``.
        num_features: Forwarded to ``build_normalization_layer``.
        norm_type: Forwarded to ``build_normalization_layer``.
        num_groups: Forwarded to ``build_normalization_layer``.
        *args, **kwargs: Accepted but not used.

    Returns:
        A module whose forward pass yields a ``torch.complex64`` tensor.
    """

    class ComplexNorm(nn.Module):
        """
        Naive complex normalization: the real and the imaginary part are
        normalized independently, each by its own layer.
        """

        def __init__(self, opts, num_features, norm_type, num_groups):
            super().__init__()
            layer_args = (opts, num_features, norm_type, num_groups)
            # Two distinct layers, so the two parts never share parameters
            # or statistics.
            self.bn_r = build_normalization_layer(*layer_args)
            self.bn_i = build_normalization_layer(*layer_args)

        def forward(self, input):
            # Normalize each part on its own, cast to complex64, then
            # recombine as real + i * imag.
            real_out = self.bn_r(input.real).type(torch.complex64)
            imag_out = self.bn_i(input.imag).type(torch.complex64)
            return real_out + 1j * imag_out

    return ComplexNorm(opts, num_features, norm_type, num_groups)

def get_normalization_layer(
    opts,
    num_features: int,
    norm_type: Optional[str] = None,
    num_groups: Optional[int] = None,
    *args,
    **kwargs
) -> nn.Module:
    """
    Helper function to get normalization layers.

    Thin wrapper that forwards ``opts``, ``num_features``, ``norm_type`` and
    ``num_groups`` (in that order) to ``build_normalization_layer`` and
    returns its result unchanged. Any extra ``*args`` / ``**kwargs`` are
    accepted but not used.
    """
    norm_layer = build_normalization_layer(opts, num_features, norm_type, num_groups)
    return norm_layer


class AdjustBatchNormMomentum(object):
    """
    This class enables adjusting the momentum in batch normalization layer.

    Once the iteration counter reaches ``scheduler.warmup_iterations``, the
    momentum is annealed from ``model.normalization.momentum`` down to
    ``model.normalization.adjust_bn_momentum.final_momentum_value`` over
    ``max_steps`` steps, following either a cosine or a linear curve
    (``model.normalization.adjust_bn_momentum.anneal_type``). The value is
    written to every ``nn.BatchNorm2d`` / ``nn.SyncBatchNorm`` module of the
    model that is in training mode.

    Options are read from ``opts`` with ``getattr`` using the dotted names
    above; a missing option falls back to the default given in ``__init__``.

    .. note::
        It's an experimental feature and should be used with caution.
    """

    # Number of decimal places the annealed momentum is rounded to.
    round_places = 6

    def __init__(self, opts, *args, **kwargs):
        """
        Read the schedule configuration from ``opts``.

        Args:
            opts: Object exposing the dotted option names as attributes.
            *args, **kwargs: Accepted but not used.

        Raises:
            AssertionError: In iteration-based mode, if no steps remain after
                subtracting the warm-up iterations from the max. iterations.
            RuntimeError: If the anneal type is neither ``cosine`` nor
                ``linear`` (case-insensitive).
        """
        self.is_iteration_based = getattr(opts, "scheduler.is_iteration_based", True)
        self.warmup_iterations = getattr(opts, "scheduler.warmup_iterations", 10000)

        if self.is_iteration_based:
            # Annealing only starts after warm-up, so the schedule length is
            # what remains of the iteration budget.
            self.max_steps = getattr(opts, "scheduler.max_iterations", 100000)
            self.max_steps -= self.warmup_iterations
            assert self.max_steps > 0
        else:
            logger.warning(
                "Running {} for epoch-based methods. Not yet validation.".format(
                    self.__class__.__name__
                )
            )
            self.max_steps = getattr(opts, "scheduler.max_epochs", 100)

        self.momentum = getattr(opts, "model.normalization.momentum", 0.1)
        self.min_momentum = getattr(
            opts, "model.normalization.adjust_bn_momentum.final_momentum_value", 1e-6
        )

        # NOTE(review): equality is reported as an error too, although the
        # message reads "<="; confirm which of the two is intended.
        if self.min_momentum >= self.momentum:
            logger.error(
                "Min. momentum value in {} should be <= momentum. Got {} and {}".format(
                    self.__class__.__name__, self.min_momentum, self.momentum
                )
            )

        anneal_method = getattr(
            opts, "model.normalization.adjust_bn_momentum.anneal_type", "cosine"
        )
        if anneal_method is None:
            logger.warning(
                "Annealing method in {} is None. Setting to cosine".format(
                    self.__class__.__name__
                )
            )
            anneal_method = "cosine"

        anneal_method = anneal_method.lower()

        # Supported annealing curves, keyed by their lower-cased option value.
        anneal_fns = {
            "cosine": self._cosine,
            "linear": self._linear,
        }
        if anneal_method not in anneal_fns:
            raise RuntimeError(
                "Anneal method ({}) not yet implemented".format(anneal_method)
            )
        self.anneal_fn = anneal_fns[anneal_method]
        self.anneal_method = anneal_method

    def _cosine(self, step: int) -> float:
        """Half-cosine decay from ``momentum`` (step 0) to ``min_momentum`` (step ``max_steps``)."""
        curr_momentum = self.min_momentum + 0.5 * (
            self.momentum - self.min_momentum
        ) * (1 + math.cos(math.pi * step / self.max_steps))

        return round(curr_momentum, self.round_places)

    def _linear(self, step: int) -> float:
        """Straight-line decay from ``momentum`` (step 0) to ``min_momentum`` (step ``max_steps``)."""
        momentum_step = (self.momentum - self.min_momentum) / self.max_steps
        curr_momentum = self.momentum - (step * momentum_step)
        return round(curr_momentum, self.round_places)

    def adjust_momentum(self, model: nn.Module, iteration: int, epoch: int) -> None:
        """
        Set the annealed momentum on the model's batch-norm layers.

        Does nothing while ``iteration`` is below the warm-up threshold.

        Args:
            model: Model whose ``nn.BatchNorm2d`` / ``nn.SyncBatchNorm``
                modules in training mode are updated in place.
            iteration: Current iteration; gates the warm-up and, in
                iteration-based mode, determines the schedule step.
            epoch: Current epoch; used as the schedule step in epoch-based
                mode.
        """
        if iteration < self.warmup_iterations:
            return

        if self.is_iteration_based:
            step = iteration - self.warmup_iterations
        else:
            step = epoch

        # Keep the step inside the schedule. Without the upper bound the
        # cosine curve climbs back towards the initial momentum and the
        # linear curve drops below the final momentum once training runs
        # past ``max_steps``.
        step = min(max(step, 0), self.max_steps)

        # Rounding may yield a tiny negative value; momentum must be >= 0.
        curr_momentum = max(0.0, self.anneal_fn(step))

        for m in model.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)) and m.training:
                m.momentum = curr_momentum

    def __repr__(self):
        return "{}(iteration_based={}, inital_momentum={}, final_momentum={}, anneal_method={})".format(
            self.__class__.__name__,
            self.is_iteration_based,
            self.momentum,
            self.min_momentum,
            self.anneal_method,
        )
