from fast_scaler import FastScaler
import torch.nn as nn
import torch

from models.model_wrapper import ModelWrapper

try:
    # Probe for the `profile` decorator: when the module is run under kernprof,
    # `profile` is available as a builtin and this bare name lookup succeeds.
    profile
except NameError:
    # Not running under kernprof: fall back to the project's no-op version so
    # the `@profile` decorations below still resolve.
    from nn_util import profile

class ScaleWrapper(ModelWrapper):
    """Wrap a module so that scaling happens inside the forward pass.

    Inputs are passed through ``input_scaler.transform`` before reaching the
    wrapped module. Outputs are passed through
    ``output_scaler.inverse_transform`` only while the wrapped module is not in
    training mode. Either scaler may be ``None`` to skip that step.

    Attributes that are not found on the wrapper are forwarded to the wrapped
    module.

    NOTE(review): ``ModelWrapper`` is not visible here; this class assumes it
    is an ``nn.Module`` subclass (it relies on ``_modules`` and ``to``).
    """

    def __init__(self, wrapped: nn.Module, input_scaler: FastScaler | None, output_scaler: FastScaler | None):
        """Store the wrapped module and the optional input/output scalers."""
        super().__init__()

        self.wrapped = wrapped
        self.input_scaler = input_scaler
        self.output_scaler = output_scaler

    @profile
    def forward(self, input):
        """Run the wrapped module on the (optionally) scaled input.

        The output is inverse-transformed only when an output scaler is set
        and ``self.wrapped.training`` is false; in training mode the raw
        output of the wrapped module is returned.
        """
        if self.input_scaler is not None:
            input = self.input_scaler.transform(input)

        output = self.wrapped(input)

        if self.output_scaler is not None and not self.wrapped.training:
            output = self.output_scaler.inverse_transform(output)

        return output

    def to(self, *args, **kwargs):
        """Move the module as ``nn.Module.to`` does, then move the scalers too.

        Returns whatever the parent ``to`` returns.
        """
        result = super().to(*args, **kwargs)

        for attr in ("input_scaler", "output_scaler"):
            scaler = getattr(self, attr)
            # Compare against None (as forward() does) so a scaler that happens
            # to be falsy is still relocated.
            if scaler is None:
                continue
            # `device` is not defined in this class; it resolves through the
            # base class or is forwarded to the wrapped module.
            moved = scaler.to_device(self.device)
            # NOTE(review): FastScaler.to_device is not visible here. Elsewhere
            # in this file its return value is used as the scaler, so adopt it
            # when one is returned and keep the existing object otherwise.
            if moved is not None:
                setattr(self, attr, moved)

        return result

    def __getattr__(self, name):
        """Resolve ``name`` on the wrapper first, then on the wrapped module.

        Raises:
            AttributeError: if neither the wrapper nor the wrapped module has
                the attribute (including when no module is wrapped yet).
        """
        try:
            # Parent lookup covers registered parameters, buffers and
            # submodules, including ``wrapped`` itself.
            return super().__getattr__(name)
        except AttributeError:
            pass

        # Read through __dict__ so a partially initialised instance produces
        # AttributeError instead of KeyError or unbounded recursion.
        wrapped = self.__dict__.get("_modules", {}).get("wrapped")
        if wrapped is not None:
            try:
                return getattr(wrapped, name)
            except AttributeError:
                pass
        raise AttributeError(f"'{type(self).__name__}' object and its wrapped module have no attribute '{name}'")

