import math
import torch
import torch.nn
import torch.optim
import torchvision
import numpy as np
import os
import sys
sys.path.append('./')
sys.path.append('../')
from INN.model.model_ori import *
import INN.config as c
from INN.dataset_ori import get_data_loaders
import modules.Unet_common as common
from tensorboardX import SummaryWriter
from transform import *
from EditGuard.code.utils.JPEG import DiffJPEG
from skimage.metrics import structural_similarity as ssim
import numpy as np
from FIN.utils.jpeg import JpegSS, JpegTest
from SSlayer import SSlayer
from png import apply_png


# Compute SSIM between two images whose channels lie on axis 0.
def compute_ssim(img1, img2):
    """Return the structural similarity of ``img1`` against ``img2``.

    Both arrays are expected with the channel dimension first
    (``channel_axis=0``). The dynamic range is taken from ``img2`` alone, and a
    3x3 window is used.
    """
    value_span = img2.max() - img2.min()
    return ssim(img1, img2, win_size=3, channel_axis=0, data_range=value_span)



# Run on the third CUDA device (index 2) when CUDA is available, else on CPU.
# NOTE(review): torch.nn.DataParallel below is built with c.device_ids; its
# first id must refer to this same device — confirm in the config.
device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")
# Shared TensorBoard writer; created lazily by get_writer() and closed in the
# __main__ block's finally clause.
writer = None



def get_writer():
    """Return the module-wide SummaryWriter, constructing it on the first call.

    Later calls return the same instance stored in the global ``writer``.
    """
    global writer
    if writer is not None:
        return writer
    writer = SummaryWriter(comment='hinet', filename_suffix="steg")
    return writer



def load(name):
    """Restore the module-level ``net`` (and, best effort, ``optim``) from a checkpoint.

    ``state_dicts['net']`` is loaded into ``net`` with every entry whose key
    contains ``'tmp_var'`` dropped. ``strict=False`` means missing or unexpected
    keys are tolerated without error. ``state_dicts['opt']`` is then loaded
    into ``optim``; any failure there (missing key, mismatched parameter
    groups, ...) is reported and ignored.

    Args:
        name: Path of the checkpoint file.
    """
    # map_location remaps tensors saved on another GPU (or on a GPU at all)
    # onto the device this script runs on; without it loading fails when
    # the checkpoint's original device is unavailable.
    state_dicts = torch.load(name, map_location=device)
    network_state_dict = {k: v for k, v in state_dicts['net'].items() if 'tmp_var' not in k}
    net.load_state_dict(network_state_dict, strict=False)
    try:
        optim.load_state_dict(state_dicts['opt'])
    except Exception as err:
        # Best effort: nothing visible in this script steps the optimizer.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate, and the reason is now reported.
        print('Cannot load optimizer for some reason or other:', err)


def gauss_noise(shape):
    """Return a tensor of i.i.d. standard-normal samples on ``device``.

    Args:
        shape: Any size accepted by ``torch.randn`` (e.g. ``tensor.shape``).

    Returns:
        A default-dtype (float32) tensor of ``shape`` allocated directly on the
        module-level ``device``.
    """
    # A single on-device draw replaces the previous per-sample loop. That loop
    # allocated zeros on the device, sampled each row on the CPU, and copied
    # it host-to-device. The distribution, dtype and shape are unchanged.
    return torch.randn(shape, device=device)

def binary_noise(shape):
    """Return a float tensor of ``shape`` on ``device`` filled with random 0s and 1s.

    Args:
        shape: Any size accepted by ``torch.empty`` (e.g. ``tensor.shape``).

    Returns:
        A default-dtype (float32) tensor whose entries are drawn uniformly
        from ``{0.0, 1.0}``.
    """
    # random_(0, 2) fills with integers in [0, 1], stored in the float tensor.
    # This single on-device fill replaces the per-sample CPU loop with its
    # host-to-device copies.
    return torch.empty(shape, device=device).random_(0, 2)

