from dataclasses import dataclass

import math
import torch

from ...classes import MLP, Hyperparameters, ModelInterface


def _split_list(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


class Decoder(torch.nn.Module):
    """Two-stage MLP decoder with a skip connection into its second stage.

    ``forward`` concatenates ``X`` and ``h`` along the last axis and feeds the
    result to a first MLP. The same concatenation is then fed again, together
    with the first MLP's output, to a second MLP that produces the result.

    ``hidden_sizes`` is split into two contiguous halves with ``_split_list``.
    The first half builds the first MLP, with its last entry used as that
    MLP's output width. The second half supplies the second MLP's hidden
    layers.
    """

    def __init__(self, input_size: int, output_size: int, hidden_sizes: list,
                 activation_func_module: torch.nn.Module):
        """Build both MLP stages.

        Args:
            input_size: size of the last axis of ``torch.cat((X, h), -1)``,
                i.e. the feature sizes of ``X`` and ``h`` combined.
            output_size: output width of the second MLP (the decoder output).
            hidden_sizes: hidden layer widths; must contain at least two
                entries, otherwise ``_split_list(hidden_sizes, 2)`` rejects it.
            activation_func_module: passed unchanged to both ``MLP``
                instances. NOTE(review): whether ``MLP`` expects a module
                class or an instance is not visible here — confirm.
        """
        super().__init__()
        # Exactly two halves; unpacking also enforces that (even under -O).
        first_half, second_half = _split_list(hidden_sizes, 2)
        first_half_width = first_half[-1]
        # output_activation=True presumably applies the activation to the
        # first MLP's output, since it is an intermediate feature rather than
        # the final result — confirm against MLP.
        self.mlp_first_half = MLP(
            input_size=input_size,
            output_size=first_half_width,
            hidden_sizes=first_half[:-1],
            activation_func_module=activation_func_module,
            output_activation=True)
        # Skip connection: the second MLP sees the concatenated input next to
        # the first MLP's output, so its input width is the sum of both.
        self.mlp_second_half = MLP(
            input_size=first_half_width + input_size,
            output_size=output_size,
            hidden_sizes=second_half,
            activation_func_module=activation_func_module)

    def forward(self, X: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """Decode ``X`` together with ``h``.

        Args:
            X: input tensor; concatenated with ``h`` along the last axis, so
                every other dimension of ``X`` and ``h`` must match.
            h: tensor concatenated onto ``X``.

        Returns:
            The output of the second MLP.
        """
        # Assumes the feature axis is last, as it was for the rank-3 inputs
        # the previous hard-coded dim=2 supported (and as MLP's input_size
        # implies) — confirm. dim=-1 is identical for rank-3 tensors and also
        # accepts other ranks.
        joined = torch.cat((X, h), dim=-1)
        first_out = self.mlp_first_half(joined)
        return self.mlp_second_half(torch.cat((joined, first_out), dim=-1))
