from typing import Dict, Optional

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from mcu.arm.utils.vpt_lib.action_head import fan_in_linear
from mcu.arm.utils.vpt_lib.normalize_ewma import NormalizeEwma


class ScaledMSEHead(nn.Module):
    """
    Linear prediction head paired with a target normalizer.

    The linear layer produces predictions that live in the normalizer's
    "normalized" space; the stated intent is that targets end up normalized
    to N(0, 1) before the MSE loss is taken. Use `denormalize` to map
    values from that space back to the original target scale.
    """

    def __init__(
        self, input_size: int, output_size: int, norm_type: Optional[str] = "ewma", norm_kwargs: Optional[Dict] = None
    ):
        """
        :param input_size: number of input features of the linear layer
        :param output_size: number of outputs; also handed to the normalizer
        :param norm_type: only recorded on the instance -- a NormalizeEwma is
            always constructed, whatever value is given here
        :param norm_kwargs: extra keyword arguments forwarded to NormalizeEwma
            (None means no extra arguments)
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.norm_type = norm_type

        # The linear layer is registered before the normalizer so the order of
        # submodules (and therefore of parameters / state-dict keys) is stable.
        self.linear = nn.Linear(input_size, output_size)

        if norm_kwargs is None:
            norm_kwargs = {}
        self.normalizer = NormalizeEwma(output_size, **norm_kwargs)

    def reset_parameters(self):
        """Re-initialise the linear layer, then reset the normalizer."""
        layer = self.linear
        # Orthogonal init comes first; fan_in_linear then post-processes the
        # same layer (presumably a fan-in rescale -- see its definition).
        init.orthogonal_(layer.weight)
        fan_in_linear(layer)
        self.normalizer.reset_parameters()

    def forward(self, input_data):
        """Run `input_data` through the linear layer and return the result."""
        prediction = self.linear(input_data)
        return prediction

    def loss(self, prediction, target, reduction="mean"):
        """
        Compute the MSE loss between a prediction and a target.

        `prediction` is expected to already be in normalized space, whereas
        `target` is given in the original (denormalized) space and is passed
        through the normalizer here, so the loss is measured in normalized space.

        NOTE(review): calling the normalizer may also update its internal
        statistics -- confirm against NormalizeEwma.

        :param reduction: forwarded unchanged to `F.mse_loss`
        """
        normalized_target = self.normalizer(target)
        return F.mse_loss(input=prediction, target=normalized_target, reduction=reduction)

    def denormalize(self, input_data):
        """Map a value from normalized space back into the original space."""
        to_original_space = self.normalizer.denormalize
        return to_original_space(input_data)

    def normalize(self, input_data):
        """Pass `input_data` through the normalizer and return the result."""
        normalized = self.normalizer(input_data)
        return normalized
