# Test code for MM24: DERO
from email.policy import strict
from model import Encoder_Decoder_Latent_Inversion
from Noise_Layer import Identity,GaussianNoise,GaussianFilter,MedFilter,Jpeg,SP,Dropout
import torchvision.models as models
from random import choice
from torch.autograd import Variable
from torchvision.utils import save_image
import torch.optim as optim
import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import datetime
import cv2
import torch.nn as nn
import scipy.io as scio
# import skimage
import cv2
import math
import time
import lpips
import pytorch_ssim
import random
os.environ["CUDA_VISIBLE_DEVICES"] = '3'
from einops import rearrange
import argparse, os, sys, glob
import PIL
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
from torchvision.utils import make_grid
from torch import autocast
from contextlib import nullcontext
import time
import torchvision
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from diffusers import AutoencoderKL
import lpips

def chunk(it, size):
    """Split an iterable into consecutive tuples of at most ``size`` items.

    Returns an iterator over the tuples. The last tuple may be shorter than
    ``size``; iteration ends as soon as an empty tuple would be produced.
    """
    source = iter(it)

    def next_group():
        # Pull up to `size` further items from the shared iterator.
        return tuple(islice(source, size))

    # Two-argument iter(): keep calling next_group() until it returns ().
    return iter(next_group, ())

def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: Configuration object; its ``model`` entry is passed to
            ``instantiate_from_config``.
        ckpt: Path of the checkpoint file. The loaded object must contain a
            ``"state_dict"`` entry.
        verbose: When true, print the missing and unexpected state-dict keys.

    Returns:
        The model in eval mode, moved to CUDA when a GPU is available and
        left on the CPU otherwise.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # Deserialize onto the CPU first so loading does not depend on the
    # device the checkpoint was saved from.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False: key mismatches are reported below instead of raising.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and len(missing) > 0:
        print("missing keys:")
        print(missing)
    if verbose and len(unexpected) > 0:
        print("unexpected keys:")
        print(unexpected)

    # Only move to the GPU when one is visible; an unconditional .cuda()
    # raises on CPU-only hosts.
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model


def load_img(path):
    """Load an image file as a float32 tensor of shape (1, 3, H, W) in [-1, 1].

    The image is converted to RGB and resized so that its width and height
    are each rounded down to an integer multiple of 32.
    """
    picture = Image.open(path).convert("RGB")
    w, h = picture.size
    print(f"loaded input image of size ({w}, {h}) from {path}")

    # Round each side down to an integer multiple of 32.
    target_w = w - w % 32
    target_h = h - h % 32
    picture = picture.resize((target_w, target_h), resample=PIL.Image.LANCZOS)

    # 8-bit pixel values -> floats in [0, 1].
    pixels = np.array(picture).astype(np.float32) / 255.0
    # HWC -> 1CHW: add a batch axis, then move channels ahead of H and W.
    batch = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))
    return 2.*batch - 1.

