import torch.nn as nn
import torch as th
from torch.autograd import Function
import nn as nn_modules
import utils
from nn.residual import ResidualBlock, SkipConnection
from nn.encoder import AggressiveDownConv
from nn.encoder import AggressiveConvTo1x1
from nn.decoder import AggressiveUpConv
from utils.utils import LambdaModule, ForcedAlpha, PrintShape, Warp, create_grid
from nn.predictor import EpropAlphaGateL0rd
from nn.eprop_gate_l0rd import EpropGateL0rd
from nn.eprop_flow_gru_l0rd import ResidualEpropFlowGruL0rd
from nn.vae import VariationalFunction
from einops import rearrange, repeat, reduce
from utils.utils import MultiArgSequential

from typing import Union, Tuple



class CameraFlow(nn.Module):
    """Differentiable pinhole reprojection that turns depth + camera into a sampling flow.

    Each pixel ``(x, y)`` of a depth map is lifted to the 3D point
    ``depth * K^-1 [x, y, 1]^T``, moved by the rigid transform ``[R | t]`` and
    projected back with ``K``. The resulting image-plane coordinates are
    normalised so that pixel index ``0`` maps to ``-1`` and ``size - 1`` to ``1``.

    Raw inputs are scaled by learnable ``*_alpha`` gains (initialised to
    1e-16) and shifted by ``*_bias`` offsets. For moderate raw inputs this means
    that, right after construction, the intrinsics are close to
    ``fx = fy = 1, s = ox = oy = 0``, the pose is close to zero and the depth
    is close to one.

    Args:
        size: ``(height, width)`` of the largest depth map passed to ``forward``;
            smaller maps use the top-left crop of the pixel grid.
    """

    def __init__(self, size):
        super(CameraFlow, self).__init__()

        height, width = size[0], size[1]

        # Homogeneous pixel coordinates, channel order (x, y, 1); shape (1, 3, H, W).
        x_coords = repeat(th.arange(0, width).float(), 'w -> 1 1 h w', h = height)
        y_coords = repeat(th.arange(0, height).float(), 'h -> 1 1 h w', w = width)
        homogeneous = th.ones((1, 1, *size))
        self.register_buffer(
            'grid',
            th.cat((x_coords, y_coords, homogeneous), dim=1),
            persistent=False,
        )

        # Intrinsics are ordered (fx, fy, s, ox, oy); pose is (rx, ry, rz, tx, ty, tz).
        self.intrinsic_bias  = nn.Parameter(th.tensor([[1.0, 1.0, 0.0, 0.0, 0.0]], dtype=th.float32))
        self.intrinsic_alpha = nn.Parameter(th.full((1, 5), 1e-16))
        self.pose_alpha      = nn.Parameter(th.full((1, 6), 1e-16))

        self.depth_bias  = nn.Parameter(th.ones(1))
        self.depth_alpha = nn.Parameter(th.full((1,), 1e-16))

    @staticmethod
    def _matrix_from_rows(*rows):
        """Stack rows of per-sample scalars (each shaped (B,)) into a (B, rows, cols) batch."""
        return th.stack([th.stack(row, dim=-1) for row in rows], dim=-2)

    def intrinsics_to_matrix(self, intrinsics):
        """Build the batched 3x3 intrinsic matrix ``K`` from raw (B, 5) parameters.

        After the learnable affine map the parameters are read as
        ``(fx, fy, s, ox, oy)`` and arranged as
        ``[[fx, s, ox], [0, fy, oy], [0, 0, 1]]``.
        """
        fx, fy, s, ox, oy = (intrinsics * self.intrinsic_alpha + self.intrinsic_bias).unbind(dim=1)
        zero = th.zeros_like(fx)
        one  = th.ones_like(fx)
        return self._matrix_from_rows(
            (fx,   s,    ox),
            (zero, fy,   oy),
            (zero, zero, one),
        )

    def pose_to_matrix(self, pose):
        """Build the batched 3x4 extrinsic matrix ``[R | t]`` from raw (B, 6) parameters.

        The parameters, scaled by ``pose_alpha``, are read as
        ``(rx, ry, rz, tx, ty, tz)``. The angles are in radians and the
        rotation is composed as ``R = Rx @ Ry @ Rz``.
        """
        rx, ry, rz, tx, ty, tz = (pose * self.pose_alpha).unbind(dim=1)
        zero = th.zeros_like(rx)
        one  = th.ones_like(rx)

        cos_x, sin_x = th.cos(rx), th.sin(rx)
        cos_y, sin_y = th.cos(ry), th.sin(ry)
        cos_z, sin_z = th.cos(rz), th.sin(rz)

        rot_x = self._matrix_from_rows(
            (one,  zero,   zero),
            (zero, cos_x, -sin_x),
            (zero, sin_x,  cos_x),
        )
        rot_y = self._matrix_from_rows(
            ( cos_y, zero, sin_y),
            ( zero,  one,  zero),
            (-sin_y, zero, cos_y),
        )
        rot_z = self._matrix_from_rows(
            (cos_z, -sin_z, zero),
            (sin_z,  cos_z, zero),
            (zero,   zero,  one),
        )

        translation = th.stack((tx, ty, tz), dim=1).unsqueeze(-1)   # (B, 3, 1)
        return th.cat((rot_x @ rot_y @ rot_z, translation), dim=2)

    def forward(self, intrinsics, pose, depth):
        """Reproject every pixel of ``depth`` and return normalised sampling coordinates.

        Args:
            intrinsics: raw intrinsic parameters, (B, 5).
            pose: raw pose parameters, (B, 6).
            depth: raw depth map, presumably (B, 1, H, W), with H and W no
                larger than the ``size`` given at construction.

        Returns:
            flow: (B, 2, H, W) absolute target coordinates (x, y) in ``[-1, 1]``
                units; values outside that range are possible.
            valid_mask: (B, 1, H, W); 1 where both coordinates lie in
                ``[-1, 1]`` and 0 otherwise. It carries no gradient.
            depth: the depth map after the learnable affine transform.
        """
        batch, _, height, width = depth.shape
        depth = depth * self.depth_alpha + self.depth_bias

        pixels = repeat(self.grid[:,:,:height,:width], '1 c h w -> b c (h w)', b = batch)
        intrinsic = self.intrinsics_to_matrix(intrinsics)
        projection = intrinsic @ self.pose_to_matrix(pose)   # K [R | t], (B, 3, 4)

        # Back-project pixels to 3D points, apply the camera motion and project again.
        rays = intrinsic.inverse() @ pixels
        points = rearrange(depth, 'b c h w -> b c (h w)') * rays
        points = projection[:,:,:3] @ points + projection[:,:,3:]

        x, y, z = th.chunk(points, 3, dim=1)
        z = th.clamp(z, 1e-3)   # keep the perspective divide finite (z >= 1e-3)

        # Perspective divide, then map pixel index [0, size - 1] onto [-1, 1].
        x = 2 * (x / z) / (width - 1) - 1
        y = 2 * (y / z) / (height - 1) - 1

        outside = th.cat((
            ((x > 1) + (x < -1)).detach().float(),
            ((y > 1) + (y < -1)).detach().float(),
        ), dim=1)
        valid_mask = 1 - reduce(outside, 'b c (h w) -> b 1 h w', 'max', h=height, w=width)

        flow = rearrange(th.cat((x, y), dim=1), ' b c (h w) -> b c h w', h=height, w=width)
        return flow, valid_mask, depth

        
