"""main.py"""

import argparse
import os
from os.path import join
from torchvision import transforms
import time

import numpy as np
# from scipy import stats
import torch
import torch.utils.data as Data
import copy
from itertools import islice

# from betaVAE.solver import Solver
from betaVAE.utils import str2bool, cuda
from betaVAE.model import BetaVAE_H, BetaVAE_B
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.nn.functional as F

from clstool import __version__, build_model
from clstool.utils.io import checkpoint_loader
from betaVAE.dataset import CelebA

# Enable cuDNN and turn on its benchmark mode, which lets cuDNN time the
# available convolution algorithms and pick the fastest one.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

def load_checkpoint(net, ckpt_dir, viz_name, filename):
    """Load weights from ``ckpt_dir/viz_name/filename`` into ``net``.

    The checkpoint is expected to be a mapping whose
    ``['model_states']['net']`` entry is a state dict accepted by
    ``net.load_state_dict``.  When the file does not exist, a message is
    printed and ``net`` is left untouched.

    Args:
        net: Object exposing ``load_state_dict``.
        ckpt_dir: Root directory that holds the checkpoints.
        viz_name: Sub-directory of ``ckpt_dir`` for this run.
        filename: Checkpoint file name inside that sub-directory.
    """
    file_path = os.path.join(ckpt_dir, viz_name, filename)
    if not os.path.isfile(file_path):
        print("=> no checkpoint found at '{}'".format(file_path))
        return

    # Deserialize onto the CPU so a checkpoint written on a GPU machine can
    # still be read on a CPU-only host; ``load_state_dict`` then copies the
    # values into ``net``'s own parameters wherever they live.
    checkpoint = torch.load(file_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_states']['net'])

    # Report the iteration stored in the checkpoint when there is one; the
    # previously hard-coded value is kept only as a fallback.
    iteration = checkpoint.get('iter', 1500000)
    print("=> loaded checkpoint '{} (iter {})'".format(file_path, iteration))

    
def _top1_labels(model, images):
    """Return the index of the largest classifier output per image (NumPy array)."""
    logits = model.forward(images).detach().cpu().numpy()
    # Last column of the ascending argsort == first column of the reversed
    # (descending) order, i.e. the top-1 class.
    return np.array(logits).argsort()[:, -1]


def _first_flip_index(model, net, transform, mu, eps, zi, ref_labels, start, stop, step):
    """Scan grid indices ``range(start, stop, step)`` of latent dim ``zi``.

    For every sample, returns the first visited grid index at which the
    classifier's label for the decoded, perturbed latent differs from
    ``ref_labels``.  Samples that never flip get the last index the scan
    reached (or ``start - step`` when the range is empty).
    """
    # 0 is the "not found yet" sentinel; no index visited by the callers is 0.
    found = torch.zeros(mu.shape[0]).long()
    next_idx = start
    for idx in range(start, stop, step):
        if not bool((found == 0).any()):
            break  # every sample has already flipped
        z = mu.clone()
        z[:, zi] = mu[:, zi] + eps[:, zi, idx]
        x_recon = torch.sigmoid(net._decode(z))
        labels = _top1_labels(model, transform(x_recon))
        flipped = torch.tensor(labels != ref_labels) & (found == 0)
        found[flipped] = idx
        next_idx = idx + step
    found[found == 0] = next_idx - step
    return found


@torch.no_grad()
def dynamic_pert(image, model, net, mu, logvar, num_eps, decrement, num_classes=2, max_iter=50):
    """Measure how far each latent dimension can move before the label flips.

    The reference label of every sample is the classifier's prediction on the
    decoded mean ``net._decode(mu)``.  Each latent dimension is then shifted,
    one at a time, along a grid of ``2 * num_eps`` offsets
    ``(k - num_eps) * decrement * std`` (``k`` = grid index), first to the
    right of the mean and then to the left, until the prediction changes.

    Relies on the module-level ``args`` (``args.device``) and overwrites
    ``args.image_size`` with 224.  Runs entirely under ``torch.no_grad()``,
    so the returned tensors carry no autograd graph.

    Args:
        image: Input batch.  Unused; kept for interface compatibility.
        model: Classifier applied to the resized reconstructions.
        net: VAE exposing ``_decode``.
        mu: Latent means, indexed as ``(batch, z_dim)``.
        logvar: Latent log-variances, same shape as ``mu``.
        num_eps: Number of grid points on each side of the mean (int).
        decrement: Grid spacing in units of the per-dimension std.
        num_classes: Unused; kept for interface compatibility.
        max_iter: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(thresholds, tolerances, deviations)``, each summed over the
        batch: ``thresholds`` has shape ``(z_dim, 2)`` with the left/right
        grid indices, ``tolerances`` has shape ``(z_dim, 1)`` with the width
        of the label-preserving interval in std units, and ``deviations`` has
        shape ``(z_dim, 1)`` with the distance of the interval's midpoint
        from ``mu`` in std units.
    """
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print("Using GPU" if device.type == 'cuda' else "Using CPU")
    net = net.to(device)
    model = model.to(device)

    # NOTE(review): this mutates the shared ``args`` namespace; kept as-is
    # because other code may read ``args.image_size`` afterwards.
    args.image_size = 224
    transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
    ])

    # Reference prediction: classifier output on the reconstruction of mu.
    image_r = transform(torch.sigmoid(net._decode(mu)))
    label_r = _top1_labels(model, image_r)

    batch_size, z_dim = mu.shape
    std = logvar.div(2).exp()

    # Offset grid built from integer steps: always exactly 2 * num_eps entries
    # (a float-step arange can come out one element long or short) and no
    # per-element host/device synchronisation.
    steps = torch.arange(-num_eps, num_eps, device=mu.device, dtype=mu.dtype) * decrement
    eps = std.unsqueeze(-1) * steps  # (batch, z_dim, 2 * num_eps)

    threshold = torch.zeros(batch_size, z_dim, 2).long()
    for zi in range(z_dim):
        # Right of the mean: grid indices num_eps .. 2 * num_eps - 1.
        threshold[:, zi, 1] = _first_flip_index(
            model, net, transform, mu, eps, zi, label_r, num_eps, 2 * num_eps, 1)
        # Left of the mean: grid indices num_eps - 1 .. 1.
        threshold[:, zi, 0] = _first_flip_index(
            model, net, transform, mu, eps, zi, label_r, num_eps - 1, 0, -1)

    mu_cpu = mu.cpu()
    std_cpu = std.cpu()
    left_tail = mu_cpu + (threshold[..., 0] - num_eps) * decrement * std_cpu
    right_tail = mu_cpu + (threshold[..., 1] - num_eps) * decrement * std_cpu

    out_dtype = torch.get_default_dtype()
    tol = ((right_tail - left_tail) / std_cpu).unsqueeze(-1).to(out_dtype)
    bia = (torch.abs((right_tail + left_tail) / 2 - mu_cpu) / std_cpu).unsqueeze(-1).to(out_dtype)

    return torch.sum(threshold, dim=0), torch.sum(tol, dim=0), torch.sum(bia, dim=0)