def computePSNR(origin, pred):
    """Return the PSNR in dB between ``origin`` and ``pred`` for a peak value of 255.

    Both inputs are converted to float32 NumPy arrays before the mean squared
    error is taken. When that error is below 1e-10, the images are treated as
    identical and ``100`` is returned.
    """
    reference = np.asarray(origin, dtype=np.float32)
    estimate = np.asarray(pred, dtype=np.float32)
    mse = np.mean(np.square(reference - estimate))
    return 100 if mse < 1.0e-10 else 10 * math.log10(255.0 ** 2 / mse)


# Build the network on `device`, run init_model on it, then wrap it in
# DataParallel over c.device_ids.
# NOTE(review): DataParallel requires the module's parameters to already be on
# device_ids[0]; confirm c.device_ids[0] is the "cuda:2" chosen for `device`.
net = Model()
net.to(device)
init_model(net)
net = torch.nn.DataParallel(net, device_ids=c.device_ids)
# Only parameters with requires_grad go to the optimizer. Nothing visible in
# this file steps `optim` or `weight_scheduler`; `optim` exists so load() can
# restore its state from the checkpoint.
params_trainable = (list(filter(lambda p: p.requires_grad, net.parameters())))
optim = torch.optim.Adam(params_trainable, lr=c.lr, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay)
weight_scheduler = torch.optim.lr_scheduler.StepLR(optim, c.weight_step, gamma=c.gamma)

# JPEG attack applied by inference() when transformation == 'jpeg'.
# 85 is presumably the JPEG quality factor — confirm against FIN.utils.jpeg.
test_noise_layer = JpegTest(85)

# Restore weights (and optimizer state, best effort) from c.MODEL_PATH + 'best.pt'.
c.suffix = 'best.pt'
load(c.MODEL_PATH + c.suffix)

net.eval()
# Only referenced by commented-out lines in inference().
ss_layer = SSlayer(requires_grad=False).to(device)

# NOTE(review): placeholder that unconditionally overrides whatever validation
# path the config set; presumably read by get_data_loaders(c, ...) in the
# __main__ block — replace with a real dataset path before running.
c.VAL_PATH = 'your_path'

def backward_only(steg_img_path, net, device):
    """Recover the hidden secret from a stego image file using only the reverse pass.

    Processing steps:

    1. Load the image as RGB and convert it with ``ToTensor``.
    2. Add a batch dimension and move it to ``device``.
    3. DWT-transform it and concatenate Gaussian noise of the same shape
       along the channel axis.
    4. Run the result through ``net(..., rev=True)``.
    5. Take channels from ``4 * c.channels_in`` onward as the secret's
       coefficients and inverse-DWT them.
    6. Save the result as ``recovered_secret.png`` inside
       ``c.IMAGE_PATH_secret_rev``.

    Args:
        steg_img_path: Path of the stego image to decode.
        net: Invertible network, called with ``rev=True``.
        device: Device on which the image, noise, DWT/IWT and network run.

    Returns:
        The recovered secret tensor (batch of one).
    """
    dwt = common.DWT().to(device)
    iwt = common.IWT().to(device)

    with torch.no_grad():
        # NOTE(review): `Image` is not imported explicitly in this file; it
        # presumably arrives via one of the star imports (PIL.Image) — confirm.
        # The context manager closes the file handle deterministically.
        with Image.open(steg_img_path) as img:
            steg_img = torchvision.transforms.ToTensor()(img.convert('RGB'))
        steg_img = steg_img.unsqueeze(0).to(device)

        # The DWT is deterministic, so compute it once and reuse it.
        steg_dwt = dwt(steg_img)
        # Draw the noise on the same device/dtype as the coefficients. The
        # module-level gauss_noise() uses the *global* device, which breaks
        # torch.cat whenever the `device` argument differs from it.
        backward_z = torch.randn_like(steg_dwt)

        backward_img = net(torch.cat((steg_dwt, backward_z), 1), rev=True)

        # Channels past the first 4 * c.channels_in hold the secret's coefficients.
        secret_rev = backward_img.narrow(1, 4 * c.channels_in, backward_img.shape[1] - 4 * c.channels_in)
        secret_rev = iwt(secret_rev)

        # os.path.join keeps the file inside the directory created here even
        # when the configured path lacks a trailing separator.
        os.makedirs(c.IMAGE_PATH_secret_rev, exist_ok=True)
        torchvision.utils.save_image(secret_rev, os.path.join(c.IMAGE_PATH_secret_rev, 'recovered_secret.png'))

        return secret_rev





