import torch.nn as nn
import torch as th
import numpy as np
from torch.autograd import Function
import nn as nn_modules
import utils
from nn.residual import ResidualBlock, SkipConnection
from nn.encoder import AggressiveDownConv
from nn.encoder import AggressiveConvTo1x1
from nn.decoder import AggressiveUpConv
from utils.utils import LambdaModule, ForcedAlpha, PrintShape, Warp, create_grid
from nn.eprop_gate_l0rd import EpropGateL0rd
from nn.vae import VariationalFunction
from einops import rearrange, repeat, reduce
from utils.utils import MultiArgSequential

from typing import Union, Tuple



class CameraFlow(nn.Module):
    """Turn camera intrinsics, a 6-DoF pose and a depth map into a sampling flow.

    Every pixel (x, y) of the grid is back-projected with the inverse
    intrinsics, scaled by its depth, moved by the pose and re-projected with
    the intrinsics.  The re-projected pixel coordinates are then mapped to
    [-1, 1] via ``2 * p / (size - 1) - 1``.

    All raw network outputs go through a learned affine map
    (``value * alpha + bias``).  The alphas start at ~1e-16, so for bounded
    raw outputs the initial camera has K close to the identity, a pose close
    to the identity transform and a depth close to 1.
    """

    def __init__(self, size):
        super(CameraFlow, self).__init__()

        # Homogeneous pixel coordinates of shape (1, 3, size[0], size[1]):
        # channel 0 = column index, channel 1 = row index, channel 2 = ones.
        # persistent=False keeps this derived grid out of the state dict.
        columns = repeat(th.arange(0, size[1]).float(), 'w -> 1 1 h w', h = size[0])
        rows    = repeat(th.arange(0, size[0]).float(), 'h -> 1 1 h w', w = size[1])
        homogeneous = th.cat((columns, rows, th.ones((1, 1, *size))), dim=1)
        self.register_buffer('grid', homogeneous, persistent=False)

        # Learned affine maps applied to the raw intrinsics / pose / depth.
        self.intrinsic_bias  = nn.Parameter(th.tensor([[1,1,0,0,0]]).float())
        self.intrinsic_alpha = nn.Parameter(th.tensor([[0,0,0,0,0]])+1e-16)
        self.pose_alpha      = nn.Parameter(th.tensor([[0,0,0,0,0,0]])+1e-16)

        self.depth_bias  = nn.Parameter(th.ones(1))
        self.depth_alpha = nn.Parameter(th.zeros(1)+1e-16)

        # NOTE: process-wide side effect on numpy's float print formatting.
        np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})

    def intrinsics_to_matrix(self, intrinsics):
        """Build a batch of 3x3 intrinsic matrices from raw ``(B, 5)`` outputs.

        After the affine map the columns are read as (fx, fy, skew, ox, oy) and
        laid out as ``[[fx, skew, ox], [0, fy, oy], [0, 0, 1]]``.
        """
        scaled = intrinsics * self.intrinsic_alpha + self.intrinsic_bias
        fx, fy, skew, ox, oy = th.chunk(scaled, 5, dim=1)
        zero, one = th.zeros_like(fx), th.ones_like(fx)

        entries = [fx, skew, ox, zero, fy, oy, zero, zero, one]
        return th.cat(entries, dim=1).reshape(-1, 3, 3)

    def pose_to_matrix(self, pose):
        """Build a batch of 3x4 ``[R | t]`` matrices from raw ``(B, 6)`` outputs.

        After scaling by ``pose_alpha`` the columns are read as
        (rx, ry, rz, tx, ty, tz).  The angles go through ``sin``/``cos`` and the
        rotation is composed as ``R = Rx @ Ry @ Rz``.
        """
        rx, ry, rz, tx, ty, tz = th.chunk(pose * self.pose_alpha, 6, dim=1)
        zero, one = th.zeros_like(rx), th.ones_like(rx)

        def to_3x3(*entries):
            # Nine (B, 1) columns in row-major order -> (B, 3, 3).
            return th.cat(entries, dim=1).reshape(-1, 3, 3)

        sin_x, cos_x = th.sin(rx), th.cos(rx)
        sin_y, cos_y = th.sin(ry), th.cos(ry)
        sin_z, cos_z = th.sin(rz), th.cos(rz)

        rot_x = to_3x3(one,   zero,   zero,
                       zero,  cos_x, -sin_x,
                       zero,  sin_x,  cos_x)
        rot_y = to_3x3(cos_y,  zero, sin_y,
                       zero,   one,  zero,
                       -sin_y, zero, cos_y)
        rot_z = to_3x3(cos_z, -sin_z, zero,
                       sin_z,  cos_z, zero,
                       zero,   zero,  one)

        rotation    = rot_x @ rot_y @ rot_z
        translation = th.stack((tx, ty, tz), dim=1)   # (B, 3, 1)
        return th.cat((rotation, translation), dim=2)

    def forward(self, intrinsics, pose, depth):
        """Compute the reprojection flow for a depth map (presumably one channel).

        Returns:
            flow: (B, 2, H, W) normalised x / y sample coordinates.
            valid_mask: (B, 1, H, W); 1 where both coordinates lie in
                [-1, 1], else 0.
            depth: the depth after the learned affine map.
        """
        batch, _, height, width = depth.shape

        depth = depth * self.depth_alpha + self.depth_bias

        pixels = repeat(self.grid[:,:,:height,:width], '1 c h w -> b c (h w)', b = batch)
        K      = self.intrinsics_to_matrix(intrinsics)
        Rt     = self.pose_to_matrix(pose)

        # K @ [R | t] == [K R | K t]
        KRt    = K @ Rt
        KR, Kt = KRt[:,:,:3], KRt[:,:,3:]

        # Back-project (d * K^-1 p), then re-project (K R X + K t).
        rays   = K.inverse() @ pixels
        points = rearrange(depth, 'b c h w -> b c (h w)') * rays
        points = (KR @ points) + Kt

        px, py, pz = th.chunk(points, 3, dim=1)
        # Keep the divisor >= 1e-3 so zero / negative depths cannot blow up.
        pz = th.clamp(pz, 1e-3)

        # Perspective divide, then pixel units -> [-1, 1].
        px = 2 * (px / pz) / (width - 1) - 1
        py = 2 * (py / pz) / (height - 1) - 1

        outside = th.cat((
            ((px > 1) | (px < -1)).detach().float(),
            ((py > 1) | (py < -1)).detach().float(),
        ), dim=1)

        flow       = rearrange(th.cat((px, py), dim=1), ' b c (h w) -> b c h w', h=height, w=width)
        valid_mask = 1 - reduce(outside, 'b c (h w) -> b 1 h w', 'max', h=height, w=width)

        return flow, valid_mask, depth

