import torch as th
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from nn.convnext import ConvNeXtEncoder, ConvNeXtBlock
from utils.utils import LambdaModule, MultiArgSequential, Binarize
from nn.upscale import MemoryEfficientUpscaling
from nn.manifold import *
from nn.downscale import MemoryEfficientRGBDecoderStem, MemoryEfficientPatchDownScale
from nn.memory_efficient import MemoryEfficientConv3x3Residual

# TODO train really good depth and RGB decoders, then freeze them during Loci training and use them only as a "learned object loss"
# TODO only the masks should still be allowed to adapt!!!
# TODO also use the learned encoders to supervise the gestalt encoder in the pretrainer (and maybe also in Loci, by masking out the object in question)
# TODO instead of freezing, maybe just set a very small learning rate for the RGB and depth decoders ???
# TODO instead of dropout, try "randout" (randomly noising some bits)

# TODO compute an (average) gestalt over different scales / positions of the same object and use it as the target for the Loci encoder
# TODO better: compute the gestalt code of the unscaled object, placed as centered as possible, and use it as the target
# TODO important: compute distinct gestalt codes for the flipped image !!!

class HyperConv1x1(nn.Module):
    """1x1 convolution whose weights and bias are predicted per sample.

    A small hyper-network global-average-pools the input and emits
    ``channels * channels + channels`` values. These are handed, together
    with the input, to a ``LinearManifold`` wrapped in ``HyperSequential``.
    How that vector is split into weight and bias happens inside the manifold
    module (presumably out*in weights followed by out biases; TODO confirm
    against nn.manifold).
    """

    def __init__(self, channels):
        super(HyperConv1x1, self).__init__()

        self.linear_manifold = HyperSequential(LinearManifold(channels, channels))

        hidden_width = channels * 4
        # For a 1x1 conv: out_channels * in_channels weights + out_channels biases.
        num_conv_params = channels * channels + channels
        self.hyper_net = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, hidden_width, bias=True),
            nn.SiLU(inplace=True),
            nn.Linear(hidden_width, num_conv_params, bias=True),
        )

    def forward(self, x):
        """Apply the sample-conditioned 1x1 convolution to ``x``."""
        predicted_params = self.hyper_net(x)
        return self.linear_manifold(x, predicted_params)

class HyperResidual(nn.Module):
    """Residual block: 3x3 conv -> GroupNorm -> SiLU -> HyperConv1x1, plus a skip path.

    When the channel count changes, the skip path adapts it without
    parameters. It averages groups of ``in // out`` channels when shrinking
    and repeats each channel ``out // in`` times when growing. One channel
    count is assumed to divide the other.
    """

    def __init__(self, in_channels, out_channels = None):
        super(HyperResidual, self).__init__()

        out_channels = in_channels if out_channels is None else out_channels

        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
            nn.SiLU(inplace=True),
            HyperConv1x1(out_channels),
        )

        if in_channels > out_channels:
            group = in_channels // out_channels
            self.skip = LambdaModule(lambda x: reduce(x, 'b (c n) h w -> b c h w', 'mean', n=group))
        elif in_channels < out_channels:
            group = out_channels // in_channels
            self.skip = LambdaModule(lambda x: repeat(x, 'b c h w -> b (c n) h w', n=group))
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        """Return ``skip(x) + layers(x)``."""
        shortcut = self.skip(x)
        return shortcut + self.layers(x)
        