def _to_pixel_array(tensor):
    """Scale ``tensor`` by 255, move it to a squeezed NumPy array, and clip to [0, 255]."""
    # np.clip returns a new array; it must be assigned to take effect.
    return np.clip(tensor.cpu().numpy().squeeze() * 255, 0, 255)


def _apply_transformation(steg_img, transformation):
    """Apply the named test-time attack to ``steg_img``.

    Returns:
        ``(attacked, to_save)``. ``attacked`` is fed to the reverse pass and
        ``to_save`` is written as the stego image. They differ only for
        ``'flip'``: the flipped image is saved, and ``F.hflip`` of it is
        decoded. Unrecognised names (e.g. ``'none'``) return the input
        unchanged.
    """
    if transformation == 'noise':
        # Additive Gaussian noise, std 0.05; only generated when requested.
        noise_intensity = 0.05
        attacked = steg_img + torch.randn_like(steg_img) * noise_intensity
    elif transformation == 'brightness':
        attacked = apply_selected_transformations(steg_img, device, brightness=True)
    elif transformation == 'contrast':
        attacked = apply_selected_transformations(steg_img, device, contrast=True)
    elif transformation == 'blur':
        attacked = apply_selected_transformations(steg_img, device, blur=True)
    elif transformation == 'flip':
        # NOTE(review): `F` is not imported explicitly in this file; presumably
        # torchvision.transforms.functional via a star import — confirm.
        flipped = apply_selected_transformations(steg_img, device, flip=True)
        return F.hflip(flipped), flipped
    elif transformation == 'jpeg':
        attacked = test_noise_layer(steg_img.clone()).to(device)
    else:
        attacked = steg_img
    return attacked, attacked


