import torch.nn as nn
import torch as th
from torch.autograd import Function
import nn as nn_modules
import utils
from nn.residual import ResidualBlock, SkipConnection
from nn.encoder import AggressiveDownConv
from nn.encoder import AggressiveConvTo1x1
from nn.decoder import AggressiveUpConv
from utils.utils import LambdaModule, ForcedAlpha, PrintShape, Warp
from nn.predictor import EpropAlphaGateL0rd
from nn.eprop_gate_l0rd import EpropGateL0rd
from nn.eprop_flow_gate_l0rd import EpropFlowGateL0rd
from nn.vae import VariationalFunction
from einops import rearrange, repeat, reduce
from utils.utils import MultiArgSequential

from typing import Union, Tuple



class CameraFlow(nn.Module):
    """Turn predicted camera intrinsics, pose and depth into a sampling grid.

    Every pixel is lifted with the inverse intrinsic matrix and its depth,
    moved by the (intrinsics @ pose) projection and divided by its new depth.
    The resulting pixel positions are rescaled to the [-1, 1] range.

    All three inputs are passed through learnable affine maps whose scales
    start at 1e-16, so the initial output is driven by the biases alone.
    """

    def __init__(self, size):
        """Create the pixel grid buffer for `size` = (height, width)."""
        super().__init__()

        height, width = size[0], size[1]

        # Homogeneous pixel coordinates (x, y, 1), stored as 1 x 3 x H x W.
        x_coords = repeat(th.arange(0, width).float(), 'w -> 1 1 h w', h = height)
        y_coords = repeat(th.arange(0, height).float(), 'h -> 1 1 h w', w = width)
        homogeneous = th.ones((1, 1, *size))
        self.register_buffer(
            'grid',
            th.cat((x_coords, y_coords, homogeneous), dim=1),
            persistent=False,
        )

        # Affine re-parameterisation of (fx, fy, s, ox, oy).
        self.intrinsic_bias  = nn.Parameter(th.tensor([[1, 1, 0, 0, 0]]).float())
        self.intrinsic_alpha = nn.Parameter(th.tensor([[0, 0, 0, 0, 0]]) + 1e-16)

        # Scale for (rx, ry, rz, tx, ty, tz).
        self.pose_alpha = nn.Parameter(th.tensor([[0, 0, 0, 0, 0, 0]]) + 1e-16)

        # Affine re-parameterisation of the depth map.
        self.depth_bias  = nn.Parameter(th.ones(1))
        self.depth_alpha = nn.Parameter(th.zeros(1) + 1e-16)

    def intrinsics_to_matrix(self, intrinsics):
        """Build B x 3 x 3 intrinsic matrices from B x 5 raw intrinsics.

        The five columns are read as (fx, fy, s, ox, oy) after applying the
        learnable scale and bias.
        """
        scaled = intrinsics * self.intrinsic_alpha + self.intrinsic_bias
        fx, fy, s, ox, oy = th.chunk(scaled, 5, dim=1)
        zero = th.zeros_like(fx)
        one  = th.ones_like(fx)

        top    = [fx,   s,    ox]
        middle = [zero, fy,   oy]
        bottom = [zero, zero, one]

        return th.cat(top + middle + bottom, dim=1).reshape(-1, 3, 3)

    def pose_to_matrix(self, pose):
        """Build B x 3 x 4 [R | t] matrices from B x 6 raw poses.

        The six columns are read as (rx, ry, rz, tx, ty, tz) after scaling;
        the rotation is composed as Rx @ Ry @ Rz.
        """
        rx, ry, rz, tx, ty, tz = th.chunk(pose * self.pose_alpha, 6, dim=1)

        zero, one = th.zeros_like(rx), th.ones_like(rx)
        sin_x, cos_x = th.sin(rx), th.cos(rx)
        sin_y, cos_y = th.sin(ry), th.cos(ry)
        sin_z, cos_z = th.sin(rz), th.cos(rz)

        def as_matrix(*entries):
            # Nine B x 1 columns in row-major order -> B x 3 x 3.
            return th.cat(entries, dim=1).reshape(-1, 3, 3)

        rot_z = as_matrix(
            cos_z, -sin_z, zero,
            sin_z,  cos_z, zero,
            zero,   zero,  one,
        )
        rot_y = as_matrix(
            cos_y,  zero, sin_y,
            zero,   one,  zero,
            -sin_y, zero, cos_y,
        )
        rot_x = as_matrix(
            one,  zero,  zero,
            zero, cos_x, -sin_x,
            zero, sin_x,  cos_x,
        )

        rotation    = rot_x @ rot_y @ rot_z
        translation = th.stack((tx, ty, tz), dim=1)
        return th.cat((rotation, translation), dim=2)

    def forward(self, intrinsics, pose, depth):
        """Compute the flow grid, its validity mask and the rescaled depth.

        Returns:
            flow:       B x 2 x H x W target positions, scaled so that pixel
                        0 maps to -1 and pixel (size - 1) maps to 1.
            valid_mask: B x 1 x H x W, 1 where both coordinates lie in [-1, 1].
            depth:      the input depth after the learnable affine map.
        """
        batch, _, height, width = depth.shape

        depth = depth * self.depth_alpha + self.depth_bias

        pixel_coordinates = repeat(
            self.grid[:, :, :height, :width], '1 c h w -> b c (h w)', b = batch
        )
        intrinsic_matrix = self.intrinsics_to_matrix(intrinsics)
        pose_matrix      = self.pose_to_matrix(pose)

        # Split K @ [R | t] into its linear part and its offset.
        projection  = intrinsic_matrix @ pose_matrix
        rotation    = projection[:, :, :3]
        translation = projection[:, :, 3:]

        # Lift pixels with depth, then apply the projection.
        camera_rays  = th.inverse(intrinsic_matrix) @ pixel_coordinates
        flat_depth   = rearrange(depth, 'b c h w -> b c (h w)')
        scene_points = flat_depth * camera_rays
        scene_points = (rotation @ scene_points) + translation

        X, Y, Z = th.chunk(scene_points, 3, dim=1)
        Z = Z.clamp(min=1e-3)  # keep the perspective division finite

        X = 2 * (X / Z) / (width - 1) - 1
        Y = 2 * (Y / Z) / (height - 1) - 1

        # 1 where a coordinate falls outside the normalised image range.
        x_outside = ((X > 1) | (X < -1)).detach().float()
        y_outside = ((Y > 1) | (Y < -1)).detach().float()

        flow = rearrange(
            th.cat((X, Y), dim=1), ' b c (h w) -> b c h w', h=height, w=width
        )

        outside = th.cat((x_outside, y_outside), dim=1)
        valid_mask = 1 - reduce(
            outside, 'b c (h w) -> b 1 h w', 'max', h=height, w=width
        )
        return flow, valid_mask, depth

        
