import torch as th
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat

from nn.encoder import PatchDownConv, ConvNeXtBlock
from utils.utils import LambdaModule, MultiArgSequential
from nn.manifold import *

# TODO test different object-depth-to-RGB architectures
# TODO (mask object and depth before processing)
# TODO use depth-to-space and space-to-depth in the first layer to get a real high-res skip connection!!!
# FIXME fix data augmentation resizing!!!!

class ShuffelUpDown(nn.Module):
    """Move pixels between the spatial and the channel dimensions.

    Wraps ``nn.PixelUnshuffle`` (``mode='down'``) or ``nn.PixelShuffle``
    (``mode='up'``) behind a single module.

    Args:
        mode: ``'down'`` to fold ``scale_factor x scale_factor`` spatial
            patches into channels, ``'up'`` for the inverse operation.
        scale_factor: Spatial factor handed to the wrapped torch module.

    Raises:
        ValueError: If ``mode`` is neither ``'up'`` nor ``'down'``.
    """

    def __init__(self, mode='down', scale_factor=2):
        super(ShuffelUpDown, self).__init__()
        if mode == 'up':
            self.f = nn.PixelShuffle(scale_factor)
        elif mode == 'down':
            self.f = nn.PixelUnshuffle(scale_factor)
        else:
            # Fail at construction time; otherwise ``self.f`` would be missing
            # and the error would only surface later inside ``forward``.
            raise ValueError(f"mode must be 'up' or 'down', got {mode!r}")

    def forward(self, x):
        """Apply the selected shuffle operation to ``x``."""
        return self.f(x)

class UnetDown(nn.Module):
    """First (non-hyper) encoder stage of the U-Net.

    The input is pixel-unshuffled by a factor of 4 (16x more channels), then
    passed through a small stack of 3x3 convolutions. The convolution output
    is scaled by the learnable ``alpha`` (initialised to 1e-6) and added to a
    skip path that tiles the unshuffled input along the channel axis until it
    has ``base_channels`` channels.

    Args:
        in_channels: Channels of the incoming image.
        base_channels: Channels produced by this stage; must be a multiple of
            ``in_channels * 16``.
        layers: Number of convolutions in the stack (the first one is always
            created).
    """

    def __init__(self, in_channels, base_channels, layers):
        super().__init__()
        shuffled_channels = in_channels * 16
        assert base_channels % shuffled_channels == 0

        self.shuffel = ShuffelUpDown('down', 4)

        # Entry convolution followed by ``layers - 1`` SiLU + conv pairs.
        stack = [nn.Conv2d(shuffled_channels, base_channels, 3, padding=1)]
        for _ in range(layers - 1):
            stack.append(
                nn.Sequential(nn.SiLU(), nn.Conv2d(base_channels, base_channels, 3, padding=1))
            )
        self.layers = nn.Sequential(*stack)

        # Parameter-free skip: tile the channels up to ``base_channels``.
        copies = base_channels // shuffled_channels
        self.skip = LambdaModule(lambda x: repeat(x, 'b c h w -> b (n c) h w', n=copies))

        self.alpha = nn.Parameter(th.ones(1) * 1e-6)

    def forward(self, x):
        """Return the tiled skip plus the ``alpha``-scaled convolution branch."""
        x = self.shuffel(x)
        residual = self.layers(x) * self.alpha
        return residual + self.skip(x)