class PatchDownConv(nn.Module):
    """Downsample a feature map by ``kernel_size`` with a residual patch convolution.

    The output is ``skip + alpha * conv(x)``:

    * ``conv`` is a ``kernel_size`` x ``kernel_size`` convolution with stride
      ``kernel_size``.
    * ``skip`` is the mean of every non-overlapping patch, with each input
      channel repeated ``out_channels // in_channels`` times so the channel
      counts match.

    With a small ``alpha`` the block starts close to plain average pooling.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels; must be a multiple of
            ``in_channels``.
        kernel_size: patch size and stride, i.e. the downsampling factor.
        alpha: initial value of the learned residual scale.

    Raises:
        ValueError: if ``out_channels`` is not a multiple of ``in_channels``.
    """
    def __init__(self, in_channels, out_channels, kernel_size = 4, alpha = 1):
        super(PatchDownConv, self).__init__()
        if out_channels % in_channels != 0:
            raise ValueError(
                f"out_channels ({out_channels}) must be a multiple of in_channels ({in_channels})"
            )

        self.layers = nn.Conv2d(
            in_channels  = in_channels,
            out_channels = out_channels,
            kernel_size  = kernel_size,
            stride       = kernel_size,
        )

        self.alpha = nn.Parameter(th.zeros(1) + alpha)
        # Must equal the conv's kernel/stride so the skip and residual paths
        # produce the same spatial size (this used to be hard-coded to 4).
        self.kernel_size = kernel_size
        self.channels_factor = out_channels // in_channels

    def forward(self, input: th.Tensor):
        """Map (B, in_channels, H, W) to (B, out_channels, H // k, W // k).

        Trailing rows/columns that do not fill a whole k x k patch are dropped
        by both paths (avg-pool and strided conv both floor), so the shapes
        always agree.
        """
        k = self.kernel_size
        skip = nn.functional.avg_pool2d(input, kernel_size=k, stride=k)
        # Each channel is repeated contiguously: c0, c0, ..., c1, c1, ...
        skip = skip.repeat_interleave(self.channels_factor, dim=1)
        return skip + self.alpha * self.layers(input)

