import torch as th
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from nn.convnext import ConvNeXtEncoder, ConvNeXtBlock
from utils.utils import LambdaModule, MultiArgSequential, Binarize
from nn.upscale import MemoryEfficientUpscaling
from nn.manifold import *
from nn.downscale import MemoryEfficientRGBDecoderStem, MemoryEfficientPatchDownScale
from nn.memory_efficient import MemoryEfficientConv3x3Residual

# TODO train really good depth and RGB decoders, then freeze them during Loci training and use them only as a "learned object loss"
# TODO only the masks should still be allowed to adapt!!!
# TODO also use the learned encoders to supervise the gestalt encoder in the pretrainer (and maybe also in Loci, by masking the object in question)
# TODO instead of freezing, maybe just use a very small learning rate for the RGB and depth decoders?
# TODO instead of dropout, try "randout" (randomly adding noise to some bits)

# TODO compute an (average) gestalt over different scales and positions of the same object and use it as the target for the Loci encoder
# TODO better: compute the gestalt code of the unscaled object, as centred as possible, and use it as the target
# TODO important: compute different gestalt codes for the flipped image!!!

class HyperConv1x1(nn.Module):
    """Input-conditioned 1x1 convolution (``channels -> channels``).

    A small hypernetwork globally average-pools the input and maps the pooled
    vector through an MLP to ``channels * channels + channels`` values, i.e. the
    weight matrix plus bias of a 1x1 conv. Those values are passed, together
    with the input, to a ``LinearManifold`` wrapped in ``HyperSequential``.
    """

    def __init__(self, channels):
        super().__init__()

        hidden_width = channels * 4
        # out_channels * in_channels weights + out_channels biases
        generated_params = channels * channels + channels

        # Registered first to keep the original parameter/initialisation order.
        self.linear_manifold = HyperSequential(LinearManifold(channels, channels))

        self.hyper_net = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, hidden_width, bias=True),
            nn.SiLU(inplace=True),
            nn.Linear(hidden_width, generated_params, bias=True),
        )

    def forward(self, x):
        # Predict the per-sample conv parameters, then apply them to x.
        conv_params = self.hyper_net(x)
        return self.linear_manifold(x, conv_params)

class HyperResidual(nn.Module):
    """Residual block: 3x3 conv -> GroupNorm(1) -> SiLU -> HyperConv1x1, plus a skip path.

    The skip path has no parameters. When the block shrinks the channel count it
    averages consecutive groups of ``in_channels // out_channels`` channels.
    When it grows the count it repeats each channel ``out_channels // in_channels`` times.
    NOTE(review): presumably the larger channel count must be a multiple of the
    smaller one for the einops patterns to apply — confirm.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output; defaults to ``in_channels``.
    """

    def __init__(self, in_channels, out_channels = None):
        super(HyperResidual, self).__init__()

        out_channels = in_channels if out_channels is None else out_channels

        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
            nn.SiLU(inplace=True),
            HyperConv1x1(out_channels),
        )

        if in_channels > out_channels:
            factor = in_channels // out_channels
            self.skip = LambdaModule(lambda x: reduce(x, 'b (c n) h w -> b c h w', 'mean', n=factor))
        elif in_channels < out_channels:
            factor = out_channels // in_channels
            self.skip = LambdaModule(lambda x: repeat(x, 'b c h w -> b (c n) h w', n=factor))
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        residual = self.layers(x)
        return self.skip(x) + residual
        