def main(args):
    """Run the latent-perturbation analysis and append a report to disk.

    Builds the beta-VAE (``BetaVAE_H``) and the classifier, evaluates
    ``dynamic_pert`` on the first ``len(dataloader) // 20`` batches of the
    CelebA ``'valid'`` split, averages the per-dimension results over all
    evaluated samples and appends them to ``out_dis_<z_dim>.txt``.

    Side effects on ``args``: ``viz_name`` and ``no_pretrain`` are overwritten.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of this file).
    """
    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    args.viz_name = 'celeba_H_beta10_z' + str(args.z_dim)
    args.no_pretrain = True

    num_eps = 101
    decrement = 0.5  # (0, 1)

    # ** net **
    net = BetaVAE_H(args.z_dim, 3)
    net.to(device)
    net.eval()

    if args.ckpt_name is not None:
        load_checkpoint(net, args.ckpt_dir, args.viz_name, args.ckpt_name)

    # ** dataset **
    attr_label = {'main_attr': args.main_attr, 'sub_attr': args.sub_attrs}
    test_dataset = CelebA(args.data_root, args.image_size, 'valid', attr_label=attr_label)
    test_dataloader = Data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=True)

    # ** model **
    model = build_model(args)
    model.to(device)
    # Honour --resume when it is given; otherwise fall back to the original
    # hard-coded checkpoint location.
    default_ckpt = (
        f'/nfs/jhz/Disentangled_Fairness/runs/{args.main_attr.lower()}/'
        f'{args.model}_{args.dataset}/checkpoint0029.pth'
    )
    resume_dir = getattr(args, 'resume', None) or default_ckpt
    checkpoint = torch.load(resume_dir, map_location='cpu')
    print("=> loaded checkpoint '{}'".format(resume_dir))
    checkpoint_loader(model, checkpoint['model'], delete_keys=())
    model.eval()

    # ** dynamic_pert **
    num_batches = len(test_dataloader) // 20  # valid:19867 -> 993
    print(num_batches)
    slice_dataloader = islice(test_dataloader, num_batches)
    Ths = torch.zeros(args.z_dim, 2)
    Tols = torch.zeros(args.z_dim, 1)
    Deviation = torch.zeros(args.z_dim, 1)
    batch_idx = 0
    # Pure inference: without no_grad every batch's autograd graph would stay
    # alive through the running sums below.
    with torch.no_grad():
        for x, _, _, _ in slice_dataloader:
            x = x.to(device)
            _, mu, logvar = net(x, repara_name='reparametrize_slide', num_eps=num_eps, decrement=decrement)
            thresholds, tolerances, deviation = dynamic_pert(x, model, net, mu, logvar, num_eps, decrement)
            Ths = Ths + thresholds
            Tols = Tols + tolerances
            Deviation = Deviation + deviation
            batch_idx = batch_idx + 1
            print('batch_idx: ', batch_idx)

    if batch_idx == 0:
        # Fewer than 20 batches in the loader: nothing was evaluated, and
        # averaging would only write NaNs into the report.
        print("=> no batches were evaluated; nothing to report")
        return

    num_samples = batch_idx * args.batch_size
    Ths = Ths / num_samples
    Tols = Tols / num_samples
    Deviation = Deviation / num_samples

    with open(f'out_dis_{args.z_dim}.txt', 'a', encoding='utf-8') as file:
        print(f'model:{args.model}', file=file)
        coupling = 0
        for zi in range(args.z_dim):
            # NOTE(review): 0.9 * (num_eps - 1) is 90% of the largest possible
            # tolerance only when decrement == 0.5 -- confirm if that changes.
            if Tols[zi].item() < 0.9 * (num_eps - 1):
                coupling += 1
            print("z{}  min: {:.2f}  max: {:.2f}  tol:{:.2f}  bia:{:.2f}".format(zi, Ths[zi, 0], Ths[zi, 1], Tols[zi].item(), Deviation[zi].item()), file=file)
        print("coupling:{:.2f}  tolerance:{:.2f}  bia:{:.2f}".format(coupling / args.z_dim, torch.mean(Tols.float(), dim=0).item(), torch.mean(Deviation.float(), dim=0).item()), file=file)
        print(' ', file=file)
    
    
   


