import torch.nn as nn
import torch as th
import numpy as np
from utils.utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, LambdaModule, Binarize, PartialBinarize
from nn.eprop_lstm import EpropLSTM
from nn.eprop_gate_l0rd import EpropGateL0rd
from nn.residual import ResidualBlock
from nn.tracker import CaterSnitchTracker
from nn.vae import VariationalFunction
from torch.autograd import Function
from einops import rearrange, repeat, reduce

from typing import Tuple, Union, List
import utils
import cv2



class AlphaAttention(nn.Module):
    def __init__(
        self,
        num_hidden,
        num_objects,
        heads,
        dropout = 0.0
    ):
        super(AlphaAttention, self).__init__()

        self.to_sequence = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o = num_objects))
        self.to_batch    = LambdaModule(lambda x: rearrange(x, 'b o c -> (b o) c', o = num_objects))

        self.alpha     = nn.Parameter(th.zeros(1)+1e-12)
        self.attention = nn.MultiheadAttention(
            num_hidden, 
            heads, 
            dropout = dropout, 
            batch_first = True
        )

    def forward(self, x: th.Tensor):
        x = self.to_sequence(x)
        x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0]
        return self.to_batch(x)

class EpropAlphaGateL0rd(nn.Module):
    def __init__(self, num_hidden, batch_size, reg_lambda):
        super(EpropAlphaGateL0rd, self).__init__()
        
        self.alpha = nn.Parameter(th.zeros(1)+1e-12)
        self.l0rd  = EpropGateL0rd(
            num_inputs  = num_hidden, 
            num_hidden  = num_hidden, 
            num_outputs = num_hidden, 
            reg_lambda  = reg_lambda,
            batch_size = batch_size
        )

    def forward(self, input):
        return input + self.alpha * self.l0rd(input)

