import pdb
import os
# Restrict this process to GPU index 1. The variable has to be set before torch
# initialises CUDA, which is why it sits here among the imports, ahead of `import torch`.
os.environ['CUDA_VISIBLE_DEVICES']="1"
# import copy
from copy import deepcopy
import math
from typing import Optional
from munch import Munch
import numpy as np
import time
import datetime
import torch
from torch.backends import cudnn
from core.data_loader import get_train_loader, get_test_loader, InputFetcher
from torch import nn, Tensor
import torch.nn as nn
import torch.nn.functional as F
# visT
import pdb
import logging
from fastai.vision import *
import torchvision.transforms as transforms
#from score.resnet50_ft_dims_2048 import resnet50_ft
from torch.cuda.amp import autocast
import torchvision.transforms.functional as Func
import matplotlib.pyplot as plt
import tensorboardX
from Model.Face_Swap_loss import Loss_for_discriminator, Masked_Loss_for_the_mask


#os.environ['CUDA_VISIBLE_DEVICES']=1

# Single TensorBoard writer shared by the whole script: every scalar, histogram
# and image logged during training ends up under this directory.
_TB_LOG_DIR = './logs'
train_writer = tensorboardX.SummaryWriter(_TB_LOG_DIR)

def moving_average(model, model_test, beta=0.999):
    """Blend ``model``'s parameters into ``model_test`` as an exponential moving average.

    Parameters are paired positionally via ``parameters()``; each averaged
    parameter becomes ``beta * p_test + (1 - beta) * p`` (``torch.lerp`` with
    ``start=p``, ``end=p_test``, ``weight=beta``). Only parameters are touched;
    buffers are left as they are. ``zip`` stops at the shorter parameter list,
    so mismatched layouts are not detected here.

    Args:
        model: module holding the live (trained) parameters.
        model_test: module with the same parameter layout that holds the
            average; its parameters' ``.data`` are rebound.
        beta: fraction of ``model_test``'s current value that is kept.
    """
    live_params = model.parameters()
    avg_params = model_test.parameters()
    for live, avg in zip(live_params, avg_params):
        # torch.lerp(start, end, w) == start + w * (end - start)
        avg.data = torch.lerp(live.data, avg.data, beta)


# Build the trainable model. The module is imported as TM because the name
# `TheMask` is rebound to the model instance right below.
# An independent deep copy (TheMask_eval) is taken immediately: the training
# loop blends TheMask's weights into it with moving_average() and checkpoints
# it alongside TheMask. The copy is made *before* Prepareing_adversrial_learning()
# so whatever that call sets up (presumably adversarial/discriminator parts —
# TODO confirm) is applied only to the trained model, not to the eval copy.
import TheMask as TM
TheMask = TM.build_models()
TheMask_eval = deepcopy(TheMask)
TheMask.Prepareing_adversrial_learning()

# Training data: one CelebA_v2 loader, exposed to the rest of the script as
# `loaders.src` (InputFetcher below consumes it).
_src_train_loader = get_train_loader(
    root=['./dataset/CelebA_Dataset/CelebA_v2'],
    img_size=128,
    batch_size=16,
    num_workers=8,
)
loaders = Munch(src=_src_train_loader)
# Loss coefficients handed to Loss_for_discriminator / Masked_Loss_for_the_mask
# (presumably per-term weights, going by the lambda_ names).
config_for_loss = dict(
    lambda_adv=1.0,
    lambda_reg=0.01,
    lambda_id=100.0,
    lambda_recon=10.0,
    lambda_att_face=1.0,
)

# Optimiser hyper-parameters handed to TheMask.set_grads_and_opts_4all().
config_for_opt = dict(
    lr=0.00001,           # learning rate for attribute encoder and discriminator
    id_lr=0.00001,        # learning rate for identity encoder
    beta1=0.0,            # decay rate for 1st moment of Adam
    beta2=0.99,           # decay rate for 2nd moment of Adam
    weight_decay=0.0001,  # weight decay for optimizer
)