class Gaus2D(nn.Module):
    """Render one 2D Gaussian blob per batch element on a normalised [-1, 1] grid.

    Args:
        size: ``(H, W)`` of the rendered map, or ``None`` to defer building the
            coordinate grids until :meth:`update_grid` is called.
        position_limit: blob centres are clipped to ``[-position_limit, position_limit]``.
    """

    def __init__(self, size = None, position_limit = 1):
        super(Gaus2D, self).__init__()
        self.size = size
        self.position_limit = position_limit
        self.min_std = 0.1
        self.max_std = 0.5

        # Non-persistent placeholders: they follow .to(device) but are not stored
        # in the state_dict. update_grid() replaces them with real coordinate maps.
        self.register_buffer("grid_x", th.zeros(1,1,1,1), persistent=False)
        self.register_buffer("grid_y", th.zeros(1,1,1,1), persistent=False)

        if size is not None:
            self.min_std = 1.0 / min(size)
            self.update_grid(size)

        print(f"Gaus2D: min std: {self.min_std}")

    @staticmethod
    def _axis(n, device):
        """Return ``n`` evenly spaced coordinates from -1 to 1 (a single 0 when ``n == 1``)."""
        if n == 1:
            # The general formula would divide 0 by 0 and yield NaN.
            return th.zeros(1, device=device)
        coords = th.arange(n, device=device)
        return (coords / (n - 1)) * 2 - 1

    def update_grid(self, size):
        """(Re)build the ``(1, 1, H, W)`` coordinate grids for ``size = (H, W)``.

        Nothing happens when the current grids already have that shape. ``size``
        is normalised to a tuple first. A ``list`` never compares equal to the
        buffer's ``torch.Size``, so without this the grids were rebuilt on every
        call whenever a caller passed a list.
        """
        size = tuple(size)
        if size == tuple(self.grid_x.shape[2:]):
            return

        self.size    = size
        self.min_std = 1.0 / min(size)
        H, W = size
        device = self.grid_x.device

        # x varies along columns, y along rows; clone() materialises the
        # expanded views so each buffer owns its own storage.
        self.grid_x = self._axis(W, device).view(1, 1, 1, -1).expand(1, 1, H, W).clone()
        self.grid_y = self._axis(H, device).view(1, 1, -1, 1).expand(1, 1, H, W).clone()

    def forward(self, input: th.Tensor, compute_std = True):
        """Render the Gaussians.

        Args:
            input: ``(B, C)`` tensor with ``2 <= C <= 4``. Channel 0 is the x centre
                and channel 1 the y centre. For ``C == 3`` channel 2 holds the std;
                for ``C == 4`` channel 3 holds it and channel 2 is ignored. Without
                a std channel, a raw std of 0 is used.
            compute_std: if True, the raw std is squashed with a sigmoid into
                ``[min_std, max_std]``; otherwise it is used as given (still clipped).

        Returns:
            ``(B, 1, H, W)`` map with value 1 at the (clipped) centre.
        """
        assert input.shape[1] >= 2 and input.shape[1] <= 4
        H, W = self.size

        x   = rearrange(input[:,0:1], 'b c -> b c 1 1')
        y   = rearrange(input[:,1:2], 'b c -> b c 1 1')

        if input.shape[1] == 3:
            std = rearrange(input[:,2:3], 'b c -> b c 1 1')
        elif input.shape[1] == 4:
            std = rearrange(input[:,3:4], 'b c -> b c 1 1')
        else:
            std = th.zeros_like(x)

        x   = th.clip(x, -self.position_limit, self.position_limit)
        y   = th.clip(y, -self.position_limit, self.position_limit)

        if compute_std:
            std = th.sigmoid(std) * (self.max_std - self.min_std) + self.min_std

        std = th.clip(std, self.min_std, self.max_std)

        # Both axes span [-1, 1] regardless of resolution; scaling std_x by H / W
        # compensates for the aspect ratio of non-square maps.
        std_y = std.clone()
        std_x = std * (H / W)

        return th.exp(-1 * ((self.grid_x - x)**2/(2 * std_x**2) + (self.grid_y - y)**2/(2 * std_y**2)))