class DiffusionScaleWrapper(nn.Module):
    """Wrap a diffusion-style module so that scaling happens inside ``forward``.

    The mode is taken from ``wrapped.model.training``:

    * Training: every column except the last ``action_len`` is passed through
      ``input_scaler.transform`` (when a scaler is set), and the last
      ``action_len`` columns are mapped linearly from
      ``[action_min, action_max]`` to ``[-1, 1]``. The wrapped module's output
      is returned untouched.
    * Otherwise: the whole input is passed through the scaler and the wrapped
      module's output is mapped from ``[-1, 1]`` back to
      ``[action_min, action_max]``.

    Attributes that are not found on the wrapper are forwarded to the wrapped
    module.
    """

    def __init__(self, wrapped: nn.Module, input_scaler: "FastScaler | None", action_min, action_max, action_len):
        """Store the wrapped module and place scaler and bounds on its device.

        Args:
            wrapped: module to delegate to; must expose ``device`` and
                ``model.training``.
            input_scaler: optional scaler offering ``transform`` and
                ``to_device``.
            action_min: tensor holding the lower action bound.
            action_max: tensor holding the upper action bound.
            action_len: number of trailing input columns that hold actions.
        """
        super().__init__()

        self.wrapped = wrapped
        self.input_scaler = input_scaler
        self._move_scaler(wrapped.device)
        self.action_min = action_min.to(wrapped.device)
        self.action_max = action_max.to(wrapped.device)
        self.action_len = action_len

    def _move_scaler(self, device):
        """Move the input scaler (if any) to ``device``."""
        if self.input_scaler is None:
            return
        moved = self.input_scaler.to_device(device)
        # NOTE(review): FastScaler.to_device is not visible here. Adopt its
        # return value when there is one, and keep the current scaler when it
        # returns None, so scaling is never silently switched off.
        if moved is not None:
            self.input_scaler = moved

    @staticmethod
    def _target_device(args, kwargs):
        """Return the device requested by a ``to(...)`` call, or ``None``."""
        device = kwargs.get('device')
        if device is None and args:
            first = args[0]
            if isinstance(first, torch.Tensor):
                # to(tensor) means "match this tensor's device (and dtype)".
                return first.device
            if isinstance(first, (torch.device, str, int)):
                device = first
        return None if device is None else torch.device(device)

    def forward(self, input):
        """Scale ``input``, run the wrapped module, and un-scale if not training.

        The caller's ``input`` tensor is not modified by the training-mode
        normalisation.
        """
        if self.wrapped.model.training:
            # Normalise a private copy: doing this in place on the caller's
            # batch would scale it again if the same tensor were passed twice.
            scaled = input.clone()
            if self.input_scaler is not None:
                scaled[:, :-self.action_len] = self.input_scaler.transform(scaled[:, :-self.action_len])

            # Map the action columns from [action_min, action_max] to [-1, 1].
            scaled[:, -self.action_len:].sub_(self.action_min).div_(self.action_max - self.action_min).mul_(2).sub_(1)

            # At training time, this will be noise loss - don't touch it
            output = self.wrapped(scaled)
        else:
            if self.input_scaler is not None:
                input = self.input_scaler.transform(input)

            output = self.wrapped(input)

            # Map the output from [-1, 1] back to [action_min, action_max].
            output.add_(1).div_(2).mul_(self.action_max - self.action_min).add_(self.action_min)

        return output

    def to(self, *args, **kwargs):
        """Move the module, the wrapped module, the scaler and the bounds.

        Accepts the same arguments as ``nn.Module.to`` and returns its result.

        NOTE(review): only ``to`` is overridden; ``cuda()``/``cpu()`` do not
        pass through here, so they leave the scaler and bounds where they are.
        """
        result = super().to(*args, **kwargs)

        # Called explicitly so that an overridden ``to`` on the wrapped module
        # also runs.
        self.wrapped.to(*args, **kwargs)

        new_device = self._target_device(args, kwargs)
        if new_device is not None:
            self._move_scaler(new_device)

        self.action_min = self.action_min.to(*args, **kwargs)
        self.action_max = self.action_max.to(*args, **kwargs)

        return result

    def __getattr__(self, name):
        """Resolve ``name`` on the wrapper first, then on the wrapped module.

        Raises:
            AttributeError: if neither the wrapper nor the wrapped module has
                the attribute (including when no module is wrapped yet).
        """
        try:
            # nn.Module's lookup covers registered parameters, buffers and
            # submodules, including ``wrapped`` itself.
            return super().__getattr__(name)
        except AttributeError:
            pass

        # Read through __dict__ so a partially initialised instance produces
        # AttributeError instead of KeyError or unbounded recursion.
        wrapped = self.__dict__.get('_modules', {}).get('wrapped')
        if wrapped is not None:
            try:
                return getattr(wrapped, name)
            except AttributeError:
                pass
        raise AttributeError(f"'{type(self).__name__}' object and its wrapped module have no attribute '{name}'")
