import argparse
import torch
import os
import torch.nn as nn
import torch.optim as optim
import torchvision
from datetime import datetime
from utils import *
from dataset import get_dataloader
from torchvision import transforms
from forward_operator import get_operator
from torch.autograd import Variable
from otur.srresnet_unet_CAB import _NetG ,_NetD

from collections import defaultdict

# Module-level table of metric records that are created on first access: any
# missing key is initialised to a fresh record with "psnr" and "ssim" set to
# 0.0, an empty "lpips" list and a "count" of 0.  Each record gets its own
# list, so records never share state.
# NOTE(review): keys are presumably degradation names - confirm against the
# evaluation code that fills this table.
deg_metrics = defaultdict(
    lambda: dict(psnr=0.0, ssim=0.0, lpips=[], count=0)
)

def extract_center_from_padded(padded_tensor, original_resolution=128):
    """Return the centered square crop of the last two dimensions.

    Args:
        padded_tensor: Object exposing ``.shape`` and supporting
            ``[..., a:b, c:d]`` indexing (e.g. a ``torch.Tensor``). The last
            two axes are treated as height and width; leading axes are kept.
        original_resolution: Side length of the square window to cut out.

    Returns:
        The window ``padded_tensor[..., top:top + r, left:left + r]`` with
        ``r = original_resolution``. The offsets are floor-divided, so when
        the total padding is odd the extra row/column is dropped from the
        bottom/right.

    Raises:
        ValueError: If either trailing dimension is smaller than
            ``original_resolution``. Without this check the start offset
            would be negative and the slice would silently return a
            wrong-sized, off-centre window.
    """
    h, w = padded_tensor.shape[-2:]

    if original_resolution > h or original_resolution > w:
        raise ValueError(
            "original_resolution={} exceeds the spatial size ({}, {}) of the "
            "padded tensor".format(original_resolution, h, w)
        )

    start_h = (h - original_resolution) // 2
    start_w = (w - original_resolution) // 2

    end_h = start_h + original_resolution
    end_w = start_w + original_resolution

    return padded_tensor[..., start_h:end_h, start_w:end_w]

def margin(current_epoch, num_epoch, warm_up, max_margin):
    """Return the margin for ``current_epoch`` under a linear warm-up.

    While ``current_epoch`` is below ``warm_up`` the margin grows linearly as
    ``max_margin * (current_epoch / warm_up)``; from ``warm_up`` onwards it
    stays at ``max_margin``.

    Args:
        current_epoch: Epoch index the margin is requested for.
        num_epoch: Accepted for interface compatibility; not used.
        warm_up: Number of epochs over which the margin ramps up.
        max_margin: Margin value reached once the warm-up is over.

    Returns:
        The margin for the given epoch.
    """
    still_warming = current_epoch < warm_up
    if not still_warming:
        return max_margin

    fraction_done = current_epoch / warm_up
    return max_margin * fraction_done
    
def resize(input, factor):
    """Resize ``input`` to a square of side ``256 // factor``.

    Args:
        input: Image (or batch) accepted by ``torchvision.transforms.Resize``.
        factor: Integer divisor applied to the fixed base size of 256.

    Returns:
        The result of applying ``transforms.Resize`` with the target size
        ``(256 // factor, 256 // factor)`` to ``input``.
    """
    side = 256 // factor
    return transforms.Resize((side, side))(input)

def hjb_regularizer(V_net, t, x, device, tau, alpha=1.0):
    """Compute the per-sample HJB residual of a potential network.

    For every sample the returned value is

        | 2 * tau * dV/dt + (1 / (2 * alpha)) * ||grad_x V||^2 |

    i.e. the absolute value (not the square) of the HJB expression, with the
    time derivative scaled by ``2 * tau``.

    Args:
        V_net: Callable invoked as ``V_net(x, t)``. Its output is summed
            before differentiation, so it is assumed to yield independent
            per-sample values (TODO confirm against the network definition).
        t: Time tensor holding one scalar per sample, e.g. shape ``(B,)`` or
            ``(B, 1, 1, 1)``. It is switched to ``requires_grad=True`` in
            place.
        x: Input tensor whose first axis is the batch axis, e.g.
            ``(B, C, H, W)``. It is switched to ``requires_grad=True`` in
            place.
        device: Unused; kept so existing callers keep working.
        tau: Scalar multiplier applied (times 2) to the time derivative.
        alpha: Weighting parameter of the squared spatial-gradient term.

    Returns:
        Tensor of shape ``(B,)`` with the absolute HJB residual per sample.
        The result stays attached to the autograd graph
        (``create_graph=True``), so it can be used as a training loss.
    """
    # Both inputs must be tracked by autograd so V can be differentiated
    # with respect to them.
    t.requires_grad_(True)
    x.requires_grad_(True)

    # Evaluate the potential V(x, t).
    V = V_net(x, t)

    # dV/dt and grad_x V from a single pass; create_graph=True keeps the
    # derivatives differentiable for the optimizer step on V_net.
    grad_t, grad_x = torch.autograd.grad(
        V.sum(), (t, x),
        create_graph=True,
    )

    # grad_t has the shape of t. Flatten it to (B,) so it lines up with the
    # per-sample gradient norm; a (B, 1, 1, 1) tensor would otherwise
    # broadcast against (B,) into a (B, 1, 1, B) cross-sample matrix.
    grad_t = grad_t.reshape(-1)

    # ||grad_x V||^2 per sample, computed directly as a sum of squares so no
    # square root (and its ill-conditioned derivative near zero) is involved.
    grad_x_norm_sq = grad_x.reshape(x.shape[0], -1).pow(2).sum(dim=1)  # (B,)

    # HJB expression per sample.
    hjb_expr = 2 * grad_t * tau + (1.0 / (2 * alpha)) * grad_x_norm_sq  # (B,)

    # Final per-sample penalty: absolute residual.
    return hjb_expr.abs()

