# Test code for MM24: DERO
from ctypes import resize
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from Noise_Layer import Dropout,Cropout,Identity,GaussianNoise,GaussianFilter,MedFilter,Jpeg,SP,JpegSS,JpegMask,JpegTest,Combined
from model_MBRS import Decoder_Diffusion, Encoder_MP_Diffusion
import random
from torch.nn import Parameter
import torchvision
import time
from ViT_Mask import vit_base_patch16_128_decoder
from diffusers import AutoencoderKL
from random import choice

def l2normalize(v, eps=1e-12):
    """Return ``v`` divided by its (whole-tensor) L2 norm.

    ``eps`` is added to the norm so an all-zero input yields zeros instead of
    NaNs from a 0/0 division.
    """
    denominator = v.norm() + eps
    return v / denominator

class ConvBNRelu(nn.Module):
    """3x3 convolution -> BatchNorm2d -> in-place ReLU.

    Args:
        channels_in: number of input channels.
        channels_out: number of output channels.
        stride: convolution stride (padding is 1, so stride 1 keeps H and W).
    """

    def __init__(self, channels_in, channels_out, stride=1):
        super(ConvBNRelu, self).__init__()
        conv = nn.Conv2d(channels_in, channels_out, 3, stride, padding=1)
        norm = nn.BatchNorm2d(channels_out)
        activation = nn.ReLU(inplace=True)
        # Keep the single `layers` Sequential so state_dict keys stay layers.0.*, layers.1.*
        self.layers = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        y = self.layers(x)
        return y

class SingleConv(nn.Module):
    """One 3x3 conv (with bias) -> BatchNorm2d -> in-place ReLU.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
        s: convolution stride (padding is 1).
    """

    def __init__(self, inchannel, outchannel, s):
        super(SingleConv, self).__init__()
        stages = [
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=s, padding=1, bias=True),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class ResidualBlock(nn.Module):
    """Two 3x3 convs with InstanceNorm plus a skip connection, then ReLU.

    The skip is the identity when shape is preserved; otherwise it is a
    strided 1x1 conv + InstanceNorm projection so the two branches can be added.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
        s: stride of the first conv (and of the projection shortcut).
    """

    def __init__(self, inchannel, outchannel, s):
        super(ResidualBlock, self).__init__()
        # Main branch (BatchNorm was swapped for InstanceNorm2d in this model).
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=s, padding=1, bias=False),
            nn.InstanceNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(outchannel),
        )
        needs_projection = (s != 1) or (inchannel != outchannel)
        # An empty Sequential returns its input unchanged (identity skip).
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=s, bias=False),
                nn.InstanceNorm2d(outchannel),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        main = self.left(x)
        return F.relu(main + self.shortcut(x))

class DoubleConv(nn.Module):
    """Two stride-1 (3x3 conv -> InstanceNorm2d -> ReLU) stages.

    Spatial size is preserved (padding 1); channels go inchannel -> outchannel.

    Args:
        inchannel: number of input channels.
        outchannel: number of channels produced by both convolutions.
    """

    def __init__(self, inchannel, outchannel):
        super(DoubleConv, self).__init__()
        # InstanceNorm2d's num_features must equal the conv output width. It was
        # previously hard-coded to 64, which (with affine=False) is not used for
        # normalisation but makes PyTorch warn on every forward for any other
        # width. These norms hold no parameters/buffers, so checkpoints are unaffected.
        self.conv = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=True),
            nn.InstanceNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=True),
            nn.InstanceNorm2d(outchannel),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class up_conv(nn.Module):
    """2x upsample (nn.Upsample default mode) -> 3x3 conv -> InstanceNorm2d -> ReLU.

    Args:
        inchannel: number of input channels.
        outchannel: number of output channels.
    """

    def __init__(self, inchannel, outchannel):
        super(up_conv, self).__init__()
        # num_features now tracks the conv output width instead of a hard-coded
        # 64, which triggered a num_features-mismatch warning for other widths
        # (affine=False, so no parameters and no checkpoint change).
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=True),
            nn.InstanceNorm2d(outchannel),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.up(x)
        return x

