from typing import Union

import gym
import torch as th
import torch.nn as nn
from stable_baselines3.common.utils import get_device, get_schedule_fn, update_learning_rate

from ..utils.features_extractor import ResizeFeatureExtractors
from .features_extractor import ActionsExtractor, ObservationsExtractor


class _BinaryClassifier(nn.Module):
    """Small feed-forward network that emits one raw score per input row.

    Layout: ``BatchNorm1d -> Linear -> ReLU -> Dropout(0.3)`` for the hidden
    widths 256 and 96, followed by a final ``BatchNorm1d(96)`` and
    ``Linear(96, 1)``. No sigmoid is applied inside the network.

    :param input_dim: Number of features in each input row.
    """

    def __init__(self, input_dim) -> None:
        super().__init__()

        # Consecutive pairs of this tuple give (in, out) for each hidden stage.
        widths = (input_dim, 256, 96)

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [
                nn.BatchNorm1d(fan_in),
                nn.Linear(fan_in, fan_out),
                nn.ReLU(),
                nn.Dropout(0.3),
            ]
        # Output head: normalise the last hidden activation, project to one unit.
        stack += [nn.BatchNorm1d(widths[-1]), nn.Linear(widths[-1], 1)]

        # Layers are created in the same order as a literal listing, so the
        # Sequential indices (and therefore state_dict keys) stay 0..9.
        self.model = nn.Sequential(*stack)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        scores = self.model(input)
        return scores


class InapplicableActionsClassifier(nn.Module):
    """Binary classifier over (action, observation) pairs.

    Actions and observations are encoded by ``ActionsExtractor`` and
    ``ObservationsExtractor``, concatenated along dim 1 and scored by
    ``_BinaryClassifier``, which returns one raw logit per row. The module
    also owns its loss (``BCEWithLogitsLoss``), its Adam optimizer and the
    learning-rate schedule used to drive that optimizer.

    :param action_space: Space handed to ``ActionsExtractor``.
    :param observation_space: Space handed to ``ObservationsExtractor``.
    :param use_policy_features: When True, the ``observations`` given to
        ``forward`` are concatenated as-is and the observations extractor is
        skipped. The classifier's input width is still computed from
        ``observations_extractor.output_dim``, so such inputs must have that
        width.
    :param device: Device specification resolved through ``get_device``.
    """

    # Hyper-parameters kept in one place instead of being repeated inline.
    _INITIAL_LR = 3e-4  # learning rate for a freshly constructed model
    _RELOAD_LR = 1e-4  # learning rate used after ``load`` rebuilds the optimizer
    _ADAM_EPS = 1e-5

    def __init__(
        self,
        action_space: gym.Space,
        observation_space: gym.Space,
        use_policy_features: bool = False,
        device: Union[th.device, str] = "auto",
    ) -> None:
        super().__init__()

        self._observation_space = observation_space
        self.use_policy_features = use_policy_features
        self.device = get_device(device)

        self.observations_extractor = ObservationsExtractor(self._observation_space)
        self.actions_extractor = ActionsExtractor(action_space)

        self.classifier = _BinaryClassifier(
            self.actions_extractor.output_dim + self.observations_extractor.output_dim
        )
        self.criterion = nn.BCEWithLogitsLoss()

        # Move the parameters first, then build the optimizer over them.
        self.to(self.device)
        self._reset_optimizer(self._INITIAL_LR)

    def _reset_optimizer(self, learning_rate) -> None:
        """Install a schedule for ``learning_rate`` and a fresh Adam optimizer over all current parameters."""
        self.lr_schedule = get_schedule_fn(learning_rate)
        self.optimizer = th.optim.Adam(
            self.parameters(), lr=self.lr_schedule(1), eps=self._ADAM_EPS
        )

    def _combine_inputs(self, actions, observations):
        """Encode the inputs and concatenate them as ``[actions, observations]`` along dim 1."""
        if not self.use_policy_features:
            observations = self.observations_extractor(observations)
        actions = self.actions_extractor(actions)

        return th.cat([actions, observations], dim=1)

    def forward(self, actions=None, observations=None):
        """Return the classifier's raw logits for the given actions and observations.

        Both arguments default to ``None`` only so they can be passed by
        keyword; the method does not substitute anything for a missing one.
        """
        return self.classifier(self._combine_inputs(actions, observations))

    def loss(self, y_pred, y_true):
        """Return ``BCEWithLogitsLoss`` of ``y_pred`` (raw logits) against ``y_true``."""
        return self.criterion(y_pred, y_true)

    def update_learning_rate(self, progress_remaining):
        """Set the optimizer's learning rate to ``lr_schedule(progress_remaining)``."""
        update_learning_rate(self.optimizer, self.lr_schedule(progress_remaining))

    def save(self, path):
        """Serialize the whole module object (not just a state dict) to ``path``."""
        th.save(self, path)

    def load(self, path, load_optimizer=False, load_linear=True):
        """Restore extractors and classifier from a checkpoint written by ``save``.

        :param path: Location of the checkpoint.
        :param load_optimizer: When True, adopt the optimizer object stored in
            the checkpoint; otherwise build a fresh Adam optimizer with a
            lower learning rate.
        :param load_linear: Forwarded to the extractors' own ``load`` methods.
        """
        # SECURITY: th.load unpickles arbitrary objects; only load checkpoints
        # from trusted sources.
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored on this instance's device.
        # NOTE(review): PyTorch >= 2.6 defaults ``weights_only=True``, which
        # rejects fully pickled modules; pass ``weights_only=False`` here if
        # the project targets those versions.
        model = th.load(path, map_location=self.device)

        saved_dim = model.observations_extractor._observation_space.shape[1]
        own_dim = self.observations_extractor._observation_space.shape[1]
        if saved_dim != own_dim:
            # Observation shapes differ on axis 1: wrap the saved extractor
            # instead of copying its weights into ours.
            self.observations_extractor = ResizeFeatureExtractors(model.observations_extractor)
        else:
            self.observations_extractor.load(model.observations_extractor, load_linear=load_linear)

        self.actions_extractor.load(model.actions_extractor, load_linear=load_linear)
        self.classifier = model.classifier

        # Submodules swapped in above (including anything the resize wrapper
        # created) must live on the same device as the rest of the module.
        self.to(self.device)

        if load_optimizer:
            # NOTE(review): this optimizer was built over the checkpoint's
            # parameters. If the extractors' ``load`` copies weights rather
            # than sharing tensors, it will not update this instance's
            # extractors - confirm against their implementation.
            self.optimizer = model.optimizer
        else:
            self._reset_optimizer(self._RELOAD_LR)