class MaskCenter(nn.Module):
    """Centre of mass and a size estimate of (soft) masks.

    Args:
        size: ``(height, width)`` of the masks passed to :meth:`forward`.
        normalize: if True, coordinates span [-1, 1] with the first/last pixel at
            -1/+1. Otherwise they are pixel indices ``0 .. width-1`` /
            ``0 .. height-1``. The previous ``linspace(0, width, width)`` spaced
            points ``width / (width - 1)`` apart, so pixel i was reported as
            ``i * width / (width - 1)``.
    """

    def __init__(self, size, normalize=True):
        super(MaskCenter, self).__init__()

        height, width = size

        if normalize:
            x_range = th.linspace(-1, 1, width)
            y_range = th.linspace(-1, 1, height)
        else:
            x_range = th.arange(width, dtype=th.get_default_dtype())
            y_range = th.arange(height, dtype=th.get_default_dtype())

        # Dense coordinate maps: x varies along columns, y along rows (the layout
        # th.meshgrid(y_range, x_range) with 'ij' indexing produces), built without
        # meshgrid to avoid its missing-`indexing` deprecation warning.
        x_coords = x_range.view(1, -1).expand(height, width).clone()
        y_coords = y_range.view(-1, 1).expand(height, width).clone()

        # Shape (1, 1, H, W) so they broadcast over (B, C, H, W) masks.
        self.register_buffer('x_coords', x_coords[None, None, :, :], persistent=False)
        self.register_buffer('y_coords', y_coords[None, None, :, :], persistent=False)

    def forward(self, mask):
        """Compute per-channel centres and sizes.

        Args:
            mask: ``(B, C, H, W)`` mask; ``H, W`` must match ``size``.

        Returns:
            ``(B, 3 * C)`` tensor: all x centres, then all y centres, then all
            sizes. For a single-channel mask this is ``[x, y, size]`` per sample.
            ``size`` is the square root of the fraction of the image the mask covers.
            An all-zero mask yields centre 0 (the 1e-8 term avoids division by zero).
        """
        mass = th.sum(mask, dim=(2, 3))
        denom = mass + 1e-8

        center_x = th.sum(self.x_coords * mask, dim=(2, 3)) / denom
        center_y = th.sum(self.y_coords * mask, dim=(2, 3)) / denom

        std = (mass / (mask.shape[2] * mask.shape[3]))**0.5

        return th.cat((center_x, center_y, std), dim=-1)

class PositionPooling(nn.Module):
    """Pool a feature map into one vector per sample using a Gaussian window.

    The window comes from ``Gaus2D(size)`` evaluated at ``position`` with
    ``compute_std=False``, so the position's std is used as given (then clipped
    by Gaus2D). The window is normalised to sum to ~1 per sample. The
    window-weighted channel sums are then mapped ``in_channels -> out_channels``
    by a ``LinearSkip`` plus a residual two-layer MLP.

    Args:
        size: ``(H, W)`` of the feature maps.
        in_channels: channels of the feature maps.
        out_channels: size of the pooled output vector.
    """

    def __init__(self, size, in_channels, out_channels):
        super(PositionPooling, self).__init__()

        hidden_width = max(in_channels, out_channels) * 4

        self.gaus2d = Gaus2D(size)
        self.skip = LinearSkip(in_channels, out_channels)
        self.residual = nn.Sequential(
            nn.Linear(in_channels, hidden_width),
            nn.SiLU(),
            nn.Linear(hidden_width, out_channels),
        )

    def forward(self, feature_maps, position):
        window = self.gaus2d(position, compute_std=False)
        # Normalise each sample's window to unit mass (1e-8 guards an all-zero window).
        window = window / (window.sum(dim=(1, 2, 3), keepdim=True) + 1e-8)

        pooled = (window * feature_maps).sum(dim=(2, 3))
        return self.skip(pooled) + self.residual(pooled)


