import os
from util.dataset import trainloader    
from util.dataset import testloader   
from util import viz
import util.config as c
import INN_based_hiding_network.Unet_common as common
import warnings
import numpy as np
from util.sec_reshape import sec_reshape
from util.utils import *
from INN_based_hiding_network.model import *


import math
from torch.utils.tensorboard import SummaryWriter

from skimage.metrics import structural_similarity as SSIM, peak_signal_noise_ratio as PSNR

# Silence every Python warning for the whole process.
# NOTE(review): this is global and unconditional, so it also hides deprecation
# warnings coming from torch / skimage; consider narrowing it by category/module.
warnings.filterwarnings("ignore")
def weights_init(m):
    """Kaiming-normal initialise the weight of a Conv2d or Linear module.

    Uses fan_out mode with the leaky_relu gain. Biases, and modules of any
    other type, are left untouched.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def gauss_noise(shape):
    """Return a tensor of i.i.d. standard-normal samples of ``shape`` on ``device``."""
    # Draw the whole batch in one call directly on the target device. The previous
    # version allocated zeros, then filled each sample from a host-side randn
    # followed by a host-to-device copy. The distribution is the same.
    return torch.randn(shape, device=device)

def concealing_loss(output, bicubic_image):
    """Summed (not averaged) squared error between ``output`` and ``bicubic_image``.

    Returns:
        0-d tensor, moved to the module-level ``device``.
    """
    # reduction='sum' is the supported spelling of the deprecated
    # (reduce=True, size_average=False) pair: same value, no per-call module.
    loss = torch.nn.functional.mse_loss(output, bicubic_image, reduction='sum')
    return loss.to(device)

def revealing_loss(rev_input, input):
    """Summed (not averaged) squared error between the revealed and the original secret.

    Returns:
        0-d tensor, moved to the module-level ``device``.
    """
    # reduction='sum' replaces the deprecated (reduce=True, size_average=False).
    loss = torch.nn.functional.mse_loss(rev_input, input, reduction='sum')
    return loss.to(device)

def low_pass_filter_loss(ll_input, gt_input):
    """Summed (not averaged) squared error between two low-frequency sub-band tensors.

    Returns:
        0-d tensor, moved to the module-level ``device``.
    """
    # reduction='sum' replaces the deprecated (reduce=True, size_average=False).
    loss = torch.nn.functional.mse_loss(ll_input, gt_input, reduction='sum')
    return loss.to(device)

def get_parameter_number(net):
    """Count the scalar parameters of ``net``.

    Returns:
        ``{'Total': all parameters, 'Trainable': parameters with requires_grad}``.
    """
    total = 0
    trainable = 0
    for param in net.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return {'Total': total, 'Trainable': trainable}

def calc_psnr_ssim(cover_imgv, container_img, data_range=1):
    """Return the batch-mean PSNR and SSIM between two image batches.

    Args:
        cover_imgv: reference images, channels-first (N, C, H, W). Axis 1 is
            moved to the end before measuring.
        container_img: images to compare against ``cover_imgv``, same shape.
        data_range: value range passed to skimage's PSNR/SSIM. The default of 1
            matches the previously hard-coded value (presumably images in [0, 1]).

    Returns:
        ``(psnr, ssim)`` as Python floats.
        - PSNR is computed separately for every image and every channel, then averaged.
        - SSIM is computed per image over all channels, then averaged over the batch.
    """
    # Detach before cpu() so the host copy carries no autograd history. The arrays
    # are only read, so the previous defensive clone() is unnecessary.
    cover_np = cover_imgv.detach().cpu().numpy().transpose(0, 2, 3, 1)
    container_np = container_img.detach().cpu().numpy().transpose(0, 2, 3, 1)
    num_imgs = cover_np.shape[0]
    num_channels = cover_np.shape[3]

    # Loop over the real channel count instead of a hard-coded 3. Otherwise a
    # 1-channel batch indexes out of range and extra channels are silently dropped.
    psnr = np.zeros((num_imgs, num_channels))
    for i in range(num_imgs):
        for ch in range(num_channels):
            psnr[i, ch] = PSNR(cover_np[i, :, :, ch], container_np[i, :, :, ch],
                               data_range=data_range)
    psnr_res = psnr.mean().item()

    ssim = np.zeros(num_imgs)
    for i in range(num_imgs):
        ssim[i] = SSIM(cover_np[i], container_np[i], data_range=data_range, channel_axis=2)
    ssim_res = ssim.mean().item()

    return psnr_res, ssim_res

def calc_msg_acc(secret_imgv_nh, rev_secret_img):
    """Fraction of secret entries recovered correctly.

    Both tensors are binarised with ``>= 0.5`` and compared element-wise.

    Returns:
        0-d float tensor in [0, 1], on ``rev_secret_img``'s device.
    """
    # Compare on the device the revealed secret already lives on. Forcing .cuda()
    # crashed on CPU-only machines even though the script's device falls back to CPU.
    secret_imgv_nh = secret_imgv_nh.to(rev_secret_img.device)
    matches = (rev_secret_img >= 0.5).eq(secret_imgv_nh >= 0.5)
    return matches.sum().float() / secret_imgv_nh.numel()

def load(name):
    """Restore the module-level ``net`` and, best-effort, ``optim`` from checkpoint ``name``.

    - The checkpoint is read as a dict with a ``'net'`` state dict and an optional ``'opt'`` entry.
    - Network entries whose key contains ``'tmp_var'`` are dropped before loading.
    - Failure to restore the optimizer is reported and otherwise ignored.
    """
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    state_dicts = torch.load(name, map_location=device)
    network_state_dict = {k: v for k, v in state_dicts['net'].items() if 'tmp_var' not in k}
    net.load_state_dict(network_state_dict)
    try:
        optim.load_state_dict(state_dicts['opt'])
    except Exception as err:
        # Deliberately best-effort (e.g. missing 'opt' key, mismatched param groups),
        # but no longer swallows KeyboardInterrupt/SystemExit like the bare except did.
        print('Cannot load optimizer for some reason or other:', repr(err))


#####################
# Model initialize: #
#####################

net = Model()
net.to(device)
init_model(net)
para = get_parameter_number(net)
print(para)
params_trainable = [p for p in net.parameters() if p.requires_grad]

optim = torch.optim.Adam(params_trainable, lr=c.lr, betas=c.betas, eps=1e-6, weight_decay=c.weight_decay)
# StepLR multiplies the learning rate by c.gamma every c.weight_step calls to
# weight_scheduler.step().
weight_scheduler = torch.optim.lr_scheduler.StepLR(optim, c.weight_step, gamma=c.gamma)

# Per-epoch batch caps: the training / validation loops stop after this many batches.
IMAGE_NUM = 5000
BATCH_SIZE = c.batch_size
RUN_BATCH = IMAGE_NUM // BATCH_SIZE

VAL_IMAGE_NUM = 3000
# Own name for the validation batch size. BATCH_SIZE used to be rebound here,
# leaving it holding the validation value for the rest of the script.
BATCH_SIZE_VAL = c.batchsize_val
RUN_BATCH_VAL = VAL_IMAGE_NUM // BATCH_SIZE_VAL

# Project transform modules from Unet_common (presumably the discrete wavelet
# transform and its inverse; not visible here).
dwt = common.DWT().to(device)
iwt = common.IWT().to(device)

# Restore net/optimizer from c.MODEL_PATH + c.suffix when the config flag is set
# (the flag is spelled `tain_next` in the config module).
if c.tain_next:
    load(c.MODEL_PATH + c.suffix)
log_path = "train_first_phase.txt"


def _append_log(line):
    """Append ``line`` plus a newline to the text log at ``log_path``."""
    # Mode "a" creates the file if it is missing, so no exists() check is needed,
    # and the handle is always closed.
    with open(log_path, "a") as f:
        f.write(line + "\n")


def _embed_and_reveal(cover):
    """Hide a fresh random binary secret in ``cover``, then reveal it again.

    Shared by the training and validation loops.

    Returns:
        ``(cover_input, sec_map, output_steg, steg_img, secret_rev)``.
    """
    # Secret size 256x256 is hard-coded, as in the original loops.
    secret = np.random.choice([0, 1], size=[cover.size(0), 1, 256, 256])
    sec_map = torch.from_numpy(sec_reshape(secret)).to(device)

    cover_input = dwt(cover).to(device)
    secret_input = dwt(sec_map).to(device)
    input_img = torch.cat((cover_input, secret_input), 1)

    # forward: the first 4 * channels_in output channels form the stego part;
    # the remaining channels are the auxiliary output z.
    n_steg = 4 * c.channels_in
    output = net(input_img).to(device)
    output_steg = output.narrow(1, 0, n_steg)
    output_z = output.narrow(1, n_steg, output.shape[1] - n_steg)
    steg_img = iwt(output_steg).to(device)

    # backward: run the network in reverse with z replaced by Gaussian noise.
    output_z_gauss = gauss_noise(output_z.shape).to(device)
    output_rev = torch.cat((output_steg, output_z_gauss), 1)
    output_image = net(output_rev, rev=True).to(device)
    secret_rev = output_image.narrow(1, n_steg, output_image.shape[1] - n_steg)
    secret_rev = iwt(secret_rev).to(device)

    return cover_input, sec_map, output_steg, steg_img, secret_rev


train_writer = None
test_writer = None
try:
    train_writer = SummaryWriter('train_first_phase')
    test_writer = SummaryWriter('test_first_phase')

    for epoch_idx in range(c.epochs):
        i_epoch = epoch_idx + c.trained_epoch + 1
        loss_history = []
        L_con_history = []
        L_rev_history = []
        L_lpf_history = []

        #################
        #     train:    #
        #################
        # Validation below switches the net to eval mode; switch back every epoch.
        net.train()
        for i_batch, mydata in enumerate(trainloader):
            cover = mydata.to(device)
            cover_input, sec_map, output_steg, steg_img, secret_rev = _embed_and_reveal(cover)

            # No forced .cuda(): every tensor is already on `device`, and .cuda()
            # crashed on CPU-only machines.
            L_con = concealing_loss(steg_img, cover)
            L_rev = revealing_loss(secret_rev, sec_map.float())
            steg_low = output_steg.narrow(1, 0, c.channels_in)
            cover_low = cover_input.narrow(1, 0, c.channels_in)
            L_lpf = low_pass_filter_loss(steg_low, cover_low)
            total_loss = c.lamda_c * L_con + c.lamda_r * L_rev + c.lamda_l * L_lpf

            # PSNR/SSIM need a host copy plus skimage per image; compute them only
            # for the batches that are actually logged.
            log_now = (i_batch + 1) % 100 == 0
            if log_now:
                psnr, ssim = calc_psnr_ssim(cover, steg_img)
                acc = calc_msg_acc(sec_map, secret_rev)

            total_loss.backward()
            optim.step()
            optim.zero_grad()

            loss_history.append([total_loss.item(), 0.])
            L_con_history.append(L_con.item())
            L_rev_history.append(L_rev.item())
            L_lpf_history.append(L_lpf.item())

            if log_now:
                _append_log(
                    "train loss: %.4f \tencoded_loss: %.4f \treval_loss: %.4f \tpsnr: %.2f \tssim: %.2f \tacc: %.4f"
                    % (total_loss.item(), L_con.item(), L_rev.item(), psnr, ssim, acc))

            if (i_batch + 1) == RUN_BATCH:
                _append_log("Train: %d--End" % (i_epoch))
                break

        epoch_losses = np.mean(np.array(loss_history), axis=0)
        L_con_epoch_loss = np.mean(L_con_history)
        L_rev_epoch_loss = np.mean(L_rev_history)
        L_lpf_epoch_loss = np.mean(L_lpf_history)

        # The second slot carries log10 of the current learning rate for viz.show_loss.
        epoch_losses[1] = np.log10(optim.param_groups[0]['lr'])

        #################
        #      val:     #
        #################
        if i_epoch % c.val_freq == 0:
            with torch.no_grad():
                psnr_history = []
                ssim_history = []
                acc_history = []
                net.eval()

                for i_batch, mydata in enumerate(testloader):
                    cover = mydata.to(device)
                    _, sec_map, _, steg_img, secret_rev = _embed_and_reveal(cover)

                    psnr, ssim = calc_psnr_ssim(cover, steg_img)
                    acc = float(calc_msg_acc(sec_map, secret_rev))
                    psnr_history.append(psnr)
                    ssim_history.append(ssim)
                    acc_history.append(acc)

                    if (i_batch + 1) % 30 == 0:
                        _append_log("test_psnr: %.2f \tssim: %.2f \tacc: %.4f" % (psnr, ssim, acc))

                    if (i_batch + 1) == RUN_BATCH_VAL:
                        # NOTE: label kept as-is; this line marks the end of validation.
                        _append_log("Train: %d--End" % (i_epoch))
                        break

                psnr_avg = sum(psnr_history) / len(psnr_history)
                ssim_avg = sum(ssim_history) / len(ssim_history)
                acc_avg = sum(acc_history) / len(acc_history)

                print("psnr_avg", psnr_avg)
                print("ssim_avg", ssim_avg)
                print("acc_avg", acc_avg)

                test_writer.add_scalar("psnr_avg", psnr_avg, i_epoch)
                test_writer.add_scalar("ssim_avg", ssim_avg, i_epoch)
                test_writer.add_scalar("acc_avg", acc_avg, i_epoch)

        viz.show_loss(epoch_losses)
        train_writer.add_scalar("Train_Loss", epoch_losses[0], i_epoch)
        train_writer.add_scalars("Train", {"L_con_epoch_losses": L_con_epoch_loss,
                                           "L_rev_losses": L_rev_epoch_loss,
                                           "L_lpf_epoch_losses": L_lpf_epoch_loss}, i_epoch)
        if i_epoch > 0 and (i_epoch % c.SAVE_freq) == 0:
            torch.save({'opt': optim.state_dict(),
                        'net': net.state_dict()}, c.MODEL_PATH_1 + 'train_first_phase_checkpoint_%.5i' % i_epoch + '.pt')

        # The StepLR scheduler was created but never stepped, so the LR never decayed.
        weight_scheduler.step()

    torch.save({'opt': optim.state_dict(),
                'net': net.state_dict()}, c.MODEL_PATH_1 + 'train_first_phase_model' + '.pt')

except BaseException:
    # Deliberately broad (includes KeyboardInterrupt) so an interrupted run can
    # still write an abort checkpoint; the exception is always re-raised.
    if c.checkpoint_on_error:
        torch.save({'opt': optim.state_dict(),
                    'net': net.state_dict()}, c.MODEL_PATH_1 + 'model_ABORT' + '.pt')
    raise

finally:
    # Close the TensorBoard writers on every exit path, not only on success.
    for writer in (train_writer, test_writer):
        if writer is not None:
            writer.close()
    viz.signal_stop()
