import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from .. import utils
from ..utils.nn import one_hot
from .multcombination import MultCombination

class DeterministicModel(utils.nn.Network):
    """Network mapping a state ``x`` and an action ``a`` to a vector of width ``state_size``.

    The inputs are first merged by a ``MultCombination`` module, and the result
    is fed through a fully connected stack: ``n_hidden_layers`` hidden
    ``nn.Linear`` layers, each followed by a ReLU, then a linear output layer
    with ``state_size`` units. Judging by the ``x_`` naming used by callers,
    the output is presumably a next-state prediction.

    Args:
        state_size: Width of the state vector; also the width of the output.
        n_actions: Number of actions, forwarded to ``MultCombination``.
        n_factors: Forwarded to ``MultCombination``.
        n_hidden_layers: Number of hidden layers in the fully connected stack.
            With 0 the stack is a single linear layer.
        n_nodes_per_layer: Width of each hidden layer.
        learning_rate: Stored on the instance; nothing in this class reads it yet.
    """

    def __init__(self, state_size, n_actions, n_factors=64, n_hidden_layers=1, n_nodes_per_layer=256, learning_rate=1e-4):
        super().__init__()
        self.state_size = state_size
        self.n_actions = n_actions
        self.n_factors = n_factors
        self.n_hidden_layers = n_hidden_layers
        self.n_nodes_per_layer = n_nodes_per_layer
        self.learning_rate = learning_rate

        self.combination = MultCombination(state_size, n_actions, n_factors)

        # NOTE(review): the first linear layer takes ``state_size`` inputs, so
        # this assumes MultCombination emits ``state_size`` features -- confirm.
        fc_layers = []
        for i in range(n_hidden_layers + 1):
            is_output_layer = (i == n_hidden_layers)
            input_size = state_size if i == 0 else n_nodes_per_layer
            output_size = state_size if is_output_layer else n_nodes_per_layer
            fc_layers.append(nn.Linear(input_size, output_size))
            # Every hidden layer gets a non-linearity; only the output layer
            # stays linear. (The previous bound of ``n_hidden_layers - 1``
            # skipped the ReLU after the last hidden layer, which left the
            # default one-hidden-layer network entirely linear.)
            if not is_output_layer:
                fc_layers.append(nn.ReLU())
        self.fc_net = nn.Sequential(*fc_layers)

    def forward(self, x, a):
        """Return the network output for state ``x`` and action ``a``.

        Args:
            x: State input, passed to ``MultCombination`` together with ``a``.
            a: Action input (callers in this file pass a one-hot encoding).

        Returns:
            The output of the fully connected stack; its last dimension is
            ``state_size``.
        """
        conditioned_x = self.combination(x, a)
        x_ = self.fc_net(conditioned_x)
        return x_

    def train_batch(self, x, a, x_):
        """Placeholder for a single training step; currently does nothing.

        Args:
            x: State input.
            a: Action input.
            x_: Target for the model output.

        Returns:
            None. No loss is computed and no parameters are updated.
        """
        # TODO: implement the optimisation step (``learning_rate`` is unused until then).
        pass

def main():
    """Smoke-test ``DeterministicModel`` on a constant batch.

    Builds a small model, runs a forward pass, calls ``train_batch`` for a few
    epochs, runs another forward pass and checks that the model output has the
    same shape as the target.
    """
    x_size = 128
    a_size = 5
    n_factors = 16
    batch_size = 32
    T = DeterministicModel(x_size, a_size, n_factors)
    x = torch.ones(batch_size, x_size)
    # Sample one action index per batch element uniformly, then one-hot encode.
    a = one_hot(torch.multinomial(torch.ones(a_size), batch_size, replacement=True), a_size)
    x_ = 2 * torch.ones_like(x)
    # Forward pass before training: only verifies that the model runs.
    xhat_ = T(x, a)

    n_epochs = 10
    for i in tqdm(range(n_epochs)):
        T.train_batch(x, a, x_)
    xhat_ = T(x, a)

    # Check the model output against the target. The previous check compared
    # ``x_`` with ``x``, which always holds because ``x_`` is built from
    # ``torch.ones_like(x)``.
    assert xhat_.shape == x_.shape
    print('Testing complete.')

# Run the smoke test only when this file is executed directly, not on import.
# NOTE(review): the module uses package-relative imports, so it presumably has
# to be launched as a module (``python -m <package>.<module>``) -- confirm.
if __name__ == '__main__':
    main()