class Gaus2D(nn.Module):
    """Render one 2D Gaussian blob per sample on a grid normalized to [-1, 1].

    ``forward`` takes a ``(B, C)`` tensor with ``2 <= C <= 4``:

    - column 0 is x and column 1 is y;
    - the std comes from column 2 when C == 3, or from column 3 when C == 4
      (column 2 is then ignored here);
    - with C == 2 the raw std is 0.

    Args:
        size: optional ``(H, W)`` of the output grid. If omitted,
            ``update_grid`` must be called before ``forward``.
        position_limit: x and y are clipped to ``[-position_limit, position_limit]``.
    """

    def __init__(self, size = None, position_limit = 1):
        super(Gaus2D, self).__init__()
        self.size = size
        self.position_limit = position_limit
        self.min_std = 0.1
        self.max_std = 0.5

        # Placeholder 1x1 grids, replaced by update_grid. Non-persistent buffers
        # are left out of state_dict but still follow .to(device).
        self.register_buffer("grid_x", th.zeros(1,1,1,1), persistent=False)
        self.register_buffer("grid_y", th.zeros(1,1,1,1), persistent=False)

        if size is not None:
            self.min_std = 1.0 / min(size)
            self.update_grid(size)

        print(f"Gaus2D: min std: {self.min_std}")

    def update_grid(self, size):
        """(Re)build the ``(1, 1, H, W)`` coordinate grids if ``size`` changed.

        Also resets ``min_std`` to ``1 / min(H, W)``.
        """
        # Compare as tuples: th.Size is a tuple subclass, so a list ``size``
        # would never compare equal and the grid would be rebuilt every call.
        if tuple(size) != tuple(self.grid_x.shape[2:]):
            self.size    = size
            self.min_std = 1.0 / min(size)
            H, W = size
            device = self.grid_x.device

            # Map pixel indices to [-1, 1]. max(..., 1) avoids 0/0 (NaN) on a
            # 1-pixel axis.
            grid_x = (th.arange(W, device=device) / max(W - 1, 1)) * 2 - 1
            grid_y = (th.arange(H, device=device) / max(H - 1, 1)) * 2 - 1

            self.grid_x = grid_x.view(1, 1, 1, -1).expand(1, 1, H, W).clone()
            self.grid_y = grid_y.view(1, 1, -1, 1).expand(1, 1, H, W).clone()

    def forward(self, input: th.Tensor, compute_std = True):
        """Return a ``(B, 1, H, W)`` Gaussian map for ``(B, C)`` parameters.

        If ``compute_std`` is True, the raw std is squashed with a sigmoid into
        ``[min_std, max_std]``. Otherwise it is used as given and only clipped
        to that range.

        Raises:
            RuntimeError: if no grid size has been set.
        """
        assert input.dim() == 2 and 2 <= input.shape[1] <= 4
        if self.size is None:
            raise RuntimeError("Gaus2D: grid size is not set; pass size or call update_grid first")
        H, W = self.size

        # (B, 1) -> (B, 1, 1, 1) so each parameter broadcasts over the grid.
        x = input[:, 0:1, None, None]
        y = input[:, 1:2, None, None]
        if input.shape[1] == 3:
            std = input[:, 2:3, None, None]
        elif input.shape[1] == 4:
            std = input[:, 3:4, None, None]
        else:
            std = th.zeros_like(x)

        x = th.clip(x, -self.position_limit, self.position_limit)
        y = th.clip(y, -self.position_limit, self.position_limit)

        if compute_std:
            std = th.sigmoid(std) * (self.max_std - self.min_std) + self.min_std

        std = th.clip(std, self.min_std, self.max_std)

        # std is expressed in normalized-height units. Scaling the x std by
        # H / W keeps the blob circular in pixel space on non-square grids.
        std_y = std
        std_x = std * (H / W)

        return th.exp(-1 * ((self.grid_x - x)**2/(2 * std_x**2) + (self.grid_y - y)**2/(2 * std_y**2)))

class MaskCenter(nn.Module):
    """Compute the per-channel centroid of a soft mask and a size estimate.

    For a mask of shape ``(B, C, H, W)`` the result has shape ``(B, 3 * C)``,
    laid out as ``[center_x (C), center_y (C), std (C)]``. ``std`` is the
    square root of the fraction of the image covered by the mask.

    Args:
        size: ``(height, width)`` of the masks this module will receive.
        normalize: if True, coordinates span ``[-1, 1]``. Otherwise they are
            pixel indices ``0 .. width - 1`` and ``0 .. height - 1``.
    """

    def __init__(self, size, normalize=True):
        super(MaskCenter, self).__init__()

        height, width = size

        if normalize:
            x_range = th.linspace(-1, 1, width)
            y_range = th.linspace(-1, 1, height)
        else:
            # Unit-spaced pixel indices. The previous linspace(0, width, width)
            # had spacing width / (width - 1), so its values were not pixel
            # coordinates.
            x_range = th.linspace(0, width - 1, width)
            y_range = th.linspace(0, height - 1, height)

        # Default ("ij") ordering: y varies along rows, x along columns.
        y_coords, x_coords = th.meshgrid(y_range, x_range)

        # Add batch/channel axes so the grids broadcast against (B, C, H, W).
        self.register_buffer('x_coords', x_coords[None, None, :, :], persistent=False)
        self.register_buffer('y_coords', y_coords[None, None, :, :], persistent=False)

    def forward(self, mask):
        """Return ``cat(center_x, center_y, std)`` along the last dimension."""
        area = th.sum(mask, dim=(2, 3))

        # The epsilon keeps an empty mask from dividing by zero (center -> 0).
        center_x = th.sum(self.x_coords * mask, dim=(2, 3)) / (area + 1e-8)
        center_y = th.sum(self.y_coords * mask, dim=(2, 3)) / (area + 1e-8)

        # Square root of the covered fraction of the H x W image.
        std = (area / (mask.shape[2] * mask.shape[3]))**0.5

        return th.cat((center_x, center_y, std), dim=-1)