class MaskDepthToRGBUnet(nn.Module):
    """U-Net decoding an RGB image from a gestalt code, position, mask and depth.

    The stem fuses ``(position, gestalt, mask, depth)`` into a feature map, and
    three patch down-scaling stages build a 4-level pyramid (channels ``base``,
    ``2*base``, ``4*base``, ``8*base``), each refined by residual blocks. The
    up-scaling path concatenates the matching pyramid level before every
    up-sampling step. Besides the final 3-channel output, 1x1 conv heads predict
    RGB from the three intermediate decoder levels; those predictions are
    bilinearly resized to the final output's spatial size.

    Args:
        gestalt_size: size of the gestalt code passed to the stem.
        base_channels: channel count of the first pyramid level.
        blocks: number of residual blocks at each of the 4 pyramid levels.
            Only indexed, so any 4-element sequence works. The default is a tuple
            to avoid a shared mutable default.
    """

    def __init__(self, gestalt_size = 256, base_channels = 16, blocks = (1, 2, 3, 4)):
        super(MaskDepthToRGBUnet, self).__init__()

        # Module construction order is kept as-is so random initialisation and
        # state_dict ordering stay unchanged.
        self.stem  = MemoryEfficientRGBDecoderStem(gestalt_size,      base_channels,     scale_factor=2, expand_ratio=16)
        self.down1 = MemoryEfficientPatchDownScale(base_channels,     base_channels * 2, scale_factor=2, expand_ratio=16)
        self.down2 = MemoryEfficientPatchDownScale(base_channels * 2, base_channels * 4, scale_factor=2, expand_ratio=16)
        self.down3 = MemoryEfficientPatchDownScale(base_channels * 4, base_channels * 8, scale_factor=2, expand_ratio=16)

        self.layer0 = nn.Sequential(*[MemoryEfficientConv3x3Residual(base_channels)     for _ in range(blocks[0])])
        self.layer1 = nn.Sequential(*[MemoryEfficientConv3x3Residual(base_channels * 2) for _ in range(blocks[1])])
        self.layer2 = nn.Sequential(*[MemoryEfficientConv3x3Residual(base_channels * 4) for _ in range(blocks[2])])
        self.layer3 = nn.Sequential(*[MemoryEfficientConv3x3Residual(base_channels * 8) for _ in range(blocks[3])])

        # Decoder inputs after up3 are [upsampled, skip] concatenations, hence the
        # doubled input channel counts.
        self.up3 = MemoryEfficientUpscaling(base_channels * 8, base_channels * 4, scale_factor=2, expand_ratio=16)
        self.up2 = MemoryEfficientUpscaling(base_channels * 8, base_channels * 2, scale_factor=2, expand_ratio=16)
        self.up1 = MemoryEfficientUpscaling(base_channels * 4, base_channels,     scale_factor=2, expand_ratio=16)
        self.up0 = MemoryEfficientUpscaling(base_channels * 2, 3,                 scale_factor=2, expand_ratio=16)

        # Intermediate RGB heads on the outputs of up3, up2 and up1.
        self.out3 = nn.Conv2d(base_channels * 4, 3, 1)
        self.out2 = nn.Conv2d(base_channels * 2, 3, 1)
        self.out1 = nn.Conv2d(base_channels * 1, 3, 1)

    def forward(self, position, gestalt, mask, depth):
        """Decode RGB.

        Returns:
            ``(out3, out2, out1, x)``: RGB predictions ordered from the coarsest to
            the finest decoder level, all at the spatial size of the final output ``x``.
        """
        x0 = self.stem(position, gestalt, mask, depth)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)

        x0 = self.layer0(x0)
        x1 = self.layer1(x1)
        x2 = self.layer2(x2)
        x3 = self.layer3(x3)

        x = self.up3(x3)
        out3 = self.out3(x)

        x = self.up2(th.cat((x, x2), dim=1))
        out2 = self.out2(x)

        x = self.up1(th.cat((x, x1), dim=1))
        out1 = self.out1(x)

        x = self.up0(th.cat((x, x0), dim=1))

        target_size = x.shape[-2:]
        out3, out2, out1 = (
            F.interpolate(out, size=target_size, mode='bilinear', align_corners=False)
            for out in (out3, out2, out1)
        )

        return out3, out2, out1, x


