import torch.nn as nn
import torch as th
import numpy as np
import nn as nn_modules
from utils.utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, LambdaModule, ForcedAlpha, PrintGradient
from nn.eprop_gate_l0rd import EpropGateL0rd
from nn.residual import ResidualBlock, LinearResidual
from nn.tracker import CaterSnitchTracker
from nn.vae import VariationalFunction
from einops import rearrange, repeat, reduce

from typing import Tuple, Union, List
import utils
import cv2



class NewtonianCorrection(nn.Module):
    """Gate an encoded position against the previous prediction.

    Two sigmoid gates are computed from the latent state. The position gate
    mixes the encoded position with the previously predicted position; the
    velocity gate mixes the displacement since the last corrected position
    with the previously predicted velocity. Decayed running sums describing
    how far each gate stays closed (``1 - gate``) are kept as plain Python
    numbers on the module.
    """

    def __init__(self, latent_size):
        super().__init__()

        # Offset added to the gate logits before the sigmoid (initialised to 5).
        # NOTE(review): sigmoid(5) ~ 0.993, so this presumably starts the gate
        # nearly open -- depends on the scale of LinearResidual's output.
        self.position_bias = nn.Parameter(th.ones(1)*5)
        self.position_gate = nn.Sequential(
            LinearResidual(latent_size, latent_size),
            LinearResidual(latent_size, 2),
            LambdaModule(lambda x: x + self.position_bias),
            nn.Sigmoid()
        )

        self.velocity_gate = nn.Sequential(
            LinearResidual(latent_size, latent_size),
            LinearResidual(latent_size, 2),
            nn.Sigmoid()
        )

        # Running statistics, updated on every forward call.
        self.mean_velocity_update  = 0
        self.mean_position_update  = 0
        self.mean_velocity_update2 = 0
        self.mean_position_update2 = 0
        self.mean_sum              = 0

    @staticmethod
    def _decayed_sum(running, increment):
        """Return ``running`` decayed by 0.999 plus ``increment``."""
        return running * 0.999 + increment

    def forward(
        self, 
        encoded_position, 
        corrected_position_last,
        predicted_position_last, 
        predicted_velocity_last, 
        latent_state
    ):
        """Return the gated ``(corrected_position, corrected_velocity)`` pair.

        Both gates have two output channels; the position/velocity arguments
        are presumably ``(N, 2)`` tensors -- confirm against the caller.
        """
        # --- position: gate == 1 trusts the encoder, gate == 0 the prediction
        position_gate = self.position_gate(latent_state)
        position_keep = 1 - position_gate
        corrected_position = (
            encoded_position * position_gate
            + position_keep * predicted_position_last
        )

        self.mean_position_update = self._decayed_sum(
            self.mean_position_update, th.sum(position_keep).item()
        )
        self.mean_position_update2 = self._decayed_sum(
            self.mean_position_update2, th.sum(position_keep**2).item()
        )

        # --- velocity: displacement since the last corrected position
        displacement = corrected_position - corrected_position_last
        velocity_gate = self.velocity_gate(latent_state)
        velocity_keep = 1 - velocity_gate
        corrected_velocity = (
            displacement * velocity_gate
            + velocity_keep * predicted_velocity_last
        )

        self.mean_velocity_update = self._decayed_sum(
            self.mean_velocity_update, th.sum(velocity_keep).item()
        )
        self.mean_velocity_update2 = self._decayed_sum(
            self.mean_velocity_update2, th.sum(velocity_keep**2).item()
        )
        # Decayed element count, so the sums above can be turned into means.
        self.mean_sum = self._decayed_sum(self.mean_sum, velocity_gate.numel())

        return corrected_position, corrected_velocity

class EulerIntegrator(nn.Module):
    """Advance position and velocity by one step using a learned force.

    A two-channel force is computed from the latent state and scaled by the
    learnable ``alpha``. The velocity is updated first and the new velocity
    is then added to the position. Decayed running sums of the force
    magnitude are kept as plain Python numbers on the module.
    """

    def __init__(self, latent_size):
        super().__init__()

        # Learnable scale on the force output, initialised to 1e-16.
        self.alpha = nn.Parameter(th.ones(1) * 1e-16)
        self.compute_force = nn.Sequential(
            LinearResidual(latent_size, latent_size),
            LinearResidual(latent_size, 2),
            LambdaModule(lambda x: x * self.alpha),
        )

        # Running statistics, updated on every forward call.
        self.mean_force  = 0
        self.mean_force2 = 0
        self.mean_sum    = 0

    def forward(self, position, velocity, latent_state):
        """Return the updated ``(position, velocity)`` pair."""
        force = self.compute_force(latent_state)

        # The freshly updated velocity is what moves the position.
        next_velocity = velocity + force
        next_position = position + next_velocity

        abs_total     = th.sum(th.abs(force)).item()
        squared_total = th.sum(force**2).item()

        self.mean_force  = self.mean_force  * 0.999 + abs_total
        self.mean_force2 = self.mean_force2 * 0.999 + squared_total
        # Decayed element count, so the sums above can be turned into means.
        self.mean_sum    = self.mean_sum    * 0.999 + force.numel()

        return next_position, next_velocity