def inference(testloader, c, transformation):
    """Hide, attack, and recover secrets over ``testloader``; return mean PSNRs.

    Each batch is split in half: the first half serves as cover images, the
    second half as secrets. For every batch this:

    1. Runs the forward (hiding) pass.
    2. Applies ``transformation`` to the stego image.
    3. Runs the reverse pass with Gaussian noise as the extra input.
    4. Saves cover/secret/stego/recovered/residual images.
    5. Prints the cover-vs-stego SSIM.
    6. Records secret and cover PSNRs on [0, 255]-clipped arrays.

    Args:
        testloader: Iterable of image batches (moved to ``device``).
        c: Config providing ``channels_in`` and the ``IMAGE_PATH_*`` output
            directories.
        transformation: ``'noise'``, ``'brightness'``, ``'contrast'``,
            ``'blur'``, ``'flip'`` or ``'jpeg'``; any other value leaves the
            stego image untouched.

    Returns:
        ``(mean secret PSNR, mean PSNR of cover vs. attacked stego)``.

    Note:
        Output file names depend only on ``transformation``, so each batch
        overwrites the previous batch's images.
    """
    dwt = common.DWT().to(device)
    iwt = common.IWT().to(device)
    # Width of the stego/cover coefficient slice of the network output.
    n_in = 4 * c.channels_in

    # Directory creation is idempotent; do it once instead of per batch.
    output_dirs = (c.IMAGE_PATH_cover, c.IMAGE_PATH_secret, c.IMAGE_PATH_steg,
                   c.IMAGE_PATH_secret_rev, c.IMAGE_PATH_resi_cover,
                   c.IMAGE_PATH_resi_secret)
    for out_dir in output_dirs:
        os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        psnr_s = []
        psnr_c = []
        for data in testloader:
            data = data.to(device)
            half = data.shape[0] // 2
            cover = data[:half]
            secret = data[half:]
            cover_input = dwt(cover)
            secret_input = dwt(secret)

            # Forward (hiding) pass.
            output = net(torch.cat((cover_input, secret_input), 1))
            output_steg = output.narrow(1, 0, n_in)
            # Overwrite the first c.channels_in stego channels with the cover's own.
            output_steg[:, :c.channels_in, :, :] = cover_input[:, :c.channels_in, :, :]
            steg_img = iwt(output_steg).to(device)

            steg_img, steg_to_save = _apply_transformation(steg_img, transformation)

            # Reverse (recovery) pass; the DWT is computed once and reused.
            steg_dwt = dwt(steg_img).to(device)
            backward_z = gauss_noise(steg_dwt.shape).to(device)
            backward_img = net(torch.cat((steg_dwt, backward_z), 1), rev=True)
            secret_rev = iwt(backward_img.narrow(1, n_in, backward_img.shape[1] - n_in))
            cover_rev = iwt(backward_img.narrow(1, 0, n_in))

            # Residuals scaled by 10 before saving.
            resi_cover = (steg_img.to(device) - cover) * 10
            resi_secret = (secret_rev.to(device) - secret) * 10

            # os.path.join keeps files inside the directories created above
            # even when a configured path lacks a trailing separator.
            outputs = (
                (cover, c.IMAGE_PATH_cover, 'cover_'),
                (secret, c.IMAGE_PATH_secret, 'secret_'),
                (steg_to_save, c.IMAGE_PATH_steg, 'steg_'),
                (secret_rev, c.IMAGE_PATH_secret_rev, 'secret_rev_'),
                (resi_cover, c.IMAGE_PATH_resi_cover, 'resi_cover_'),
                (resi_secret, c.IMAGE_PATH_resi_secret, 'resi_secret_'),
                (cover_rev, c.IMAGE_PATH_cover, 'cover_rev_'),
            )
            for tensor, out_dir, prefix in outputs:
                torchvision.utils.save_image(
                    tensor, os.path.join(out_dir, f'{prefix}{transformation}.png'))

            # Metrics on clipped [0, 255] arrays. Previously np.clip's result
            # was discarded, so out-of-range values inflated the error. Cover
            # PSNR/SSIM compare against the attacked stego that was decoded.
            secret_rev_np = _to_pixel_array(secret_rev)
            secret_np = _to_pixel_array(secret)
            cover_np = _to_pixel_array(cover)
            steg_np = _to_pixel_array(steg_img)
            psnr_s.append(computePSNR(secret_rev_np, secret_np))
            psnr_c.append(computePSNR(cover_np, steg_np))

            ssim_cover = compute_ssim(cover_np, steg_np)
            print(f"SSIM (cover): {ssim_cover}")

        w = get_writer()
        if w is not None:
            w.add_scalars("PSNR_S", {"average psnr": np.mean(psnr_s)})
            w.add_scalars("PSNR_C", {"average psnr": np.mean(psnr_c)})

        return np.mean(psnr_s), np.mean(psnr_c)

if __name__ == '__main__':
    try:
        # Batch size 2: inference() uses the first half of each batch as the
        # cover and the second half as the secret.
        testloader = get_data_loaders(c, 2, c.cropsize_val, "val")
        # Available options: 'none', 'noise', 'brightness', 'contrast',
        # 'blur', 'flip', 'jpeg'.
        transformations = ['none']
        for transformation in transformations:
            psnr_s, psnr_c = inference(testloader, c, transformation)
            report = (
                ('Transformation: ', transformation),
                ('PSNR_S: ', psnr_s),
                ('PSNR_C: ', psnr_c),
            )
            for label, value in report:
                print(label, value)
    finally:
        # The writer exists only if get_writer() was called.
        if writer is not None:
            writer.close()
