#%%
import sys
sys.path.append("..")
sys.path.append("../..")
sys.path.append("../../..")


import torch
from torch import nn
from utils.mrctools import * 
from utils.visualization import *
from my_tomotwin.modules.networks.SiameseNet3D import * 

#%%
# Lookup tables for the TomoTwin SiameseNet3D architecture: they map the number
# of feature channels to the spatial edge length of the corresponding feature
# volume. The decoder needs this information to size its upsampling steps.

# Table for inputs with spatial dimensions 37x37x37.
NUM_CHANS_TO_SPAT_DIM_37 = {
    # Example: with a 37x37x37 input, the features after the first encoder
    # block have 64 channels and spatial dimensions of 18x18x18.
    64: 18,
    128: 14,
    256: 10,
    512: 6,
    1024: 2,
}

# Table for inputs with spatial dimensions 64x64x64.
NUM_CHANS_TO_SPAT_DIM_64 = {
    # Example: with a 64x64x64 input, the features after the first encoder
    # block have 64 channels and spatial dimensions of 31x31x31.
    64: 31,
    128: 27,
    256: 23,
    512: 19,
    1024: 15,
}


class ConvLayer(nn.Module):
    """3D convolution followed by group normalization and an activation.

    The convolution uses ``kernel_size=3``, ``stride=1`` and ``padding=1``, so
    the spatial dimensions of the input are preserved. The normalization is a
    ``GroupNorm`` with as many groups as channels, i.e. every channel forms
    its own group.

    Args:
        in_chans: Number of input channels.
        out_chans: Number of output channels.
        activation: Module applied after the normalization. ``None`` disables
            the non-linearity (an ``nn.Identity`` is used instead), matching
            how ``LatentToFILM`` treats ``None``.

    Note:
        The default activation object is created once, when the class is
        defined, and is therefore shared by every layer built with the
        default. ``nn.LeakyReLU`` has no parameters or buffers, so the
        sharing does not couple any learned state.
    """

    def __init__(self, in_chans: int, out_chans: int, activation=nn.LeakyReLU()) -> None:
        super().__init__()
        if activation is None:
            # nn.Sequential rejects None; use a no-op module so the layer
            # keeps the same three-entry layout.
            activation = nn.Identity()
        self.layers = nn.Sequential(
            nn.Conv3d(in_chans, out_chans, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(out_chans, out_chans, eps=1e-05, affine=True),
            activation,
        )

    def forward(self, x):
        """Apply convolution, normalization and activation to ``x``."""
        return self.layers(x)


class SpatialUpsampling(nn.Module):
    """Resize a feature volume to a fixed spatial size, then apply a ``ConvLayer``.

    Args:
        out_size: Target spatial size handed to ``nn.Upsample(size=...)``.
        in_chans: Number of channels of the incoming features.
        out_chans: Number of channels produced by the convolution.
        activation: Activation module forwarded to ``ConvLayer``.
    """

    def __init__(self, out_size, in_chans: int, out_chans: int, activation=nn.LeakyReLU()) -> None:
        super().__init__()
        # Trilinear interpolation to the requested size first, channel
        # change second.
        resize = nn.Upsample(size=out_size, mode='trilinear')
        refine = ConvLayer(in_chans, out_chans, activation=activation)
        self.layers = nn.Sequential(resize, refine)

    def forward(self, x):
        """Return ``x`` resized to ``out_size`` and passed through the conv layer."""
        resized_features = self.layers(x)
        return resized_features

class LatentToFILM(nn.Module):
    """Map a latent vector to FiLM coefficients with a linear layer.

    Args:
        out_dim: Size of the produced coefficient vector.
        in_dim: Size of the incoming latent vector.
        activation: Module applied to the linear output. ``None`` means no
            activation; an ``nn.Identity`` is inserted in its place.
    """

    def __init__(self, out_dim: int, in_dim: int=32, activation=nn.LeakyReLU()) -> None:
        super().__init__()
        nonlinearity = nn.Identity() if activation is None else activation
        projection = nn.Linear(in_dim, out_dim)
        self.layers = nn.Sequential(projection, nonlinearity)

    def forward(self, x):
        """Return the (optionally activated) linear projection of ``x``."""
        coefficients = self.layers(x)
        return coefficients


class DecoderBlock(nn.Module):
    """One decoder stage: upsample, fuse with a skip tensor, modulate, convolve.

    Args:
        out_size: Spatial size the input is upsampled to.
        in_chans: Number of channels of the incoming features.
        out_chans: Number of channels produced by this block; the skip tensor
            concatenated in ``forward`` must also have this many channels,
            because the first convolution expects ``2 * out_chans`` inputs.
        prompt_dim: Size of the prompt vector fed to the FiLM layers.
        film_activation: Activation passed to both ``LatentToFILM`` layers.
    """

    def __init__(self, out_size, in_chans, out_chans, prompt_dim=32, film_activation=nn.LeakyReLU()):
        super().__init__()
        self.out_size = out_size
        self.in_chans = in_chans
        self.out_chans = out_chans
        self.prompt_dim = prompt_dim
        # Submodules are created in a fixed order so parameter registration
        # (and random initialization) stays reproducible.
        self.upsampler = SpatialUpsampling(out_size, in_chans, out_chans)
        self.conv_layer0 = ConvLayer(in_chans=2 * out_chans, out_chans=out_chans)
        self.conv_layer1 = ConvLayer(in_chans=out_chans, out_chans=out_chans)
        self.film_scale_layer = LatentToFILM(out_chans, prompt_dim, film_activation)
        self.film_loc_layer = LatentToFILM(out_chans, prompt_dim, film_activation)

    def _apply_film(self, features, prompt):
        """Scale and shift ``features`` with coefficients computed from ``prompt``."""
        # Three trailing singleton axes are appended so the coefficients
        # broadcast over the spatial axes.
        # NOTE(review): presumably prompt is (batch, prompt_dim) -- confirm with callers.
        scale = self.film_scale_layer(prompt)[..., None, None, None]
        shift = self.film_loc_layer(prompt)[..., None, None, None]
        return features * scale + shift

    def forward(self, x, cat, prompt):
        """Run the block.

        Args:
            x: Features coming from the previous decoder stage.
            cat: Skip-connection tensor concatenated along ``dim=1``.
            prompt: Conditioning vector, or ``None`` to skip the modulation.
        """
        upsampled = self.upsampler(x)
        fused = self.conv_layer0(torch.concat((upsampled, cat), dim=1))
        if prompt is not None:
            fused = self._apply_film(fused, prompt)
        return self.conv_layer1(fused)

class SiameseNet3DDecoder(nn.Module):
    """Decoder that expands 1024-channel features back to an ``out_size`` volume.

    A chain of ``DecoderBlock`` stages is built, one per entry of
    ``encoder_cat_layer_ids``. Each id indexes the channel list
    ``[512, 256, 128, 64]`` and selects both the channel count of the stage and
    the skip tensor it consumes in ``forward``. A final ``SpatialUpsampling``
    resizes to ``out_size`` and maps to ``out_chans`` channels.

    Args:
        out_size: Spatial size of the output; only 37 and 64 are supported.
        prompt_dim: Size of the prompt vector passed to every decoder block.
        out_chans: Number of channels of the final output.
        film_activation: Activation handed to the FiLM layers of each block.
        final_activation: Activation of the final upsampling stage.
        encoder_cat_layer_ids: Indices of the skip connections to use, in the
            order in which the decoder stages consume them.

    Raises:
        ValueError: If ``out_size`` is neither 37 nor 64.
    """

    def __init__(self, out_size=64, prompt_dim=32, out_chans=1, film_activation=nn.LeakyReLU(), final_activation=nn.Sigmoid(), encoder_cat_layer_ids=(0, 1, 2, 3)):
        super().__init__()
        self.out_size = out_size
        self.prompt_dim = prompt_dim
        self.out_chans = out_chans
        self.film_activation = film_activation
        self.final_activation = final_activation
        # Keep a private copy: the ids must stay in sync with the blocks built
        # below, so later mutation of the caller's sequence (or of a shared
        # default) must not leak into this instance.
        self.encoder_cat_layer_ids = list(encoder_cat_layer_ids)
        self.decoder_blocks = nn.ModuleList()
        if out_size == 37:
            spatial_dims = NUM_CHANS_TO_SPAT_DIM_37
        elif out_size == 64:
            spatial_dims = NUM_CHANS_TO_SPAT_DIM_64
        else:
            raise ValueError(f"Unsupported out_size: {out_size}")
        skip_chans = [512, 256, 128, 64]
        # 1024 is always first, as this is the channel count of the decoder input.
        num_chans_list = [1024] + [skip_chans[i] for i in self.encoder_cat_layer_ids]
        # Consecutive pairs give the (input, output) channels of each stage.
        for block_in, block_out in zip(num_chans_list, num_chans_list[1:]):
            self.decoder_blocks.append(
                DecoderBlock(
                    out_size=spatial_dims[block_out],
                    in_chans=block_in,
                    out_chans=block_out,
                    prompt_dim=prompt_dim,
                    film_activation=film_activation,
                )
            )

        self.final_upsampler = SpatialUpsampling(
            out_size=out_size,
            in_chans=num_chans_list[-1],
            out_chans=out_chans,
            activation=final_activation,
        )

    def forward(self, x, cats, prompt):
        """Decode ``x`` using the selected skip connections.

        Args:
            x: Decoder input features (1024 channels are expected by the first
                stage, or by the final upsampler when no stage is configured).
            cats: Skip-connection tensors; ``cats[i]`` must carry
                ``[512, 256, 128, 64][i]`` channels, since stage ``k`` is built
                for the channel count selected by ``encoder_cat_layer_ids[k]``.
            prompt: Conditioning vector forwarded to every block, or ``None``.

        Raises:
            ValueError: If ``cats`` lacks an entry required by
                ``encoder_cat_layer_ids`` or the number of selected skips does
                not match the number of decoder blocks.
        """
        cats = list(cats)
        # Select by id, in the same order that __init__ used to build the
        # blocks, so each stage receives the skip tensor it was sized for.
        try:
            selected = [cats[i] for i in self.encoder_cat_layer_ids]
        except IndexError as err:
            raise ValueError(
                f"cats has {len(cats)} elements, which does not cover "
                f"encoder_cat_layer_ids={self.encoder_cat_layer_ids}"
            ) from err
        if len(selected) != len(self.decoder_blocks):
            raise ValueError("Number of cat elements (skip connection) must match number of decoder blocks")
        for decoder_block, cat in zip(self.decoder_blocks, selected):
            x = decoder_block(x, cat, prompt)
        return self.final_upsampler(x)

