import os
from typing import List, Callable
from collections import defaultdict

import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
from driving_gridworld.actions import ACTIONS
from environment_generator import EnvironmentDataset
from config import Config

class RewardModel(nn.Module):
    """
    A simple convnet used to learn an approximation of the reward function
    in the driving_gridworld environment.

    The observation before and the observation after a transition are each
    passed through their own 2x2 convolution. The flattened features are
    concatenated with a one-hot encoding of the speeds after the transition
    and fed to a two-layer fully connected head. One head is kept per speed
    value in ``0 .. config.speed_limit``; the head is selected by the speed
    before the transition.
    """
    def __init__(self):
        super().__init__()
        self.conv0, self.conv1 = nn.Conv2d(4, 4, 2), nn.Conv2d(4, 4, 2)
        self.config = Config()
        # The heads are created pairwise (fc1[i], then fc2[i]) so that random
        # weight initialisation consumes the RNG in a fixed, seed-reproducible
        # order.
        # NOTE(review): the 52 input features are the flattened conv outputs
        # plus the one-hot speed vector; this presumably ties the model to one
        # board size and speed limit -- confirm against the environment.
        first_layers, second_layers = [], []
        for _ in range(self.config.speed_limit + 1):
            first_layers.append(nn.Linear(52, 32))
            second_layers.append(nn.Linear(32, 1))
        self.fc1 = nn.ModuleList(first_layers)
        self.fc2 = nn.ModuleList(second_layers)

    def onehot_encode_spds(self, spds, device=None):
        """
        One-hot encodes a sequence of speeds.

        :param spds: A sequence of integer speeds, each usable as an index
            into a row of width ``config.speed_limit + 1``.
        :param device: Device on which to allocate the result. ``None``
            (the default) uses torch's default device.
        :returns: A float tensor of shape
            ``(len(spds), config.speed_limit + 1)``.
        """
        onehot = torch.zeros(
            len(spds),
            self.config.speed_limit + 1,
            device=device,
        )
        for row, spd in enumerate(spds):
            onehot[row, spd] = 1
        return onehot

    def forward(self, minibatch) -> torch.Tensor:
        """
        Assumes that all speeds in the minibatch are the same.
        :param minibatch: A pair ``((layers, spds), (layers_after,
            spds_after))`` holding the batched observations and speeds
            before and after a transition.
        :returns: A tensor of shape ``(batch, 1)`` with one predicted value
            per element of the minibatch.
        """

        (layers, spds), (layers_after, spds_after) = minibatch
        assert all(other == spds[0] for other in spds), \
            "all speeds in a minibatch must be the same"
        # The shared speed selects which fully connected head is used.
        spd = spds[0]
        # Allocate the encoding next to the inputs so that the concatenation
        # below also works when the batch lives on an accelerator.
        spds_after = self.onehot_encode_spds(spds_after, device=layers.device)

        features = torch.cat(
            (
                F.relu(self.conv0(layers)),
                F.relu(self.conv1(layers_after)),
            ),
            dim=1
        )
        features = features.reshape(layers.shape[0], -1)
        features = torch.cat((features, spds_after), dim=1)
        hidden = F.relu(self.fc1[spd](features))
        return self.fc2[spd](hidden)

def load_all_saved_models(
        eval: bool = True,
        amount_to_load: int = None,
) -> List[nn.Module]:
    """
    Loads all saved reward models located in config.models_dir.

    Every file in the directory is expected to be named ``<seed>.<ext>``
    with an integer seed; the seed is attached to the loaded model as
    ``model.seed``.

    :param eval: If true, ``.eval()`` is called on every loaded model and
        the results are returned.
    :param amount_to_load: If given and non-zero, only the first
        ``amount_to_load`` files (in sorted filename order) are loaded.
        ``None`` and ``0`` both load everything.
    :returns: The list of loaded models.
    :raises ValueError: If a filename does not start with an integer seed.
    """
    config = Config()
    # os.listdir returns entries in arbitrary order; sorting makes the
    # subset picked by amount_to_load (and the result order) reproducible.
    filenames = sorted(os.listdir(config.models_dir))
    if amount_to_load:
        filenames = filenames[:amount_to_load]
    models = []
    for filename in filenames:
        seed = int(filename.split('.')[0])
        # NOTE(security): torch.load unpickles the file and can therefore
        # execute arbitrary code; models_dir must only contain trusted files.
        model = torch.load(os.path.join(config.models_dir, filename))
        model.seed = seed
        models.append(model)
    return [model.eval() for model in models] if eval else models

def model_to_reward_function(
        model: Callable,
        env: EnvironmentDataset,
):
    """
    Converts a model into a reward function for the driving_gridworld
    environments.

    :param model: A callable taking ``(input, next_input)`` and returning
        the predicted reward of that transition.
    :param env: The environment providing the state representations and
        the transition probabilities ``prob_trans_mat[state][action]``.
    :returns: A function ``reward_function(state_index, action)`` giving
        the expected predicted reward over all next states with non-zero
        transition probability. Results are memoized per (state, action).
    """
    memo = defaultdict(dict)

    def get_model_input(state_index):
        """Builds the model input for a state, adding a batch dimension."""
        x = env.state_to_tensor(
            env.board_to_state(
                env.obtain_board_representation(state_index)
            )
        )
        return (x[0].unsqueeze(0), x[1])

    def reward_function(state_index, action):
        if state_index in memo and action in memo[state_index]:
            return memo[state_index][action]
        expected_reward = 0
        inp = get_model_input(state_index)
        transition_probs = env.prob_trans_mat[state_index][action]
        for next_index, next_state_prob in enumerate(transition_probs):
            if next_state_prob > 0:
                next_inp = get_model_input(next_index)
                with torch.no_grad():
                    expected_reward += model((inp, next_inp)) * next_state_prob
        # Cache only after the whole expectation has been accumulated: an
        # empty transition row then yields 0 instead of a KeyError, and a
        # failure part-way through never leaves a partial sum in the memo.
        memo[state_index][action] = expected_reward
        return expected_reward
    return reward_function