class RGBPretrainer(nn.Module):
    """Auto-encoder that compresses an object's RGB into a binarised gestalt code.

    ``forward(mask, depth, rgb)`` computes the mask's centre/size with
    ``MaskCenter``. It encodes ``rgb`` with a ConvNeXt encoder and pools the
    features around that position into a ``gestalt_size`` vector, which is passed
    through ``Binarize``. It then decodes RGB from ``(position, gestalt, mask, depth)``.

    Args:
        size: ``(H, W)`` of the input images.
        gestalt_size: length of the gestalt code. Previously this argument was
            ignored: the pool output was fixed to ``base_channels * 16`` and the
            decoder used its own default. It now sets both, which is identical to
            the old behaviour at the default arguments (256 == 16 * 16).
        base_channels: channel base of the encoder.
        blocks: encoder block configuration; defaults to ``[1, 2, 3, 0]``.
    """

    def __init__(self, size, gestalt_size=256, base_channels = 16, blocks = None):
        super(RGBPretrainer, self).__init__()

        # Fresh list per instance instead of a shared mutable default.
        if blocks is None:
            blocks = [1, 2, 3, 0]

        # NOTE(review): assumes the encoder downsamples by 16 and outputs
        # base_channels * 16 channels — confirm against ConvNeXtEncoder.
        latent_size = [size[0] // 16, size[1] // 16]

        self.encoder = ConvNeXtEncoder(3, base_channels * 4, blocks)
        self.pool   = MultiArgSequential(
            PositionPooling(latent_size, base_channels * 16, gestalt_size),
            Binarize()
        )

        self.mask_center = MaskCenter(size)
        self.decoder = MaskDepthToRGBUnet(gestalt_size=gestalt_size)

    def forward(self, mask, depth, rgb):
        """Return ``(decoder_outputs, gestalt)``; see ``MaskDepthToRGBUnet.forward``."""
        position = self.mask_center(mask)
        gestalt  = self.pool(self.encoder(rgb), position)

        x = self.decoder(position, gestalt,  mask,  depth)

        return x, gestalt

if __name__ == '__main__':
    # Training script: pretrain RGBPretrainer on single-object crops.
    import os
    from data.datasets.hdf5_lightning_objects import HDF5_Dataset
    from utils.optimizers import Ranger
    from torch.utils.data import DataLoader
    from utils.loss import MaskedL1SSIMLoss
    from utils.io import Timer, UEMA
    import cv2

    DEVICE = th.device('cuda:1' if th.cuda.is_available() else 'cpu')

    dataset = HDF5_Dataset('/media/chief/data/Datasets-HDF5-Compressed/Kubric-Datasets/movi-e-train-256x256.hdf5', (128,128))

    net = RGBPretrainer((128,128)).to(DEVICE)
    opt = Ranger(net.parameters(), lr=2.5e-4)

    # To resume: net.load_state_dict(th.load('net003_000000.pth'))

    print(f'Number of parameters:            {sum(p.numel() for p in net.parameters())}')
    print(f'Number of encoder parameters:    {sum(p.numel() for p in net.encoder.parameters())}')
    print(f'Number of decoder parameters:    {sum(p.numel() for p in net.decoder.parameters())}')

    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    l1ssim = MaskedL1SSIMLoss().to(DEVICE)

    avg_loss = UEMA()
    avg_rgb_l1 = UEMA()
    avg_rgb_ssim = UEMA()
    avg_rgb_rec_l1 = UEMA()
    avg_rgb_rec_ssim = UEMA()
    avg_bin_mean = UEMA()
    avg_bin_std = UEMA()
    avg_gestalt_mean = UEMA()

    timer = Timer()

    # cv2.imwrite silently returns False and th.save raises if 'out/' is missing.
    os.makedirs('out', exist_ok=True)

    for epoch in range(100):
        for i, batch in enumerate(dataloader):

            rgb, depth, mask = batch[0].to(DEVICE), batch[1].to(DEVICE), batch[2].to(DEVICE)

            # Standardise depth inside the mask (per sample), zero outside.
            depth_mean = th.sum(depth * mask, dim=(1,2,3), keepdim=True) / (th.sum(mask, dim=(1,2,3), keepdim=True) + 1e-6)
            depth_std  = th.sqrt(th.sum((depth - depth_mean)**2 * mask, dim=(1,2,3), keepdim=True) / (th.sum(mask, dim=(1,2,3), keepdim=True) + 1e-6))

            depth = ((depth - depth_mean) / (depth_std + 1e-6)) * mask
            rgb = rgb * mask

            rgbs_out, gestalt = net(mask, depth, rgb)

            # Distance of each gestalt entry to the nearest of {0, 1}.
            gestalt_mean    = th.mean(gestalt)
            binarized_mean  = th.mean(th.minimum(th.abs(gestalt), th.abs(1 - gestalt)))
            binarized_mean2 = th.mean(th.minimum(th.abs(gestalt), th.abs(1 - gestalt))**2)
            # Clamp: E[x^2] - E[x]^2 can round slightly below 0 and give NaN.
            binarized_std   = th.sqrt(th.clamp_min(binarized_mean2 - binarized_mean**2, 0))

            # Weighted average of the losses over all decoder outputs. rgbs_out is
            # ordered coarse -> fine, so weight 1 / 2**n favours the coarsest one.
            # NOTE(review): confirm that this weighting is intended.
            loss = l1 = ssim = weight_sum = rgb_l1 = rgb_ssim = 0
            for n, rgb_out in enumerate(rgbs_out):
                cur_loss, cur_l1, cur_ssim = l1ssim(rgb_out, rgb, mask)
                loss       += cur_loss / (2**n)
                l1         += cur_l1   / (2**n)
                ssim       += cur_ssim / (2**n)
                weight_sum += 1 / (2**n)

                if n == len(rgbs_out) - 1:
                    rgb_l1   = cur_l1
                    rgb_ssim = cur_ssim

            loss /= weight_sum
            l1   /= weight_sum
            ssim /= weight_sum

            opt.zero_grad()
            loss.backward()
            opt.step()

            avg_loss.update(loss.item())
            avg_rgb_l1.update(l1.item())
            avg_rgb_ssim.update(ssim.item())
            avg_rgb_rec_l1.update(rgb_l1.item())
            avg_rgb_rec_ssim.update(rgb_ssim.item())
            avg_bin_mean.update(binarized_mean.item())
            avg_bin_std.update(binarized_std.item())
            avg_gestalt_mean.update(gestalt_mean.item())

            print("[{}|{}|{}|{:.2f}%] {}, Loss: {:.2e}, L1: {:.2e}|{:.2e}, SSIM: {:.2e}|{:.2e}, binar: {:.2e}|{:.2e}|{:.3f}".format(
                epoch, i, len(dataloader), i/len(dataloader)*100,
                str(timer),
                float(avg_loss),
                float(avg_rgb_l1), float(avg_rgb_rec_l1),
                float(avg_rgb_ssim), float(avg_rgb_rec_ssim),
                float(avg_bin_mean), float(avg_bin_std), float(avg_gestalt_mean)
            ), flush=True)

            if i % 100 == 0:
                for b in range(mask.shape[0]):
                    cv2.imwrite(f'out/mask_{b:06d}.jpg', mask[b].permute(1,2,0).detach().cpu().numpy() * 255)
                    cv2.imwrite(f'out/rgb_{b:06d}.jpg', rgb[b].permute(1,2,0).detach().cpu().numpy() * 255)
                    # rgbs_out holds 4 predictions, coarse -> fine.
                    for res, rgb_pred in zip((16, 32, 64, 128), rgbs_out):
                        cv2.imwrite(f'out/rgb_out{res}_{b:06d}.jpg', rgb_pred[b].permute(1,2,0).detach().cpu().numpy() * 255)

                    cv2.imwrite(f'out/depth_{b:06d}.jpg', th.sigmoid(depth[b]).permute(1,2,0).detach().cpu().numpy() * 255)

                    # Error of the final prediction, averaged over the RGB channels.
                    # (Previously used the leaked loop variable's last channel.)
                    abs_error = th.clip(th.mean(th.abs(rgbs_out[-1][b] - rgb[b]), dim=0) * 5, 0, 1)
                    cv2.imwrite(f'out/abs_error_{b:06d}.jpg', abs_error.detach().cpu().numpy() * 255)

            if i % 10000 == 0:
                th.save(net.state_dict(), f'out/net{epoch:03d}_{i:06d}.pth')