class BicubicUpscale(nn.Module):
    """Bicubic upsampling followed by a channel-reducing skip plus a scaled conv residual.

    ``forward(x) = channel_skip(up(x)) + alpha * residual(up(x))``, where
    ``up`` is bicubic interpolation by ``scale_factor`` with
    ``align_corners=True`` and ``residual`` is conv3x3 -> ReLU -> conv3x3.
    """
    def __init__(self, in_channels, out_channels, scale_factor = 4, alpha = 1):
        super(BicubicUpscale, self).__init__()
        assert in_channels % out_channels == 0

        def _bicubic(x):
            return nn.functional.interpolate(
                x,
                scale_factor  = scale_factor,
                mode          = 'bicubic',
                align_corners = True,
            )

        self.upscale = LambdaModule(_bicubic)
        self.channel_skip = SkipConnection(in_channels, out_channels)

        # Both convs keep the spatial size (3x3, padding 1).
        first_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
        second_conv = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1)
        self.residual = nn.Sequential(first_conv, nn.ReLU(), second_conv)

        # Learned residual scale; a small value starts the block near the skip path.
        self.alpha = nn.Parameter(th.zeros(1) + alpha)

    def forward(self, input):
        upsampled = self.upscale(input)
        shortcut = self.channel_skip(upsampled)
        return shortcut + self.alpha * self.residual(upsampled)

class EpropAlphaGateL0rd2D(nn.Module):
    """Apply an ``EpropGateL0rd`` cell to every pixel as an alpha-scaled residual.

    Each pixel's channel vector is treated as one batch element of the cell.
    NOTE(review): ``batch_size`` is passed straight to ``EpropGateL0rd``; it
    presumably has to equal B * H * W of the inputs — confirm.
    """
    def __init__(self, num_hidden, batch_size, reg_lambda):
        super(EpropAlphaGateL0rd2D, self).__init__()

        # Tiny initial residual scale: the module starts as ~identity.
        self.alpha = nn.Parameter(th.zeros(1)+1e-12)
        self.l0rd  = EpropGateL0rd(
            num_inputs  = num_hidden,
            num_hidden  = num_hidden,
            num_outputs = num_hidden,
            reg_lambda  = reg_lambda,
            batch_size  = batch_size,
        )

    def forward(self, input):
        height, width = input.shape[2:]
        per_pixel = rearrange(input, 'b c h w -> (b h w) c')
        gated = per_pixel + self.alpha * self.l0rd(per_pixel)
        return rearrange(gated, '(b h w) c -> b c h w', h = height, w = width)

class ImageAttention(nn.Module):
    """Residual multi-head self-attention over all spatial positions of a feature map.

    Each of the H*W pixels is a token with ``channels`` features.  The
    attention output is scaled by a learned ``alpha`` (initialised at 1e-16)
    and added to the input, so the module starts out as (almost) the identity.

    Args:
        channels: feature channels (embedding size of the attention).
        channels_per_head: channels per attention head; must divide ``channels``.
        dropout: attention dropout probability.

    Raises:
        ValueError: if ``channels`` is not a multiple of ``channels_per_head``.
    """
    def __init__(self, channels, channels_per_head = 14, dropout = 0.0):
        super(ImageAttention, self).__init__()

        if channels % channels_per_head != 0:
            raise ValueError(
                f"channels ({channels}) must be a multiple of channels_per_head ({channels_per_head})"
            )
        heads = channels // channels_per_head

        self.alpha = nn.Parameter(th.zeros(1)+1e-16)
        self.attention = nn.MultiheadAttention(
            channels,
            heads,
            dropout = dropout,
            batch_first = True
        )

    def image_attention(self, tokens: th.Tensor) -> th.Tensor:
        """Self-attention over a (B, N, C) token sequence; weights are not returned.

        A plain method rather than a lambda capturing ``self``.  A lambda
        survives ``copy.deepcopy`` unchanged, so a copied module would keep
        calling the original module's ``attention``.
        """
        return self.attention(tokens, tokens, tokens, need_weights=False)[0]

    def forward(self, input: th.Tensor) -> th.Tensor:
        B, C, H, W = input.shape
        tokens = input.flatten(2).transpose(1, 2)            # b c h w -> b (h w) c
        output = tokens + self.image_attention(tokens) * self.alpha
        return output.transpose(1, 2).reshape(B, C, H, W)    # b (h w) c -> b c h w