class PatchDownConv(nn.Module):
    """Downsample by non-overlapping ``kernel_size`` patches with a residual skip.

    The output is the sum of two paths:

    * a parameter-free skip path: each ``kernel_size x kernel_size`` patch is
      average-pooled, and every input channel is repeated
      ``out_channels // in_channels`` times (consecutively) to match the output
      width;
    * a learned path: a ``Conv2d`` whose kernel and stride are both
      ``kernel_size``, scaled by the learnable gain ``alpha``.

    If H or W is not divisible by ``kernel_size``, both paths drop the trailing
    rows/columns (floor), so their shapes always agree.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output; must be a multiple of ``in_channels``.
        kernel_size: patch size, used as both kernel and stride (default 4).
        alpha: initial value of the learnable gain on the convolution path.
    """

    def __init__(self, in_channels, out_channels, kernel_size = 4, alpha = 1):
        super(PatchDownConv, self).__init__()
        assert out_channels % in_channels == 0

        self.layers = nn.Conv2d(
            in_channels  = in_channels,
            out_channels = out_channels,
            kernel_size  = kernel_size,
            stride       = kernel_size,
        )

        self.alpha = nn.Parameter(th.zeros(1) + alpha)
        # Must equal the conv's kernel/stride so the skip and conv outputs have the
        # same spatial size (this used to be hard-coded to 4).
        self.kernel_size = kernel_size
        self.channels_factor = out_channels // in_channels

    def forward(self, input: th.Tensor):
        """Return ``skip(input) + alpha * conv(input)``.

        Args:
            input: (B, in_channels, H, W) feature map.

        Returns:
            (B, out_channels, H // kernel_size, W // kernel_size) feature map.
        """
        # Patch mean, identical to a mean over each k x k block.
        skip = nn.functional.avg_pool2d(input, self.kernel_size)
        # 'b c h w -> b (c n) h w': each channel repeated n times in a row.
        skip = skip.repeat_interleave(self.channels_factor, dim=1)
        return skip + self.alpha * self.layers(input)