# Training setup. NOTE: there is no checkpoint-resume logic; training always
# starts from the freshly built models above.
print('Start training...')

start_time = time.time()

fetcher = InputFetcher(loaders.src,  'train')
# The first batch is kept as a fixed "validation" batch; the loop re-renders it
# periodically so results on the same images can be compared over time.
inputs = next(fetcher)

# Checkpoint path prefix: './save/YYYY_MM_DD_HH_MM_' (the loop appends
# '_eval_' for the EMA copy).
save_dir = './save/'
dt = datetime.datetime.now()
save_name = dt.strftime("%Y_%m_%d_%H_%M_")
save_dir = save_dir + save_name

TheMask.set_grads_and_opts_4all(config_for_opt)

src_val, tar_val = inputs.src, inputs.tar
src_id_val, tar_id_val = inputs.src_id, inputs.tar_id
src_lm_val, tar_lm_val = inputs.src_lm, inputs.tar_lm
src_depth_val, tar_depth_val = inputs.src_depth, inputs.tar_depth
src_mask_val, tar_mask_val = inputs.src_mask, inputs.tar_mask

# Initial forward pass of the eval copy on the held-out batch. Its output is
# discarded (vis is recomputed before it is ever used), so run it without
# autograd: otherwise the activation graph stays alive until the first
# visualisation step.
with torch.no_grad():
    vis, _ = TheMask_eval(tar_val, tar_depth_val, tar_lm_val, source_id_img=src_id_val)

# Training-loop schedule (all counts are optimisation steps).
MAX_STEPS = 12000000
LOG_EVERY = 50        # scalar losses -> TensorBoard + stdout
EMA_EVERY = 50        # blend TheMask into TheMask_eval
EMA_BETA = 0.999
VIS_EVERY = 500       # parameter histograms and preview images
SAVE_EVERY = 4000     # checkpoints
VIS_DIR = './tmp_vis'


def _log_param_histograms(module, prefix, step):
    """Write one TensorBoard histogram per named parameter of ``module``."""
    for name, param in module.named_parameters():
        # add_histogram consumes the array immediately, so no clone is needed.
        train_writer.add_histogram(prefix + name, param.detach().cpu().numpy(), step)


def _to_display_range(img):
    """Rescale ``img`` with 127.5 * (img + 1), truncate to int, divide by 255.

    Maps [-1, 1] to [0, 1] in 1/255 steps (assumes the input lies in [-1, 1]).
    """
    return (127.5 * (img + 1.0)).type(torch.int64) / 255.0


def _save_preview(batch, title, path):
    """Save the first image of an (N, C, H, W) batch to ``path`` with a title.

    A dedicated figure is created and always closed, so repeated calls do not
    stack images on a shared pyplot axes (which previously grew memory on
    every save).
    """
    img = batch.detach().cpu().numpy().transpose(0, 2, 3, 1)[0]
    # imshow clips integer RGB data to [0, 255] anyway; clip explicitly.
    img = np.clip(img * 255, 0, 255).astype(np.uint8)
    fig, ax = plt.subplots()
    try:
        ax.imshow(img)
        ax.set_title(title)
        fig.savefig(path)
    finally:
        plt.close(fig)


# savefig does not create missing directories.
os.makedirs(VIS_DIR, exist_ok=True)