class EpropAlphaFlowGateL0rd(nn.Module):
    """Residual wrapper around `EpropFlowGateL0rd` with a learnable gain.

    The output is `input + alpha * l0rd(input, flow, mask)`, where the inner
    layer uses `num_hidden` for its input, hidden and output widths.
    """

    def __init__(self, num_hidden, size, batch_size, reg_lambda, alpha = 1):
        super().__init__()

        # Learnable scalar weighting the recurrent branch.
        self.alpha = nn.Parameter(th.zeros(1) + alpha)

        l0rd_config = dict(
            num_inputs  = num_hidden,
            num_hidden  = num_hidden,
            num_outputs = num_hidden,
            reg_lambda  = reg_lambda,
            batch_size  = batch_size,
            size        = size,
        )
        self.l0rd = EpropFlowGateL0rd(**l0rd_config)

    def forward(self, input, flow, mask):
        """Add the alpha-scaled l0rd response onto `input`."""
        update = self.l0rd(input, flow, mask)
        return input + self.alpha * update

class PatchDownConv(nn.Module):
    """Patch-wise downsampling convolution with an average-pooled skip path.

    The input is reduced by a factor of `kernel_size` along both spatial axes
    twice: by a strided convolution (kernel == stride) and by a plain average
    over the same non-overlapping patches. The pooled tensor is expanded to
    `out_channels` by repeating every input channel
    `out_channels // in_channels` times in a row, and the convolution output
    is added on top, scaled by the learnable `alpha`.

    Args:
        in_channels:  number of input channels.
        out_channels: number of output channels; must be a multiple of
                      `in_channels`.
        kernel_size:  patch size, used as both kernel and stride.
        alpha:        initial value of the learnable conv-branch gain.
    """

    def __init__(self, in_channels, out_channels, kernel_size = 4, alpha = 1):
        super(PatchDownConv, self).__init__()
        assert out_channels % in_channels == 0
        
        self.layers = nn.Conv2d(
            in_channels  = in_channels, 
            out_channels = out_channels, 
            kernel_size  = kernel_size,
            stride       = kernel_size,
        )

        self.alpha = nn.Parameter(th.zeros(1) + alpha)
        # Must match the conv's patch size; it used to be hard-coded to 4,
        # which broke the skip path for any other `kernel_size`.
        self.kernel_size = kernel_size
        self.channels_factor = out_channels // in_channels

    def forward(self, input: th.Tensor):
        """Return `skip + alpha * conv(input)` for a B x C x H x W input."""
        # Mean over non-overlapping k x k patches (stride defaults to the
        # kernel size), i.e. the same spatial reduction as the convolution.
        skip = nn.functional.avg_pool2d(input, kernel_size=self.kernel_size)
        # Channel c of the pooled map fills output channels c*n .. c*n + n-1.
        skip = skip.repeat_interleave(self.channels_factor, dim=1)
        return skip + self.alpha * self.layers(input)

