"""
Max-sliced Bures MNIST DCGAN experiment
After each GAN training epoch, the script computes the one-sided max-sliced Bures divergence between the real and
the generated (fake) images and finds the real images with the highest discrepancy (largest witness function
evaluations). It then computes the fraction of these witness images whose mode (digit triple) does not appear among
the modes of the generated images, and the same fraction for randomly selected real images. The experiment is run
for the desired number of Monte Carlo trials. When training completes, a plot of the number of modes, the mode
precision for max-discrepancy (witness points) real images, and the mode precision for randomly selected real images
is generated for each epoch. The plotted values are also saved as NumPy arrays. A CUDA-enabled computer capable of
running PyTorch is required. This script was tested using PyTorch version 1.7 but should work with earlier versions.

Options:
samples_per_epoch: Sample size (and training steps) per epoch
n_fake_examples: Number of fake images used to compute uncentered covariance matrices for Max-sliced Bures computation
n_test_witness_points: Number of real witness points used to compute mode coverage

GAN is based on https://github.com/pytorch/examples/blob/master/dcgan/main.py
"""

from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import types
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import sklearn.metrics
from tqdm import tqdm
import matplotlib.pyplot as plt

# Experiment configuration. The values are fixed here in code; nothing is
# parsed from the command line.
opt = types.SimpleNamespace()
opt.dataset = 'stacked_mnist'      # 'mnist' or 'stacked_mnist'
opt.dataroot = './'                # where torchvision stores/downloads MNIST
opt.workers = 2                    # DataLoader worker processes for training
opt.batchSize = 64
opt.imageSize = 28
opt.nz = 100                       # size of the generator's latent vector
opt.ngf = 64                       # base channel width of the generator
opt.ndf = 64                       # base channel width of the discriminator
opt.niter = 100                    # number of training epochs
opt.lr = 0.0002
opt.beta1 = 0.5
opt.cuda = True
opt.dry_run = False                # True: one batch and one epoch only
opt.ngpu = 1
opt.netG = ''                      # optional generator checkpoint to resume from
opt.netD = ''                      # optional discriminator checkpoint to resume from
opt.outf = './mnist_out'           # output directory for images and checkpoints
opt.manualSeed = None
opt.classes = None
opt.samples_per_epoch = 1000       # stacked-MNIST samples drawn per epoch
opt.n_fake_examples = 10000        # generated images used for evaluation
opt.n_test_witness_points = 10     # real witness images used for mode precision
print(opt)

# An already existing output directory is fine; any other failure (for
# example missing permissions) is reported here instead of being swallowed
# and resurfacing later when the first image or checkpoint is written.
os.makedirs(opt.outf, exist_ok=True)

if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed)
# NOTE: the seed is only reported. The seeding calls below are disabled, so
# the printed value does not make a run reproducible.
#random.seed(opt.manualSeed)
#torch.manual_seed(opt.manualSeed)


cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

if opt.dataroot is None and str(opt.dataset).lower() != 'fake':
    raise ValueError("`dataroot` parameter is required for dataset \"%s\"" % opt.dataset)

## Create stacked MNIST dataset
class StackedMnistDataset(Dataset):
    """Stacked-MNIST dataset: three random MNIST digits stacked as channels.

    Every item is assembled on the fly from three MNIST images drawn
    uniformly at random (with replacement) through NumPy's global random
    generator. The index passed to ``__getitem__`` therefore does not
    determine which digits are returned, and ``len()`` only controls how
    many samples one pass over the dataset produces.
    """

    def __init__(self, dataset_size=None, transform=None):
        """
        Args:
            dataset_size (int, optional): Number of samples reported by
                ``len()``. Defaults to the size of the underlying MNIST
                dataset.
            transform (callable, optional): Optional transform applied to
                the stacked image tensor of every sample.
        """
        self.transform = transform
        self.mnist_dataset = dset.MNIST(root=opt.dataroot, download=True,
                                        transform=transforms.Compose([
                                            transforms.Resize(opt.imageSize),
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.5,), (0.5,)),
                                        ]))

        if dataset_size is None:
            dataset_size = len(self.mnist_dataset)
        self.dataset_size = dataset_size

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        """Return ``[img, label]`` for one randomly composed sample.

        ``img`` is the concatenation of the three digit images along the
        first (channel) dimension and ``label`` is a tensor holding the three
        digit labels in the same order. ``idx`` is accepted for the
        ``Dataset`` protocol but is not used.
        """
        rand_ints = np.random.randint(0, len(self.mnist_dataset), size=3)

        imgs = []
        labels = []
        for rand_int in rand_ints:
            # Index the dataset once per digit: every lookup re-runs the
            # image transform pipeline, so fetching image and label
            # separately would do that work twice.
            digit_img, digit_label = self.mnist_dataset[rand_int]
            imgs.append(digit_img)
            labels.append(digit_label)

        img = torch.cat(imgs)
        label = torch.tensor(labels)

        if self.transform is not None:
            img = self.transform(img)

        return [img, label]