class AlphaAttention(nn.Module):
    """Residual multi-head self-attention across the objects of each sample.

    Position, velocity and latent state are concatenated and embedded; that
    embedding (scaled by ``embedding_alpha``) is added to the latent state.
    The result is regrouped from ``(b o) c`` to ``b o c``, passed through
    self-attention whose output is scaled by ``alpha`` and added residually,
    and finally flattened back to ``(b o) c``.
    """

    def __init__(
        self,
        num_hidden,
        num_objects,
        heads,
        dropout = 0.0
    ):
        super().__init__()

        # Regroup between a flat (batch * objects) axis and separate axes.
        self.to_sequence = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o = num_objects))
        self.to_batch    = LambdaModule(lambda x: rearrange(x, 'b o c -> (b o) c', o = num_objects))

        # Learnable scale on the embedding branch, initialised to 1e-12.
        self.embedding_alpha = nn.Parameter(th.zeros(1)+1e-12)
        # The "+ 4" accounts for the concatenated position and velocity,
        # presumably two channels each -- confirm against the caller.
        self.newtonian_embedding = nn.Sequential(
            LambdaModule(lambda x: th.cat(x, dim=1)),
            nn.Linear(num_hidden + 4, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden , num_hidden),
            LambdaModule(lambda x: x * self.embedding_alpha)
        )

        # Learnable scale on the attention branch, initialised to 1e-12.
        self.alpha     = nn.Parameter(th.zeros(1)+1e-12)
        self.attention = nn.MultiheadAttention(
            num_hidden, 
            heads, 
            dropout = dropout, 
            batch_first = True
        )

    def forward(self, position, velocity, latent_state):
        """Return the attention-updated latent state in ``(b o) c`` layout."""
        embedded = self.newtonian_embedding((position, velocity, latent_state))
        tokens   = self.to_sequence(latent_state + embedded)

        # Query, key and value are all the same token sequence.
        attended = self.attention(tokens, tokens, tokens, need_weights=False)[0]
        tokens   = tokens + self.alpha * attended

        return self.to_batch(tokens)

class EpropAlphaGateL0rd(nn.Module):
    """Residual wrapper around an ``EpropGateL0rd`` layer.

    The wrapped layer receives position, velocity and latent state
    concatenated along dim 1; its output is scaled by the learnable
    ``alpha`` and added to the incoming latent state.
    """

    def __init__(self, num_hidden, batch_size, reg_lambda):
        super().__init__()

        # Learnable scale on the recurrent branch, initialised to 1e-12.
        self.alpha = nn.Parameter(th.zeros(1)+1e-12)
        # The "+ 4" accounts for the concatenated position and velocity,
        # presumably two channels each -- confirm against the caller.
        self.l0rd  = EpropGateL0rd(
            num_inputs  = num_hidden + 4,
            num_hidden  = num_hidden,
            num_outputs = num_hidden,
            reg_lambda  = reg_lambda,
            batch_size  = batch_size,
        )

    def forward(self, position, velocity, latent_state):
        """Return ``latent_state`` plus the alpha-scaled output of the wrapped layer."""
        layer_input = th.cat((position, velocity, latent_state), dim=1)
        residual    = self.alpha * self.l0rd(layer_input)
        return latent_state + residual