class BicubicUpscale(nn.Module):
    """Bicubic upscaling step that mixes in encoder feature maps.

    `pre` projects the low-resolution maps to `out_channels` with a 1x1
    convolution and enlarges them by `scale_factor` using bicubic
    interpolation. `post` concatenates the upscaled maps with the encoder maps
    along the channel axis and runs conv3x3 -> ReLU -> conv3x3. The result of
    `post` is scaled by `alpha1` and added onto the upscaled maps; the encoder
    maps are scaled by `alpha2` (initialised to 1e-16) before concatenation.
    """

    def __init__(self, in_channels, out_channels, scale_factor = 4, alpha = 1):
        super().__init__()
        assert in_channels % out_channels == 0

        def upsample(maps):
            return nn.functional.interpolate(
                maps,
                scale_factor = scale_factor,
                mode = 'bicubic',
                align_corners = True
            )

        def concat_channels(maps):
            return th.cat(maps, dim=1)

        # FIXME: this 1x1 projection is a plain convolution, not a residual
        # skip connection.
        project = nn.Conv2d(
            in_channels  = in_channels,
            out_channels = out_channels,
            kernel_size  = 1,
        )
        self.pre = nn.Sequential(project, LambdaModule(upsample))

        merge = LambdaModule(concat_channels)
        fuse = nn.Conv2d(
            in_channels  = out_channels * 2,
            out_channels = out_channels,
            kernel_size  = 3,
            padding      = 1,
        )
        activation = nn.ReLU()
        refine = nn.Conv2d(
            in_channels  = out_channels,
            out_channels = out_channels,
            kernel_size  = 3,
            padding      = 1,
        )
        self.post = nn.Sequential(merge, fuse, activation, refine)

        # Gain of the refinement branch and of the encoder maps respectively.
        self.alpha1 = nn.Parameter(th.zeros(1) + alpha)
        self.alpha2 = nn.Parameter(th.zeros(1) + 1e-16)

    def forward(self, lower_maps, encoder_maps):
        """Upscale `lower_maps` and add the encoder-informed refinement."""
        upscaled = self.pre(lower_maps)
        gated_encoder = encoder_maps * self.alpha2
        refinement = self.post((upscaled, gated_encoder))
        return upscaled + self.alpha1 * refinement