#%%
def wif(id):
    """DataLoader ``worker_init_fn``: seed NumPy for this worker and log it.

    ``StackedMnistDataset`` draws its digits with ``np.random``. Worker
    processes can start with identical copies of the parent's NumPy random
    state, in which case every worker would return the same samples. The
    per-worker torch seed is therefore used to re-seed NumPy here.

    Args:
        id: Worker index supplied by the DataLoader (only printed).
    """
    seed = torch.initial_seed()
    # NumPy only accepts seeds that fit into 32 bits.
    np.random.seed(seed % 2**32)
    print(id, seed, flush=True)


# Select the training dataset and the matching number of image channels (nc).
if opt.dataset == 'mnist':
    dataset = dset.MNIST(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,)),
                         ]))
    nc = 1
elif opt.dataset == 'stacked_mnist':
    dataset = StackedMnistDataset(dataset_size=opt.samples_per_epoch)
    nc = 3
else:
    # Fail with a clear message instead of a NameError on the undefined
    # `dataset` further down.
    raise ValueError("Unsupported dataset \"%s\"; expected 'mnist' or 'stacked_mnist'" % opt.dataset)

# Explicit check instead of `assert`, which is removed when Python runs with -O.
if len(dataset) == 0:
    raise ValueError("Dataset \"%s\" is empty" % opt.dataset)

dataloader = DataLoader(dataset, batch_size=opt.batchSize, shuffle=False,
                        num_workers=int(opt.workers), worker_init_fn=wif)
#%%
device = torch.device("cuda:0" if opt.cuda else "cpu")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)