class Solver(object):
    """Solver that builds the models and runs the message embedding and decoding-accuracy tests."""

    def __init__(self, data_loader,data_loader_test, config):
        """Store the data loaders, copy the settings from ``config`` and build the models."""

        # Data loaders.
        self.data_loader = data_loader
        self.data_loader_test = data_loader_test

        # Model, training, test and miscellaneous settings are copied from
        # config onto attributes of the same name.
        for option in (
            'image_size', 'num_channels', 'decoder_number',
            'dataset', 'batch_size',
            'lambda1', 'lambda2', 'lambda3',
            'num_epoch', 'resume_epoch', 'distortion',
            'test_iters',
            'use_tensorboard',
        ):
            setattr(self, option, getattr(config, option))

        # Use the GPU when one is available, otherwise the CPU.
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        # Directories, then logging / checkpoint step sizes.
        for option in (
            'log_dir', 'model_save_dir', 'model_name',
            'result_dir', 'embedding_epoch',
            'log_step', 'model_save_step',
        ):
            setattr(self, option, getattr(config, option))

        # Build the models first, then the tensorboard logger if requested.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

    def build_model(self):
        """Load the Stable Diffusion model and the message encoder/decoder.

        Nothing is built unless ``self.dataset`` is ``'test_embed'`` or
        ``'test_accuracy'``. Otherwise this sets ``self.model`` (the diffusion
        model on ``self.device``), ``self.net_ED`` (encoder/decoder with the
        weights from ``Encoder_Decoder_Model_latent_inversion.pth``) and
        ``self.net_D`` (the decoder sub-module of ``self.net_ED``).
        """
        # Both test modes need exactly the same models.
        if self.dataset not in ('test_embed', 'test_accuracy'):
            return

        config = OmegaConf.load("stable-diffusion/v1-inference.yaml")
        model = load_model_from_config(config, "stable-diffusion/stable-diffusion-v1/sd-v1-4.ckpt")
        self.model = model.to(self.device)

        self.net_ED = Encoder_Decoder_Latent_Inversion(self.distortion)
        # Deserialize onto the CPU so loading does not depend on the GPU the
        # checkpoint was saved from; the weights move with .to() below.
        state_dict = torch.load('Encoder_Decoder_Model_latent_inversion.pth', map_location='cpu')
        # strict=False tolerates key mismatches between checkpoint and model.
        self.net_ED.load_state_dict(state_dict, strict=False)
        self.net_ED.to(self.device)
        self.net_D = self.net_ED.Decoder
			
    def print_network(self, model, name):
        """Print the network's name and its total number of parameters.

        Args:
            model: Object whose ``parameters()`` yields items exposing
                ``numel()``.
            name: Label printed for the network.
        """
        num_params = sum(p.numel() for p in model.parameters())
        print(name)
        print("The number of parameters: {}".format(num_params))

    def build_tensorboard(self):
        """Create the logger that writes to ``self.log_dir`` and keep it on ``self.logger``."""
        # Imported inside the method, so the `logger` module is only needed
        # when this method is actually called.
        from logger import Logger

        log_directory = self.log_dir
        self.logger = Logger(log_directory)
		        
    def test_embed(self):
        """Embed messages, measure decoding accuracy under distortions, save images.

        For every batch of ``self.data_loader`` the encoder output is passed
        through each distortion layer; the distorted image is encoded into the
        first-stage latent space of ``self.model`` and decoded by
        ``self.net_ED.Decoder``. The bit accuracy per distortion is printed at
        the end. For every image, a side-by-side panel (input | encoded | last
        distortion) and the encoded image alone are written below
        ``self.result_dir``.
        """
        data_loader = self.data_loader
        self.net_ED.eval()
        self.net_D.eval()

        # Scale applied to the encoder residual before it is added back.
        alpha = 1
        # NOTE: alternative distortion set that was used previously:
        # [JpegTest(50),JpegTest(60),JpegTest(70),GaussianNoise(0.01),GaussianNoise(0.02),GaussianNoise(0.05),Dropout(0.7),Dropout(0.8),Dropout(0.9)]
        noisers = [
            SP(0.01), SP(0.02), SP(0.03),
            GaussianFilter(1), GaussianFilter(2), GaussianFilter(3),
            MedFilter(3), MedFilter(5), MedFilter(7),
        ]
        correct_bits = [0.0] * len(noisers)
        total_bits = 0

        out_dir = self.result_dir+'/Image_test_'+self.distortion+'/images_embed_'+str(self.embedding_epoch)
        embedded_dir = out_dir+'/Embedded/'
        # Create the whole output tree once instead of checking per image.
        os.makedirs(embedded_dir, exist_ok=True)

        def to_hwc_255(tensor):
            """Rescale a CHW tensor with (x + 1) / 2 * 255 and return it as an HWC array."""
            array = (tensor.detach().to('cpu').numpy() + 1) / 2 * 255
            return np.transpose(array, (1, 2, 0))

        for data, m, num in data_loader:
            inputs = data.to(self.device)
            m = m.float().to(self.device)
            num = num.to('cpu').numpy()

            Encoded_image = self.net_ED.Encoder(inputs, m)
            # Nothing is back-propagated here, so drop the autograd graph.
            Encoded_image = ((Encoded_image - inputs) * alpha + inputs).detach()

            # Evaluation only: same no-grad treatment as test_accuracy.
            with torch.no_grad():
                for k, noiser in enumerate(noisers):
                    Noised_image = noiser(Encoded_image)
                    distorted_latent = self.model.get_first_stage_encoding(
                        self.model.encode_first_stage(Noised_image))
                    Decoded_message = self.net_ED.Decoder(distorted_latent.float())
                    decoded_rounded = Decoded_message.round().clip(0, 1)
                    correct_bits[k] += float(torch.sum(decoded_rounded.squeeze(1) == m.clip(0, 1)))
            total_bits += m.shape[0] * m.shape[1]

            for j in range(Encoded_image.shape[0]):
                original = to_hwc_255(inputs[j])
                embedded = to_hwc_255(Encoded_image[j])
                # Noised_image holds the output of the LAST distortion layer.
                distorted = to_hwc_255(Noised_image[j])
                width = original.shape[1]
                panel = np.concatenate((original, embedded, distorted), axis=1).astype(np.float64)
                # Reverse the channel order before writing with OpenCV
                # (presumably RGB tensors -> BGR for cv2.imwrite).
                panel = np.ascontiguousarray(panel[:, :, ::-1])
                index = num[j]
                cv2.imwrite(out_dir+'/'+str(index)+'.png', panel)
                # Middle third of the panel is the encoded image on its own.
                cv2.imwrite(embedded_dir+str(index)+'.png', panel[:, width:width*2, :])

        if total_bits == 0:
            # Empty loader: avoid dividing by zero below.
            print('No samples were evaluated; correct rate is undefined.')
        else:
            for hits in correct_bits:
                print("Correct Rate:%.3f"%((hits/total_bits)*100)+'%')
        print('Embed finished!')

    def test_accuracy(self):
        data_loader = self.data_loader
        self.net_D.eval()
        acc1 = 0
        total = 0
        current_prompt = ' '
        with torch.no_grad():
            for i, (data, m, num) in enumerate(data_loader):
                inputs, m = Variable(data), Variable(m.float())
                inputs, m = inputs.to(self.device), m.to(self.device)
                num = num.to('cpu').numpy()
                init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(inputs))
                Decoded_message1 = self.net_D(init_latent.to(inputs.dtype))
                

                decoded_rounded = Decoded_message1.detach().cpu().numpy().round().clip(0, 1)
                correct = np.sum(np.abs(decoded_rounded - m.detach().cpu().numpy()))
                # print(num)
                # print(correct)
                acc1 = acc1 + correct
                total = total + inputs.shape[0]

            print(acc1)
            print("Correct Rate:%.3f"%((1-(acc1/((i+1)*inputs.shape[0] * m.shape[1])))*100)+'%')