class PositionPooling(nn.Module):
    """Pool a feature map into one vector using a Gaussian window at a given position.

    The window comes from ``Gaus2D`` with the std taken as given
    (``compute_std=False``) and is normalized to sum to 1 per sample. The
    pooled ``(B, in_channels)`` vector is mapped to ``out_channels`` by a
    linear skip plus a two-layer SiLU MLP.
    """

    def __init__(self, size, in_channels, out_channels):
        super(PositionPooling, self).__init__()

        hidden_width = max(in_channels, out_channels) * 4

        self.gaus2d = Gaus2D(size)
        self.skip = LinearSkip(in_channels, out_channels)
        self.residual = nn.Sequential(
            nn.Linear(in_channels, hidden_width),
            nn.SiLU(),
            nn.Linear(hidden_width, out_channels),
        )

    def forward(self, feature_maps, position):
        """Return ``skip(pooled) + residual(pooled)`` for the window-weighted features."""
        window = self.gaus2d(position, compute_std=False)
        total = window.sum(dim=(1, 2, 3), keepdim=True)
        window = window / (total + 1e-8)

        pooled = (window * feature_maps).sum(dim=(2, 3))
        return self.skip(pooled) + self.residual(pooled)

class OldMaskDepthToRGBUnet(nn.Module):
    """Earlier three-level U-Net decoder; not used by ``RGBPretrainer`` in this file.

    The stem embeds (position, gestalt, mask, depth). The RGB input is
    average-pooled by 2 and then folded space-to-depth by 2 (3 -> 12
    channels) before being concatenated with the stem output. Judging by the
    FIXME below and ``up0``'s scale factor of 4, the stem appears to work at
    1/4 of the output resolution (TODO confirm against the stem module).
    """

    def __init__(self, gestalt_size = 256, base_channels = 16, blocks = (1, 2, 3)):
        # Fixed: this previously called super(MaskDepthToRGBUnet, self), which
        # raises TypeError because this class is not a subclass of
        # MaskDepthToRGBUnet.
        super(OldMaskDepthToRGBUnet, self).__init__()

        self.stem  = MemoryEfficientRGBDecoderStem(gestalt_size,      base_channels,     scale_factor=4, expand_ratio=16)
        self.down1 = MemoryEfficientPatchDownScale(base_channels,     base_channels * 2, scale_factor=2, expand_ratio=16)
        self.down2 = MemoryEfficientPatchDownScale(base_channels * 2, base_channels * 4, scale_factor=2, expand_ratio=16)

        # FIXME don't reduce resolution by 4 in the beginning (only by 2)
        # Downsample RGB by 2 (mean), then fold 2x2 patches into channels.
        self.avg_pool = nn.Sequential(
            LambdaModule(lambda x: reduce(x, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=2, w2=2)),
            LambdaModule(lambda x: rearrange(x, 'b c (h h2) (w w2) -> b (c h2 w2) h w', h2=2, w2=2)),
        )
        self.rgb_in = nn.Conv2d(base_channels + 3 * 2 * 2, base_channels, kernel_size=3, stride=1, padding=1, bias=False)

        self.layer0 = nn.Sequential(*[MemoryEfficientConv3x3Residual(base_channels)     for _ in range(blocks[0])])
        self.layer1 = nn.Sequential(*[MemoryEfficientConv3x3Residual(base_channels * 2) for _ in range(blocks[1])])
        self.layer2 = nn.Sequential(*[MemoryEfficientConv3x3Residual(base_channels * 4) for _ in range(blocks[2])])

        # up1 / up0 take the concatenation of the upsampled path and the skip.
        self.up2 = MemoryEfficientUpscaling(base_channels * 4, base_channels * 2, scale_factor=2, expand_ratio=16)
        self.up1 = MemoryEfficientUpscaling(base_channels * 4, base_channels,     scale_factor=2, expand_ratio=16)
        self.up0 = MemoryEfficientUpscaling(base_channels * 2,             3,     scale_factor=4, expand_ratio=16)

    def forward(self, position, gestalt, mask, depth, rgb):
        """Decode to a 3-channel image; ``rgb`` is injected only at the first down step."""
        x0 = self.stem(position, gestalt, mask, depth)
        x1 = self.down1(self.rgb_in(th.cat((x0, self.avg_pool(rgb)), dim=1)))
        x2 = self.down2(x1)

        x0 = self.layer0(x0)
        x1 = self.layer1(x1)
        x2 = self.layer2(x2)

        x = self.up2(x2)
        x = self.up1(th.cat((x, x1), dim=1))
        x = self.up0(th.cat((x, x0), dim=1))

        return x