from einops import rearrange, repeat, reduce

class Decoder_Latent(nn.Module):
    """Map a 4-channel latent feature map to a 256-dim message vector.

    The conv trunk halves H and W once (the stride-2 residual block) and
    collapses to one channel; ``linear1`` then expects exactly 32*32 = 1024
    features, so inputs must be (B, 4, H, W) with a 32x32 map after the
    stride-2 stage (e.g. 64x64 latents).
    """

    def __init__(self):
        super(Decoder_Latent, self).__init__()
        trunk = [
            SingleConv(4, 64, 1),
            ResidualBlock(64, 64, 1),
            ResidualBlock(64, 64, 2),  # the only downsampling step
            ResidualBlock(64, 64, 1),
            nn.Conv2d(64, 1, kernel_size=1, stride=1, padding=0, bias=False),  # squeeze to 1 channel
        ]
        self.conv = nn.Sequential(*trunk)
        self.linear1 = nn.Linear(32 * 32, 512)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(512, 256)

    def forward(self, x):
        feature_map = self.conv(x)
        flat = feature_map.view(feature_map.shape[0], -1)
        hidden = self.relu(self.linear1(flat))
        return self.linear2(hidden)

class U_Net_Encoder_Diffusion(nn.Module):
    """U-Net watermark embedder that injects the message at every decoder scale.

    The 256-dim ``watermark`` is projected by ``linear`` to a 1x64x64 map. The
    shared ``Conv_message`` block lifts it to 64 channels before it is
    concatenated with the bottleneck and with each decoder stage. The network
    outputs ``x`` plus a predicted residual.

    Shape note: at the bottleneck the 64x64 message map is concatenated with
    features that went through three 2x max-pools, so the input spatial size
    must pool down to 64x64 (presumably 512x512 images).
    """

    def __init__(self, inchannel=3, outchannel=3):
        super(U_Net_Encoder_Diffusion, self).__init__()

        # Pooling: 2x2 along the contracting path, 4x4 for the coarse context branch.
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Globalpool = nn.MaxPool2d(kernel_size=4, stride=4)

        # Contracting path: inchannel -> 16 -> 32 -> 64.
        self.Conv1 = DoubleConv(inchannel, 16)
        self.Conv2 = DoubleConv(16, 32)
        self.Conv3 = DoubleConv(32, 64)

        # Expanding path; every fuse block sees [skip, upsampled, message(64 ch)].
        self.Up4 = up_conv(64 * 3, 64)  # bottleneck + tiled context + message
        self.Conv7 = DoubleConv(64 * 3, 64)
        self.Up3 = up_conv(64, 32)
        self.Conv8 = DoubleConv(32 * 2 + 64, 32)
        self.Up2 = up_conv(32, 16)
        self.Conv9 = DoubleConv(16 * 2 + 64, 16)

        # Heads: residual projection, message projection (256 -> 64*64), message features.
        self.Conv_1x1 = nn.Conv2d(16, outchannel, kernel_size=1, stride=1, padding=0)
        self.linear = nn.Linear(256, 4096)
        self.Conv_message = DoubleConv(1, 64)

    def _message_at(self, message_map, reference):
        """Resize the 1-channel message map to ``reference``'s H, W and lift it to 64 channels."""
        target_hw = (reference.shape[2], reference.shape[3])
        resized = torch.nn.functional.interpolate(message_map, size=target_hw, mode='bilinear')
        return self.Conv_message(resized)

    def forward(self, x, watermark):
        # Contracting path.
        skip1 = self.Conv1(x)
        skip2 = self.Conv2(self.Maxpool(skip1))
        skip3 = self.Conv3(self.Maxpool(skip2))
        bottom = self.Maxpool(skip3)

        # Coarse context: 4x4 max-pool, then tile back to the bottleneck size.
        context = self.Globalpool(bottom).repeat(1, 1, 4, 4)

        message_map = self.linear(watermark).view(-1, 1, 64, 64)
        bottom = torch.cat((bottom, context, self.Conv_message(message_map)), dim=1)

        # Expanding path with the message re-injected at each scale.
        up = self.Up4(bottom)
        up = self.Conv7(torch.cat((skip3, up, self._message_at(message_map, up)), dim=1))

        up = self.Up3(up)
        up = self.Conv8(torch.cat((skip2, up, self._message_at(message_map, up)), dim=1))

        up = self.Up2(up)
        up = self.Conv9(torch.cat((skip1, up, self._message_at(message_map, up)), dim=1))

        residual = self.Conv_1x1(up)
        return residual + x