class InputEmbeding(nn.Module):
    def __init__(self, num_inputs, num_hidden):
        super(InputEmbeding, self).__init__()

        self.embeding = nn.Sequential(
            nn.ReLU(),
            nn.Linear(num_inputs, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
        )
        self.skip = LambdaModule(
            lambda x: repeat(x, 'b c -> b (n c)', n = num_hidden // num_inputs)
        )
        self.alpha = nn.Parameter(th.zeros(1)+1e-12)

    def forward(self, input: th.Tensor):
        return self.skip(input) + self.alpha * self.embeding(input)

class OutputEmbeding(nn.Module):
    def __init__(self, num_hidden, num_outputs):
        super(OutputEmbeding, self).__init__()

        self.embeding = nn.Sequential(
            nn.ReLU(),
            nn.Linear(num_hidden, num_outputs),
            nn.ReLU(),
            nn.Linear(num_outputs, num_outputs),
        )
        self.skip = LambdaModule(
            lambda x: reduce(x, 'b (n c) -> b c', 'mean', n = num_hidden // num_outputs)
        )
        self.alpha = nn.Parameter(th.zeros(1)+1e-12)

    def forward(self, input: th.Tensor):
        return self.skip(input) + self.alpha * self.embeding(input)

class EpropGateL0rdTransformer(nn.Module):
    def __init__(
        self, 
        channels,
        multiplier,
        num_objects,
        batch_size,
        heads, 
        deepth,
        reg_lambda,
        dropout=0.0
    ):
        super(EpropGateL0rdTransformer, self).__init__()

        num_inputs  = channels
        num_outputs = channels
        num_hidden  = channels
        num_hidden  = channels * multiplier

        print(f"Predictor channels: {num_hidden}@({num_hidden // heads}x{heads})")

        
        self.deepth = deepth
        _layers = []
        _layers.append(InputEmbeding(num_inputs, num_hidden))

        for i in range(deepth):
            _layers.append(AlphaAttention(num_hidden, num_objects, heads, dropout))
            _layers.append(EpropAlphaGateL0rd(num_hidden, batch_size * num_objects, reg_lambda))

        _layers.append(OutputEmbeding(num_hidden, num_outputs))
        self.layers = nn.Sequential(*_layers)

    def get_openings(self):
        openings = 0
        for i in range(self.deepth):
            openings += self.layers[2 * (i + 1)].l0rd.openings.item()

        return openings / self.deepth

    def get_hidden(self):
        states = []
        for i in range(self.deepth):
            states.append(self.layers[2 * (i + 1)].l0rd.get_hidden())

        return th.cat(states, dim=1)

    def set_hidden(self, hidden):
        states = th.chunk(hidden, self.deepth, dim=1)
        for i in range(self.deepth):
            self.layers[2 * (i + 1)].l0rd.set_hidden(states[i])

    def forward(self, input: th.Tensor) -> th.Tensor:
        return self.layers(input)

class LociPredictor(nn.Module): 
    def __init__(
        self, 
        heads: int, 
        layers: int,
        channels_multiplier: int,
        reg_lambda: float,
        vae_factor: float,
        num_objects: int, 
        gestalt_size: int, 
        batch_size: int,
        bottleneck: str,
        camera_view_matrix = None,
        zero_elevation = None,
    ):
        super(LociPredictor, self).__init__()
        self.num_objects = num_objects
        self.std_alpha   = nn.Parameter(th.zeros(1)+1e-16)
        self.bottleneck_type = bottleneck
        self.gestalt_size = gestalt_size

        self.reg_lambda = reg_lambda
        self.predictor  = EpropGateL0rdTransformer(
            channels    = gestalt_size + 3 + 1 + 2, 
            multiplier  = channels_multiplier,
            heads       = heads, 
            deepth      = layers,
            num_objects = num_objects,
            reg_lambda  = reg_lambda, 
            batch_size  = batch_size,
        )

        self.tracker = None
        if camera_view_matrix is not None:
            self.tracker = CaterSnitchTracker(
                latent_size        = gestalt_size + 4,
                num_objects        = num_objects,
                camera_view_matrix = camera_view_matrix,
                zero_elevation     = zero_elevation
            )
        
        if bottleneck == 'binar':
            print("Binary bottleneck")
            self.bottleneck = nn.Sequential(
                LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
                ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
                Binarize(),
                LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects))
            )

        elif bottleneck == "partial_binar":
             print("Partial Binar bottleneck")
             self.bottleneck = nn.Sequential(
                LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
                ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
                PartialBinarize(gestalt_size = gestalt_size),
                LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects))
            )                       

        elif bottleneck == 'vae':
            print("VAE bottleneck")
            self.bottleneck = nn.Sequential(
                LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
                ResidualBlock(gestalt_size, gestalt_size * 2, kernel_size=1),
                VariationalFunction(factor = vae_factor),
                nn.Sigmoid(),
                LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects))
            )
        else:
            print("unrestricted bottleneck")
            self.bottleneck = nn.Sequential(
                LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
                ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
                nn.Sigmoid(),
                LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects))
            )
                
        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects))
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects))

    def get_openings(self):
        return self.predictor.get_openings()

    def get_hidden(self):
        return self.predictor.get_hidden()

    def set_hidden(self, hidden):
        self.predictor.set_hidden(hidden)

    def get_snitch_position(self):
        snitch_position = None
        if self.tracker is not None:
            snitch_position = self.tracker(*self.tracker_input)

        return snitch_position

    def forward(
        self, 
        gestalt: th.Tensor, 
        priority: th.Tensor,
        position: th.Tensor,
        slots_closed: th.Tensor,
    ):

        position = self.to_batch(position)
        gestalt_cur  = self.to_batch(gestalt)
        priority = self.to_batch(priority)
        slots_closed = rearrange(slots_closed, 'b o c -> (b o) c').detach()

        input  = th.cat((gestalt_cur, priority, position, slots_closed), dim=1)
        output = self.predictor(input)

        gestalt  = output[:, :self.gestalt_size]
        priority = output[:,self.gestalt_size:(self.gestalt_size+1)]
        xy       = output[:,(self.gestalt_size+1):(self.gestalt_size+3)]
        std      = output[:,(self.gestalt_size+3):(self.gestalt_size+4)]

        if self.tracker is not None:
            self.tracker_input = (xy, th.cat((gestalt, priority, std, slots_closed), dim=1))

        position = th.cat((xy, std * self.std_alpha), dim=1)
        
        position = self.to_shared(position)
        gestalt  = self.bottleneck(gestalt)
        priority = self.to_shared(priority)

        return position, gestalt, priority