class UnetUp(nn.Module):
    """Last (non-hyper) decoder stage of the U-Net.

    A stack of 3x3 convolutions maps ``base_channels`` features to
    ``out_channels * 16`` channels. The result is scaled by the learnable
    ``alpha`` (initialised to 1e-6), added to a parameter-free skip path and
    finally pixel-shuffled by a factor of 4 to ``out_channels`` channels.

    The skip path always has to deliver ``out_channels * 16`` channels so it
    can be added to the convolution output:

    * fewer input channels  -> tile the input along the channel axis,
    * more input channels   -> average groups of channels,
    * equal channel counts  -> identity.

    Args:
        base_channels: Channels of the incoming feature map.
        out_channels: Channels of the produced image.
        layers: Nominal depth of the stack; at least two convolutions are
            always created.

    Raises:
        ValueError: If ``base_channels`` and ``out_channels * 16`` are not
            positive multiples of one another, because the skip path could
            then never match the convolution output.
    """

    def __init__(self, base_channels, out_channels, layers):
        super(UnetUp, self).__init__()
        shuffle_channels = out_channels * 16
        out_base_channels = max(base_channels, shuffle_channels)

        narrow = min(base_channels, shuffle_channels)
        if narrow < 1 or out_base_channels % narrow != 0:
            raise ValueError(
                f"base_channels ({base_channels}) and out_channels * 16 "
                f"({shuffle_channels}) must be positive multiples of one another"
            )

        self.shuffel = ShuffelUpDown('up', 4)
        self.layers = nn.Sequential(
            nn.Conv2d(base_channels, out_base_channels, 3, padding=1),
            *[nn.Sequential(nn.SiLU(), nn.Conv2d(out_base_channels, out_base_channels, 3, padding=1)) for _ in range(layers-2)],
            nn.SiLU(),
            nn.Conv2d(out_base_channels, out_channels * 16, 3, padding=1),
        )

        # Compare against the channel count the conv stack really emits
        # (``out_channels * 16``). Comparing against ``out_base_channels``
        # can never select the reduce branch, and an identity skip breaks the
        # addition whenever ``base_channels > out_channels * 16``.
        if shuffle_channels > base_channels:
            copies = shuffle_channels // base_channels
            self.skip = LambdaModule(lambda x: repeat(x, 'b c h w -> b (n c) h w', n=copies))
        elif shuffle_channels < base_channels:
            group = base_channels // shuffle_channels
            self.skip = LambdaModule(lambda x: reduce(x, 'b (n c) h w -> b c h w', 'mean', n=group))
        else:
            self.skip = LambdaModule(lambda x: x)

        self.alpha = nn.Parameter(th.ones(1) * 1e-6)

    def forward(self, x):
        """Return the pixel-shuffled sum of the skip and the scaled residual."""
        return self.shuffel(self.layers(x) * self.alpha + self.skip(x))

