from dataclasses import dataclass

import math
import torch

from ..classes import MLP, Hyperparameters, ModelInterface


def _split_list(array, segments):
    assert len(array) >= segments
    num_elements_per_segment = math.ceil(len(array) / segments)
    ret = []
    for _ in range(segments - 1):
        ret.append(array[:num_elements_per_segment])
        array = array[num_elements_per_segment:]
    ret.append(array)
    return ret


@dataclass
class DecoderHyperparameters(Hyperparameters):
    """Configurable hyperparameters for :class:`Decoder`.

    NOTE(review): how these values reach ``Decoder`` (e.g. parsing
    ``hidden_sizes`` into a list of ints) happens outside this view —
    confirm against the caller.
    """
    # Unannotated, so this is a plain class attribute, not a dataclass field.
    # Presumably used by ``Hyperparameters`` to namespace these options
    # (e.g. config/CLI keys) — TODO confirm.
    prefix = "decoder"
    # Comma-separated layer widths. ``Decoder`` takes a list, so this is
    # presumably split and converted to ints by the caller — verify.
    hidden_sizes: str = "512,512,512,512,512,512,512"
    # NOTE(review): ``Decoder`` in this file has no parameter for this flag
    # and always applies its skip connection; confirm where it is honoured.
    disable_skip_connection: bool = False
    # Dropout probability; presumably forwarded to ``Decoder(dropout_prob=...)``.
    dropout_prob: float = 0.2
    # Weight-normalization toggle; presumably forwarded to
    # ``Decoder(weight_norm=...)``.
    weight_norm: bool = False


class Decoder(ModelInterface):
    """Two-stage MLP decoder with a skip connection at its midpoint.

    ``hidden_sizes`` is split into two halves; when the count is odd, the
    first half gets the extra layer. The first MLP maps the input to the last
    width of the first half. Its output is concatenated with the input and fed
    to the second MLP, which produces the final output.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 hidden_sizes: list,
                 activation_func_module: torch.nn.Module,
                 dropout_prob: float,
                 weight_norm: bool,
                 same_dims=False):
        """Build the two MLP halves.

        Args:
            input_size: feature width of the input to ``forward``, including
                any ``z`` features concatenated there.
            output_size: feature width produced by the second MLP.
            hidden_sizes: layer widths; must contain at least two entries.
            activation_func_module: activation passed through to ``MLP``.
            dropout_prob: dropout probability passed through to ``MLP``.
            weight_norm: weight-normalization flag passed through to ``MLP``.
            same_dims: if True, the first half's output width is reduced by
                ``input_size``. After the input is concatenated back, the
                second half then sees exactly the last first-half width,
                instead of that width plus ``input_size``.

        Raises:
            ValueError: if ``hidden_sizes`` has fewer than two entries, or if
                ``same_dims`` is set and the last first-half width does not
                exceed ``input_size``.
        """
        super().__init__()
        if len(hidden_sizes) < 2:
            raise ValueError(
                f"hidden_sizes must contain at least 2 entries, got {len(hidden_sizes)}")
        first_half, second_half = _split_list(hidden_sizes, 2)
        bottleneck = first_half[-1]

        if same_dims:
            # Shrink the first half's output so that [input, output] has
            # exactly `bottleneck` features.
            first_half_output_size = bottleneck - input_size
            if first_half_output_size <= 0:
                raise ValueError(
                    f"same_dims requires the last first-half hidden size ({bottleneck}) "
                    f"to exceed input_size ({input_size})")
            second_half_input_size = bottleneck
        else:
            first_half_output_size = bottleneck
            second_half_input_size = bottleneck + input_size

        self.mlp_first_half = MLP(
            input_size=input_size,
            output_size=first_half_output_size,
            hidden_sizes=first_half[:-1],
            activation_func_module=activation_func_module,
            dropout_prob=dropout_prob,
            weight_norm=weight_norm,
            output_activation=True)
        self.mlp_second_half = MLP(
            input_size=second_half_input_size,
            output_size=output_size,
            hidden_sizes=second_half,
            dropout_prob=dropout_prob,
            weight_norm=weight_norm,
            activation_func_module=activation_func_module)

    def forward(self, x: torch.Tensor, z: torch.Tensor = None):
        """Decode ``x``, optionally augmented with ``z``.

        Features are concatenated along the last dimension. For 3-D inputs
        this is dim 2, matching the previous hard-coded axis; 2-D inputs now
        work as well. NOTE(review): this assumes ``MLP`` applies its layers
        over the last dimension — confirm.

        Args:
            x: input features.
            z: optional extra features concatenated onto ``x`` before
                decoding. The combined width must equal ``input_size``.

        Returns:
            The output of the second MLP.
        """
        if z is not None:
            x = torch.cat((x, z), dim=-1)
        first_half_output = self.mlp_first_half(x)
        # Skip connection: the second half sees the (augmented) input
        # alongside the first half's features.
        second_half_input = torch.cat((x, first_half_output), dim=-1)
        return self.mlp_second_half(second_half_input)