class Encoder_Decoder_Latent_Inversion(nn.Module):
    """Pixel-space watermark embedder paired with a latent-space message decoder.

    ``forward`` embeds the message, maps the (optionally distorted) watermarked
    image through ``model``'s first-stage encoder, and decodes the message from
    the resulting latent.
    """

    def __init__(self, distortion):
        super(Encoder_Decoder_Latent_Inversion, self).__init__()
        self.Encoder = U_Net_Encoder_Diffusion()
        self.Decoder = Decoder_Latent()
        # 'Combined' -> pixel-space Combined noise pool before encoding;
        # anything else -> random mix of clean / noised / distorted latents.
        self.Distortion = distortion

    def forward(self, x, m, model, simulation_net=None):
        """Embed ``m`` into ``x`` and decode it back from a perturbed latent.

        Args:
            x: cover images for the U-Net encoder.
            m: message tensor fed to the encoder (projected from 256 features).
            model: object exposing ``encode_first_stage`` and
                ``get_first_stage_encoding`` (presumably a latent-diffusion
                model whose first stage is a VAE — confirm against callers).
            simulation_net: unused; kept for call-site compatibility.

        Returns:
            (Encoded_image, Decoded_message): the watermarked image and the
            latent decoder's 256-feature output.
        """
        # Distortion level for the 'distorted' latent view. Drawn in every mode
        # and before the encoder runs, preserving the existing order of Python
        # `random` draws for seeded runs.
        K = choice([0, 1, 2])
        Encoded_image = self.Encoder(x, m)

        if self.Distortion == 'Combined':
            noiser = Combined([MedFilter(7), GaussianNoise(0.05), JpegTest(50), Cropout(0.4, 0.4), Dropout(0.4)])
            Noised_image = noiser(Encoded_image)
            Noised_latent = model.get_first_stage_encoding(model.encode_first_stage(Noised_image))
        else:
            Noised_latent = self._augmented_latent(Encoded_image, model, K)

        Decoded_message = self.Decoder(Noised_latent.float())
        return Encoded_image, Decoded_message

    @staticmethod
    def _augmented_latent(Encoded_image, model, K):
        """Return one randomly chosen latent view of ``Encoded_image``.

        Weighting is clean 1/5, noised 3/5, distorted 1/5. Only the chosen
        view is computed, so each step runs the first-stage encoder once
        instead of always encoding both the clean and the distorted image.
        """
        variant = choice(['clean', 'noised', 'noised', 'noised', 'distorted'])

        if variant == 'distorted':
            # Downscale to 128/256/512, Gaussian-filter, rescale to 512x512, add pixel noise.
            side = (128, 256, 512)[K]
            resize_down = torchvision.transforms.Resize([side, side])
            resize_up = torchvision.transforms.Resize([512, 512])
            blur = GaussianFilter((2, 3, 4)[K])
            Noised_image = resize_up(blur(resize_down(Encoded_image))) + torch.randn_like(Encoded_image) * 0.5
            return model.get_first_stage_encoding(model.encode_first_stage(Noised_image))

        init_latent = model.get_first_stage_encoding(model.encode_first_stage(Encoded_image))
        if variant == 'noised':
            # Shrink the latent and add Gaussian noise in latent space.
            return init_latent * 0.6 + torch.randn_like(init_latent) * 0.8
        return init_latent