# NOTE(review): this definition shadows the ``PatchDownConv`` name imported
# from ``nn.encoder`` at the top of the file -- confirm that this is intended.
class PatchDownConv(nn.Module):
    """Downsample by projecting non-overlapping ``K x K`` patches.

    Every patch is flattened and sent through one linear layer. The result is
    scaled by the learnable ``alpha`` (initialised to 1e-6) and added to a
    parameter-free skip path: the patch-wise mean of the input with each
    channel repeated ``out_channels // in_channels`` times.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels after downsampling; must be a multiple of
            ``in_channels``.
        kernel_size: Patch edge length ``K`` (also the downsampling factor).
    """

    def __init__(self, in_channels, out_channels, kernel_size = 4):
        super().__init__()
        assert out_channels % in_channels == 0

        patch_features = in_channels * kernel_size**2
        self.layers = nn.Linear(patch_features, out_channels)

        self.kernel_size = kernel_size
        self.channels_factor = out_channels // in_channels

        self.alpha = nn.Parameter(th.ones(1) * 1e-6)

    def forward(self, input: th.Tensor):
        """Return the downsampled map; ``input`` is indexed as ``b c h w``."""
        patch = self.kernel_size
        height, width = input.shape[2:]

        # Skip path: average each patch, then repeat every channel.
        pooled = reduce(input, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=patch, w2=patch)
        skip = repeat(pooled, 'b c h w -> b (c n) h w', n=self.channels_factor)

        # Residual path: one row per patch, projected by the linear layer.
        flat_patches = rearrange(input, 'b c (h h2) (w w2) -> (b h w) (c h2 w2)', h2=patch, w2=patch)
        projected = self.layers(flat_patches)
        residual = rearrange(projected, '(b h w) c -> b c h w', h=height // patch, w=width // patch)

        return skip + residual * self.alpha

class PatchUpscale(nn.Module):
    """Upsample a feature map by ``kernel_size`` with a residual branch.

    The residual branch is ReLU -> 3x3 conv -> ReLU -> transposed conv whose
    kernel size and stride both equal ``kernel_size``. Its output is scaled by
    the learnable ``alpha`` (initialised to 1e-6) and added to the output of a
    ``SkipConnection`` (defined outside this file) built with the same
    channel counts and ``scale_factor=kernel_size``.

    Args:
        in_channels: Channels of the incoming feature map; must be a multiple
            of ``out_channels``.
        out_channels: Channels after upsampling.
        kernel_size: Kernel size and stride of the transposed convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size = 4):
        super().__init__()
        assert in_channels % out_channels == 0

        self.skip = SkipConnection(in_channels, out_channels, scale_factor=kernel_size)

        # Channel-preserving 3x3 conv, then the transposed conv that upsamples.
        refine = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        expand = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=kernel_size,
        )
        self.residual = nn.Sequential(nn.ReLU(), refine, nn.ReLU(), expand)

        self.alpha = nn.Parameter(th.ones(1) * 1e-6)

    def forward(self, input):
        """Return the skip output plus the ``alpha``-scaled residual branch."""
        shortcut = self.skip(input)
        residual = self.residual(input)
        return shortcut + residual * self.alpha

class HyperConvNeXtUnet(nn.Module):
    """U-Net whose inner stages receive weights generated from a condition.

    The outermost stages (``down0`` / ``up0``) are ordinary modules. Every
    inner stage is a ``HyperSequential`` (defined outside this file) that is
    called as ``stage(x, weights)``. ``weights`` comes from a small MLP that
    maps the conditioning vector ``c`` to ``stage.num_weights()`` values.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output image.
        base_channels: Channel width after ``down0``; doubled at every further
            downsampling step. Must be at least ``in_channels * 16``.
        blocks: Four block counts. ``blocks[0]`` is the depth of ``down0`` and
            ``up0``; ``blocks[1]`` and ``blocks[2]`` are the number of
            ``HyperConvNextBlock`` per stage on both the down and up path;
            ``blocks[3]`` is the number in the bottom stage.
        norm_group_size: Currently unused; kept for interface compatibility.
        alpha: Value forwarded as ``alpha`` to every ``HyperConvNextBlock``.
        hyper_in_channels: Size of the conditioning vector ``c``.
        hyper_channels: Hidden width of the weight-generating MLPs.
        hyper_layers: Nominal number of linear layers per MLP; at least two
            are always created.
    """

    def __init__(
        self, 
        in_channels,
        out_channels, 
        base_channels, 
        blocks=(3, 3, 9, 3),
        norm_group_size = 32, 
        alpha = 1e-6,
        hyper_in_channels = 16,
        hyper_channels = 32,
        hyper_layers = 3,
    ):
        super(HyperConvNeXtUnet, self).__init__()
        assert base_channels >= in_channels * 16

        def weights_for(stage):
            # One generator MLP per hyper stage, sized by the stage itself.
            return self._weight_generator(
                hyper_in_channels, hyper_channels, hyper_layers, stage.num_weights()
            )

        self.down0 = UnetDown(in_channels, base_channels, blocks[0])

        self.down1 = HyperSequential(
            NonHyperWrapper(PatchDownConv(base_channels, base_channels * 2, 2)),
            *[HyperConvNextBlock(base_channels * 2, alpha=alpha) for _ in range(blocks[1])],
        )
        self.down1_weights = weights_for(self.down1)

        self.down2 = HyperSequential(
            NonHyperWrapper(PatchDownConv(base_channels * 2, base_channels * 4, 2)),
            *[HyperConvNextBlock(base_channels * 4, alpha=alpha) for _ in range(blocks[2])],
        )
        self.down2_weights = weights_for(self.down2)

        # The bottom stage downsamples, processes and upsamples again, so its
        # output can be added directly to the ``down2`` features.
        self.bottom = HyperSequential(
            NonHyperWrapper(PatchDownConv(base_channels * 4, base_channels * 8, 2)),
            *[HyperConvNextBlock(base_channels * 8, alpha=alpha) for _ in range(blocks[3])],
            NonHyperWrapper(PatchUpscale(base_channels * 8, base_channels * 4, 2)),
        )
        self.bottom_weights = weights_for(self.bottom)

        self.up2 = HyperSequential(
            *[HyperConvNextBlock(base_channels * 4, alpha=alpha) for _ in range(blocks[2])],
            NonHyperWrapper(PatchUpscale(base_channels * 4, base_channels * 2, 2)),
        )
        self.up2_weights = weights_for(self.up2)

        self.up1 = HyperSequential(
            *[HyperConvNextBlock(base_channels * 2, alpha=alpha) for _ in range(blocks[1])],
            NonHyperWrapper(PatchUpscale(base_channels * 2, base_channels, 2)),
        )
        self.up1_weights = weights_for(self.up1)

        self.up0 = UnetUp(base_channels, out_channels, blocks[0])

    @staticmethod
    def _weight_generator(hyper_in_channels, hyper_channels, hyper_layers, num_weights):
        """Build the MLP mapping the conditioning vector to ``num_weights`` values."""
        return nn.Sequential(
            nn.Linear(hyper_in_channels, hyper_channels),
            *[nn.Sequential(nn.SiLU(), nn.Linear(hyper_channels, hyper_channels)) for _ in range(hyper_layers-2)],
            nn.SiLU(),
            nn.Linear(hyper_channels, num_weights),
        )

    def forward(self, x, c):
        """Run the U-Net on ``x`` conditioned on ``c``.

        Args:
            x: Input image batch with ``in_channels`` channels.
            c: Conditioning input for the weight generators; its last
                dimension must be ``hyper_in_channels``.

        Returns:
            The output of ``up0``.
        """
        x0 = self.down0(x)
        x1 = self.down1(x0, self.down1_weights(c))
        x2 = self.down2(x1, self.down2_weights(c))
        x = self.bottom(x2, self.bottom_weights(c))
        # Additive skip connections from the matching encoder stages.
        x = self.up2(x + x2, self.up2_weights(c))
        x = self.up1(x + x1, self.up1_weights(c))
        x = self.up0(x + x0)
        return x

# TODO build a mask decoder from patches (output codebook!!!!) (No, this is just a deconvolution)
# TODO build a decoder just from deconvolutions!!!
# TODO build a memory-efficient (patch) upscale with inverted bottleneck (linear n -> 4n, conv transposed 4n -> n)


class Net(nn.Module):
    """Conditional image-to-image network: encoder plus hyper U-Net.

    ``encoder`` reduces the conditioning image ``c`` to one vector per sample
    with ``out_channels * 256`` entries (global max over the spatial axes).
    That vector conditions the ``HyperConvNeXtUnet`` applied to ``x``.

    Args:
        in_channels: Channels of the U-Net input ``x``.
        out_channels: Channels of the U-Net output; the encoder also expects
            its input ``c`` to have this many channels.
        base_channels: Base width of the U-Net.
        blocks: Four block counts forwarded to ``HyperConvNeXtUnet``.
    """

    def __init__(self, in_channels = 1, out_channels = 3, base_channels = 16, blocks = (2, 1, 1, 2)):
        super(Net, self).__init__()
        # Encoder output width == conditioning width of the U-Net.
        cond_channels = out_channels * 256
        self.unet = HyperConvNeXtUnet(in_channels, out_channels, base_channels, blocks, hyper_in_channels=cond_channels, hyper_channels=cond_channels, hyper_layers=3)

        self.encoder = nn.Sequential(
            PatchDownConv(out_channels, out_channels * 32),
            ConvNeXtBlock(out_channels * 32),
            PatchDownConv(out_channels * 32, out_channels * 64, 2),
            ConvNeXtBlock(out_channels * 64),
            ConvNeXtBlock(out_channels * 64),
            PatchDownConv(out_channels * 64, out_channels * 128, 2),
            ConvNeXtBlock(out_channels * 128),
            ConvNeXtBlock(out_channels * 128),
            ConvNeXtBlock(out_channels * 128),
            PatchDownConv(out_channels * 128, out_channels * 256, 2),
            ConvNeXtBlock(out_channels * 256),
            ConvNeXtBlock(out_channels * 256),
            ConvNeXtBlock(out_channels * 256),
            ConvNeXtBlock(out_channels * 256),
            # Collapse the remaining spatial grid into a single vector.
            LambdaModule(lambda x: reduce(x, 'b c h w -> b c', 'max')),
        )

    def forward(self, x, c):
        """Return ``unet(x, encoder(c))``."""
        return self.unet(x, self.encoder(c))
    

if __name__ == '__main__':
    # Ad-hoc training script: trains ``Net`` to predict RGB from depth,
    # conditioned on the RGB image itself, and periodically dumps sample
    # images and checkpoints into ``out/``.
    import os

    from data.datasets.hdf5_lightning_objects import HDF5_Dataset
    from utils.optimizers import Ranger
    from torch.utils.data import Dataset, DataLoader
    from utils.loss import L1SSIMLoss, MaskedL1SSIMLoss
    from utils.io import Timer, UEMA
    import cv2

    DEVICE = th.device('cuda:1' if th.cuda.is_available() else 'cpu')

    # Make sure the output directory exists before the first write; without
    # it the checkpoint save at step 0 fails.
    os.makedirs('out', exist_ok=True)

    dataset = HDF5_Dataset('/media/chief/data/Datasets-HDF5-Compressed/Kubric-Datasets/movi-e-train-256x256.hdf5', (128,128), max_scale_factor=0.5)

    net = Net().to(DEVICE)
    opt = Ranger(net.parameters(), lr=1e-4)

    #net.load_state_dict(th.load('convnext/revnet003_000000.pth'))

    print(f'Number of parameters: {sum(p.numel() for p in net.parameters())}')

    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    #l1ssim = L1SSIMLoss()
    l1ssim = MaskedL1SSIMLoss()

    # Running averages used only for the progress print-out.
    avg_loss = UEMA()
    avg_rgb_l1 = UEMA()
    avg_rgb_ssim = UEMA()

    timer = Timer()

    for epoch in range(100):
        for i, batch in enumerate(dataloader):

            rgb, depth, mask = batch[0], batch[1], batch[2]

            rgb = rgb.to(DEVICE)
            depth = depth.to(DEVICE)
            mask = mask.to(DEVICE)

            #rgb = rgb * mask
            #depth = depth * mask

            #mean_depth = th.sum(depth * mask, dim=(1,2,3), keepdim=True) / (th.sum(mask, dim=(1,2,3), keepdim=True) + 1e-8)
            #std_depth  = th.sqrt(th.sum((depth - mean_depth)**2 * mask, dim=(1,2,3), keepdim=True) / (th.sum(mask, dim=(1,2,3), keepdim=True) + 1e-8))
            #norm_depth = (depth - mean_depth) * mask

            # Depth is the U-Net input; the RGB image is the conditioning input.
            rgb_out = net(depth, rgb)

            loss, rgb_l1, rgb_ssim = l1ssim(rgb_out, rgb, mask)

            opt.zero_grad()
            loss.backward()
            opt.step()

            avg_loss.update(loss.item())
            avg_rgb_l1.update(rgb_l1.item())
            avg_rgb_ssim.update(rgb_ssim.item())

            print("[{}|{}|{}|{:.2f}%] {}, Loss: {:.2e}, L1: {:.2e}, SSIM: {:.2e}".format(
                epoch, i, len(dataloader), i/len(dataloader)*100,
                str(timer),
                float(avg_loss),
                float(avg_rgb_l1), float(avg_rgb_ssim),
            ), flush=True)

            if i % 100 == 0:
                # NOTE(review): images are written as float arrays scaled by
                # 255 in the tensor's channel order; cv2 assumes BGR -- confirm
                # that the saved colours and value range are as intended.
                cv2.imwrite(f'out/rgb_{epoch:03d}_{i:06d}.jpg', rgb[0].permute(1,2,0).detach().cpu().numpy() * 255)
                cv2.imwrite(f'out/rgb_out_{epoch:03d}_{i:06d}.jpg', rgb_out[0].permute(1,2,0).detach().cpu().numpy() * 255)

            if i % 10000 == 0:
                th.save(net.state_dict(), f'out/net{epoch:03d}_{i:06d}.pth')



