import os
from util.dataset import trainloader
from util import viz
import util.config as c
import INN_based_hiding_network.Unet_common as common
import warnings
from optimization_based_adversary_module.fr_util import generate_high
import time
from torch.autograd import Variable
import numpy as np

from util.utils import *
from INN_based_hiding_network.model import *

from util.sec_reshape import sec_reshape
from skimage.metrics import structural_similarity as SSIM, peak_signal_noise_ratio as PSNR
from torchvision.utils import save_image
from optimization_based_adversary_module.cov import Net 
import torchvision.transforms as T

# Compute device: the first CUDA GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Silence every Python warning for the remainder of the process.
warnings.filterwarnings(action="ignore")
def weights_init(m):
    """Re-initialise the weight of a convolutional or linear layer in place.

    ``nn.Conv2d`` and ``nn.Linear`` modules get a Kaiming-normal weight
    (``fan_out`` mode, ``leaky_relu`` gain). Any other module is left
    untouched, and biases are never modified.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')

def gauss_noise(shape):
    """Return a tensor of standard-normal noise with the given ``shape``.

    The result lives on the module-level ``device``. Each slice along the
    first dimension is drawn with its own ``torch.randn`` call and then
    copied into the output buffer.
    """
    buffer = torch.zeros(shape).to(device)
    batch_size = buffer.shape[0]
    sample_shape = buffer.shape[1:]
    for idx in range(batch_size):
        # One independent draw per leading-dimension slice.
        buffer[idx] = torch.randn(sample_shape).to(device)
    return buffer

def concealing_loss(output, bicubic_image):
    """Return the summed squared error between ``output`` and ``bicubic_image``.

    Args:
        output: predicted tensor.
        bicubic_image: reference tensor, broadcast-compatible with ``output``.

    Returns:
        A scalar tensor on the module-level ``device`` holding the sum (not
        the mean) of the element-wise squared differences.
    """
    # reduction='sum' is the documented replacement for the deprecated
    # reduce=True / size_average=False pair and yields the same value.
    loss_fn = torch.nn.MSELoss(reduction='sum')
    return loss_fn(output, bicubic_image).to(device)

def revealing_loss(rev_input, input):
    """Return the summed squared error between ``rev_input`` and ``input``.

    Args:
        rev_input: recovered tensor.
        input: reference tensor, broadcast-compatible with ``rev_input``.

    Returns:
        A scalar tensor on the module-level ``device`` holding the sum (not
        the mean) of the element-wise squared differences.
    """
    # reduction='sum' is the documented replacement for the deprecated
    # reduce=True / size_average=False pair and yields the same value.
    loss_fn = torch.nn.MSELoss(reduction='sum')
    return loss_fn(rev_input, input).to(device)

def low_pass_filter_loss(ll_input, gt_input):
    """Return the summed squared error between ``ll_input`` and ``gt_input``.

    Args:
        ll_input: predicted tensor.
        gt_input: reference tensor, broadcast-compatible with ``ll_input``.

    Returns:
        A scalar tensor on the module-level ``device`` holding the sum (not
        the mean) of the element-wise squared differences.
    """
    # reduction='sum' is the documented replacement for the deprecated
    # reduce=True / size_average=False pair and yields the same value.
    loss_fn = torch.nn.MSELoss(reduction='sum')
    return loss_fn(ll_input, gt_input).to(device)


def get_parameter_number(net):
    """Count the scalar parameters of ``net``.

    Returns:
        A dict with two integer entries: ``'Total'`` counts every element of
        every parameter, ``'Trainable'`` counts only the elements of
        parameters whose ``requires_grad`` flag is set.
    """
    total_count = 0
    trainable_count = 0
    for param in net.parameters():
        size = param.numel()
        total_count += size
        if param.requires_grad:
            trainable_count += size
    return {'Total': total_count, 'Trainable': trainable_count}

def calc_psnr_ssim(cover_imgv, container_img):
    """Return the batch-averaged PSNR and SSIM between two image batches.

    Both arguments are 4-D tensors whose second axis is the channel axis;
    they are copied to NumPy and rearranged to channels-last before scoring.
    PSNR is evaluated separately on the first three channels of every sample
    and averaged over samples and channels. SSIM is evaluated once per sample
    over all channels and averaged over samples. Both metrics are called with
    ``data_range=1`` (presumably pixel values lie in [0, 1] - confirm with
    the caller).

    Returns:
        ``(psnr, ssim)`` as plain Python floats.
    """
    N, _, _, _ = cover_imgv.shape

    def _to_channels_last(batch):
        # Detached host copy, axes reordered from (N, C, H, W) to (N, H, W, C).
        return batch.clone().cpu().detach().numpy().transpose(0, 2, 3, 1)

    cover_np = _to_channels_last(cover_imgv)
    container_np = _to_channels_last(container_img)

    psnr = np.zeros((N, 3))
    ssim = np.zeros(N)
    for idx in range(N):
        for ch in range(3):
            psnr[idx, ch] = PSNR(cover_np[idx, :, :, ch], container_np[idx, :, :, ch], data_range=1)
        ssim[idx] = SSIM(cover_np[idx], container_np[idx], data_range=1, channel_axis=2)

    return psnr.mean().item(), ssim.mean().item()
def high_pass_filter_loss(output, bicubic_image):
    """Return the summed squared error between ``output`` and ``bicubic_image``.

    Args:
        output: first tensor.
        bicubic_image: second tensor, broadcast-compatible with ``output``.

    Returns:
        A scalar tensor on the module-level ``device`` holding the sum (not
        the mean) of the element-wise squared differences.
    """
    # reduction='sum' is the documented replacement for the deprecated
    # reduce=True / size_average=False pair and yields the same value.
    loss_fn = torch.nn.MSELoss(reduction='sum')
    return loss_fn(output, bicubic_image).to(device)

def calc_msg_acc(secret_imgv_nh, rev_secret_img):
    """Return the fraction of message bits that were recovered correctly.

    Both tensors are binarised at 0.5 and compared element-wise.

    Args:
        secret_imgv_nh: ground-truth message tensor.
        rev_secret_img: recovered message tensor of the same shape.

    Returns:
        A 0-dim float tensor in [0, 1], located on the device of
        ``secret_imgv_nh``. An empty input yields NaN (0 / 0).
    """
    # Compare on whatever device the ground truth already lives on instead of
    # forcing CUDA, so the function also works on CPU-only hosts.
    rev_secret_img = rev_secret_img.to(secret_imgv_nh.device)

    matches = (rev_secret_img >= 0.5).eq(secret_imgv_nh >= 0.5)
    decoder_acc = matches.sum().float() / secret_imgv_nh.numel()
    return decoder_acc

def load(name):
    """Restore the module-level ``net`` (and, if possible, ``optim``) from a checkpoint.

    Args:
        name: path of a checkpoint written with ``torch.save`` that contains a
            ``'net'`` state dict and, optionally, an ``'opt'`` state dict.

    Entries of the network state dict whose key contains ``'tmp_var'`` are
    dropped before loading. Failing to restore the optimizer is not fatal:
    the reason is printed and training continues with a fresh optimizer state.
    """
    # map_location lets a checkpoint saved on a GPU be restored on whatever
    # device this run selected (including CPU-only hosts).
    state_dicts = torch.load(name, map_location=device)
    network_state_dict = {k: v for k, v in state_dicts['net'].items() if 'tmp_var' not in k}
    net.load_state_dict(network_state_dict)
    try:
        optim.load_state_dict(state_dicts['opt'])
    except Exception as err:
        # Best effort only, but let Ctrl-C / SystemExit propagate and report
        # the actual cause instead of hiding it.
        print('Cannot load optimizer for some reason or other')
        print(repr(err))
        
################################
#  Generator  Model initialize: #
#################################
# Build the hiding network on the selected device, then run the project's
# weight initialisation on it.
net = Model().to(device)
init_model(net)
new_state_dict = dict()

# Report total vs. trainable parameter counts.
para = get_parameter_number(net)
print(para)

# Only parameters that require gradients are handed to the optimizer.
params_trainable = [p for p in net.parameters() if p.requires_grad]

optim = torch.optim.Adam(
    params_trainable,
    lr=c.lr,
    betas=c.betas,
    eps=1e-6,
    weight_decay=c.weight_decay,
)
weight_scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=c.weight_step, gamma=c.gamma)

# Number of batches needed to cover IMAGE_NUM images at the configured batch size.
IMAGE_NUM = 5000
BATCH_SIZE = c.batch_size_2
RUN_BATCH = IMAGE_NUM // BATCH_SIZE

# Forward / inverse wavelet transform modules from Unet_common.
dwt = common.DWT().to(device)
iwt = common.IWT().to(device)

# Optionally continue from the first-phase checkpoint.
if c.tain_next_2:
    load(c.train_first_phase)

############################
#  a CNN-based steganalyzer: #
#############################
# CovNet is trained and fixed: its weights are loaded from a checkpoint and the
# model is switched to eval mode below.
pretrained_model = c.CovNet_model
WEIGHT_DECAY=5e-4
target_model = Net().to(device)
params = target_model.parameters()

# Split trainable parameters into two groups: tensors with more than one
# dimension receive weight decay, 1-D tensors do not.
params_wd, params_rest = [], []
for param_item in params:
    if param_item.requires_grad:
        (params_wd if param_item.dim() != 1 else params_rest).append(param_item)

param_groups = [{"params": params_wd, "weight_decay": WEIGHT_DECAY}, {"params": params_rest}]
optimizer = torch.optim.SGD(param_groups, lr=0.0001, momentum=0.9)

# map_location lets a checkpoint saved on a GPU be restored on whatever device
# this run selected (including CPU-only hosts).
all_state = torch.load(pretrained_model, map_location=device)
original_state = all_state["original_state"]
optimizer_state = all_state["optimizer_state"]
target_model.load_state_dict(original_state)
optimizer.load_state_dict(optimizer_state)

# Inference mode for the steganalyzer.
target_model.eval()
############################
#     Model      Train: #
#############################
def _append_log(path, line):
    """Append ``line`` plus a newline to the text file at ``path``.

    Append mode creates the file when it does not exist yet, and the ``with``
    block closes the handle so every record is flushed before the next write.
    """
    with open(path, 'a') as log_file:
        log_file.write(line + '\n')


# Anomaly detection is left enabled for the whole run.
torch.autograd.set_detect_anomaly(True)

psnr_history=[]
ssim_history=[]
acc_history=[]
total_acc=0.0
try:

    log_path = "train_second_phase.txt"

    # Transform applied to the stego image when it is read back from disk;
    # it is the same for every batch, so build it once.
    transform_val = T.Compose([
        T.CenterCrop(c.cropsize_val),
        T.ToTensor(),
    ])

    for i_batch, mydata in enumerate(trainloader):
        totalTime = time.time()
        cover = mydata.to(device)
        # Random binary message, reshaped by sec_reshape into the secret map.
        secret = np.random.choice([0, 1], size=[cover.size(0), 1, 256, 256])
        sec_map = torch.from_numpy(sec_reshape(secret)).to(device)   #1,3,256,256

        # P setting: a learnable additive perturbation of the cover, restarted
        # from a small constant for every batch and updated by its own optimizer.
        X_ori = torch.full((1, 3, 256, 256), 0.0001, device=device, requires_grad=True)
        optim2 = torch.optim.Adam([X_ori], lr=c.lr2)

        for i_epoch in range(c.epochs):
            P = X_ori
            cover_1 = P + cover
            cover_input = dwt(cover_1)     #1,12,128,128
            secret_input = dwt(sec_map)    #1,4,128,128
            input_img = torch.cat((cover_input, secret_input), 1)

            # forward (hiding) pass
            output = net(input_img)        #1,16,128,128
            output_steg = output.narrow(1, 0, 4 * c.channels_in)   #1,12,128,128
            output_z = output.narrow(1, 4 * c.channels_in, output.shape[1] - 4 * c.channels_in)   #1,4,128,128
            steg_img = iwt(output_steg)

            # backward (revealing) pass: the z branch is replaced by fresh Gaussian noise
            output_z_guass = gauss_noise(output_z.shape)
            output_rev = torch.cat((output_steg, output_z_guass), 1)
            output_image = net(output_rev, rev=True)
            secret_rev = output_image.narrow(1, 4 * c.channels_in, output_image.shape[1] - 4 * c.channels_in)
            secret_rev = iwt(secret_rev)

            #####    INN LOSS :    #####
            L_con = concealing_loss(steg_img, cover)
            L_rev = revealing_loss(secret_rev, sec_map.float())
            steg_low = output_steg.narrow(1, 0, c.channels_in)
            cover_low = cover_input.narrow(1, 0, c.channels_in)
            L_lpf = low_pass_filter_loss(steg_low, cover_low)
            total_loss = c.lamda_c * L_con + c.lamda_r * L_rev + c.lamda_l * L_lpf

            #####    HPF LOSS  :    #####
            clean_hfc = generate_high(cover, r=12)
            per_hfc = generate_high(steg_img, r=12)
            hpf_loss = high_pass_filter_loss(clean_hfc, per_hfc)

            #####  steganalyzer:    #####
            # The steganalyzer output is only compared against thresholds and
            # never enters a loss, so no autograd graph is needed for it.
            with torch.no_grad():
                images_cat = torch.cat((cover, steg_img), 0) * 255
                target_images = torch.clamp(images_cat, 0, 255).to(device, dtype=torch.float)
                target_outputs = target_model(target_images)
                probs = torch.softmax(target_outputs, dim=1)
            # Class 0 is "cover". Row 0 is the first cover image; row 1 is the
            # stego image only when the batch holds a single image.
            # NOTE(review): confirm the batch size is always 1 here.
            cover_percentage = probs[0][0] * 100
            stego_percentage = probs[1][0] * 100

            acc = calc_msg_acc(sec_map, secret_rev)
            psnr, ssim = calc_psnr_ssim(cover, steg_img)

            # Stop optimising this batch as soon as both images are classified
            # as cover and the quality targets are met, or when the epoch
            # budget is used up. Both cases save and log the result identically.
            converged = stego_percentage >= 99.99 and cover_percentage >= 99.99 and acc > 0.99 and psnr > 50
            last_epoch = (i_epoch + 1 == c.epochs)
            if converged or last_epoch:
                folder_test_0_1 = os.path.join(c.IMAGE_PATH_1, str(i_batch))
                os.makedirs(folder_test_0_1, exist_ok=True)
                stego_path = os.path.join(folder_test_0_1, str(i_batch) + 'stego.png')
                save_image(cover, os.path.join(folder_test_0_1, str(i_batch) + 'cover.png'))
                save_image(steg_img, stego_path)
                save_image(cover_1, os.path.join(folder_test_0_1, str(i_batch) + 'cover_opt.png'))

                # Measure the message accuracy on the stego image as it was
                # actually written to disk, not on the in-memory tensor.
                image = Image.open(stego_path)
                image = to_rgb(image)
                item = transform_val(image).unsqueeze(0).to(device)

                # Evaluation only: nothing below is back-propagated.
                with torch.no_grad():
                    stego_input = dwt(item)
                    output_z_guass = gauss_noise(output_z.shape)
                    output_rev = torch.cat((stego_input, output_z_guass), 1)
                    output_image = net(output_rev, rev=True)
                    secret_rev = output_image.narrow(1, 4 * c.channels_in, output_image.shape[1] - 4 * c.channels_in)
                    secret_rev = iwt(secret_rev)
                acc = calc_msg_acc(sec_map, secret_rev)

                totalstop_time = time.time()
                time_cost = totalstop_time - totalTime
                psnr_history.append([psnr, 0.])
                ssim_history.append([ssim, 0.])
                log_info = "Finally: Last_Batch:%d \tpsnr: %.4f \tssim:%.4f \tacc:%.4f \tP: %.4f \ttime_cost: %.4f \t" \
                    % (i_batch, psnr,ssim,acc,P.max(),time_cost)
                _append_log(log_path, log_info)
                break

            # Gradients of both losses accumulate before either optimizer steps.
            optim2.zero_grad()
            optim.zero_grad()
            total_loss.backward(retain_graph=True)

            L_HPF = 1 * hpf_loss
            L_HPF.backward()
            optim.step()
            optim2.step()

            if (i_epoch + 1) % 50 == 0:
                log_info = "Batch:%d \ttotal_loss: %.4f \tpsnr: %.4f \tssim:%.4f \tacc:%.4f \tP: %.4f \t" \
                    % (i_batch,total_loss.item(), psnr,ssim,acc,P.max())
                _append_log(log_path, log_info)

        if i_batch > 0 and (i_batch % c.SAVE_freq) == 0:
            torch.save({'opt': optim.state_dict(),
                        'net': net.state_dict()}, c.MODEL_PATH_2 + 'train_second_phase_checkpoint_%.5i' % i_batch + '.pt')

        # Stop once RUN_BATCH batches have been fully processed. This check
        # previously sat inside the epoch loop, where it only cut that one
        # batch short after a single step and never ended the outer loop.
        if (i_batch + 1) == RUN_BATCH:
            break

    # Averages over the batches that reached the finalisation step.
    if psnr_history:
        psnr_avg = sum(item[0] for item in psnr_history) / len(psnr_history)
        ssim_avg = sum(item[0] for item in ssim_history) / len(ssim_history)
        print("psnr_avg",psnr_avg)
        print("ssim_avg", ssim_avg)
    else:
        # Nothing was processed (e.g. empty loader); avoid a division by zero.
        print("No batch was processed; PSNR/SSIM averages are unavailable")

    torch.save({'opt': optim.state_dict(),
                'net': net.state_dict()}, c.MODEL_PATH_2 + 'train_second_phase_model' + '.pt')


except BaseException:
    # Deliberately catches everything (including Ctrl-C) so the current weights
    # can be saved before the error is re-raised unchanged.
    if c.checkpoint_on_error:
        torch.save({'opt': optim.state_dict(),
                    'net': net.state_dict()}, c.MODEL_PATH_2 + 'model_ABORT' + '.pt')
    raise

finally:
    viz.signal_stop()