def train(args):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = 'cuda:0'
    args.device = device
    batch_size = args.batch_size
    nz = args.nz
    operator_type = args.operator_type
    lmda = args.lmda
    alpha = args.alpha
    normalize = args.normalize
    reg_A = args.reg_A
    reg_rank = args.reg_rank
    reg_grad = args.reg_grad
    noise = args.noise
    w_d_g1 = args.w_d_g1
    w_d_g2 = args.w_d_g2
    cycle = args.cycle




    # Get Data   
    if args.dataset in ['mnist', 'cifar10', 'cifar10+mnist', 'celeba_256'] :
        data_loader = get_dataloader(args)
        'tuple' != data_loader
    else :
        data_loader, data_loader_test = get_dataloader(args)

    # Set Generator
    from models.ncsnpp_generator_adagn import NCSNpp
    from otur.srresnet_unet_CAB import _NetG, _NetD

    if args.generator == 'NCSNpp':
        netG = NCSNpp(args).to(device)
        optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))
        schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.schedule, eta_min=1e-5)
    
    elif args.generator == 'otur':
        netG = _NetG().to(device)
        optimizerG = optim.RMSprop(netG.parameters(), lr=args.lr_g)

    schedule = args.num_iterations
    total_epoch = schedule//len(data_loader)

    if args.use_ema:
        optimizerG = EMA(optimizerG, ema_decay=args.ema_decay)
    

    netG = nn.DataParallel(netG)
    
    
    # Set potential
    if args.dataset in ['mnist','cifar10','cifar10+mnist']:
        from models.discriminator import Discriminator_small
        netD = Discriminator_small(nc = args.num_channels, ngf = args.ngf, act=nn.LeakyReLU(0.2)).to(device)
        if w_d_g1>0:
            netW1= Discriminator_small(nc = args.num_channels, ngf = args.ngf, act=nn.LeakyReLU(0.2)).to(device)
    else:
        from models.discriminator import Discriminator_large
        if args.discriminator == 'NCSNpp':
            netD = Discriminator_large(nc = args.num_channels, ngf = args.ngf, act=nn.LeakyReLU(0.2)).to(device)
            optimizerD = optim.Adam(netD.parameters(), lr=args.lr_d, betas = (args.beta1, args.beta2))
            schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.schedule, eta_min=1e-5)
        
        elif args.discriminator == 'otur':
            netD = _NetD().to(device)
            # args.lr_d = args.lr_g
            optimizerD = optim.RMSprop(netD.parameters(), lr=args.lr_d)

        if w_d_g1>0:
            if args.w1_discriminator == 'NCSNpp':
                netW1 = Discriminator_large(nc = args.num_channels, ngf = args.ngf, act=nn.LeakyReLU(0.2)).to(device)
                optimizerW1 = optim.Adam(netW1.parameters(), lr=args.lr_w1, betas = (args.beta1, args.beta2))
                schedulerW1 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerW1, args.schedule, eta_min=1e-5)
        
            elif args.w1_discriminator == 'otur':
                netW1 = _NetD().to(device)
                optimizerW1 = optim.RMSprop(netW1.parameters(), lr=args.lr_w1)

    
    netD = nn.DataParallel(netD)
    if w_d_g1>0:
        netW1 = nn.DataParallel(netW1)
    
    if cycle:
        # Set Generator
        from models.ncsnpp_generator_adagn import NCSNpp
        if args.generator =='NCSNpp':
            netG2 = NCSNpp(args).to(device)
            optimizerG2 = optim.Adam(netG2.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))
            schedulerG2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG2, args.schedule, eta_min=1e-5)
        elif args.generator =='otur':
            netG2 = _NetG().to(device)
            optimizerG2 = optim.RMSprop(netG.parameters(), lr=args.lr_g)

        schedule = args.num_iterations
        total_epoch = schedule//len(data_loader)

        if args.use_ema:
            optimizerG2 = EMA(optimizerG2, ema_decay=args.ema_decay)
        
        netG2 = nn.DataParallel(netG2)
        
        
        # Set potential
        if args.dataset in ['mnist','cifar10','cifar10+mnist']:
            from models.discriminator import Discriminator_small
            netD2 = Discriminator_small(nc = args.num_channels, ngf = args.ngf, act=nn.LeakyReLU(0.2)).to(device)
            if w_d_g2>0:
                netW2= Discriminator_small(nc = args.num_channels, ngf = args.ngf, act=nn.LeakyReLU(0.2)).to(device)
        else:
            from models.discriminator import Discriminator_large
            if args.discriminator == 'NCSNpp':
                netD2 = Discriminator_large(nc = args.num_channels, ngf = args.ngf, act=nn.LeakyReLU(0.2)).to(device)
                optimizerD2 = optim.Adam(netD2.parameters(), lr=args.lr_d, betas = (args.beta1, args.beta2))
                schedulerD2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD2, args.schedule, eta_min=1e-5)
        
            elif args.discriminator == 'otur':
                netD2 = _NetD().to(device)
                # args.lr_d = args.lr_g
                optimizerD2 = optim.RMSprop(netD.parameters(), lr=args.lr_d)

            if w_d_g2>0:
                if args.w2_discriminator == 'NCSNpp':
                    netW2 = Discriminator_large(nc = args.num_channels, ngf = args.ngf, act=nn.LeakyReLU(0.2)).to(device)
                    optimizerW2 = optim.Adam(netW2.parameters(), lr=args.lr_w2, betas = (args.beta1, args.beta2))
                    schedulerW2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerW2, args.schedule, eta_min=1e-5)
                elif args.w2_discriminator == 'otur':
                    netW2 = _NetD().to(device)
                    optimizerW2 = optim.RMSprop(netW2.parameters(), lr=args.lr_w2)

        netD2 = nn.DataParallel(netD2)
        if w_d_g2>0:
            netW2 = nn.DataParallel(netW2)

    # Create log path
    exp = args.exp
    if args.phi1 =='kl':
        parent_dir = "./train_logs/uotm_{}/{}".format(args.generator, args.dataset)
        if args.dataset == 'mixed':
            parent_dir = "./train_logs/uotm_{}/{}/mixed_ratio_{}".format(args.generator, args.dataset, args.mixed_ratio)
    elif args.phi1 =='linear':
        parent_dir = "./train_logs/otm_{}/{}".format(args.generator, args.dataset)
        if args.dataset == 'mixed' :
            parent_dir = "./train_logs/otm_{}/{}/mixed_ratio_{}".format(args.generator, args.dataset, args.mixed_ratio)
    else :
        print('which type ot Map?')
        assert 0
    exp_path = os.path.join(parent_dir, args.phi1)
    log_path = os.path.join(exp_path, args.operator_type, args.log_dir)
    os.makedirs(log_path, exist_ok=True)
    arg2json(args, log_path)     

    def adjust_learning_rate(lr_gd, epoch):
        """Sets the learning rate to the initial LR decayed by 10"""
        lr = lr_gd * (0.1 ** (epoch // 100))
        return lr    
  
    # Resume
    if args.resume:
        checkpoint_file = os.path.join(log_path, 'content.pth')
        checkpoint = torch.load(checkpoint_file, map_location=device)
        init_epoch = checkpoint['epoch']
        epoch = init_epoch
        netG.load_state_dict(checkpoint['netG_dict'])
        # load G
        optimizerG.load_state_dict(checkpoint['optimizerG'])
        if args.generator in ('NCSNpp', 'blind'):
            schedulerG.load_state_dict(checkpoint['schedulerG'])
        # load D
        netD.load_state_dict(checkpoint['netD_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD'])
        if args.discriminator == 'NCSNpp':
            schedulerD.load_state_dict(checkpoint['schedulerD'])
        global_step = checkpoint['global_step']
        print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
        if w_d_g1>0:
            netW1.load_state_dict(checkpoint['netW1_dict'])
            optimizerW1.load_state_dict(checkpoint['optimizerW1'])
            if args.w1_discriminator == 'NCSNpp':
                schedulerW1.load_state_dict(checkpoint['schedulerW1'])

        if cycle:
            checkpoint_file = os.path.join(log_path, 'content2.pth')
            checkpoint = torch.load(checkpoint_file, map_location=device)
            netG2.load_state_dict(checkpoint['netG2_dict'])
            # load G2
            optimizerG2.load_state_dict(checkpoint['optimizerG2'])
            if args.generator in ('NCSNpp', 'blind'):
                schedulerG2.load_state_dict(checkpoint['schedulerG2'])
            # load D2
            netD2.load_state_dict(checkpoint['netD2_dict'])
            optimizerD2.load_state_dict(checkpoint['optimizerD2'])
            if args.discriminator == 'NCSNpp':
                schedulerD2.load_state_dict(checkpoint['schedulerD2'])
            if w_d_g2>0:
                netW2.load_state_dict(checkpoint['netW2_dict'])
                optimizerW2.load_state_dict(checkpoint['optimizerW2'])
                if args.w2_discriminator == 'NCSNpp':
                    schedulerW2.load_state_dict(checkpoint['schedulerW2'])
    else:
        global_step, epoch, init_epoch = 0, 0, 0
    
    
    # Make log file
    with open(os.path.join(log_path, 'log.txt'), 'w') as f:
        f.write("Start Training")
        f.write('\n')
    
    
    # get phi star
    phi_star1 = select_phi(args.phi1)
    phi_star2 = select_phi(args.phi2)


    # Start training
    start = datetime.now()
    operator_selector = get_operator(args.operator_type)
    operator, deg_name = operator_selector()
    for epoch in range(init_epoch, total_epoch+1):
        global_step=0

        if args.generator == 'otur':
            lr_g = adjust_learning_rate(args.lr_g, epoch)
            for param_group in optimizerG.param_groups:
                param_group["lr"] = lr_g

        if args.discriminator =='otur':
            lr_d = adjust_learning_rate(args.lr_d, epoch)       
            for param_group in optimizerD.param_groups:
                param_group["lr"] = lr_d
        if w_d_g1>0:
            if args.w1_discriminator =='otur':
                lr_w1 = adjust_learning_rate(args.lr_w1,epoch)
                for param_group in optimizerW1.param_groups:
                    param_group["lr"] = lr_w1
        if cycle:
            if args.generator == 'otur':
                lr_g2 = adjust_learning_rate(args.lr_g, epoch)
                for param_group in optimizerG2.param_groups:
                    param_group["lr"] = lr_g2
            if args.discriminator =='otur':
                lr_d2 = adjust_learning_rate(args.lr_d, epoch)
                for param_group in optimizerD2.param_groups:
                    param_group["lr"] = lr_d2
            if w_d_g2>0:
                if args.w2_discriminator == 'otur':
                    lr_w2 = adjust_learning_rate(args.lr_w2, epoch)
                    for param_group in optimizerW2.param_groups:
                        param_group["lr"] = lr_w2


        for _, x in enumerate(data_loader):
            try: x_low, x_origin = x
            except: pass

            x_low = x_low.float().to(device, non_blocking=True)


            ## Update W1
            if w_d_g1 > 0:
                for p in netW1.parameters():  
                    p.requires_grad = True
                
                real_data = x_origin.float().to(device, non_blocking=True)
                real_data.requires_grad = True
                
                netW1.zero_grad()
                


                with torch.no_grad():
                    latent_z = torch.randn(batch_size, nz, device=device)
                    if args.generator =='NCSNpp':
                        x_predict = netG(x_low, latent_z)
                    elif args.generator in ('otur', 'blind'):
                        x_predict = netG(x_low)
                
                W_loss1 = -(-netW1(x_predict)) -(netW1(real_data))
                
                W_loss1 = W_loss1.mean()

                # gradient penalty
                alpha_ = torch.rand(real_data.size(0), 1, 1, 1)
                alpha1 = alpha_.cuda().expand_as(real_data) 
                interpolated1 = Variable(alpha1 * real_data + (1 - alpha1) * x_predict, requires_grad=True)
                out = netW1(interpolated1)
                grad = torch.autograd.grad(outputs=out,
                                    inputs=interpolated1,
                                    grad_outputs=torch.ones(out.size()).cuda(),
                                    retain_graph=True,
                                    create_graph=True,
                                    only_inputs=True)[0]
                
                grad = grad.view(grad.size(0), -1)
                grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                w_loss_gp1 = torch.mean((grad_l2norm - 1) ** 2)
                w_gp_loss1 = 10 * w_loss_gp1
                total_w_loss1 = w_d_g1*(W_loss1 + w_gp_loss1)
                total_w_loss1.backward()
                optimizerW1.step()
                
                netW1.zero_grad()
                for p in netW1.parameters():  
                    p.requires_grad = False

            #### Update potential ####
            for p in netD.parameters():  
                p.requires_grad = True

            real_data = x_origin.float().to(device, non_blocking=True)
            real_data.requires_grad = True 
            netD.zero_grad()

            # real D loss
            
            D_real = netD(real_data)
            errD_real = phi_star2(-D_real)
            errD_real = errD_real.mean()
            errD_real.backward(retain_graph=True)
            
            # R1 regularization
            grad_real = torch.autograd.grad(outputs=D_real.sum(), inputs=real_data, create_graph=True)[0]
            grad_penalty = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()
            grad_penalty = args.r1_gamma / 2 * grad_penalty
            grad_penalty.backward()

            # fake D loss
            latent_z = torch.randn(batch_size, nz, device=device)
            
            if ('down_sampling' in operator_type) or ('down_sampling' in deg_name):
                x_low_resize = F.interpolate(x_low, size=(128, 128), mode='bicubic',)
                with torch.no_grad():
                    if args.generator =='NCSNpp':
                        x_0_predict = netG(x_low_resize, latent_z)
                    elif args.generator in ('otur', 'blind'):
                        x_0_predict = netG(x_low_resize)
                cost_cal = args.tau * torch.sum(((x_0_predict-x_low_resize).view(x_low_resize.size(0), -1))**2, dim=1)+\
                    args.c_like * args.tau * torch.sum(((F.interpolate(operator.measure(x_0_predict, normalize, 0), size=(128,128), mode='bicubic')-x_low_resize).view(x_low_resize.size(0), -1))**2, dim=1)
            else:
                with torch.no_grad(): 
                    if args.generator =='NCSNpp':
                        x_0_predict = netG(x_low, latent_z)
                    elif args.generator in ('otur', 'blind'):
                        x_0_predict = netG(x_low)
                if ('phase_retrieval' in operator_type) or ('phase_retrieval' in deg_name):
                    cost_cal = args.tau * torch.sum(((x_0_predict-x_low).view(x_low.size(0), -1))**2, dim=1)+\
                    args.c_like * args.tau * torch.sum(((operator.measure(extract_center_from_padded(x_0_predict), normalize, 0)-x_low).view(x_low.size(0), -1))**2, dim=1)
                else:
                    cost_cal = args.tau * torch.sum(((x_0_predict-x_low).view(x_low.size(0), -1))**2, dim=1)+\
                        args.c_like * args.tau * torch.sum(((operator.measure(x_0_predict, normalize, 0)-x_low).view(x_low.size(0), -1))**2, dim=1)
                
            
            D_fake = netD(x_0_predict)
            
            errD_fake = phi_star1(D_fake - cost_cal)
            errD_fake = errD_fake.mean()
            errD_fake.backward()
            errD = errD_real + errD_fake
            optimizerD.step()


            #### Update Generator ####
            for p in netD.parameters():
                p.requires_grad = False
            
            
            netG.zero_grad()

            # Generator loss
            x_low = x_low.clone().float().to(device)
            latent_z = torch.randn(batch_size, nz, device=device)
            
            if ('down_sampling' in operator_type) or ('down_sampling' in deg_name):
                x_low_resize = F.interpolate(x_low, size=(128, 128), mode='bicubic')
                x_low = x_low_resize
                if args.generator =='NCSNpp':
                    x_0_predict = netG(x_low, latent_z)
                elif args.generator in ('otur', 'blind'):
                    x_0_predict = netG(x_low)
                cost_cal = args.tau * torch.sum(((x_0_predict-x_low).view(x_low.size(0), -1))**2, dim=1)+\
                        args.c_like * args.tau * torch.sum(((F.interpolate(operator.measure(x_0_predict, normalize, 0), size=(128,128), mode='bicubic')-x_low).view(x_low.size(0), -1))**2, dim=1)

            
            else:
                if args.generator =='NCSNpp':
                    x_0_predict = netG(x_low, latent_z)
                elif args.generator in ('otur', 'blind'):
                    x_0_predict = netG(x_low)
                
                if ('phase_retrieval' in operator_type) or ('phase_retrieval' in deg_name):
                    cost_cal = args.tau * torch.sum(((x_0_predict-x_low).view(x_low.size(0), -1))**2, dim=1)+\
                        args.c_like * args.tau * torch.sum(((operator.measure(extract_center_from_padded(x_0_predict), normalize, 0)-x_low).view(x_low.size(0), -1))**2, dim=1)
                else:
                    cost_cal = args.tau * torch.sum(((x_0_predict-x_low).view(x_low.size(0), -1))**2, dim=1)+\
                    args.c_like * args.tau * torch.sum(((operator.measure(x_0_predict, normalize, 0)-x_low).view(x_low.size(0), -1))**2, dim=1)

            D_fake = netD(x_0_predict)
            
            err = cost_cal - D_fake
            err = err.mean()
            if w_d_g1 >0:
                gen_w_loss1= w_d_g1 * netW1(x_0_predict).mean()
                err += - gen_w_loss1
            # || T2(T1(x))-x ||=0
            if cycle: 
                if args.generator == 'NCSNpp':
                    diff = (netG2(netG(x_low, latent_z.clone().float().to(device)), latent_z.clone().float().to(device))-x_low) 
                elif args.generator in ('otur', 'blind'):
                    diff = (netG2(netG(x_low))-x_low) 
                T2T1_loss= diff.view(diff.size(0),-1).pow(2).mean(dim=1).mean()
                err += T2T1_loss

            
            err.backward()
            optimizerG.step()
            global_step += 1
            
            ## save losses
            if global_step % args.print_every == 0:
                if cycle:
                    if w_d_g1>0:
                        print(f'Epoch {epoch:04d}_iter{_} : G Loss {err.item():.4f}, D Loss {errD.item():.4f}, T2T1 loss {T2T1_loss.item():.4f}, W Loss1 {W_loss1.item():.4f}, gen_w_loss1 : {gen_w_loss1.item():.4f}')
                    else:
                        print(f'Epoch {epoch:04d}_iter{_} : G Loss {err.item():.4f}, D Loss {errD.item():.4f}, T2T1 loss {T2T1_loss.item():.4f}')
                else:
                    if w_d_g1>0:
                        print(f'Epoch {epoch:04d}_iter{_} : G Loss {err.item():.4f}, D Loss {errD.item():.4f}, W Loss1 {W_loss1.item():.4f}')
                    else:
                        print(f'Epoch {epoch:04d}_iter{_} : G Loss {err.item():.4f}, D Loss {errD.item():.4f}')

                with open(os.path.join(log_path, 'log.txt'), 'a') as f:
                    if cycle:
                        if w_d_g1>0:
                            f.write(f'Epoch {epoch:04d}_iter{_} : G Loss {err.item():.4f}, D Loss {errD.item():.4f}, T2T1 loss {T2T1_loss.item():.4f}, W Loss1 {W_loss1.item():.4f}, gen_w_loss1 : {gen_w_loss1.item():.4f}')
                        else:
                            f.write(f'Epoch {epoch:04d}_iter{_} : G Loss {err.item():.4f}, D Loss {errD.item():.4f}, T2T1 loss {T2T1_loss.item():.4f}')
                    else:
                        if w_d_g1>0:
                            f.write(f'Epoch {epoch:04d}_iter{_} : G Loss {err.item():.4f}, D Loss {errD.item():.4f}, W Loss1 {W_loss1.item():.4f}')
                        else:
                            f.write(f'Epoch {epoch:04d}_iter{_} : G Loss {err.item():.4f}, D Loss {errD.item():.4f}')
                    f.write('\n')


            if cycle:

                ## update W2
                x_low.requires_grad = True
                if w_d_g2 > 0:
                    for p in netW2.parameters():  
                        p.requires_grad = True
                    
                    source_data = x_origin.float().to(device, non_blocking=True)

                    
                    netW2.zero_grad()
                    
                    if 'down_sampling' in operator_type or ('down_sampling' in deg_name):
                        x_low_resize = F.interpolate(x_low, size=(128, 128), mode='bicubic')
                        x_low = x_low_resize # for regular-term

                    with torch.no_grad():
                        latent_z = torch.randn(batch_size, nz, device=device)
                        if args.generator =='NCSNpp':
                            x_predict2 = netG2(source_data, latent_z)
                        elif args.generator in ('otur', 'blind'):
                            x_predict2 = netG2(source_data)

                    
                    W_loss2 = -(-netW2(x_predict2)) -(netW2(x_low))
                    
                    W_loss2 = W_loss2.mean()

                    # gradient penalty
                    alpha_ = torch.rand(x_low.size(0), 1, 1, 1)
                    alpha1 = alpha_.cuda().expand_as(x_low)
                    interpolated1 = Variable(alpha1 * x_low + (1 - alpha1) * x_predict2, requires_grad=True)
                    out = netW1(interpolated1)
                    grad = torch.autograd.grad(outputs=out,
                                        inputs=interpolated1,
                                        grad_outputs=torch.ones(out.size()).cuda(),
                                        retain_graph=True,
                                        create_graph=True,
                                        only_inputs=True)[0]
                    
                    grad = grad.view(grad.size(0), -1)
                    grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
                    w_loss_gp2 = torch.mean((grad_l2norm - 1) ** 2)
                    w_gp_loss2 = 10 * w_loss_gp2
                    total_w_loss2 = w_d_g2 * (W_loss2 + w_gp_loss2)
                    total_w_loss2.backward()
                    optimizerW2.step()
                    
                    netW2.zero_grad()
                    for p in netW2.parameters():  
                        p.requires_grad = False


                #### Update potential ####
                for p in netD2.parameters():  
                    p.requires_grad = True

                if ('down_sampling' in operator_type) or ('down_sampling' in deg_name):
                    x_low = F.interpolate(x_low, size=(128, 128), mode='bicubic')
                
                    
                source_data = x_origin.float().to(device, non_blocking=True)
                x_low = x_low.clone()
                netD2.zero_grad()

                # real D loss
                
                D_real2 = netD2(x_low)
                errD_real2 = phi_star2(-D_real2)
                errD_real2 = errD_real2.mean()
                errD_real2.backward(retain_graph=True)
                
                # R1 regularization
                grad_real2 = torch.autograd.grad(outputs=D_real2.sum(), inputs=x_low, create_graph=True)[0]
                
                grad_penalty2 = (grad_real2.view(grad_real2.size(0), -1).norm(2, dim=1) ** 2).mean()
                grad_penalty2 = args.r1_gamma / 2 * grad_penalty2
                grad_penalty2.backward()

                # fake D loss
                latent_z = torch.randn(batch_size, nz, device=device)
                
                with torch.no_grad():
                    if args.generator =='NCSNpp':
                        x_0_predict2 = netG2(source_data, latent_z)
                    elif args.generator in ('otur', 'blind'):
                        x_0_predict2 = netG2(source_data)
                cost_cal = args.tau * torch.sum(((x_0_predict2-source_data).view(source_data.size(0), -1))**2, dim=1)

                D_fake2 = netD2(x_0_predict2)
                
                
                errD_fake2 = phi_star1(D_fake2 - cost_cal)
                errD_fake2 = errD_fake2.mean()
                errD_fake2.backward()
                errD2 = errD_real2 + errD_fake2
                optimizerD2.step()


                #### Update Generator ####
                for p in netD2.parameters():
                    p.requires_grad = False
                
                
                netG2.zero_grad()

                # Generator loss
                source_data = source_data.clone().float().to(device)
                latent_z = torch.randn(batch_size, nz, device=device)
                
                if args.generator =='NCSNpp':
                    x_0_predict2 = netG2(source_data, latent_z)
                elif args.generator in ('otur', 'blind'):
                    x_0_predict2 = netG2(source_data)
                cost_cal = args.tau * torch.sum(((x_0_predict2-source_data).view(source_data.size(0), -1))**2, dim=1)
                
                D_fake2 = netD2(x_0_predict2)
                
                err2 = cost_cal - D_fake2
                err2 = err2.mean()
                
                
                if w_d_g2 > 0:
                    gen_w_loss2 = w_d_g2 * netW2(x_0_predict2).mean()
                    err2 += - gen_w_loss2

                 # || T1(T2(y))-y ||=0
                if cycle: 
                    if args.generator == 'NCSNpp':
                        diff = (netG(netG2(source_data, latent_z.clone().float().to(device)), latent_z.clone().float().to(device))-source_data) 
                    elif args.generator in ('otur', 'blind'):
                        diff = (netG(netG2(source_data))-source_data)
                    T1T2_loss= diff.view(diff.size(0),-1).pow(2).mean(dim=1).mean()
                    err += T2T1_loss

                err2.backward()
                optimizerG2.step()
                
                ## save losses
                if global_step % args.print_every == 0:
                    if cycle:
                        if w_d_g2>0:
                            print(f'Epoch {epoch:04d}_iter{_} : G2 Loss {err2.item():.4f}, D2 Loss {errD2.item():.4f}, T1T2 loss {T1T2_loss.item():.4f}, W Loss2 {W_loss2.item():.4f}, gen_w_loss2 : {gen_w_loss2.item():.4f}')
                        else:
                            print(f'Epoch {epoch:04d}_iter{_} : G2 Loss {err2.item():.4f}, D2 Loss {errD2.item():.4f}, T1T2 loss {T1T2_loss.item():.4f}')
                    else:
                        if w_d_g2>0:
                            print(f'Epoch {epoch:04d}_iter{_} : G2 Loss {err2.item():.4f}, D2 Loss {errD2.item():.4f}, W Loss2 {W_loss2.item():.4f}')
                        else:
                            print(f'Epoch {epoch:04d}_iter{_} : G2 Loss {err2.item():.4f}, D2 Loss {errD2.item():.4f}')

                    with open(os.path.join(log_path, 'log.txt'), 'a') as f:
                        if cycle:
                            if w_d_g2>0:
                                f.write(f'Epoch {epoch:04d}_iter{_} : G2 Loss {err2.item():.4f}, D2 Loss {errD2.item():.4f}, T1T2 loss {T1T2_loss.item():.4f}, W Loss2 {W_loss2.item():.4f}, gen_w_loss2 : {gen_w_loss2.mean().item():.4f}')
                            else:
                                f.write(f'Epoch {epoch:04d}_iter{_} : G2 Loss {err2.item():.4f}, D2 Loss {errD2.item():.4f}, T1T2 loss {T1T2_loss.item():.4f}')
                        else:
                            if w_d_g2>0:
                                f.write(f'Epoch {epoch:04d}_iter{_} : G2 Loss {err2.item():.4f}, D2 Loss {errD2.item():.4f}, W Loss2 {W_loss2.item():.4f}')
                            else:
                                f.write(f'Epoch {epoch:04d}_iter{_} : G2 Loss {err2.item():.4f}, D2 Loss {errD2.item():.4f}')
                        f.write('\n')

            # break 
            total_psnr = 0.0
            total_ssim = 0.0
            total_images = 0
            x_predict_list = []
            x_origin_list = []
            all_lpips_scores = []
            
            deg_metrics = defaultdict(lambda: {
                "psnr": 0.0,
                "ssim": 0.0,
                "lpips": [],
                "count": 0
            })
            
            if args.dataset == 'mixed':
                deg_metrics_dog = defaultdict(lambda: {
                    "psnr": 0.0,
                    "ssim": 0.0,
                    "lpips": [],
                    "count": 0
                })
                
                deg_metrics_cat = defaultdict(lambda: {
                    "psnr": 0.0,
                    "ssim": 0.0,
                    "lpips": [],
                    "count": 0
                })

            if _ == len(data_loader)//2 or _ == 0 or _==len(data_loader)-1:
                if args.use_ema:
                    optimizerG.swap_parameters_with_ema(store_params_in_ema=True)   
                netG.eval()
                with torch.no_grad():
                    saved_deg_names = defaultdict(int)
                    for th, x in enumerate(data_loader_test):
                        try: x_low, x_origin = x
                        except: pass

                        x_low = x_low.float().to(device, non_blocking=True)
                        
                        real_data = x_origin.float().to(device, non_blocking=True)

                            

                        # sample latent noise for the generator (only passed to netG when args.generator == 'NCSNpp')
                        latent_z = torch.randn(batch_size, nz, device=device)
                        

                        if ('down_sampling' in operator_type) or ('down_sampling' in deg_name):
                            x_low_resize = F.interpolate(x_low, size=(128, 128), mode='bicubic')
                            if args.generator =='NCSNpp':
                                x_0_predict = netG(x_low_resize, latent_z)
                            elif args.generator in ('otur', 'blind'):
                                x_0_predict = netG(x_low_resize)
                        else:
                            if args.generator =='NCSNpp':
                                x_0_predict = netG(x_low, latent_z)
                            elif args.generator in ('otur', 'blind'):
                                x_0_predict = netG(x_low)

                        if ('phase_retrieval' in operator_type) or ('phase_retrieval' in deg_name):
                            x_0_predict = extract_center_from_padded(x_0_predict)
                            
                        # for LPIPS
                        if normalize:
                            real_data_normalized = (real_data.clone().detach()).to(device)
                            x_0_predict_normalized = (x_0_predict.clone().detach()).to(device)
                        else: 
                            real_data_normalized = (real_data.clone().detach()*2-1).to(device)
                            x_0_predict_normalized = (x_0_predict.clone().detach()*2-1).to(device)
                        
                        batch_lpips = calculate_batch_lpips_piq(real_data_normalized, x_0_predict_normalized, device=device)
                        all_lpips_scores.append(batch_lpips)


                        # for PSNR, SSIM
                        if normalize:
                            real_data_value = ((real_data.clone().detach()+1)*0.5 * 255).to(device)
                            x_0_predict_value = ((x_0_predict.clone().detach()+1)*0.5 * 255).to(device)
                        else :
                            real_data_value = ((real_data.clone().detach()) * 255).to(device)
                            x_0_predict_value = ((x_0_predict.clone().detach()) * 255).to(device)

                        psnr_vals = calculate_batch_psnr(x_0_predict_value, real_data_value)  # shape: (batch_size,)
                        ssim_vals = calculate_ssim(x_0_predict_value, real_data_value)

                        batch_size = psnr_vals.size(0)

                        total_psnr += psnr_vals.sum().item()
                        total_ssim += ssim_vals.sum().item()
                        total_images += batch_size

                        deg_metrics[deg_name]["psnr"] += psnr_vals.sum().item()
                        deg_metrics[deg_name]["ssim"] += ssim_vals.sum().item()
                        deg_metrics[deg_name]["count"] += batch_size
                        deg_metrics[deg_name]["lpips"].append(batch_lpips)
                        
                        if args.dataset == 'mixed':
                            if total_images <= 200: ## for dog images
                                deg_metrics_dog[deg_name]["psnr"] += psnr_vals.sum().item()
                                deg_metrics_dog[deg_name]["ssim"] += ssim_vals.sum().item()
                                deg_metrics_dog[deg_name]["count"] += batch_size
                                deg_metrics_dog[deg_name]["lpips"].append(batch_lpips)
                                
                            else: ## for cat images
                                deg_metrics_cat[deg_name]["psnr"] += psnr_vals.sum().item()
                                deg_metrics_cat[deg_name]["ssim"] += ssim_vals.sum().item()
                                deg_metrics_cat[deg_name]["count"] += batch_size
                                deg_metrics_cat[deg_name]["lpips"].append(batch_lpips)

                        if args.dataset == 'mixed':
                            th_list = [0, 50]
                        else:
                            th_list = [0, 1]
                        if saved_deg_names[deg_name] in th_list:

                            if args.normalize :
                                images = (0.5*(x_0_predict+1)).detach().cpu()
                            else: 
                                images = ((x_0_predict)).detach().cpu()# images = (x_0_predict).detach().cpu()
                            torchvision.utils.save_image(images, os.path.join(log_path, f'epoch_{epoch}_iter_{_}_{th}_{deg_name}_{saved_deg_names[deg_name]}.png'))
                        saved_deg_names[deg_name] += 1

                    mean_psnr = total_psnr / total_images
                    mean_ssim = total_ssim / total_images
                
                    if (epoch==0 and _ == 0)  :
                        best_psnr = 0 


                    print("\n=== Results per Degradation Type ===")
                    for name, stats in deg_metrics.items():
                        psnr = stats["psnr"] / stats["count"]
                        ssim = stats["ssim"] / stats["count"]
                        lpips_tensor = torch.cat(stats["lpips"], dim=0)
                        lpips = lpips_tensor.mean().item()
                        
                    print(f"[{name}] PSNR: {psnr:.2f}, SSIM: {ssim:.4f}, LPIPS: {lpips:.4f}")
                                   
                    if args.dataset == 'mixed':
                        print("\n=== Results for Dog Images ===")
                        for name, stats in deg_metrics_dog.items():
                            psnr = stats["psnr"] / stats["count"]
                            ssim = stats["ssim"] / stats["count"]
                            lpips_tensor = torch.cat(stats["lpips"], dim=0)
                            lpips = lpips_tensor.mean().item()
                            print(f"[{name}] PSNR: {psnr:.2f}, SSIM: {ssim:.4f}, LPIPS: {lpips:.4f}")

                        print("\n=== Results for Cat Images ===")
                        for name, stats in deg_metrics_cat.items():
                            psnr = stats["psnr"] / stats["count"]
                            ssim = stats["ssim"] / stats["count"]
                            lpips_tensor = torch.cat(stats["lpips"], dim=0)
                            lpips = lpips_tensor.mean().item()
                            print(f"[{name}] PSNR: {psnr:.2f}, SSIM: {ssim:.4f}, LPIPS: {lpips:.4f}")
                    
                    print(f"\nAverage PSNR: {mean_psnr:.2f} dB")
                    print(f"Average SSIM: {mean_ssim:.4f}")
                    ## LPIPS
                    all_lpips_tensor = torch.cat(all_lpips_scores, dim=0)
                    lpips_mean = all_lpips_tensor.mean().item()
                    print(f"Average LPIPS: {lpips_mean:.4f}")
                                   
                
                    with open(os.path.join(log_path, 'result.txt'), 'a') as f:
                            f.write(f'Epoch {epoch}_iter {_} : Average PSNR {mean_psnr:.10f} dB, Average SSIM: {mean_ssim:.10f}, Average LPIPS: {lpips_mean:.10f}\n')
                            for name, stats in deg_metrics.items():
                                psnr = stats["psnr"] / stats["count"]
                                ssim = stats["ssim"] / stats["count"]
                                lpips_tensor = torch.cat(stats["lpips"], dim=0)
                                lpips = lpips_tensor.mean().item()
                                f.write(f'[{name}] PSNR {psnr:.4f}, SSIM {ssim:.4f}, LPIPS {lpips:.4f}\n')
                    if args.dataset == 'mixed':
                        with open(os.path.join(log_path, 'result_dog.txt'), 'a') as f:
                            f.write(f'Epoch {epoch}_iter {_} : Average PSNR {mean_psnr:.10f} dB, Average SSIM: {mean_ssim:.10f}, Average LPIPS: {lpips_mean:.10f}\n')
                            for name, stats in deg_metrics_dog.items():
                                psnr = stats["psnr"] / stats["count"]
                                ssim = stats["ssim"] / stats["count"]
                                lpips_tensor = torch.cat(stats["lpips"], dim=0)
                                lpips = lpips_tensor.mean().item()
                                f.write(f'[{name}] PSNR {psnr:.4f}, SSIM {ssim:.4f}, LPIPS {lpips:.4f}\n')
                                
                        with open(os.path.join(log_path, 'result_cat.txt'), 'a') as f:
                            f.write(f'Epoch {epoch}_iter {_} : Average PSNR {mean_psnr:.10f} dB, Average SSIM: {mean_ssim:.10f}, Average LPIPS: {lpips_mean:.10f}\n')
                            for name, stats in deg_metrics_cat.items():
                                psnr = stats["psnr"] / stats["count"]
                                ssim = stats["ssim"] / stats["count"]
                                lpips_tensor = torch.cat(stats["lpips"], dim=0)
                                lpips = lpips_tensor.mean().item()
                                f.write(f'[{name}] PSNR {psnr:.4f}, SSIM {ssim:.4f}, LPIPS {lpips:.4f}\n')
                                
                        with open(os.path.join(log_path, 'result_total.txt'), 'a') as f:
                            f.write(f'Epoch {epoch}_iter {_} : Average PSNR {mean_psnr:.10f} dB, Average SSIM: {mean_ssim:.10f}, Average LPIPS: {lpips_mean:.10f}\n')
                            for name, stats in deg_metrics_dog.items():
                                psnr = stats["psnr"] / stats["count"]
                                ssim = stats["ssim"] / stats["count"]
                                lpips_tensor = torch.cat(stats["lpips"], dim=0)
                                lpips = lpips_tensor.mean().item()
                                f.write(f'DOG : [{name}] PSNR {psnr:.4f}, SSIM {ssim:.4f}, LPIPS {lpips:.4f}\n')
                            for name, stats in deg_metrics_cat.items():
                                psnr = stats["psnr"] / stats["count"]
                                ssim = stats["ssim"] / stats["count"]
                                lpips_tensor = torch.cat(stats["lpips"], dim=0)
                                lpips = lpips_tensor.mean().item()
                                f.write(f'CAT : [{name}] PSNR {psnr:.4f}, SSIM {ssim:.4f}, LPIPS {lpips:.4f}\n')
                            for name, stats in deg_metrics.items():
                                psnr = stats["psnr"] / stats["count"]
                                ssim = stats["ssim"] / stats["count"]
                                lpips_tensor = torch.cat(stats["lpips"], dim=0)
                                lpips = lpips_tensor.mean().item()
                                f.write(f'AVG : [{name}] PSNR {psnr:.4f}, SSIM {ssim:.4f}, LPIPS {lpips:.4f}\n')
                        

                netG.train()
                
                # Save checkpoint contents when this evaluation's mean PSNR beats the best so far.
                if mean_psnr>best_psnr:
                    best_psnr = mean_psnr
                    print('Saving content.')
                    if w_d_g1: 
                        if args.generator == 'NCSNpp' and args.discriminator =='NCSNpp' and args.w1_discriminator =='NCSNpp':
                            content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
                                'netG_x_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
                                'schedulerG_x': schedulerG.state_dict(), 'netD_dict': netD.state_dict(),
                                'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict(),
                                'netW1_dict': netW1.state_dict(),
                                'optimizerW1': optimizerW1.state_dict(), 'schedulerW1': schedulerW1.state_dict()}
                        else:
                            content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
                                'netG_x_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
                                'netD_dict': netD.state_dict(),
                                'optimizerD': optimizerD.state_dict(), 
                                'netW1_dict': netW1.state_dict(),
                                'optimizerW1': optimizerW1.state_dict()}
                       

                    else:
                        if args.generator == 'NCSNpp' and args.discriminator =='NCSNpp':
                            content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
                                'netG_x_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
                                'schedulerG_x': schedulerG.state_dict(), 'netD_dict': netD.state_dict(),
                                'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
                        else:
                            content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
                                'netG_x_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
                                 'netD_dict': netD.state_dict(),
                                'optimizerD': optimizerD.state_dict()}
                        


                    torch.save(content, os.path.join(log_path, 'content.pth'))
                    if cycle:
                        if w_d_g2:
                            if args.generator == 'NCSNpp'and args.discriminator =='NCSNpp':
                                content = {'epoch': epoch + 1, 'global_step': global_step,
                                    'netG2_x_dict': netG2.state_dict(), 'optimizerG2': optimizerG2.state_dict(),
                                    'schedulerG2_x': schedulerG2.state_dict(), 'netD2_dict': netD2.state_dict(),
                                    'optimizerD2': optimizerD2.state_dict(), 'schedulerD': schedulerD2.state_dict(),
                                     'netW2_dict': netW2.state_dict(), 'optimizerW2': optimizerW2.state_dict(), 
                                     'schedulerW2': schedulerW2.state_dict()}
                            else: 
                                content = {'epoch': epoch + 1, 'global_step': global_step,
                                    'netG2_x_dict': netG2.state_dict(), 'optimizerG2': optimizerG2.state_dict(),
                                    'netD2_dict': netD2.state_dict(),
                                    'optimizerD2': optimizerD2.state_dict(),
                                     'netW2_dict': netW2.state_dict(), 'optimizerW2': optimizerW2.state_dict()}
                    
                        else:
                            if args.generator == 'NCSNpp'and args.discriminator =='NCSNpp':
                                content = {'epoch': epoch + 1, 'global_step': global_step,
                                    'netG2_x_dict': netG2.state_dict(), 'optimizerG2': optimizerG2.state_dict(),
                                    'schedulerG2_x': schedulerG2.state_dict(), 'netD2_dict': netD2.state_dict(),
                                    'optimizerD2': optimizerD2.state_dict(), 'schedulerD': schedulerD2.state_dict()}
                            else:
                                content = {'epoch': epoch + 1, 'global_step': global_step,
                                    'netG2_x_dict': netG2.state_dict(), 'optimizerG2': optimizerG2.state_dict(),
                                    'netD2_dict': netD2.state_dict(), 'optimizerD2': optimizerD2.state_dict()}
                            
                        torch.save(content, os.path.join(log_path, 'content2.pth'))
                
                    if args.use_ema:
                        optimizerG.swap_parameters_with_ema(store_params_in_ema=True)   

                    torch.save(netG.state_dict(), os.path.join(log_path, 'netG_{}_iter_{}.pth'.format(epoch, _)))
                    
                    torch.save(netD.state_dict(), os.path.join(log_path, 'netD_{}_iter_{}.pth'.format(epoch, _)))
                    
                    if cycle:
                        torch.save(netG2.state_dict(), os.path.join(log_path, 'netG2_{}_iter_{}.pth'.format(epoch, _)))
                        
                        torch.save(netD2.state_dict(), os.path.join(log_path, 'netD2_{}_iter_{}.pth'.format(epoch, _)))


                if epoch == total_epoch or epoch == total_epoch//2:
                    torch.save(netG.state_dict(), os.path.join(log_path, 'netG_{}_iter_{}.pth'.format(epoch, _)))
                    
                    torch.save(netD.state_dict(), os.path.join(log_path, 'netD_{}_iter_{}.pth'.format(epoch, _)))
                    
                    if cycle:
                        torch.save(netG2.state_dict(), os.path.join(log_path, 'netG2_{}_iter_{}.pth'.format(epoch, _)))
                        
                        torch.save(netD2.state_dict(), os.path.join(log_path, 'netD2_{}_iter_{}.pth'.format(epoch, _)))

        if args.generator in ('NCSNpp', 'blind'):
            schedulerG.step()
        if args.discriminator =='NCSNpp':
            schedulerD.step()
        if w_d_g1>0:
            if args.w1_discriminator =='NCSNpp':
                schedulerW1.step()

        if cycle:
            if args.generator in ('NCSNpp', 'blind'):
                schedulerG2.step()
            if args.discriminator =='NCSNpp':
                schedulerD2.step()
            if w_d_g2>0:
                if args.w2_discriminator =='NCSNpp':
                    schedulerW2.step()
        
        

 