class BackgroundEnhancer(nn.Module):
    """Recurrent multi-level encoder/decoder that predicts camera intrinsics,
    camera pose and a depth map, and warps the current frame with them.

    The active level is chosen with ``set_level`` (0, 1 or 2):

    * level 0 runs only the bottleneck (``level0``);
    * level >= 1 additionally runs ``down_level1`` / ``up_level1``;
    * level >= 2 additionally runs ``down_level2`` / ``up_level2``.

    Each down/up stage changes resolution by the ``PatchDownConv`` /
    ``BicubicUpscale`` default factor of 4.

    The module is stateful.  Every ``forward`` caches, detached, the masked
    flow, the warped depth and the warped RGB.  These are fed back as extra
    input channels on the next call.  ``reset_state`` clears the cache.
    """
    def __init__(
        self, 
        input_size: Tuple[int, int], 
        img_channels: int, 
        level1_channels,
        latent_channels,
        gestalt_size,
        batch_size,
        reg_lambda,
        vae_factor,
        deepth
    ):
        """
        Args:
            input_size: full-resolution (H, W).  The bottleneck grid is
                ``input_size // 16``; ``CameraFlow`` and the level-2 warp use
                it as-is.
            img_channels: ignored; it is overwritten below with 7
                (rgb + warped depth + flow + error).
            level1_channels: channel width of the intermediate stage.
            latent_channels: channel width of the bottleneck.
            gestalt_size: not used in this class.
            batch_size: forwarded to the ``EpropGateL0rd`` heads.
            reg_lambda: forwarded to the ``EpropGateL0rd`` heads (x10 for the
                intrinsics head).
            vae_factor: not used in this class.
            deepth: stack depth.  ``down_level2`` and ``up_level1`` get
                ``deepth`` residual blocks each; ``level0`` gets ``deepth``
                pairs of residual blocks.
        """
        super(BackgroundEnhancer, self).__init__()

        latent_size = [input_size[0] // 16, input_size[1] // 16]
        self.input_size = input_size

        level1_layers = deepth
        level0_layers = deepth
        # The img_channels argument is overridden: the network input is the RGB
        # frame concatenated with the error, warped-depth and flow feedback.
        img_channels  = 3 + 1 + 2 + 1 # rgb, warped depth, flow, error

        # Learned scale / offset for each fed-back input (see forward).
        self.input_depth_alpha = nn.Parameter(th.ones(1))
        self.input_flow_alpha  = nn.Parameter(th.ones(1))
        self.input_error_alpha = nn.Parameter(th.ones(1))

        self.input_depth_beta = nn.Parameter(th.zeros(1))
        self.input_flow_beta  = nn.Parameter(th.zeros(1))
        self.input_error_beta = nn.Parameter(th.zeros(1))

        # Integer counter advanced by step_init(); a buffer, so it is saved in
        # the state dict.
        self.register_buffer('init', th.zeros(1).long())
        # NOTE(review): self.alpha is not read anywhere in this class.
        self.alpha = nn.Parameter(th.zeros(1)+1e-16)
        self.level = 1
        
        # One warp per level: 1/16, 1/4 and full resolution.
        self.warp  = nn.ModuleList([
            Warp((input_size[0] // 16, input_size[1] // 16)),
            Warp((input_size[0] // 4,  input_size[1] // 4)),
            Warp(input_size)
        ])
        # Built at full resolution; CameraFlow.forward crops its pixel grid to
        # the depth map's size, so it also serves the lower levels.
        self.flow  = CameraFlow(input_size)

        # Level-2 encoder stage: full -> 1/4 resolution.
        self.down_level2 = nn.Sequential(
            PatchDownConv(img_channels, level1_channels, alpha = 1e-16),
            *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(level0_layers)]
        )

        # Learned per-position bias added to the bottleneck right after the
        # level-1 downsampling.
        # NOTE(review): the lambda below closes over `self`.  Assuming
        # LambdaModule stores the callable as-is, a copy.deepcopy of this
        # module would keep adding the ORIGINAL module's bias.  Confirm before
        # deep-copying models.
        self.bias = nn.Parameter(th.zeros(1, latent_channels, *latent_size))
        self.down_level1 = nn.Sequential(
            PatchDownConv(level1_channels, latent_channels, alpha = 1e-16),
            LambdaModule(lambda x: x + self.bias),
            ResidualBlock(latent_channels, latent_channels, alpha_residual = True),
            ResidualBlock(latent_channels, latent_channels, alpha_residual = True),
        )

        # Bottleneck, run at every level.
        self.level0 = nn.Sequential(
            *[nn.Sequential(
                #ImageAttention(latent_channels),
                ResidualBlock(latent_channels, latent_channels, alpha_residual = True),
                #EpropAlphaGateL0rd2D(latent_channels, batch_size * np.prod(latent_size), reg_lambda),
                ResidualBlock(latent_channels, latent_channels, alpha_residual = True),
            ) for _ in range(level0_layers) ]
        )

        # Decoder stages mirroring down_level1 / down_level2.
        self.up_level1 = nn.Sequential(
            BicubicUpscale(latent_channels, level1_channels, alpha = 1e-16),
            *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for _ in range(level1_layers)]
        )

        self.up_level2 = BicubicUpscale(level1_channels, 1, alpha = 1e-16)

        # Input projection per level: the 7 input channels -> the width that
        # level's first stage expects.
        self.to_channels = nn.ModuleList([
            SkipConnection(img_channels, latent_channels),
            SkipConnection(img_channels, level1_channels),
            SkipConnection(img_channels, img_channels),
        ])

        # Output projection per level to a single depth channel.
        self.to_depth = nn.ModuleList([
            SkipConnection(latent_channels, 1),
            SkipConnection(level1_channels, 1),
            SkipConnection(1,               1),
        ])
                

        # Global head: bottleneck -> (B, latent_channels, 1, 1) -> (B, latent_channels)
        # -> 5 raw intrinsics values for CameraFlow.
        self.intrinsics = nn.Sequential(
            AggressiveConvTo1x1(latent_channels, latent_size, alpha=1e-16),
            LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
            EpropGateL0rd(
                num_inputs  = latent_channels,
                num_hidden  = latent_channels,
                num_outputs = 5,
                batch_size  = batch_size, 
                reg_lambda  = reg_lambda*10
            ),
        )

        # Global head: same structure, 6 raw pose values for CameraFlow.
        self.pose = nn.Sequential(
            AggressiveConvTo1x1(latent_channels, latent_size, alpha=1e-16),
            LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
            EpropGateL0rd(
                num_inputs  = latent_channels,
                num_hidden  = latent_channels,
                num_outputs = 6,
                batch_size  = batch_size, 
                reg_lambda  = reg_lambda
            ),
        )

        # Recurrent cache written at the end of forward and read at its start.
        # NOTE(review): last_flow_mask is stored but not read in this class.
        self.last_flow = None
        self.last_flow_mask = None
        self.last_depth_warped = None
        self.last_output = None

    def get_openings(self):
        """Return ``(intrinsics_openings, pose_openings, level0_openings)``.

        The first two are ``.openings.item()`` of the last layer of the
        intrinsics and pose heads.  ``level0_openings`` is always 0 because the
        gated level-0 layers are commented out.
        """
        level0_openings = 0
        #for layer in self.level0:
        #    level0_openings += layer[2].l0rd.openings.item()

        #level0_openings /= len(self.level0)
        return self.intrinsics[-1].openings.item(), self.pose[-1].openings.item(), level0_openings

    def get_init(self) -> int:
        """Return the current value of the ``init`` counter buffer."""
        return self.init.item()

    def step_init(self):
        """Increment the ``init`` counter (assignment updates the registered buffer)."""
        self.init = self.init + 1

    def reset_state(self):
        """Clear the recurrent cache so the next forward starts from zero feedback."""
        self.last_flow = None
        self.last_flow_mask = None
        self.last_depth_warped = None
        self.last_output = None

    def set_level(self, level: int):
        """Select the active resolution level (indexes warp / to_channels / to_depth).

        NOTE(review): this does not clear the recurrent cache.  If the level
        changes between calls, the cached tensors keep the old resolution and
        will presumably mismatch the new input, so callers likely need
        ``reset_state()``.  Confirm.
        """
        self.level = level

    def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None):
        """Run one recurrent step at the current level.

        Args:
            input: current frame.  It must have 3 channels so that, with the
                1 + 1 + 2 feedback channels, it reaches the 7 channels
                ``to_channels`` is built for.  Presumably (B, 3, H, W) at the
                resolution of ``self.level``; TODO confirm.
            error: ignored; it is immediately replaced by an internally
                computed error map.
            mask: ignored; it is never read.

        Returns:
            ``(None, rgb_warped, depth, flow, flow_mask)``.  The first slot is
            always ``None`` because the RGB reconstruction head below is
            disabled.  ``depth`` is the affine-mapped depth returned by
            ``CameraFlow``.
        """
        
        # Zero feedback for the first step (or after reset_state).
        error = th.zeros_like(input)[:,:1]
        depth = th.zeros_like(input)[:,:1]
        flow  = th.zeros_like(input)[:,:2]

        if self.last_output is not None:
            # Per-pixel RMS (over channels) between this frame and the previous
            # step's warped prediction, with a fixed x10 gain on top of the
            # learned scale / offset.
            error = th.sqrt(reduce((input - self.last_output)**2, 'b c h w -> b 1 h w', 'mean'))
            error = error.detach() * self.input_error_alpha * 10 + self.input_error_beta

        if self.last_depth_warped is not None:
            depth = self.last_depth_warped * self.input_depth_alpha + self.input_depth_beta

        if self.last_flow is not None:
            # Warp.get_raw_flow is defined elsewhere; presumably it converts the
            # cached normalised flow into the raw representation.  TODO confirm.
            flow = self.warp[self.level].get_raw_flow(self.last_flow) 
            flow = flow.detach() * self.input_flow_alpha * 10 + self.input_flow_beta


        latent = self.to_channels[self.level](th.cat((input, error, depth, flow), dim=1))

        # Encoder: only the stages belonging to the active level run.
        if self.level >= 2:
            latent = self.down_level2(latent)

        if self.level >= 1:
            latent = self.down_level1(latent)
        
        latent     = self.level0(latent)
        # Camera parameters are predicted from the bottleneck.
        intrinsics = self.intrinsics(latent)
        pose       = self.pose(latent)

        # Decoder mirrors the encoder.
        if self.level >= 1:
            latent = self.up_level1(latent)
            
        if self.level >= 2:
            latent = self.up_level2(latent)

        depth = self.to_depth[self.level](latent)

        flow, flow_mask, depth = self.flow(intrinsics, pose, depth)
        rgb_warped = self.warp[self.level](input, flow)
        
        # Cache detached, validity-masked results as feedback for the next step.
        self.last_flow         = flow.detach() * flow_mask.detach()
        self.last_flow_mask    = flow_mask.detach()
        self.last_depth_warped = self.warp[self.level](depth, flow).detach() * flow_mask.detach()
        self.last_output       = rgb_warped.detach() * flow_mask.detach()

        return None, rgb_warped, depth, flow, flow_mask

        # Disabled RGB reconstruction head (unreachable; kept for reference):
        #rgb   = self.rgb_level0(latent)
        #rgb   = self.to_grid(rgb)

        #if self.level >= 1:
        #    rgb   = self.rgb_up_level1(rgb)

        #if self.level >= 2:
        #    rgb   = self.rgb_up_level2(rgb)

        #rgb   = self.to_rgb[self.level](rgb)


        #return rgb, rgb_warped, depth, flow, flow_mask