global_step = 0
try:
    while global_step < MAX_STEPS:
        # Fetch the next training batch.
        inputs = next(fetcher)
        src, tar = inputs.src, inputs.tar
        src_id, tar_id = inputs.src_id, inputs.tar_id
        src_lm, tar_lm = inputs.src_lm, inputs.tar_lm
        src_depth, tar_depth = inputs.src_depth, inputs.tar_depth
        src_mask, tar_mask = inputs.src_mask, inputs.tar_mask

        if global_step % VIS_EVERY == 0:
            _log_param_histograms(TheMask.NV_module.G, 'NV_Module_Generator_', global_step)
            _log_param_histograms(TheMask.NG_module, 'NG_Module_', global_step)
            _log_param_histograms(TheMask.embedding_transfer_net, 'Embedding_transfer_net_', global_step)

        # Discriminator update.
        d_loss, d_losses_all = Loss_for_discriminator(TheMask, config_for_loss, src, tar, src_id, tar_id, src_lm, tar_lm,
                                                      src_mask, tar_mask, source_img_depth=src_depth,
                                                      target_img_depth=tar_depth)
        TheMask._set_zero_grads()
        d_loss.backward()
        TheMask.Dis_opt.step()

        # Swapper (generator-side) update.
        g_loss, g_losses_all = Masked_Loss_for_the_mask(TheMask, config_for_loss, src, tar, src_id, tar_id, src_lm, tar_lm,
                                                        src_mask, tar_mask, source_img_depth=src_depth,
                                                        target_img_depth=tar_depth)
        TheMask._set_zero_grads()
        g_loss.backward()
        TheMask.TheMask_opt.step()

        if global_step % LOG_EVERY == 0:
            all_losses = {}
            for losses, prefix in ((d_losses_all, 'Dis/all_'), (g_losses_all, 'Swapper/all_')):
                for key, value in losses.items():
                    all_losses[prefix + key] = value
                    train_writer.add_scalar(prefix + key, value, global_step)
            log = "Global step [%d], " % (global_step)
            log += ' '.join(['%s: [%.6f]' % (key, value) for key, value in all_losses.items()])
            print(log)

        if global_step % EMA_EVERY == 0 and global_step != 0:
            moving_average(TheMask.NV_module, TheMask_eval.NV_module, beta=EMA_BETA)
            moving_average(TheMask.NG_module, TheMask_eval.NG_module, beta=EMA_BETA)
            moving_average(TheMask.embedding_transfer_net, TheMask_eval.embedding_transfer_net, beta=EMA_BETA)

        if global_step % VIS_EVERY == 0 and global_step != 0:
            with torch.no_grad():
                eval_train_vis, _ = TheMask(tar, tar_depth, tar_lm, source_id_img=src_id)
                vis_train, _ = TheMask(tar_val, tar_depth_val, tar_lm_val, source_id_img=src_id_val)
                vis, _ = TheMask_eval(tar_val, tar_depth_val, tar_lm_val, source_id_img=src_id_val)

                # Fixed held-out batch: inputs, then trained and EMA outputs.
                train_writer.add_image('valid_img/1_eval_tar', tar_val[0], global_step)
                train_writer.add_image('valid_img/2_eval_lm', tar_lm_val[0], global_step)
                train_writer.add_image('valid_img/3_eval_depth', tar_depth_val[0], global_step)
                train_writer.add_image('valid_img/4_eval_source_id', _to_display_range(src_id_val[0]), global_step)
                train_writer.add_image('valid_img/6_train_model_image', vis_train[0], global_step)
                train_writer.add_image('valid_img/5_eval_model_image', vis[0], global_step)

                _save_preview(vis, 'valid_vis_t2s_img',
                              os.path.join(VIS_DIR, 'val_%d_step.png' % (global_step)))
                _save_preview(vis_train, 'train_vis_t2s_img',
                              os.path.join(VIS_DIR, 'train_%d_step.png' % (global_step)))

                # Current training batch and the trained model's swap of it.
                train_writer.add_image('train_img/1_target_image', tar[0], global_step)
                train_writer.add_image('train_img/2_target_lm_image', tar_lm[0], global_step)
                train_writer.add_image('train_img/3_target_depth_image', tar_depth[0], global_step)
                train_writer.add_image('train_img/4_source_id_image', _to_display_range(src_id[0]), global_step)
                train_writer.add_image('train_img/5_swapped_image', eval_train_vis[0], global_step)

        if global_step % SAVE_EVERY == 0 and global_step != 0:
            print('[Global step %d] Saving trained models' % (global_step))
            TheMask.save(save_dir, global_step)
            TheMask_eval.save(save_dir + '_eval_', global_step)
        global_step += 1
finally:
    # Flush pending TensorBoard events even if training stops with an exception
    # or is interrupted.
    train_writer.close()