def build_arg_parser():
    """Build the command-line parser for the UOT training script.

    Kept separate from the ``__main__`` guard so the parser can be imported
    and exercised (e.g. in tests) without starting a training run.

    Returns:
        argparse.ArgumentParser: parser exposing every training option.
        ``--operator_type``, ``--noise`` and ``--log_dir`` are required.
    """
    parser = argparse.ArgumentParser('UOT parameters')

    # Experiment description
    parser.add_argument('--seed', type=int, default=1024, help='seed used for initialization')
    parser.add_argument('--exp', default='linear', help='name of experiment')
    parser.add_argument('--resume', action='store_true', default=False, help='Resume training or not')
    parser.add_argument('--dataset', default='cifar10',
                        choices=['mnist', 'cifar10', 'cifar10+mnist', 'lsun', 'celeba_256',
                                 'AFHQ', 'FFHQ', 'DIV2K', 'mixed'],
                        help='name of dataset')
    parser.add_argument('--image_size', type=int, default=32, help='size of image')
    parser.add_argument('--num_channels', type=int, default=3, help='channel of image')
    parser.add_argument('--operator_type', type=str, required=True,
                        help="forward operator (degradation) type, e.g. a name containing "
                             "'down_sampling' or 'phase_retrieval'")
    parser.add_argument('--time_embed', action='store_true', default=False)

    # Generator configurations
    # NOTE: the flags below declared with action='store_false' default to True;
    # passing the flag on the command line turns the feature OFF.
    parser.add_argument('--generator', type=str, default='NCSNpp', choices=['NCSNpp', 'otur', 'blind'], help='generator type')
    parser.add_argument('--centered', action='store_false', default=True,
                        help='pass to disable -1,1 scale (enabled by default)')
    parser.add_argument('--num_channels_dae', type=int, default=128, help='number of initial channels in denoising model')
    parser.add_argument('--n_mlp', type=int, default=4, help='number of mlp layers for z')
    parser.add_argument('--ch_mult', nargs='+', type=int, default=[1, 2, 2, 2], help='channel multiplier')
    parser.add_argument('--num_res_blocks', type=int, default=2, help='number of resnet blocks per scale')
    # nargs/type make a command-line value parse to a list of ints; without
    # them a user-supplied value would arrive as a single string.
    parser.add_argument('--attn_resolutions', nargs='+', type=int, default=(16,), help='resolution of applying attention')
    parser.add_argument('--dropout', type=float, default=0., help='drop-out rate')
    parser.add_argument('--resamp_with_conv', action='store_false', default=True,
                        help='pass to disable always up/down sampling with conv (enabled by default)')
    parser.add_argument('--conditional', action='store_false', default=True,
                        help='pass to disable noise conditional (enabled by default)')
    parser.add_argument('--fir', action='store_false', default=True,
                        help='pass to disable FIR (enabled by default)')
    # Same reasoning as --attn_resolutions: parse CLI values into a list of ints.
    parser.add_argument('--fir_kernel', nargs='+', type=int, default=[1, 3, 3, 1], help='FIR kernel')
    parser.add_argument('--skip_rescale', action='store_false', default=True,
                        help='pass to disable skip rescale (enabled by default)')
    parser.add_argument('--resblock_type', default='biggan', help='type of resnet block, choice in biggan and ddpm')
    parser.add_argument('--progressive', type=str, default='none', choices=['none', 'output_skip', 'residual'], help='progressive type for output')
    parser.add_argument('--progressive_input', type=str, default='residual', choices=['none', 'input_skip', 'residual'], help='progressive type for input')
    parser.add_argument('--progressive_combine', type=str, default='sum', choices=['sum', 'cat'], help='progressive combine method.')
    parser.add_argument('--embedding_type', type=str, default='positional', choices=['positional', 'fourier'], help='type of time embedding')
    parser.add_argument('--fourier_scale', type=float, default=16., help='scale of fourier transform')
    parser.add_argument('--not_use_tanh', action='store_true', default=False, help='use tanh for last layer')
    parser.add_argument('--z_emb_dim', type=int, default=256, help='embedding dimension of z')
    parser.add_argument('--nz', type=int, default=100, help='latent dimension')
    parser.add_argument('--ngf', type=int, default=64, help='The default number of channels of model')
    parser.add_argument('--discriminator', type=str, default='NCSNpp', choices=['NCSNpp', 'otur'], help='discriminator type')
    parser.add_argument('--w1_discriminator', type=str, default='NCSNpp', choices=['NCSNpp', 'otur'], help='W1 discriminator type')
    parser.add_argument('--w2_discriminator', type=str, default='NCSNpp', choices=['NCSNpp', 'otur'], help='W2 discriminator type')

    # Training/Optimizer configurations
    parser.add_argument('--batch_size', type=int, default=32, help='input batch size')
    parser.add_argument('--num_epoch', type=int, default=600, help='the number of epochs')
    parser.add_argument('--lr_g', type=float, default=1.6e-4, help='learning rate g')
    parser.add_argument('--lr_d', type=float, default=1.0e-4, help='learning rate d')
    parser.add_argument('--lr_w1', type=float, default=1.0e-4, help='learning rate w1')
    parser.add_argument('--lr_w2', type=float, default=1.0e-4, help='learning rate w2')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam')
    parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for adam')
    parser.add_argument('--schedule', type=int, default=1800, help='cosine scheduler, learning rate 1e-5 until {schedule} epoch')
    parser.add_argument('--use_ema', action='store_false', default=True,
                        help='pass to disable EMA (enabled by default)')
    parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')
    parser.add_argument('--num_iterations', type=int, default=60000, help='the number of iterations')
    parser.add_argument('--lmda', type=float, default=10, help='coordinate for regularization')
    parser.add_argument('--normalize', action='store_true', default=False)
    parser.add_argument('--noise', type=str, required=True)
    parser.add_argument('--alpha', type=float, default=0.0005, help='alpha for hjb equation')
    parser.add_argument('--cycle', action='store_true', default=False)

    # Loss configurations
    parser.add_argument('--phi1', type=str, default='kl', choices=['linear', 'kl', 'softplus'])
    parser.add_argument('--phi2', type=str, default='kl', choices=['linear', 'kl', 'softplus'])
    parser.add_argument('--tau', type=float, default=0.001, help='proportion of the cost c')
    parser.add_argument('--r1_gamma', type=float, default=0.2, help='coef for r1 reg')
    parser.add_argument('--reg_A', type=float, default=0.0, help='coordinate of reg_A')
    parser.add_argument('--reg_rank', type=float, default=0.0, help='coordinate of reg_rank')
    parser.add_argument('--reg_grad', type=float, default=0.0, help='coordinate of reg_grad')
    # NOTE(review): the four help strings below were copy-pasted in the
    # original; only the typo is fixed here — confirm the intended wording.
    parser.add_argument('--w_d_g1', type=float, default=0.0, help='coordinate of w_d_real')
    parser.add_argument('--w_d_g2', type=float, default=0.0, help='coordinate of w_d_real')
    parser.add_argument('--W1clip', type=float, default=0.0, help='coordinate of w_d_real')
    parser.add_argument('--W2clip', type=float, default=0.0, help='coordinate of w_d_real')
    parser.add_argument('--c_like', type=float, default=0.1, help='coordinate of likelihood cost')

    # Visualize/Save configurations
    parser.add_argument('--print_every', type=int, default=20, help='print current loss for every x iterations')
    parser.add_argument('--save_content_every', type=int, default=10, help='save content for resuming every x epochs')
    parser.add_argument('--save_ckpt_every', type=int, default=100, help='save ckpt every x epochs')
    parser.add_argument('--save_image_every', type=int, default=10, help='save images every x epochs')
    parser.add_argument('--log_dir', type=str, required=True, help='output path')

    # For mixed dataset (Class-imbalance ablation)
    parser.add_argument('--mixed_ratio', type=float, default=1.0, help='num_cat/num_dog (<=1)')

    return parser


if __name__ == '__main__':
    # Parse the command line and hand the resulting namespace to train().
    args = build_arg_parser().parse_args()
    train(args)