class MaskDepthToRGBUnet(nn.Module):
    """U-Net decoding (position, gestalt, mask, depth, low-res RGB) into RGB.

    The number of down/up-sampling levels is ``int(log2(size // 16))``,
    clamped to [1, 3]. Missing levels are replaced by ``nn.Identity`` and the
    skip connections adapt to match. The stem output is concatenated with
    ``rgb`` along channels, so ``rgb`` must match the stem's spatial size.
    In ``RGBPretrainer``, e.g., ``decoder16`` receives 8x8 RGB and its output
    is compared against a 16x16 target.

    Args:
        size: int resolution of the level; only used to choose the depth.
        gestalt_size: length of the gestalt code fed to the stem.
        base_channels: channel count at the top level, doubled per level.
        blocks: residual-block counts for levels 0..3. Entries for levels
            that are not built are ignored.
    """

    def __init__(self, size, gestalt_size = 256, base_channels = 16, blocks = (1, 2, 3, 4)):
        super(MaskDepthToRGBUnet, self).__init__()

        # Clamp to >= 1. For size < 16, size // 16 == 0 and int(np.log2(0))
        # raised OverflowError (log2(0) == -inf).
        scale      = max(size // 16, 1)
        num_levels = min(max(1, int(np.log2(scale))), 3)  # At most 3 downscaling steps
        self.num_levels = num_levels

        self.stem  = MemoryEfficientRGBDecoderStem(gestalt_size,      base_channels,     scale_factor=2, expand_ratio=16)
        self.down1 = MemoryEfficientPatchDownScale(base_channels,     base_channels * 2, scale_factor=2, expand_ratio=16) if num_levels >= 1 else nn.Identity()
        self.down2 = MemoryEfficientPatchDownScale(base_channels * 2, base_channels * 4, scale_factor=2, expand_ratio=16) if num_levels >= 2 else nn.Identity()
        self.down3 = MemoryEfficientPatchDownScale(base_channels * 4, base_channels * 8, scale_factor=2, expand_ratio=16) if num_levels >= 3 else nn.Identity()

        # Fuses the stem features with the 3-channel low-resolution RGB input.
        self.rgb_in = nn.Conv2d(base_channels + 3, base_channels, kernel_size=3, stride=1, padding=1, bias=False)

        self.layer0 = nn.Sequential(*[MemoryEfficientConv3x3Residual(base_channels)     for _ in range(blocks[0])])
        self.layer1 = nn.Sequential(*[MemoryEfficientConv3x3Residual(base_channels * 2) for _ in range(blocks[1])]) if num_levels >= 1 else nn.Identity()
        self.layer2 = nn.Sequential(*[MemoryEfficientConv3x3Residual(base_channels * 4) for _ in range(blocks[2])]) if num_levels >= 2 else nn.Identity()
        self.layer3 = nn.Sequential(*[MemoryEfficientConv3x3Residual(base_channels * 8) for _ in range(blocks[3])]) if num_levels >= 3 else nn.Identity()

        # Input widths account for the skip concatenation (doubling) only when
        # the level below actually exists.
        self.up3 = MemoryEfficientUpscaling(base_channels * 8,                             base_channels * 4, scale_factor=2, expand_ratio=16) if num_levels >= 3 else nn.Identity()
        self.up2 = MemoryEfficientUpscaling(base_channels * (8 if num_levels >= 3 else 4), base_channels * 2, scale_factor=2, expand_ratio=16) if num_levels >= 2 else nn.Identity()
        self.up1 = MemoryEfficientUpscaling(base_channels * (4 if num_levels >= 2 else 2), base_channels,     scale_factor=2, expand_ratio=16) if num_levels >= 1 else nn.Identity()
        self.up0 = MemoryEfficientUpscaling(base_channels * (2 if num_levels >= 1 else 1),             3,     scale_factor=2, expand_ratio=16)

    def forward(self, position, gestalt, mask, depth, rgb):
        """Decode to a 3-channel image, scaled by 0.1."""
        x0 = self.stem(position, gestalt, mask, depth)
        x0 = self.rgb_in(th.cat((x0, rgb), dim=1))

        x1 = self.down1(self.layer0(x0))
        x2 = self.down2(self.layer1(x1))
        x3 = self.down3(self.layer2(x2))

        # A skip is concatenated only when the corresponding down step exists;
        # otherwise x already is that tensor (Identity path).
        x = self.up3(self.layer3(x3))
        x = self.up2(th.cat((x, x2), dim=1) if self.num_levels >= 3 else x)
        x = self.up1(th.cat((x, x1), dim=1) if self.num_levels >= 2 else x)
        x = self.up0(th.cat((x, x0), dim=1) if self.num_levels >= 1 else x)

        # NOTE(review): fixed down-scaling of the output; presumably keeps the
        # predicted residual small at initialization — confirm intent.
        return x * 0.1


class RGBPretrainer(nn.Module):
    """Pretrain an RGB encoder together with a coarse-to-fine RGB decoder pyramid.

    The encoder output is pooled with a Gaussian window around the mask
    centroid (``MaskCenter``) and passed through ``Binarize`` to form the
    ``gestalt`` code. Decoding runs coarse-to-fine:

    - ``decoder8`` maps ``(gestalt, position)`` to a 3 x 8 x 8 image;
    - each ``decoderN`` (N = 16 .. 128) predicts the residual between the
      N x N image and the 2x nearest-upsampled N/2 x N/2 image.

    The pyramid resolutions 8 .. 128 are hard-coded, so ``forward`` expects
    128 x 128 inputs.

    Args:
        size: ``(H, W)`` of the input; the pooling grid is ``(H // 16, W // 16)``.
        gestalt_size: length of the gestalt code consumed by the decoders.
        base_channels: channel width multiplier.
        blocks: per-stage block counts passed to ``ConvNeXtEncoder``.
    """

    def __init__(self, size, gestalt_size=256, base_channels = 16, blocks = [1,2,3,0]):
        super(RGBPretrainer, self).__init__()

        latent_size = [size[0] // 16, size[1] // 16]

        self.encoder = ConvNeXtEncoder(3, base_channels * 4, blocks)
        # The pooled code feeds every decoder, so its width must be
        # gestalt_size. It was previously hard-wired to base_channels * 16,
        # which only matched at the default arguments.
        self.pool   = MultiArgSequential(
            PositionPooling(latent_size, base_channels * 16, gestalt_size),
            Binarize()
        )

        self.mask_center = MaskCenter(size)

        # Coarsest level: MLP from (gestalt, 3 position values) to a 3x8x8 image.
        self.decoder8 = nn.Sequential(
            nn.Linear(gestalt_size + 3, gestalt_size),
            nn.SiLU(),
            nn.Linear(gestalt_size, gestalt_size),
            nn.SiLU(),
            nn.Linear(gestalt_size, 3 * 8 * 8),
            LambdaModule(lambda x: rearrange(x, 'b (c h w) -> b c h w', c=3, h=8, w=8))
        )

        # Residual decoders; each takes an N/2 image and predicts an N x N residual.
        self.decoder16  = MaskDepthToRGBUnet(16, gestalt_size, base_channels)
        self.decoder32  = MaskDepthToRGBUnet(32, gestalt_size, base_channels)
        self.decoder64  = MaskDepthToRGBUnet(64, gestalt_size, base_channels)
        self.decoder128 = MaskDepthToRGBUnet(128, gestalt_size, base_channels)

    def forward(self, mask, depth, rgb):
        """Compute the multi-scale reconstruction loss.

        Returns:
            ``(loss, rec, gestalt, recs, rgbs, masks)`` where

            - ``loss`` is the sum of the masked MSE over the five levels;
            - ``rec`` is the full 128x128 image decoded from ``decoder8``
              upward without ground truth, computed under ``no_grad``;
            - ``recs`` are the teacher-forced per-level reconstructions;
            - ``rgbs`` and ``masks`` are the mean-pooled target and mask pyramids.
        """
        position = self.mask_center(mask)
        gestalt  = self.pool(self.encoder(rgb), position)

        depth16 = reduce(depth, 'b c (h h2) (w w2) -> b c h w', 'mean', h=16, w=16)
        depth32 = reduce(depth, 'b c (h h2) (w w2) -> b c h w', 'mean', h=32, w=32)
        depth64 = reduce(depth, 'b c (h h2) (w w2) -> b c h w', 'mean', h=64, w=64)
        depth128 = depth

        mask8  = reduce(mask, 'b c (h h2) (w w2) -> b c h w', 'mean', h=8,  w=8)
        mask16 = reduce(mask, 'b c (h h2) (w w2) -> b c h w', 'mean', h=16, w=16)
        mask32 = reduce(mask, 'b c (h h2) (w w2) -> b c h w', 'mean', h=32, w=32)
        mask64 = reduce(mask, 'b c (h h2) (w w2) -> b c h w', 'mean', h=64, w=64)
        mask128 = mask

        # Free-running reconstruction for monitoring only. Each level consumes
        # the previous level's prediction instead of ground truth.
        with th.no_grad():
            x = self.decoder8(th.cat((gestalt.detach(), position), dim=1))
            x2 = repeat(x, 'b c h w -> b c (h h2) (w w2)', h2=2, w2=2)
            x = x2 + self.decoder16(position, gestalt,  mask16,  depth16,  x)
            x2 = repeat(x, 'b c h w -> b c (h h2) (w w2)', h2=2, w2=2)
            x = x2 + self.decoder32(position, gestalt,  mask32,  depth32,  x)
            x2 = repeat(x, 'b c h w -> b c (h h2) (w w2)', h2=2, w2=2)
            x = x2 + self.decoder64(position, gestalt,  mask64,  depth64,  x)
            x2 = repeat(x, 'b c h w -> b c (h h2) (w w2)', h2=2, w2=2)
            x = x2 + self.decoder128(position, gestalt, mask128, depth128, x)

        rgb8  = reduce(rgb, 'b c (h h2) (w w2) -> b c h w', 'mean', h=8, w=8)
        rgb16 = reduce(rgb, 'b c (h h2) (w w2) -> b c h w', 'mean', h=16, w=16)
        rgb32 = reduce(rgb, 'b c (h h2) (w w2) -> b c h w', 'mean', h=32, w=32)
        rgb64 = reduce(rgb, 'b c (h h2) (w w2) -> b c h w', 'mean', h=64, w=64)

        # Level N target: the N x N image minus the upsampled N/2 image.
        target8   = rgb8
        target16  = rgb16 - repeat(rgb8, 'b c h w -> b c (h h2) (w w2)', h2=2, w2=2)
        target32  = rgb32 - repeat(rgb16, 'b c h w -> b c (h h2) (w w2)', h2=2, w2=2)
        target64  = rgb64 - repeat(rgb32, 'b c h w -> b c (h h2) (w w2)', h2=2, w2=2)
        target128 = rgb - repeat(rgb64, 'b c h w -> b c (h h2) (w w2)', h2=2, w2=2)

        # Teacher forcing: each level gets the ground-truth coarser image.
        x8   = self.decoder8(th.cat((gestalt, position), dim=1))
        x16  = self.decoder16(position, gestalt,  mask16,  depth16,  rgb8)
        x32  = self.decoder32(position, gestalt,  mask32,  depth32,  rgb16)
        x64  = self.decoder64(position, gestalt,  mask64,  depth64,  rgb32)
        x128 = self.decoder128(position, gestalt, mask128, depth128, rgb64)

        loss8   = th.mean((x8   - target8)**2 * mask8)
        loss16  = th.mean((x16  - target16)**2 * mask16)
        loss32  = th.mean((x32  - target32)**2 * mask32)
        loss64  = th.mean((x64  - target64)**2 * mask64)
        loss128 = th.mean((x128 - target128)**2 * mask128)

        loss = loss8 + loss16 + loss32 + loss64 + loss128

        # Add back the upsampled ground-truth coarser image to get full images.
        rec8 = x8
        rec16 = x16 + repeat(rgb8, 'b c h w -> b c (h h2) (w w2)', h2=2, w2=2)
        rec32 = x32 + repeat(rgb16, 'b c h w -> b c (h h2) (w w2)', h2=2, w2=2)
        rec64 = x64 + repeat(rgb32, 'b c h w -> b c (h h2) (w w2)', h2=2, w2=2)
        rec128 = x128 + repeat(rgb64, 'b c h w -> b c (h h2) (w w2)', h2=2, w2=2)

        return loss, x, gestalt, (rec8, rec16, rec32, rec64, rec128), (rgb8, rgb16, rgb32, rgb64, rgb), (mask8, mask16, mask32, mask64, mask128)

if __name__ == '__main__':
    # Training script: pretrain RGBPretrainer on masked single-object crops and
    # periodically dump images / checkpoints to ./out.
    import os
    from data.datasets.hdf5_lightning_objects import HDF5_Dataset
    from utils.optimizers import Ranger
    from torch.utils.data import DataLoader
    from utils.loss import MaskedL1SSIMLoss
    from utils.io import Timer, UEMA
    import cv2

    DEVICE = th.device('cuda:1' if th.cuda.is_available() else 'cpu')

    # cv2.imwrite and th.save do not create missing directories.
    os.makedirs('out', exist_ok=True)

    dataset = HDF5_Dataset('/media/chief/data/Datasets-HDF5-Compressed/Kubric-Datasets/movi-e-train-256x256.hdf5', (128,128))

    net = RGBPretrainer((128,128)).to(DEVICE)
    opt = Ranger(net.parameters(), lr=2.5e-4)

    #net.load_state_dict(th.load('net003_000000.pth'))

    print(f'Number of parameters:            {sum(p.numel() for p in net.parameters())}')
    print(f'Number of encoder parameters:    {sum(p.numel() for p in net.encoder.parameters())}')
    print(f'Number of decoder8 parameters:   {sum(p.numel() for p in net.decoder8.parameters())}')
    print(f'Number of decoder16 parameters:  {sum(p.numel() for p in net.decoder16.parameters())}')
    print(f'Number of decoder32 parameters:  {sum(p.numel() for p in net.decoder32.parameters())}')
    print(f'Number of decoder64 parameters:  {sum(p.numel() for p in net.decoder64.parameters())}')
    print(f'Number of decoder128 parameters: {sum(p.numel() for p in net.decoder128.parameters())}')

    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    l1ssim = MaskedL1SSIMLoss().to(DEVICE)

    avg_loss = UEMA()
    avg_rgb_l1 = UEMA()
    avg_rgb_ssim = UEMA()
    avg_rgb_rec_l1 = UEMA()
    avg_rgb_rec_ssim = UEMA()
    avg_bin_mean = UEMA()
    avg_bin_std = UEMA()
    avg_gestalt_mean = UEMA()

    timer = Timer()

    for epoch in range(100):
        for i, batch in enumerate(dataloader):

            rgb, depth, mask = batch[0].to(DEVICE), batch[1].to(DEVICE), batch[2].to(DEVICE)

            # Standardize depth over the masked region only, then zero the background.
            depth_mean = th.sum(depth * mask, dim=(1,2,3), keepdim=True) / (th.sum(mask, dim=(1,2,3), keepdim=True) + 1e-6)
            depth_std  = th.sqrt(th.sum((depth - depth_mean)**2 * mask, dim=(1,2,3), keepdim=True) / (th.sum(mask, dim=(1,2,3), keepdim=True) + 1e-6))

            depth = ((depth - depth_mean) / (depth_std + 1e-6)) * mask
            rgb = rgb * mask

            loss, rgb_rec, gestalt, rgbs_out, rgb_targets, masks  = net(mask, depth, rgb)

            # Diagnostics only: none of these feed the backward pass, so do
            # not build autograd graphs for them.
            with th.no_grad():
                gestalt_mean    = th.mean(gestalt)
                # Distance of each gestalt entry to the nearer of {0, 1}.
                binarized_mean  = th.mean(th.minimum(th.abs(gestalt), th.abs(1 - gestalt)))
                binarized_mean2 = th.mean(th.minimum(th.abs(gestalt), th.abs(1 - gestalt))**2)
                binarized_std   = th.sqrt(binarized_mean2 - binarized_mean**2)

                _ , rgb_rec_l1, rgb_rec_ssim = l1ssim(rgb_rec, rgb, mask)

                _, rgb_l116,  rgb_ssim16  = l1ssim(rgb_targets[1], rgbs_out[1], masks[1])
                _, rgb_l132,  rgb_ssim32  = l1ssim(rgb_targets[2], rgbs_out[2], masks[2])
                _, rgb_l164,  rgb_ssim64  = l1ssim(rgb_targets[3], rgbs_out[3], masks[3])
                _, rgb_l1128, rgb_ssim128 = l1ssim(rgb_targets[4], rgbs_out[4], masks[4])

                rgb_l1   = (rgb_l116   + rgb_l132   + rgb_l164   + rgb_l1128) / 4
                rgb_ssim = (rgb_ssim16 + rgb_ssim32 + rgb_ssim64 + rgb_ssim128) / 4

            opt.zero_grad()
            loss.backward()
            opt.step()

            avg_loss.update(loss.item())
            avg_rgb_l1.update(rgb_l1.item())
            avg_rgb_ssim.update(rgb_ssim.item())
            avg_rgb_rec_l1.update(rgb_rec_l1.item())
            avg_rgb_rec_ssim.update(rgb_rec_ssim.item())
            avg_bin_mean.update(binarized_mean.item())
            avg_bin_std.update(binarized_std.item())
            avg_gestalt_mean.update(gestalt_mean.item())

            print("[{}|{}|{}|{:.2f}%] {}, Loss: {:.2e}, L1: {:.2e}|{:.2e}, SSIM: {:.2e}|{:.2e}, binar: {:.2e}|{:.2e}|{:.3f}".format(
                epoch, i, len(dataloader), i/len(dataloader)*100,
                str(timer),
                float(avg_loss),
                float(avg_rgb_l1), float(avg_rgb_rec_l1),
                float(avg_rgb_ssim), float(avg_rgb_rec_ssim),
                float(avg_bin_mean), float(avg_bin_std), float(avg_gestalt_mean)
            ), flush=True)

            if i % 100 == 0:
                # NOTE(review): cv2.imwrite interprets channels as BGR; tensors
                # are written in their stored order — confirm the dataset's order.
                for b in range(mask.shape[0]):
                    cv2.imwrite(f'out/mask_{b:06d}.jpg', mask[b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb_rec_{b:06d}.jpg', rgb_rec[b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb8_{b:06d}.jpg', rgb_targets[0][b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb16_{b:06d}.jpg', rgb_targets[1][b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb32_{b:06d}.jpg', rgb_targets[2][b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb64_{b:06d}.jpg', rgb_targets[3][b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb128_{b:06d}.jpg', rgb_targets[4][b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb_out8_{b:06d}.jpg', rgbs_out[0][b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb_out16_{b:06d}.jpg', rgbs_out[1][b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb_out32_{b:06d}.jpg', rgbs_out[2][b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb_out64_{b:06d}.jpg', rgbs_out[3][b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb_out128_{b:06d}.jpg', rgbs_out[4][b].permute(1,2,0).detach().cpu().numpy() * 255)

                    cv2.imwrite(f'out/depth_{b:06d}.jpg', th.sigmoid(depth[b]).permute(1,2,0).detach().cpu().numpy() * 255)

                    # Mean absolute RGB error, amplified 5x and clipped for visibility.
                    abs_error = th.clip(th.mean(th.abs(rgb_rec[b] - rgb[b]), dim=0) * 5, 0, 1)
                    cv2.imwrite(f'out/abs_error_{b:06d}.jpg', abs_error.detach().cpu().numpy() * 255)

            if i % 10000 == 0:
                th.save(net.state_dict(), f'out/net{epoch:03d}_{i:06d}.pth')