class DepthUnet(nn.Module):
    """U-Net shaped network producing a single-channel map from an image.

    The network has three resolution stages and `self.level` selects how many
    of them are active:

    * level 0: only the latent residual blocks run.
    * level 1: additionally the level-1 encoder, its down-convolution and the
      matching up-sampling path run.
    * level 2: additionally the level-2 encoder/decoder pair runs.

    For each level there is a dedicated input adapter (`to_channels`) and
    output head (`to_depth`), both indexed by the current level.
    """

    def __init__(
        self, 
        img_channels: int, 
        level2_channels,
        level1_channels,
        latent_channels,
        level2_layers,
        level1_layers,
        latent_layers,
        batch_size,
        reg_lambda,
        size,
    ):
        super().__init__()
        self.level = 1

        def residual_stack(channels, count, **block_kwargs):
            # `count` channel-preserving residual blocks, built in order.
            return [
                ResidualBlock(channels, channels, **block_kwargs)
                for _ in range(count)
            ]

        # ---- encoder -------------------------------------------------------
        level2_stem = ResidualBlock(img_channels, level2_channels, alpha_residual = False)
        self.level2 = nn.Sequential(
            level2_stem,
            *residual_stack(level2_channels, level2_layers, alpha_residual = False),
        )

        self.down_level2 = PatchDownConv(level2_channels, level1_channels, alpha = 1)

        self.level1 = nn.Sequential(*residual_stack(level1_channels, level1_layers))

        self.down_level1 = nn.Sequential(
            PatchDownConv(level1_channels, latent_channels),
            ResidualBlock(latent_channels, latent_channels),
        )

        # ---- latent stage --------------------------------------------------
        self.level0_restnet = nn.Sequential(
            *residual_stack(latent_channels, latent_layers)
        )

        l0rd_gates = []
        for _ in range(latent_layers):
            l0rd_gates.append(
                EpropAlphaFlowGateL0rd(
                    num_hidden = latent_channels,
                    size       = size,
                    batch_size = batch_size,
                    reg_lambda = reg_lambda * 0.1,
                )
            )
        self.level0_l0rds = nn.Sequential(*l0rd_gates)

        # ---- decoder -------------------------------------------------------
        self.up_level1 = MultiArgSequential(
            BicubicUpscale(latent_channels, level1_channels),
            *residual_stack(level1_channels, level1_layers),
        )

        self.up_level2 = MultiArgSequential(
            BicubicUpscale(level1_channels, level2_channels, alpha = 1),
            *residual_stack(level2_channels, level2_layers, alpha_residual = False),
        )

        # ---- per-level input adapters and output heads ---------------------
        input_adapters = [
            SkipConnection(img_channels, latent_channels),
            SkipConnection(img_channels, level1_channels),
            SkipConnection(img_channels, img_channels),
        ]
        self.to_channels = nn.ModuleList(input_adapters)

        output_heads = [
            SkipConnection(latent_channels, 1),
            SkipConnection(level1_channels, 1),
            ResidualBlock(level2_channels,  1, alpha_residual = True),
        ]
        self.to_depth = nn.ModuleList(output_heads)

    def set_level(self, level):
        """Select the active resolution level (index into the adapter lists)."""
        self.level = level

    def get_openings(self):
        """Return the mean `openings` value over all latent l0rd gates."""
        total = sum(gate.l0rd.openings.item() for gate in self.level0_l0rds)
        return total / len(self.level0_l0rds)

    def forward(self, input, last_flow, last_flow_mask):
        """Run the network at the current level.

        `last_flow` and `last_flow_mask` are accepted but not used here: the
        l0rd gates are bypassed and only the residual blocks are applied.
        """
        level = self.level
        adapted = self.to_channels[level](input)
        latent_level2 = latent_level1 = latent = adapted

        # Encoder: each stage only runs when the level is high enough.
        if level >= 2:
            latent_level2 = self.level2(latent_level2)
            latent_level1 = self.down_level2(latent_level2)

        if level >= 1:
            latent_level1 = self.level1(latent_level1)
            latent        = self.down_level1(latent_level1)

        # Latent stage: one residual block per (currently bypassed) l0rd gate.
        for index, _gate in enumerate(self.level0_l0rds):
            latent = self.level0_restnet[index](latent)

        # Decoder: mirror the encoder, feeding back the encoder feature maps.
        if level >= 1:
            latent = self.up_level1(latent, latent_level1)

        if level >= 2:
            latent = self.up_level2(latent, latent_level2)

        return self.to_depth[level](latent)