class InputEmbeding(nn.Module):
    """Lift a feature vector from ``num_inputs`` to ``num_hidden`` channels.

    The output is the sum of two paths:

    * a parameter-free skip path that concatenates
      ``num_hidden // num_inputs`` copies of the input along the channel
      axis, and
    * a two-layer MLP whose output is scaled by the learnable ``alpha``
      (initialised to 1e-12).

    Args:
        num_inputs: Channel count of the incoming tensor.
        num_hidden: Channel count of the result; must be a positive multiple
            of ``num_inputs`` so that both paths have the same width.

    Raises:
        ValueError: If ``num_hidden`` is not a positive multiple of a
            positive ``num_inputs``.
    """

    def __init__(self, num_inputs, num_hidden):
        super().__init__()

        # Fail at construction time: with a non-divisible pair the tiled skip
        # path can never match the MLP width, so forward() could not work.
        if num_inputs <= 0 or num_hidden <= 0 or num_hidden % num_inputs != 0:
            raise ValueError(
                f"num_hidden ({num_hidden}) must be a positive multiple of "
                f"num_inputs ({num_inputs})"
            )

        self.embeding = nn.Sequential(
            nn.ReLU(),
            nn.Linear(num_inputs, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
        )
        self.skip = LambdaModule(
            lambda x: repeat(x, 'b c -> b (n c)', n = num_hidden // num_inputs)
        )
        self.alpha = nn.Parameter(th.zeros(1)+1e-12)

    def forward(self, input: th.Tensor):
        """Return the tiled input plus the alpha-scaled MLP output."""
        return self.skip(input) + self.alpha * self.embeding(input)

class OutputEmbeding(nn.Module):
    """Reduce a feature vector from ``num_hidden`` to ``num_outputs`` channels.

    The output is the sum of two paths:

    * a parameter-free skip path that splits the channel axis into
      ``num_hidden // num_outputs`` chunks of ``num_outputs`` channels and
      averages them, and
    * a two-layer MLP whose output is scaled by the learnable ``alpha``
      (initialised to 1e-12).

    Args:
        num_hidden: Channel count of the incoming tensor; must be a positive
            multiple of ``num_outputs`` so the channel axis splits evenly.
        num_outputs: Channel count of the result.

    Raises:
        ValueError: If ``num_hidden`` is not a positive multiple of a
            positive ``num_outputs``.
    """

    def __init__(self, num_hidden, num_outputs):
        super().__init__()

        # Fail at construction time: with a non-divisible pair the averaging
        # skip path cannot split the channel axis, so forward() could not work.
        if num_outputs <= 0 or num_hidden <= 0 or num_hidden % num_outputs != 0:
            raise ValueError(
                f"num_hidden ({num_hidden}) must be a positive multiple of "
                f"num_outputs ({num_outputs})"
            )

        self.embeding = nn.Sequential(
            nn.ReLU(),
            nn.Linear(num_hidden, num_outputs),
            nn.ReLU(),
            nn.Linear(num_outputs, num_outputs),
        )
        self.skip = LambdaModule(
            lambda x: reduce(x, 'b (n c) -> b c', 'mean', n = num_hidden // num_outputs)
        )
        self.alpha = nn.Parameter(th.zeros(1)+1e-12)

    def forward(self, input: th.Tensor):
        """Return the chunk-averaged input plus the alpha-scaled MLP output."""
        return self.skip(input) + self.alpha * self.embeding(input)

class EpropGateL0rdTransformer(nn.Module):
    """Stateful predictor built from attention, gated-recurrent and integration layers.

    One forward call does the following:

    1. ``NewtonianCorrection`` blends the encoded position with the state
       remembered from the previous call.
    2. Position, velocity and latent state are concatenated and lifted to
       ``(channels + 4) * multiplier`` channels.
    3. ``deepth`` times: ``AlphaAttention``, then ``EpropAlphaGateL0rd``,
       then ``EulerIntegrator`` (which advances position and velocity).
    4. The latent state is reduced back to ``channels + 4`` channels and the
       first four channels are dropped.

    The corrected position and the final position/velocity are remembered in
    three non-persistent buffers (not part of ``state_dict``) for the next
    call.

    Args:
        channels: Latent channel count, excluding position and velocity.
        multiplier: Width multiplier for the hidden representation.
        num_objects: Number of objects per sample.
        batch_size: Number of samples; the state buffers are allocated with
            ``batch_size * num_objects`` rows.
        heads: Number of attention heads.
        deepth: Number of (attention, l0rd, integrator) layers.
        reg_lambda: Passed through to every ``EpropAlphaGateL0rd``.
        dropout: Attention dropout probability.
    """

    def __init__(
        self, 
        channels,
        multiplier,
        num_objects,
        batch_size,
        heads, 
        deepth,
        reg_lambda,
        dropout=0.0
    ):
        super().__init__()

        num_hidden  = (channels + 4) * multiplier

        print(f"Predictor channels: {num_hidden}@({num_hidden // heads}x{heads})")

        self.newtonian_correction = NewtonianCorrection(channels)

        self.deepth = deepth
        self.input_embeding = InputEmbeding(channels + 4, num_hidden)

        _attentions  = []
        _l0rds       = []
        _integrators = []

        # Layers are created interleaved, one depth level at a time.
        for _ in range(deepth):
            _attentions.append(AlphaAttention(num_hidden, num_objects, heads, dropout))
            _l0rds.append(EpropAlphaGateL0rd(num_hidden, batch_size * num_objects, reg_lambda))
            _integrators.append(EulerIntegrator(num_hidden))

        # nn.Sequential is used purely as an indexable container here; the
        # layers take three arguments and are called one by one in forward().
        self.attentions  = nn.Sequential(*_attentions)
        self.l0rds       = nn.Sequential(*_l0rds)
        self.integrators = nn.Sequential(*_integrators)

        self.output_embeding = OutputEmbeding(num_hidden, channels + 4)

        self.register_buffer("corrected_position_last", th.zeros((batch_size * num_objects, 2)), persistent = False)
        self.register_buffer("predicted_position_last", th.zeros((batch_size * num_objects, 2)), persistent = False)
        self.register_buffer("predicted_velocity_last", th.zeros((batch_size * num_objects, 2)), persistent = False)

    def reset_state(self):
        """Reset the remembered position/velocity state to zeros.

        forward() stores tensors that come straight out of the autograd graph
        in these buffers, and a detached tensor still shares their storage.
        Zeroing them in place could therefore invalidate values saved for a
        pending backward pass, so fresh zero tensors (same shape, dtype and
        device, no autograd history) are bound instead.
        """
        self.corrected_position_last = th.zeros_like(self.corrected_position_last)
        self.predicted_position_last = th.zeros_like(self.predicted_position_last)
        self.predicted_velocity_last = th.zeros_like(self.predicted_velocity_last)

    def detach(self):
        """Cut the remembered state off from the current autograd graph."""
        self.corrected_position_last = self.corrected_position_last.detach()
        self.predicted_position_last = self.predicted_position_last.detach()
        self.predicted_velocity_last = self.predicted_velocity_last.detach()

    def get_openings(self):
        """Return the ``openings`` value of the l0rd layers averaged over depth."""
        total = sum(layer.l0rd.openings.item() for layer in self.l0rds)
        return total / self.deepth

    def get_hidden(self):
        """Return the hidden states of all l0rd layers concatenated along dim 1."""
        return th.cat([layer.l0rd.get_hidden() for layer in self.l0rds], dim=1)

    def set_hidden(self, hidden):
        """Split ``hidden`` along dim 1 into ``deepth`` chunks, one per l0rd layer."""
        states = th.chunk(hidden, self.deepth, dim=1)
        # Index explicitly so that too few chunks raises instead of silently
        # leaving some layers untouched.
        for i, layer in enumerate(self.l0rds):
            layer.l0rd.set_hidden(states[i])

    def forward(self, encoded_position, latent_state) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Predict the next position, velocity and latent state.

        Returns:
            ``(position, velocity, latent)`` where ``latent`` is the output
            embedding with its first four channels removed.
        """
        corrected_position, corrected_velocity = self.newtonian_correction(
            encoded_position, 
            self.corrected_position_last,
            self.predicted_position_last,
            self.predicted_velocity_last,
            latent_state
        )

        position = corrected_position
        velocity = corrected_velocity

        latent_state = self.input_embeding(th.cat((position, velocity, latent_state), dim=1))

        for attention, l0rd, integrator in zip(self.attentions, self.l0rds, self.integrators):
            latent_state       = attention(position, velocity, latent_state)
            latent_state       = l0rd(position, velocity, latent_state)
            position, velocity = integrator(position, velocity, latent_state)

        latent_state = self.output_embeding(latent_state)

        # Remember this step's results for the next call's correction.
        self.corrected_position_last = corrected_position
        self.predicted_position_last = position
        self.predicted_velocity_last = velocity

        return position, velocity, latent_state[:,4:]

class PriorityEncoder(nn.Module):
    """Combine a per-object priority value with the object's slot index.

    The incoming priority is scaled by the number of objects, perturbed with
    Gaussian noise, scaled by the learnable ``priority_factor`` and offset by
    the slot index times the learnable ``index_factor``. The result is
    regrouped from ``(b o) 1`` to ``b o`` and multiplied by 25.
    """

    def __init__(self, num_objects, batch_size):
        super().__init__()

        self.num_objects = num_objects

        # Slot indices 0..num_objects-1 repeated for every sample, as a
        # column vector; non-persistent, so not part of state_dict.
        slot_indices = repeat(th.arange(num_objects), 'a -> (b a) 1', b=batch_size)
        self.register_buffer("indices", slot_indices, persistent=False)

        self.index_factor    = nn.Parameter(th.ones(1))
        self.priority_factor = nn.Parameter(th.ones(1))

    def forward(self, priority: th.Tensor) -> th.Tensor:
        """Return the encoded priorities with one row per sample."""
        # Noise with standard deviation 0.1 is added on every call; there is
        # no train/eval switch in this block.
        noise  = th.randn_like(priority) * 0.1
        scaled = (priority * self.num_objects + noise) * self.priority_factor

        ranked = scaled + self.indices * self.index_factor
        ranked = rearrange(ranked, '(b o) 1 -> b o', o=self.num_objects)

        return ranked * 25

class LatentEpropPredictor(nn.Module): 
    """Predict the next position, gestalt and priority of every object.

    Inputs arrive in a shared layout (``b (o c)``) and are regrouped to one
    row per object before being handed to an ``EpropGateL0rdTransformer``.
    Its output is split into a std channel, the gestalt channels and a
    priority channel, which are post-processed separately. When a camera
    view matrix is supplied, a ``CaterSnitchTracker`` additionally produces
    a snitch position from the predictor output.
    """

    def __init__(
        self, 
        heads: int, 
        layers: int,
        channels_multiplier: int,
        reg_lambda: float,
        num_objects: int, 
        gestalt_size: int, 
        vae_factor: float, 
        batch_size: int,
        camera_view_matrix = None,
        zero_elevation = None
    ):
        super().__init__()
        self.num_objects = num_objects
        # Learnable scale on the predicted std channel, initialised to 1e-16.
        self.std_alpha   = nn.Parameter(th.zeros(1)+1e-16)

        self.reg_lambda = reg_lambda
        # The "+ 2" covers the two extra channels that forward() concatenates
        # around the gestalt (third position channel and priority).
        self.predictor  = EpropGateL0rdTransformer(
            channels    = gestalt_size + 2,
            multiplier  = channels_multiplier,
            heads       = heads,
            deepth      = layers,
            num_objects = num_objects,
            reg_lambda  = reg_lambda,
            batch_size  = batch_size,
        )

        # The tracker is optional and only built when a view matrix is given.
        self.tracker = (
            CaterSnitchTracker(
                latent_size        = gestalt_size + 2,
                num_objects        = num_objects,
                camera_view_matrix = camera_view_matrix,
                zero_elevation     = zero_elevation
            )
            if camera_view_matrix is not None
            else None
        )

        # Gestalt post-processing: add two singleton axes, residual block,
        # variational function, sigmoid, then back to the shared layout.
        self.vae = nn.Sequential(
            LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
            ResidualBlock(gestalt_size, gestalt_size * 2, kernel_size=1),
            VariationalFunction(factor = vae_factor),
            nn.Sigmoid(),
            LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects))
        )

        self.priority_encoder = PriorityEncoder(num_objects, batch_size)

        # Layout converters between "b (o c)" and "(b o) c".
        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects))
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects))

    def get_openings(self):
        """Forward to the predictor's ``get_openings``."""
        return self.predictor.get_openings()

    def get_hidden(self):
        """Forward to the predictor's ``get_hidden``."""
        return self.predictor.get_hidden()

    def set_hidden(self, hidden):
        """Forward to the predictor's ``set_hidden``."""
        self.predictor.set_hidden(hidden)

    def forward(
        self, 
        position: th.Tensor, 
        gestalt: th.Tensor, 
        priority: th.Tensor,
    ):
        """Return ``(position, gestalt, priority, snitch_position)``.

        ``snitch_position`` is ``None`` when no tracker was configured.
        """
        position = self.to_batch(position)
        gestalt  = self.to_batch(gestalt)
        priority = self.to_batch(priority)

        # The first two position channels go in as the encoded position; the
        # third travels with the latent features.
        features = th.cat((position[:,2:3], gestalt, priority), dim=1)
        xy, _, output = self.predictor(position[:,:2], features)

        # Output layout mirrors the input: [std | gestalt ... | priority].
        std, gestalt, priority = output[:,:1], output[:,1:-1], output[:,-1:]

        snitch_position = self.tracker(xy, output) if self.tracker is not None else None

        position = self.to_shared(th.cat((xy, std * self.std_alpha), dim=1))
        gestalt  = self.vae(gestalt)
        priority = self.priority_encoder(priority)

        return position, gestalt, priority, snitch_position