if __name__ == "__main__":
    # Command-line entry point: parse the options, run the analysis and
    # report the elapsed time.  ``args`` must stay a module-level name
    # because ``dynamic_pert`` reads it as a global.
    parser = argparse.ArgumentParser(description='dynamic_pert')

    parser.add_argument('--device', default='cuda', help='cuda or cpu')
    parser.add_argument('--train', default=True, type=str2bool, help='train or traverse')
    parser.add_argument('--seed', default=1, type=int, help='random seed')
    parser.add_argument('--cuda', default=True, type=str2bool, help='enable cuda')
    parser.add_argument('--max_iter', default=1e6, type=float, help='maximum training iteration')
    parser.add_argument('--batch_size', default=5, type=int, help='batch size')

    parser.add_argument('--z_dim', default=10, type=int, help='dimension of the representation z')
    parser.add_argument('--beta', default=10, type=float, help='beta parameter for KL-term in original beta-VAE')
    parser.add_argument('--objective', default='H', type=str, help='beta-vae objective proposed in Higgins et al. or Burgess et al. H/B')
    # parser.add_argument('--model', default='H', type=str, help='model proposed in Higgins et al. or Burgess et al. H/B')
    parser.add_argument('--gamma', default=1000, type=float, help='gamma parameter for KL-term in understanding beta-VAE')
    parser.add_argument('--C_max', default=25, type=float, help='capacity parameter(C) of bottleneck channel')
    parser.add_argument('--C_stop_iter', default=1e5, type=float, help='when to stop increasing the capacity')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--beta1', default=0.9, type=float, help='Adam optimizer beta1')
    parser.add_argument('--beta2', default=0.999, type=float, help='Adam optimizer beta2')

    parser.add_argument('--dset_dir', default='data', type=str, help='dataset directory')
    parser.add_argument('--dataset', default='celeba', type=str, help='dataset name')
    parser.add_argument('--image_size', default=64, type=int, help='image size. now only (64,64) is supported')
    parser.add_argument('--num_workers', default=2, type=int, help='dataloader num_workers')

    parser.add_argument('--viz_on', default=True, type=str2bool, help='enable visdom visualization')
    parser.add_argument('--viz_name', default='main', type=str, help='visdom env name')
    parser.add_argument('--viz_port', default=8097, type=str, help='visdom port number')
    parser.add_argument('--save_output', default=True, type=str2bool, help='save traverse images and gif')
    parser.add_argument('--output_dir', default='outputs', type=str, help='output directory')

    parser.add_argument('--gather_step', default=1000, type=int, help='numer of iterations after which data is gathered for visdom')
    parser.add_argument('--display_step', default=10000, type=int, help='number of iterations after which loss data is printed and visdom is updated')
    parser.add_argument('--save_step', default=10000, type=int, help='number of iterations after which a checkpoint is saved')

    parser.add_argument('--ckpt_dir', default='betaVAE/checkpoints', type=str, help='checkpoint directory')
    parser.add_argument('--ckpt_name', default='last', type=str, help='load previous checkpoint. insert checkpoint filename')

    # model
    parser.add_argument('--model_lib', default='default', type=str, choices=['default', 'timm'], help='model library')
    parser.add_argument('--model', '-m', default='resnet34', type=str, help='model name')
    parser.add_argument('--model_kwargs', default=dict(), help='model specific kwargs')

    # loading weights
    parser.add_argument('--no_pretrain', action='store_true')
    parser.add_argument('--resume', '-r', type=str)
    parser.add_argument('--load_pos', type=str)

    # dataset
    parser.add_argument('--data_root', type=str, default='./data')
    # parser.add_argument('--dataset', '-d', type=str, default='celeba')
    parser.add_argument('--main_attr', type=str, default='Attractive')
    # ``type=list`` would split a command-line value into single characters
    # ("Bald" -> ['B', 'a', 'l', 'd']); take one or more attribute names
    # instead, e.g. ``--sub_attrs Bald Eyeglasses``.
    parser.add_argument('--sub_attrs', type=str, nargs='+', default=['Bald', 'Eyeglasses', 'Mouth_Slightly_Open', 'No_Beard', 'Pale_Skin'])
    parser.add_argument('--img_id', type=int, default=0)
    args = parser.parse_args()

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()

    main(args)

    end_time = time.perf_counter()
    execution_time = end_time - start_time
    print("Execution time:", execution_time, "seconds")