import torch.nn as nn
import torch as th
import numpy as np
from utils.utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, LambdaModule, Binarize, PartialBinarize, Buffer
from einops import rearrange, repeat, reduce

from typing import Tuple, Union, List
import utils
import cv2
from nn.eprop_gate_l0rd import ReTanh

class UpdateModule(nn.Module):
    """Per-object MLP that maps current/previous slot codes to two gate values.

    Every input tensor of ``forward`` is laid out as ``(batch, objects * c)``.
    The tensors are regrouped to ``(batch * objects, c)``, concatenated along
    the feature axis, passed through a three-layer tanh MLP with two output
    units, and finally through ``ReTanh``. During training, zero-mean Gaussian
    noise is added to the MLP output before ``ReTanh`` is applied.

    Args:
        num_inputs: Width of the concatenated per-object feature vector, i.e.
            the sum of the per-object widths of all nine ``forward`` inputs.
        num_hidden: Hidden layer sizes; entries ``0`` and ``1`` are used.
        bias: Whether the linear layers carry a bias term.
        num_objects: Number of objects (slots) packed into each input row.
        gate_noise_level: Standard deviation of the training-time noise.
        reg_lambda: Forwarded to ``ReTanh``.
            NOTE(review): presumably a regularisation weight - confirm
            against nn.eprop_gate_l0rd.
    """

    def __init__(
        self,
        num_inputs: int,
        num_hidden: list,
        bias: bool,
        num_objects: int,
        gate_noise_level: float = 0.1,
        reg_lambda: float = 0.000005
    ):
        super().__init__()

        # Fold the object axis into the batch axis so one MLP serves all objects ...
        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects))
        # ... and unfold it again for the returned tensor.
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o=num_objects))

        self.layers = nn.Sequential(
            nn.Linear(num_inputs, num_hidden[0], bias=bias),
            nn.Tanh(),
            nn.Linear(num_hidden[0], num_hidden[1], bias=bias),
            nn.Tanh(),
            nn.Linear(num_hidden[1], 2, bias=bias)
        )
        self.output_function = ReTanh(reg_lambda)

        # Non-persistent buffer: follows the module across devices but is not
        # written to the state_dict.
        self.register_buffer("noise", th.tensor(gate_noise_level), persistent=False)
        self.init_weights(self.layers)

    def init_weights(self, layers=None):
        """Xavier-initialise linear weights and set existing biases to 0.01.

        Args:
            layers: Iterable of modules to initialise. Defaults to
                ``self.layers`` when ``None``. Non-``nn.Linear`` entries are
                skipped.
        """
        if layers is None:
            layers = self.layers

        for layer in layers:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_uniform_(layer.weight)
            # Linear layers built with bias=False have `bias is None`.
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0.01)

    def forward(self, position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, position_last2, evaluate=False):
        """Compute the two gate values for every object.

        All tensor arguments are expected as ``(batch, objects * c)``; the
        per-object widths ``c`` may differ between arguments but must sum to
        ``num_inputs``.

        Args:
            position_cur, gestalt_cur, priority_cur: Current-step codes.
            slots_occlusionfactor_cur: Current occlusion factor (detached).
            position_last, gestalt_last, priority_last: Previous-step codes.
            slots_occlusionfactor_last: Previous occlusion factor (detached).
            position_last2: Position from two steps back (detached).
            evaluate: If true, no noise is added before ``ReTanh``.

        Returns:
            The ``ReTanh`` output regrouped by ``to_shared``.
            NOTE(review): shape is (batch, objects, 2) provided ReTanh
            preserves the shape of its input - confirm.
        """
        position_cur = self.to_batch(position_cur)
        gestalt_cur = self.to_batch(gestalt_cur)
        priority_cur = self.to_batch(priority_cur)
        position_last = self.to_batch(position_last)
        gestalt_last = self.to_batch(gestalt_last)
        priority_last = self.to_batch(priority_last)

        # No gradient flows back through these three inputs.
        slots_occlusionfactor_cur = self.to_batch(slots_occlusionfactor_cur).detach()
        slots_occlusionfactor_last = self.to_batch(slots_occlusionfactor_last).detach()
        position_last2 = self.to_batch(position_last2).detach()

        # The concatenation order defines the meaning of the first layer's
        # input columns; do not reorder.
        features = th.cat(
            (
                position_cur,
                gestalt_cur,
                priority_cur,
                slots_occlusionfactor_cur,
                position_last,
                gestalt_last,
                priority_last,
                slots_occlusionfactor_last,
                position_last2,
            ),
            dim=1,
        )
        output = self.layers(features)

        if not evaluate:
            # Training: perturb the pre-activation with N(0, noise^2).
            output = output + th.normal(mean=0, std=self.noise, size=output.shape, device=output.device)

        return self.to_shared(self.output_function(output))