class BicubicUpscale(nn.Module):
    """Upsample coarse feature maps and fuse them with an encoder skip connection.

    ``lower_maps`` is projected to ``out_channels`` with a 1x1 convolution and
    bicubically resized by ``scale_factor``. The result is concatenated with
    ``encoder_maps`` (scaled by the learnable ``alpha2``, initialised to 1e-16)
    and refined by a 3x3 conv -> ReLU -> 3x3 conv stack. The refinement is
    added back onto the upsampled maps, weighted by the learnable ``alpha1``.

    Args:
        in_channels: channels of ``lower_maps``; must be a multiple of ``out_channels``.
        out_channels: channels of the output; ``encoder_maps`` is expected to
            have this many channels too, since the first refinement conv takes
            ``out_channels * 2`` inputs.
        scale_factor: spatial upsampling factor (default 4).
        alpha: initial value of ``alpha1``.
    """

    def __init__(self, in_channels, out_channels, scale_factor = 4, alpha = 1):
        super(BicubicUpscale, self).__init__()
        assert in_channels % out_channels == 0

        def bicubic(maps):
            # Fixed-factor bicubic resize with corner-aligned sampling.
            return nn.functional.interpolate(
                maps,
                scale_factor = scale_factor,
                mode = 'bicubic',
                align_corners = True,
            )

        # 1x1 projection to the output width, followed by the resize.
        self.pre = nn.Sequential(
            nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1),
            LambdaModule(bicubic),
        )

        # Refinement over [upsampled maps, gated encoder maps] stacked along channels.
        self.post = nn.Sequential(
            LambdaModule(lambda pair: th.cat(pair, dim=1)),
            nn.Conv2d(in_channels = out_channels * 2, out_channels = out_channels, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = out_channels, out_channels = out_channels, kernel_size = 3, padding = 1),
        )

        # alpha1 gates the refinement branch; alpha2 gates the encoder skip (starts ~0).
        self.alpha1 = nn.Parameter(th.zeros(1) + alpha)
        self.alpha2 = nn.Parameter(th.zeros(1) + 1e-16)

    def forward(self, lower_maps, encoder_maps):
        """Return ``up + alpha1 * refine(cat(up, alpha2 * encoder_maps))`` with ``up = pre(lower_maps)``."""
        upsampled = self.pre(lower_maps)
        gated_skip = encoder_maps * self.alpha2
        return upsampled + self.alpha1 * self.post((upsampled, gated_skip))