# custom weights initialization called on netG and netD
def weights_init(m):
    """Re-initialise module ``m`` in place; meant for ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left
    untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)


class Generator(nn.Module):
    """DCGAN generator built from four transposed convolutions.

    Reads the module-level sizes ``nz`` (latent channels), ``ngf`` (base
    feature width) and ``nc`` (output channels). For a latent input with
    1x1 spatial extent the feature maps grow 1 -> 4 -> 7 -> 14 -> 28 and
    the output is squashed to [-1, 1] by ``Tanh``.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        layers = [
            # latent vector -> 4x4 feature map
            nn.ConvTranspose2d(nz, ngf * 4, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 4x4 -> 7x7
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 14x14 -> 28x28 image
            nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the network, spreading the batch over ``ngpu`` GPUs if requested."""
        use_data_parallel = input.is_cuda and self.ngpu > 1
        if not use_data_parallel:
            return self.main(input)
        return nn.parallel.data_parallel(self.main, input, range(self.ngpu))




class Discriminator(nn.Module):
    """DCGAN discriminator built from four strided convolutions.

    Reads the module-level sizes ``nc`` (input channels) and ``ndf`` (base
    feature width). For 28x28 inputs the feature maps shrink
    28 -> 14 -> 7 -> 4 -> 1, and the final ``Sigmoid`` yields one value in
    (0, 1) per image.
    """

    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        layers = [
            # 28x28 -> 14x14 (no batch norm on the first layer)
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 14x14 -> 7x7
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 7x7 -> 4x4
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 -> single score
            nn.Conv2d(ndf * 4, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return a 1-D tensor with one score per input image."""
        if input.is_cuda and self.ngpu > 1:
            scores = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            scores = self.main(input)
        # Flatten to a column and drop the singleton dimension again.
        return scores.view(-1, 1).squeeze(1)


class MnistClassifier(nn.Module):
    """Small CNN digit classifier for single-channel 28x28 images.

    Two 3x3 convolutions, 2x2 max pooling and two fully connected layers
    with dropout. ``forward`` returns log-probabilities over the ten digit
    classes. The attribute names define the ``state_dict`` keys expected by
    the checkpoint loaded elsewhere in this file.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12 after the two convolutions and pooling
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        features = self.dropout1(F.max_pool2d(features, 2))
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

# Pre-trained digit classifier used to assign a digit to each channel of a
# generated image.
mnist_cls = MnistClassifier().to(device)
# map_location lets the checkpoint load even if it was saved from another device.
mnist_cls.load_state_dict(torch.load('./mnist_cnn.pt', map_location=device))
# The classifier is only used for inference: switch to eval mode so that the
# dropout layers are disabled and predictions are deterministic.
mnist_cls.eval()

# Training metrics
class Metrics():
    """Accumulate named scalar metrics and report their running means.

    Values are appended per metric name with ``update``. The mean of a
    metric can be read as an attribute (``metrics.d_loss``) or by calling
    the object. Reading a metric that was never recorded as an attribute
    yields ``nan`` rather than raising.
    """

    def __init__(self, *args):
        """Pre-register the metric names given in ``args`` with empty histories."""
        self.metrics = {arg: [] for arg in args}

    def update(self, **kwargs):
        """Append each keyword value to the history of the metric of that name."""
        for key, value in kwargs.items():
            self.metrics.setdefault(key, []).append(value)

    def update_confmat_metrics(self, y_true, y_pred):
        """Record true-positive rate, false-positive rate and accuracy.

        Args:
            y_true: Binary ground-truth labels (0 or 1).
            y_pred: Binary predicted labels (0 or 1).
        """
        # Fix the label set so the matrix is always 2x2, even when only one
        # class occurs in the inputs; otherwise the unpacking below fails.
        tn, fp, fn, tp = sklearn.metrics.confusion_matrix(
            y_true, y_pred, labels=[0, 1]).ravel()

        tpr = tp / (tp + fn)
        fpr = fp / (fp + tn)
        acc = (tp + tn) / (tp + tn + fp + fn)

        self.update(tpr=tpr, fpr=fpr, acc=acc)

    def __call__(self, *args):
        """Return metric means.

        With names given, returns a list of means in that order (an unknown
        name raises ``KeyError``). Without arguments, returns a dict mapping
        every recorded metric name to its mean.
        """
        if args:
            return [np.mean(self.metrics[arg]) for arg in args]
        return {key: np.mean(value) for key, value in self.metrics.items()}

    def __getattr__(self, item):
        """Return the mean of metric ``item``, or ``nan`` if it was never recorded."""
        # __getattr__ runs for every attribute that normal lookup does not
        # find. Special-method probes (used by copy, pickle, hasattr, ...)
        # must see AttributeError rather than nan, and `metrics` itself must
        # not be resolved here or an instance without it recurses forever.
        if item == 'metrics' or (item.startswith('__') and item.endswith('__')):
            raise AttributeError(item)
        if item in self.metrics:
            return np.mean(self.metrics[item])
        return np.nan

#%% Function for computing Bures
def one_sided_max_sliced_bures(rho_x, rho_y, max_iter=100):
    """Search for the direction maximising the one-sided sliced Bures gap.

    For a direction ``w`` the objective is
    ``(sqrt(w' rho_x w) - sqrt(w' rho_y w)) / sqrt(eps + w' w)``.
    It is maximised with ``max_iter`` Adam steps on ``w``.

    Args:
        rho_x: Square array-like (d, d), second-moment matrix of the first sample.
        rho_y: Square array-like (d, d), second-moment matrix of the second sample.
        max_iter: Number of Adam steps (0 evaluates the initial direction only).

    Returns:
        tuple: ``(loss, w)`` where ``loss`` is a 0-d float32 NumPy array with
        the NEGATIVE objective at the returned direction, and ``w`` is a
        float32 NumPy array of shape (d,) holding that (unnormalised)
        direction.

    Raises:
        ValueError: If the inputs are not square matrices of equal shape.
    """
    # Plain Python float so that `eps + tensor` is handled by torch.
    eps = float(np.finfo('float32').eps)

    rho_x_tensor = torch.tensor(np.asarray(rho_x, dtype=np.float32))
    rho_y_tensor = torch.tensor(np.asarray(rho_y, dtype=np.float32))

    if (rho_x_tensor.dim() != 2
            or rho_x_tensor.shape[0] != rho_x_tensor.shape[1]
            or rho_x_tensor.shape != rho_y_tensor.shape):
        raise ValueError("rho_x and rho_y must be square matrices of the same shape")

    def objective(direction):
        rms_x = torch.sqrt(direction @ rho_x_tensor @ direction)
        rms_y = torch.sqrt(direction @ rho_y_tensor @ direction)
        return (rms_x - rms_y) / torch.sqrt(eps + direction @ direction)

    # Random starting direction. It is always drawn, so the global torch
    # random stream is consumed regardless of which start is used below.
    init = torch.randn(rho_x_tensor.shape[0], 1).flatten()
    # Preferred start: sign pattern of the difference of the column sums.
    # When that pattern is all zeros, normalising it would give 0/0 = nan,
    # so the random direction is kept instead.
    sign_dir = torch.sign(rho_x_tensor.sum(dim=0) - rho_y_tensor.sum(dim=0))
    if torch.any(sign_dir != 0):
        init = sign_dir

    w = (init / torch.sqrt(init @ init)).clone().requires_grad_(True)

    optimizer = torch.optim.Adam([w], lr=1e-2, betas=(0.9, 0.999), eps=1e-8)

    for _ in range(max_iter):
        optimizer.zero_grad()
        loss = -objective(w)
        loss.backward()
        optimizer.step()

    # Evaluate at the final direction so that the returned loss and w belong
    # together (and so that max_iter=0 is well defined).
    with torch.no_grad():
        final_loss = -objective(w)

    return final_loss.cpu().numpy(), w.detach().cpu().numpy()

#%% Create real examples and uncentered covariance matrix
# Draw n_real_samples_mult passes of n_real_samples stacked-MNIST images,
# keep every image (flattened) together with its digit triple, and
# accumulate the sum of outer products X^T X over all of them.
n_real_samples_mult = 10
n_real_samples = 50000
bs = 1000

dataset_real = StackedMnistDataset(dataset_size=n_real_samples)
dataloader_real = torch.utils.data.DataLoader(
    dataset_real,
    batch_size=bs,
    shuffle=False,
    num_workers=4,
    worker_init_fn=wif,
)
np.random.seed()
real_example = np.zeros((n_real_samples_mult, n_real_samples, 28*28*3))
real_modes = np.zeros((n_real_samples_mult, n_real_samples, 3))
real_rho_batch = np.zeros((28*28*3, 28*28*3))

for pass_idx in range(n_real_samples_mult):
    # Re-seed NumPy before every pass over the loader.
    np.random.seed()

    for batch_idx, (images, digits) in enumerate(dataloader_real):
        rows = slice(batch_idx * bs, (batch_idx + 1) * bs)
        real_example[pass_idx, rows] = images.detach().cpu().numpy().reshape((-1, 28*28*3))
        real_modes[pass_idx, rows] = digits.detach().cpu().numpy()
        batch_flat = real_example[pass_idx, rows]
        real_rho_batch += batch_flat.T @ batch_flat

# Merge the passes into one flat collection of samples.
real_example = real_example.reshape((n_real_samples_mult*n_real_samples, -1))
real_modes = real_modes.reshape((n_real_samples_mult*n_real_samples, -1))

# Unnormalised: this is the plain sum over all samples collected above.
real_rho = real_rho_batch


#%%
n_monte = 1
num_modes_list = []
kl_div_list = []
osmsb_div_list_RF = []
osmsb_div_list_FR = []
mode_precision_list = []
random_mode_precision_list = []

# real_rho is a plain sum of outer products over every stored real sample,
# while the fake matrix below is built from a different number of images.
# Put both on the same per-sample scale (uncentered covariance) before they
# are compared; otherwise the larger sample dominates the divergence.
real_rho_mean = real_rho / real_example.shape[0]

for nm in range(n_monte):
    netG = Generator(ngpu).to(device)
    netG.apply(weights_init)
    if opt.netG != '':
        netG.load_state_dict(torch.load(opt.netG))
    print(netG)

    #%%
    netD = Discriminator(ngpu).to(device)
    netD.apply(weights_init)
    if opt.netD != '':
        netD.load_state_dict(torch.load(opt.netD))
    print(netD)

    criterion = nn.BCELoss()

    # Fixed latent batch so that the saved sample grids are comparable across epochs.
    fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
    real_label = 1
    fake_label = 0

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

    if opt.dry_run:
        opt.niter = 1

    #%%
    for epoch in range(opt.niter):
        # Re-seed NumPy before the training loader is iterated for this epoch.
        np.random.seed()
        print('\nEpoch: ', epoch)
        print('Nm', nm)
        metrics = Metrics()
        progress_bar = tqdm(dataloader, bar_format='{l_bar}{bar:25}{r_bar}')
        for i, data in enumerate(progress_bar, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            # train with real
            netD.zero_grad()
            real_cpu = data[0].to(device)
            batch_size = real_cpu.size(0)
            label = torch.full((batch_size,), real_label,
                               dtype=real_cpu.dtype, device=device)

            d_out_real = netD(real_cpu)
            errD_real = criterion(d_out_real, label)
            errD_real.backward()
            D_x = d_out_real.mean().item()

            # train with fake
            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake = netG(noise)
            label.fill_(fake_label)
            # detach: the discriminator update must not backpropagate into G
            d_out_fake = netD(fake.detach())
            errD_fake = criterion(d_out_fake, label)
            errD_fake.backward()
            D_G_z1 = d_out_fake.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            metrics.update(d_loss=errD.item())
            metrics.update(g_loss=errG.item())
            metrics.update(D_x=D_x)
            metrics.update(D_G_z1=D_G_z1)
            metrics.update(D_G_z2=D_G_z2)

            # Discriminator decisions at the threshold: fakes first, then reals.
            thresh = 0.5
            disc_pred = np.concatenate((
                (d_out_fake.detach().cpu().numpy() > thresh).astype(int),
                (d_out_real.detach().cpu().numpy() > thresh).astype(int)))
            disc_true = np.concatenate((np.zeros(batch_size), np.ones(batch_size)))

            metrics.update_confmat_metrics(disc_true, disc_pred)

            progress_bar.set_postfix(d_loss=metrics.d_loss, g_loss=metrics.g_loss, tpr=metrics.tpr,
                             fpr=metrics.fpr, acc=metrics.acc)
            if i % 100 == 0:
                vutils.save_image(real_cpu,
                        '%s/real_samples.png' % opt.outf,
                        normalize=True)
                fake = netG(fixed_noise)
                vutils.save_image(fake.detach(),
                        '%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                        normalize=True)

            if opt.dry_run:
                break
        # do checkpointing
        torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
        torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))

        ## test
        # Generate n_test images, classify every channel with the MNIST
        # classifier and collect the digit triple ("mode") of each image.
        n_test = opt.n_fake_examples
        bs = opt.batchSize
        cls_out = np.zeros((n_test, 3))
        noise = torch.randn(n_test, nz, 1, 1, device=device)
        fake_all = np.zeros((n_test, 28 * 28 * 3))
        # No gradients are needed for evaluation.
        with torch.no_grad():
            # Step by batch size so that no empty trailing batch is produced
            # when n_test is an exact multiple of bs.
            for start in range(0, n_test, bs):
                stop = start + bs
                fake = netG(noise[start:stop])
                fake_all[start:stop] = fake.cpu().numpy().reshape((-1,28*28*3))
                for jj in range(3):
                    fake_chan = fake[:,jj:jj+1,:,:]
                    cls_out_batch = mnist_cls(fake_chan).cpu().numpy()
                    cls_out[start:stop,jj] = np.argmax(cls_out_batch, axis=1)

        # Distinct digit triples produced by the generator and their frequencies.
        modes, counts = np.unique(cls_out, axis=0, return_counts=True)
        p = counts / np.sum(counts)

        # KL divergence to the uniform distribution over the 1000 digit triples
        # (modes with zero count contribute nothing and are not in p).
        uniform_prob = 1.0 / 1000.0

        num_modes = modes.shape[0]
        kl_div = np.sum(p * np.log(p/uniform_prob))

        # Per-sample uncentered covariance of the generated images.
        fake_rho = fake_all.T @ fake_all / n_test

        osmsb_div_real_fake, w_real_gt_fake = one_sided_max_sliced_bures(real_rho_mean, fake_rho)
        osmsb_div_fake_real, w_real_lt_fake = one_sided_max_sliced_bures(fake_rho, real_rho_mean)

        print('num modes: ', num_modes)
        print('kl div: ', kl_div)

        num_modes_list.append(num_modes)
        kl_div_list.append(kl_div)
        osmsb_div_list_RF.append(osmsb_div_real_fake)
        osmsb_div_list_FR.append(osmsb_div_fake_real)

        #%% Find max discrepancy samples and modes
        # Score every real image by the magnitude of its projection on the
        # witness direction; ravel() keeps the score 1-D so that argsort
        # ranks the samples.
        max_slice_real_example_score = np.abs(real_example @ w_real_gt_fake.ravel())

        max_slice_real_example_score_sorted_inds = np.argsort(max_slice_real_example_score)
        max_discrep_examples = real_example[max_slice_real_example_score_sorted_inds[-opt.n_test_witness_points:]]
        max_discrep_modes = real_modes[max_slice_real_example_score_sorted_inds[-opt.n_test_witness_points:]]
        # Fraction of witness images whose mode is NOT among the generated modes.
        real_mode_in_fake_modes_list_bool = [(modes == x).all(axis=1).any() for x in max_discrep_modes]
        mode_precision = np.sum(np.logical_not( real_mode_in_fake_modes_list_bool)) / len(real_mode_in_fake_modes_list_bool)

        # Same statistic for randomly chosen real images as a baseline.
        random_modes = real_modes[np.random.permutation(real_example.shape[0])[-opt.n_test_witness_points:]]
        real_mode_in_fake_modes_list_bool_random = [(modes == x).all(axis=1).any() for x in random_modes]
        mode_precision_random = np.sum(np.logical_not( real_mode_in_fake_modes_list_bool_random)) / len(real_mode_in_fake_modes_list_bool_random)

        mode_precision_list.append(mode_precision)
        random_mode_precision_list.append(mode_precision_random)

        print('\nmode precision: ', mode_precision)
        print('\nrandom mode precision: ', mode_precision_random )

#%%
# Plot once, after all Monte Carlo trials have finished. Only then do the
# lists hold n_monte * (number of epochs) entries, which is what the reshape
# to (n_monte, -1) and the average over trials require.
lw = 4
plt.figure(figsize=(14,8))
plt.plot(np.array(num_modes_list).reshape(n_monte,-1).mean(axis=0)/1000, label='Mode coverage', linewidth=lw)
plt.plot(np.array(mode_precision_list).reshape(n_monte,-1).mean(axis=0), label='Max-sliced bures', linewidth=lw)
plt.plot(np.array(random_mode_precision_list).reshape(n_monte,-1).mean(axis=0), label='Random chance', linewidth=lw)

plt.xlabel('Epochs', fontsize=24)
plt.legend(fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

plt.tight_layout()
plt.show()


#%%
np.save('num_modes_list', num_modes_list)
np.save('mode_precision_list', mode_precision_list)
np.save('random_mode_precision_list', random_mode_precision_list)
    
    
    
    
    
    
    