class BackgroundEnhancer(nn.Module):
    """Predicts camera intrinsics, pose and depth and warps the input with them.

    The active `forward` path encodes the input down to the latent resolution,
    predicts 5 intrinsic and 6 pose values from it, predicts a depth map with
    `DepthUnet`, converts all three into a flow grid via `CameraFlow` and
    finally warps the input with that grid.

    The RGB reconstruction modules (`rgb_level0`, `to_grid`, `rgb_up_level1`,
    `rgb_up_level2`, `to_rgb`) are constructed but their use in `forward` is
    currently commented out.

    `self.level` (0, 1 or 2) selects which input adapter, encoder stages and
    `Warp` module are used.
    """

    def __init__(
        self, 
        input_size: Tuple[int, int], 
        img_channels: int, 
        level1_channels,
        latent_channels,
        gestalt_size,
        batch_size,
        reg_lambda,
        vae_factor,
        deepth
    ):
        """Build all sub-networks.

        Args:
            input_size:      (height, width) of the full-resolution input; the
                             latent grid is this size divided by 16.
            img_channels:    number of image channels.
            level1_channels: channel width of the level-1 feature maps.
            latent_channels: channel width of the latent feature maps.
            gestalt_size:    channel width of the `latent` state and of the
                             learnable `bias` grid.
            batch_size:      passed to the l0rd modules and used for the first
                             dimension of the `latent` buffer.
            reg_lambda:      regularisation weight passed to the l0rd modules
                             (multiplied by 10 for the intrinsics head).
            vae_factor:      `factor` argument of `VariationalFunction`.
            deepth:          number of residual blocks per stack.
        """
        super(BackgroundEnhancer, self).__init__()

        # Two PatchDownConv stages (default kernel size 4 each) give a total
        # reduction by 16.
        latent_size = [input_size[0] // 16, input_size[1] // 16]
        self.input_size = input_size

        # Integer counter, advanced by step_init() and read by get_init().
        self.register_buffer('init', th.zeros(1).long())
        self.alpha = nn.Parameter(th.zeros(1)+1e-16)
        self.level = 1
        
        # One Warp module per level, at 1/16, 1/4 and full resolution.
        self.warp  = nn.ModuleList([
            Warp((input_size[0] // 16, input_size[1] // 16)),
            Warp((input_size[0] // 4,  input_size[1] // 4)),
            Warp(input_size)
        ])
        self.flow  = CameraFlow(input_size)
        self.depth = DepthUnet(
            img_channels = img_channels,
            level2_channels = level1_channels // 2,
            level1_channels = level1_channels,
            latent_channels = latent_channels,
            level2_layers   = 0,
            level1_layers   = 2,
            latent_layers   = 4,
            batch_size    = batch_size,
            reg_lambda    = reg_lambda,
            size          = latent_size
        )

        # Encoder stage used at level 2: image -> level-1 feature maps.
        self.down_level2 = nn.Sequential(
            PatchDownConv(img_channels, level1_channels, alpha = 1),
            *[ResidualBlock(level1_channels, level1_channels, alpha_residual = False) for i in range(deepth)]
        )

        # Encoder stage used at level >= 1: level-1 maps -> latent maps.
        self.down_level1 = nn.Sequential(
            PatchDownConv(level1_channels, latent_channels, alpha = 1),
            *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(deepth)]
        )

        # Head producing the 5 raw intrinsic values consumed by CameraFlow.
        self.intrinsics = nn.Sequential(
            *[ResidualBlock(latent_channels, latent_channels) for i in range(deepth)],
            AggressiveConvTo1x1(latent_channels, latent_size),
            LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
            EpropGateL0rd(
                num_inputs  = latent_channels,
                num_hidden  = latent_channels,
                num_outputs = 5,
                batch_size  = batch_size, 
                reg_lambda  = reg_lambda*10
            ),
        )

        # Head producing the 6 raw pose values consumed by CameraFlow.
        self.pose = nn.Sequential(
            *[ResidualBlock(latent_channels, latent_channels) for i in range(deepth)],
            AggressiveConvTo1x1(latent_channels, latent_size),
            LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
            EpropGateL0rd(
                num_inputs  = latent_channels,
                num_hidden  = latent_channels,
                num_outputs = 6,
                batch_size  = batch_size, 
                reg_lambda  = reg_lambda
            ),
        )

        # RGB bottleneck; not used by the currently active forward path.
        # get_openings() relies on the l0rd module sitting at index -4.
        self.rgb_level0 = nn.Sequential(
            *[ResidualBlock(latent_channels, latent_channels) for i in range(deepth)],
            AggressiveConvTo1x1(latent_channels, latent_size),
            LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
            EpropAlphaGateL0rd(latent_channels, batch_size, reg_lambda),
            LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
            ResidualBlock(latent_channels, gestalt_size * 2),
            VariationalFunction(factor = vae_factor),
        )

        # Learnable grid at latent resolution that is added to the input of
        # to_grid (the lambda below reads self.bias at call time).
        self.bias = nn.Parameter(th.zeros((1, gestalt_size, *latent_size)))
        self.to_grid = nn.Sequential(
            LambdaModule(lambda x: x + self.bias),
            ResidualBlock(gestalt_size, latent_channels),
            *[ResidualBlock(latent_channels, latent_channels) for i in range(deepth)],
            ResidualBlock(latent_channels, gestalt_size),
        )            

        # RGB decoder stages; not used by the currently active forward path.
        self.rgb_up_level1 = nn.Sequential(
            ResidualBlock(gestalt_size, latent_channels),
            *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(deepth)],
            AggressiveUpConv(latent_channels, level1_channels, alpha = 1),
        )

        self.rgb_up_level2 = nn.Sequential(
            *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(deepth)],
            AggressiveUpConv(level1_channels, img_channels, alpha = 1e-16),
        )

        # Per-level input adapters: image -> channel width expected by the
        # first encoder stage that runs at that level.
        self.to_channels = nn.ModuleList([
            SkipConnection(img_channels, latent_channels),
            SkipConnection(img_channels, level1_channels),
            SkipConnection(img_channels, img_channels),
        ])

        # Per-level RGB output heads; not used by the active forward path.
        self.to_rgb = nn.ModuleList([
            SkipConnection(latent_channels, img_channels),
            SkipConnection(level1_channels, img_channels),
            SkipConnection(img_channels,    img_channels),
        ])

        # Full-resolution learnable mask initialised to 10; not read anywhere
        # in this class.
        self.mask   = nn.Parameter(th.ones(1, 1, *input_size) * 10)
        # Non-persistent state buffer; cleared by reset_state().
        self.register_buffer('latent', th.zeros((batch_size, gestalt_size, 1, 1)), persistent=False)
        # Flow and mask of the previous forward call (None until first call).
        self.last_flow = None
        self.last_flow_mask = None

    def get_openings(self):
        """Return the `openings` values of the intrinsics, pose, depth and RGB gates.

        The result is a 4-tuple of Python numbers in that order; the depth
        entry is whatever `DepthUnet.get_openings()` returns.
        """
        return self.intrinsics[-1].openings.item(), self.pose[-1].openings.item(), self.depth.get_openings(), self.rgb_level0[-4].l0rd.openings.item()
        #return self.intrinsics[-1].openings.item() * 0.5 + 0.5 * self.pose[-1].openings.item()

    def get_init(self):
        """Return the current value of the `init` counter as a Python int."""
        return self.init.item()

    def step_init(self):
        """Increment the `init` counter by one."""
        self.init = self.init + 1

    def detach(self):
        """Detach the `latent` state from the autograd graph."""
        self.latent = self.latent.detach()

    def reset_state(self):
        """Zero the `latent` state and forget the previous flow and mask."""
        self.latent = th.zeros_like(self.latent)
        self.last_flow = None
        self.last_flow_mask = None

    def set_level(self, level):
        """Set the active level on this module and on the depth network."""
        self.level = level
        self.depth.set_level(level)

    def get_last_latent_gird(self):
        """Return `to_grid` applied to the stored `latent` state."""
        return self.to_grid(self.latent)

    def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None):
        """Predict camera flow and depth for `input` and warp `input` with it.

        `error` and `mask` are accepted but not used by the active code path.

        Returns:
            A 5-tuple `(None, rgb_warped, depth, flow, flow_mask)`. The first
            entry is a placeholder for the RGB reconstruction, which is
            currently disabled.
        """
        # Adapt the channel count, then run the encoder stages of this level.
        latent = self.to_channels[self.level](input)

        if self.level >= 2:
            latent = self.down_level2(latent)

        if self.level >= 1:
            latent = self.down_level1(latent)
        
        intrinsics = self.intrinsics(latent)
        pose       = self.pose(latent)
        # The depth network works on the raw input, not on the encoded latent.
        depth      = self.depth(input, self.last_flow, self.last_flow_mask)

        # CameraFlow also returns the depth after its own affine rescaling.
        flow, flow_mask, depth = self.flow(intrinsics, pose, depth)
        rgb_warped = self.warp[self.level](input, flow)
        
        # Keep detached copies to pass to the depth network on the next call.
        self.last_flow = flow.detach()
        self.last_flow_mask = flow_mask.detach()

        return None, rgb_warped, depth, flow, flow_mask

        # Disabled RGB reconstruction path, kept for reference:
        #rgb   = self.rgb_level0(latent)
        #rgb   = self.to_grid(rgb)

        #if self.level >= 1:
        #    rgb   = self.rgb_up_level1(rgb)

        #if self.level >= 2:
        #    rgb   = self.rgb_up_level2(rgb)

        #rgb   = self.to_rgb[self.level](rgb)


        #return rgb, rgb_warped, depth, flow, flow_mask