class DepthUnet(nn.Module):
    """Recurrent U-Net that predicts a single-channel depth map.

    The network runs at one of three levels selected with ``set_level``
    (0, 1 or 2; the default is 1). Each higher level adds one encoder stage and
    one matching decoder stage:

    * level 2: ``level2`` -> ``down_level2`` ... ``up_level2``;
    * level >= 1: ``level1`` -> ``down_level1`` ... ``upscale_level1`` -> ``up_level1``;
    * always: the recurrent bottleneck ``level0``.

    ``to_channels`` / ``to_depth`` hold one input / output adapter per level.
    """
    def __init__(
        self, 
        img_channels: int, 
        input_size,
        level2_channels,
        level1_channels,
        latent_channels,
        level2_layers,
        level1_layers,
        latent_layers,
        batch_size,
        reg_lambda,
        size,
    ):
        """Build all level-specific stages.

        Args:
            img_channels: channels of the image input. One extra channel is
                added for the ``last_depth_warped`` map that ``forward``
                concatenates onto the input.
            input_size: (height, width). The recurrent layers are told their
                spatial size is ``input_size // 16`` (``level0``) and
                ``input_size // 4`` (``up_level1``).
            level2_channels / level1_channels / latent_channels: feature widths
                of the level-2, level-1 and bottleneck stages.
            level2_layers / level1_layers / latent_layers: number of repeated
                blocks in each stage.
            batch_size: forwarded to every ``ResidualEpropFlowGruL0rd`` layer.
            reg_lambda: regularisation weight; the recurrent layers receive
                ``reg_lambda * 0.01``.
            size: not used by this constructor.
        """
        super(DepthUnet, self).__init__()
        self.level = 1

        # Spatial sizes handed to the recurrent layers.
        latent_size = [input_size[0] // 16, input_size[1] // 16]
        level1_size = [input_size[0] //  4, input_size[1] //  4]

        # Level-2 encoder: input has img_channels + 1 channels (image + warped depth).
        self.level2 = nn.Sequential(
            ResidualBlock(img_channels + 1, level2_channels, alpha_residual = False),
            *[ResidualBlock(level2_channels, level2_channels, alpha_residual = False) for _ in range(level2_layers)]
        )

        # PatchDownConv downsamples by its default kernel_size of 4.
        self.down_level2 = PatchDownConv(level2_channels, level1_channels, alpha = 1)

        self.level1 = nn.Sequential(
            *[ResidualBlock(level1_channels, level1_channels) for _ in range(level1_layers)]
        )

        self.down_level1 = nn.Sequential(
            PatchDownConv(level1_channels, latent_channels),
            ResidualBlock(latent_channels, latent_channels)
        )

        # Recurrent bottleneck; every layer also receives last_flow in forward.
        self.level0 = MultiArgSequential(
            *[ResidualEpropFlowGruL0rd(
                num_inputs = latent_channels, 
                num_hidden = latent_channels // 2,
                batch_size = batch_size,
                size       = latent_size,
                reg_lambda = reg_lambda * 0.01
            ) for _ in range(latent_layers)]
        )

        # Decoder: upsample bottleneck features and fuse with the level-1 encoder features.
        self.upscale_level1 = BicubicUpscale(latent_channels, level1_channels)

        self.up_level1 = MultiArgSequential(
            *[ResidualEpropFlowGruL0rd(
                num_inputs = level1_channels, 
                num_hidden = level1_channels // 2,
                batch_size = batch_size,
                size       = level1_size,
                reg_lambda = reg_lambda * 0.01
            ) for _ in range(level1_layers)]
        )

        # NOTE(review): assumes MultiArgSequential passes both forward arguments
        # (latent, latent_level2) to the BicubicUpscale — confirm in utils.utils.
        self.up_level2 = MultiArgSequential(
            BicubicUpscale(level1_channels, level2_channels, alpha = 1),
            *[ResidualBlock(level2_channels, level2_channels, alpha_residual = False) for _ in range(level2_layers)]
        )

        # Per-level input adapters, indexed by self.level (0, 1, 2).
        self.to_channels = nn.ModuleList([
            SkipConnection(img_channels + 1, latent_channels),
            SkipConnection(img_channels + 1, level1_channels),
            SkipConnection(img_channels + 1, img_channels + 1),
        ])

        # Per-level output heads, each producing a single channel.
        self.to_depth = nn.ModuleList([
            SkipConnection(latent_channels, 1),
            SkipConnection(level1_channels, 1),
            ResidualBlock(level2_channels,  1, alpha_residual = True)
        ])

    def set_level(self, level):
        """Select the active level (index into ``to_channels`` / ``to_depth``)."""
        self.level = level

    def get_openings(self):
        """Return the mean ``l0rd.openings`` over all recurrent layers in ``level0`` and ``up_level1``."""
        openings = 0
        for l0rd in self.level0:
            openings += l0rd.l0rd.openings.item()

        for l0rd in self.up_level1:
            openings += l0rd.l0rd.openings.item()

        # NOTE(review): divides by zero if both recurrent stacks are empty.
        return openings / (len(self.level0) + len(self.up_level1))

    def forward(self, input, last_flow, last_depth_warped):
        """Predict depth for the current step at the active level.

        Args:
            input: image tensor, presumably at the active level's resolution
                (TODO confirm against the caller). It is concatenated with
                ``last_depth_warped`` along dim 1, so both must share batch
                and spatial dimensions.
            last_flow: passed unchanged to every recurrent layer in ``level0``
                and ``up_level1``.
            last_depth_warped: single extra input channel; by its name, the
                previous depth warped to the current frame (confirm with caller).

        Returns:
            Output of ``to_depth[self.level]`` (a single-channel map).
        """
        
        input = th.cat((input, last_depth_warped), dim = 1)
        # Adapt the input width to the first stage that is active at this level.
        latent_level2 = latent_level1 = latent = self.to_channels[self.level](input)

        # Encoder: each higher level adds a stage followed by a 4x downsample.
        if self.level >= 2:
            latent_level2 = self.level2(latent_level2)
            latent_level1 = self.down_level2(latent_level2)

        if self.level >= 1:
            latent_level1 = self.level1(latent_level1)
            latent        = self.down_level1(latent_level1)

        # Recurrent bottleneck; only the first returned element is used.
        latent = self.level0(latent, last_flow)[0]

        # Decoder mirrors the encoder, fusing encoder features as skip inputs.
        if self.level >= 1:
            latent = self.upscale_level1(latent, latent_level1)
            latent = self.up_level1(latent, last_flow)[0]
            
        if self.level >= 2:
            latent = self.up_level2(latent, latent_level2)

        return self.to_depth[self.level](latent)

class BackgroundEnhancer(nn.Module):
    """Estimate camera intrinsics, pose and depth, and warp the input accordingly.

    In ``forward`` the RGB reconstruction path is disabled (commented out), so
    the first returned value is always ``None``. The ``rgb_*``, ``to_grid`` and
    ``to_rgb`` modules are still built, but ``forward`` does not use them.

    Recurrent state kept between calls, in non-persistent buffers: ``latent``,
    ``last_flow`` and ``last_depth_warped``. ``init`` is a persistent step
    counter buffer.
    """
    def __init__(
        self, 
        input_size: Tuple[int, int], 
        img_channels: int, 
        level1_channels,
        latent_channels,
        gestalt_size,
        batch_size,
        reg_lambda,
        vae_factor,
        deepth
    ):
        """Build the camera/depth estimation stack and the (currently unused) RGB path.

        Args:
            input_size: full-resolution (height, width); latent stages use
                ``input_size // 16``.
            img_channels: channels of the image input.
            level1_channels: feature width of the 1/4-resolution stage.
            latent_channels: feature width of the latent stage.
            gestalt_size: channel width of the latent code in the RGB path.
            batch_size: sizes the state buffers and the recurrent layers.
            reg_lambda: regularisation weight. The intrinsics gate uses
                ``reg_lambda * 10``, the pose gate and the RGB gate use it
                as-is, and DepthUnet scales it down internally.
            vae_factor: forwarded to ``VariationalFunction``.
            deepth: number of repeated ResidualBlocks per stage (the spelling
                is part of the public keyword interface).
        """
        super(BackgroundEnhancer, self).__init__()

        latent_size = [input_size[0] // 16, input_size[1] // 16]
        self.input_size = input_size
        self.batch_size = batch_size

        self.register_buffer('init', th.zeros(1).long())
        # NOTE(review): self.alpha is not referenced anywhere in this class.
        self.alpha = nn.Parameter(th.zeros(1)+1e-16)
        self.level = 1
        
        # One Warp per level: 1/16, 1/4 and full input resolution.
        self.warp  = nn.ModuleList([
            Warp((input_size[0] // 16, input_size[1] // 16)),
            Warp((input_size[0] // 4,  input_size[1] // 4)),
            Warp(input_size)
        ])
        # Built for full resolution; CameraFlow crops its pixel grid for smaller depth maps.
        self.flow  = CameraFlow(input_size)
        self.depth = DepthUnet(
            img_channels = img_channels,
            input_size   = input_size,
            level2_channels = level1_channels // 2,
            level1_channels = level1_channels,
            latent_channels = latent_channels,
            level2_layers   = 0,
            level1_layers   = 2,
            latent_layers   = 4,
            batch_size    = batch_size,
            reg_lambda    = reg_lambda,
            size          = latent_size
        )

        # Encoder: each PatchDownConv downsamples by its default kernel_size of 4.
        self.down_level2 = nn.Sequential(
            PatchDownConv(img_channels, level1_channels, alpha = 1),
            *[ResidualBlock(level1_channels, level1_channels, alpha_residual = False) for i in range(deepth)]
        )

        self.down_level1 = nn.Sequential(
            PatchDownConv(level1_channels, latent_channels, alpha = 1),
            *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(deepth)]
        )

        # Head: latent map -> 1x1 -> vector -> gated 5 values, consumed by
        # CameraFlow as (fx, fy, s, ox, oy).
        self.intrinsics = nn.Sequential(
            *[ResidualBlock(latent_channels, latent_channels) for i in range(deepth)],
            AggressiveConvTo1x1(latent_channels, latent_size),
            LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
            EpropGateL0rd(
                num_inputs  = latent_channels,
                num_hidden  = latent_channels,
                num_outputs = 5,
                batch_size  = batch_size, 
                reg_lambda  = reg_lambda*10
            ),
        )

        # Head: 6 gated values, consumed by CameraFlow as (rx, ry, rz, tx, ty, tz).
        self.pose = nn.Sequential(
            *[ResidualBlock(latent_channels, latent_channels) for i in range(deepth)],
            AggressiveConvTo1x1(latent_channels, latent_size),
            LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
            EpropGateL0rd(
                num_inputs  = latent_channels,
                num_hidden  = latent_channels,
                num_outputs = 6,
                batch_size  = batch_size, 
                reg_lambda  = reg_lambda
            ),
        )

        # --- RGB reconstruction path (not used by forward at present) ---
        # Index -4 (EpropAlphaGateL0rd) is read by get_openings.
        self.rgb_level0 = nn.Sequential(
            *[ResidualBlock(latent_channels, latent_channels) for i in range(deepth)],
            AggressiveConvTo1x1(latent_channels, latent_size),
            LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
            EpropAlphaGateL0rd(latent_channels, batch_size, reg_lambda),
            LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
            ResidualBlock(latent_channels, gestalt_size * 2),
            VariationalFunction(factor = vae_factor),
        )

        # Learned spatial bias added to the latent code before decoding.
        self.bias = nn.Parameter(th.zeros((1, gestalt_size, *latent_size)))
        self.to_grid = nn.Sequential(
            LambdaModule(lambda x: x + self.bias),
            ResidualBlock(gestalt_size, latent_channels),
            *[ResidualBlock(latent_channels, latent_channels) for i in range(deepth)],
            ResidualBlock(latent_channels, gestalt_size),
        )            

        self.rgb_up_level1 = nn.Sequential(
            ResidualBlock(gestalt_size, latent_channels),
            *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(deepth)],
            AggressiveUpConv(latent_channels, level1_channels, alpha = 1),
        )

        self.rgb_up_level2 = nn.Sequential(
            *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(deepth)],
            AggressiveUpConv(level1_channels, img_channels, alpha = 1e-16),
        )

        # Per-level input adapters, indexed by self.level (0, 1, 2).
        self.to_channels = nn.ModuleList([
            SkipConnection(img_channels, latent_channels),
            SkipConnection(img_channels, level1_channels),
            SkipConnection(img_channels, img_channels),
        ])

        self.to_rgb = nn.ModuleList([
            SkipConnection(latent_channels, img_channels),
            SkipConnection(level1_channels, img_channels),
            SkipConnection(img_channels,    img_channels),
        ])

        # NOTE(review): self.mask is not referenced anywhere in this class.
        self.mask   = nn.Parameter(th.ones(1, 1, *input_size) * 10)
        # Recurrent state. NOTE(review): last_flow / last_depth_warped are allocated
        # at full input_size. DepthUnet concatenates last_depth_warped with the
        # input, so at levels < 2 (presumably lower-resolution inputs) the first
        # call may hit a size mismatch unless the caller handles it — confirm.
        self.register_buffer('latent', th.zeros((batch_size, gestalt_size, 1, 1)), persistent=False)
        self.register_buffer('zero_flow', create_grid(input_size), persistent=False)
        self.register_buffer('last_flow', repeat(self.zero_flow, ' 1 c h w -> b c h w', b = batch_size), persistent=False)
        self.register_buffer('last_depth_warped', th.zeros(batch_size, 1, *input_size), persistent=False)

    def get_openings(self):
        """Return ``openings`` statistics as a tuple.

        The entries are: intrinsics gate, pose gate, mean over the DepthUnet
        recurrent layers, and RGB gate (``rgb_level0[-4]``).
        """
        return self.intrinsics[-1].openings.item(), self.pose[-1].openings.item(), self.depth.get_openings(), self.rgb_level0[-4].l0rd.openings.item()
        #return self.intrinsics[-1].openings.item() * 0.5 + 0.5 * self.pose[-1].openings.item()

    def get_init(self):
        """Return the current value of the ``init`` step counter as a Python number."""
        return self.init.item()

    def step_init(self):
        """Increment the ``init`` step counter buffer by one."""
        self.init = self.init + 1

    def detach(self):
        """Detach the recurrent state tensors from the autograd graph."""
        self.latent = self.latent.detach()
        self.last_flow = self.last_flow.detach()
        self.last_depth_warped = self.last_depth_warped.detach()

    def reset_state(self):
        """Reset the recurrent state to zeros / the identity grid."""
        self.latent = th.zeros_like(self.latent)
        self.last_depth_warped = th.zeros_like(self.last_depth_warped)
        self.last_flow = repeat(self.zero_flow, ' 1 c h w -> b c h w', b = self.batch_size)
        # Plain attribute (not a buffer); it does not exist until reset_state or forward runs.
        self.last_flow_mask = None

    def set_level(self, level):
        """Select the active level (0, 1 or 2) here and in the DepthUnet."""
        self.level = level
        self.depth.set_level(level)

    def get_last_latent_gird(self):
        """Decode the stored ``latent`` through ``to_grid``.

        The name keeps its original spelling ("gird") for callers.
        NOTE(review): ``self.latent`` is only zeroed or detached in this class;
        ``forward`` never updates it.
        """
        return self.to_grid(self.latent)

    def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None):
        """Estimate camera parameters and depth, and warp ``input`` with the resulting flow.

        Args:
            input: image tensor, presumably at the active level's resolution
                (TODO confirm against the caller).
            error: unused.
            mask: unused.

        Returns:
            ``(None, rgb_warped, depth, flow, flow_mask)``. The first slot is
            the disabled RGB reconstruction.

        Side effects: stores ``last_depth_warped`` without detaching it, and
        stores detached copies of ``last_flow`` and ``last_flow_mask``.
        """
        latent = self.to_channels[self.level](input)

        # Encoder: each higher level adds a 4x downsample to reach the latent size.
        if self.level >= 2:
            latent = self.down_level2(latent)

        if self.level >= 1:
            latent = self.down_level1(latent)
        
        intrinsics = self.intrinsics(latent)
        pose       = self.pose(latent)
        depth      = self.depth(input, self.last_flow, self.last_depth_warped)

        # CameraFlow also returns depth after its learnable affine transform.
        flow, flow_mask, depth = self.flow(intrinsics, pose, depth)
        rgb_warped = self.warp[self.level](input, flow)
        self.last_depth_warped = self.warp[self.level](depth, flow)
        
        self.last_flow = flow.detach()
        self.last_flow_mask = flow_mask.detach()

        return None, rgb_warped, depth, flow, flow_mask

        # Disabled RGB reconstruction path (unreachable; kept for reference):
        #rgb   = self.rgb_level0(latent)
        #rgb   = self.to_grid(rgb)

        #if self.level >= 1:
        #    rgb   = self.rgb_up_level1(rgb)

        #if self.level >= 2:
        #    rgb   = self.rgb_up_level2(rgb)

        #rgb   = self.to_rgb[self.level](rgb)


        #return rgb, rgb_warped, depth, flow, flow_mask

