import os
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from config import Config

def head_state_to_tensor(models):
    """
    Build a converter that turns a raw ``(board, speed)`` state into
    the ``(features, speed)`` pair of tensors.

    :param models: Iterable of objects exposing ``obtain_features``;
    each one is called on the board tensor and the results are
    concatenated along dim 1, in iteration order.
    :return: A function mapping a state to ``(features, speed)``.
    """
    def state_to_tensor(state):
        board, speed = state
        # Prepend two singleton axes to the board
        # (presumably batch and channel -- TODO confirm against
        # what obtain_features expects).
        board_input = torch.Tensor(board)[None, None]
        # speed has to be a tensor already: only a leading axis is added.
        speed_batch = speed.unsqueeze(0)
        per_model_features = []
        for model in models:
            per_model_features.append(model.obtain_features(board_input))
        stacked_features = torch.cat(per_model_features, dim=1)
        return (stacked_features, speed_batch)
    return state_to_tensor

class HeadModel(nn.Module):
    """
    Fully connected head made of three linear layers
    (``in_features`` -> 52 -> 20 -> ``num_actions``) that consumes a
    subset of precomputed features together with a speed tensor.
    """

    def __init__(
            self,
            num_actions: int,
            which_features: List[int],
            in_features: int = 132,
    ):
        """
        :param num_actions: Output width of the last linear layer.
        :param which_features: Which features to use
        from the conv layer feature set for this
        HeadModel in particular. Used to index dim 1 of the
        feature tensor in ``forward``.
        :param in_features: Input width of the first linear layer.
        It must equal the flattened size of the selected features
        plus the number of speed columns. Defaults to 132, the value
        that used to be hard-coded.
        """
        super().__init__()
        self.which_features = which_features
        self.fc1 = nn.Linear(in_features, 52)
        self.fc2 = nn.Linear(52, 20)
        self.fc3 = nn.Linear(20, num_actions)

    def forward(
            self,
            minibatch: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """
        :param minibatch: ``(features, spds)`` pair. ``features`` is
        indexed along dim 1 with ``which_features`` and flattened per
        sample; ``spds`` must be 2-D with the same dim 0 so that it
        can be concatenated along dim 1.
        :return: Output of the last linear layer, one row per sample
        (no activation applied).
        """
        features, spds = minibatch
        # Keep only the feature slices assigned to this head.
        x = features[:, self.which_features]
        # Flatten everything except the batch axis.
        x = x.reshape(x.shape[0], -1)
        x = torch.cat((x, spds), dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

def load_all_saved_models(
        eval: bool = True,
        amount_to_load: Optional[int] = None,
) -> List[nn.Module]:
    """
    Loads the saved reward models located in config.heads_dir.

    :param eval: When True, ``.eval()`` is called on every loaded
    model before it is returned. The name shadows the builtin but is
    kept because it is part of the public signature.
    :param amount_to_load: Maximum number of files to load. ``None``
    loads every file in the directory; ``0`` loads none.
    :return: The loaded models, ordered by file name.
    """
    config = Config()
    # os.listdir returns entries in arbitrary order; sort them so the
    # subset selected by amount_to_load is reproducible across runs.
    filenames = sorted(os.listdir(config.heads_dir))
    # Compare against None explicitly: a plain truthiness test would
    # treat amount_to_load=0 as "load everything".
    if amount_to_load is not None:
        filenames = filenames[:amount_to_load]
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- heads_dir must only contain trusted files.
    # NOTE(review): newer PyTorch releases default torch.load to
    # weights_only=True, which rejects fully pickled modules; confirm
    # against the installed version.
    models = [torch.load(os.path.join(config.heads_dir, filename))
              for filename in filenames]
    return [model.eval() for model in models] if eval else models
