''' train_fns.py
Functions for the main loop of training different conditional image models
'''
import torch
import torch.nn as nn
from torch import autograd
import torchvision
import torch.nn.functional as F
import os
from tqdm import tqdm
import numpy as np

import utils
import losses
from DiffAugment_pytorch import DiffAugment



# Dummy training function for debugging
def dummy_training_function():
    '''Return a no-op training step, useful for debugging the outer loop.

    The returned callable accepts a batch ``(x, y)``, ignores both, and
    reports an empty metrics dict.
    '''
    def _noop_step(x, y):
        # Nothing is trained, so there are no metrics to report.
        metrics = dict()
        return metrics

    return _noop_step

# def get_grad(netD, x, y, norm=True):
#     if not x.requires_grad:
#         x = autograd.Variable(x, requires_grad=True)
#     disc_interpolates = netD(x, y)
#     gradients = autograd.grad(outputs=disc_interpolates, inputs=x, grad_outputs=torch.ones(
#                                     disc_interpolates.size()).cuda(),
#                                 create_graph=True, retain_graph=True, only_inputs=True)[0]
#     gradients = gradients.view(gradients.size(0), -1)
#     if norm:
#         gradient_norm = gradients.norm(2, dim=1)
#         return gradient_norm
#     else:
#         return gradients

def get_grad(x, disc_interpolates, norm=True, create_graph=False):
    '''Differentiate ``disc_interpolates`` with respect to ``x``.

    Args:
        x: input tensor that is part of the graph producing
            ``disc_interpolates`` (it must require grad).
        disc_interpolates: output tensor to differentiate; every element is
            weighted by 1, i.e. this is the gradient of its sum.
        norm: if True, return the per-sample L2 norm of the gradient
            (shape ``(x.size(0),)``); otherwise return the gradient flattened
            to ``(x.size(0), -1)``.
        create_graph: forwarded to ``autograd.grad``. The default (False)
            keeps the historical behaviour, in which the result is detached
            from the graph; pass True if the result must itself be
            differentiable.

    Returns:
        The gradient norm or the flattened gradient, as selected by ``norm``.

    Raises:
        RuntimeError: if ``x`` did not take part in computing
            ``disc_interpolates`` (autograd reports no gradient for it).
    '''
    # ones_like keeps the seed on the same device and dtype as the output,
    # so this works on CPU as well as on any CUDA device.
    seed = torch.ones_like(disc_interpolates)
    gradients = autograd.grad(outputs=disc_interpolates, inputs=x,
                              grad_outputs=seed,
                              create_graph=create_graph, retain_graph=True,
                              only_inputs=True, allow_unused=True)[0]
    if gradients is None:
        # allow_unused=True turns "x is not in the graph" into a None result;
        # surface that explicitly instead of failing on None.view below.
        raise RuntimeError(
            'get_grad: x was not used to compute disc_interpolates')
    # reshape (rather than view) also copes with non-contiguous gradients.
    gradients = gradients.reshape(gradients.size(0), -1)
    if norm:
        return gradients.norm(2, dim=1)
    return gradients


def GAN_training_function(G, D, GD, z_, y_, ema, state_dict, config):
    '''Build the closure that runs one GAN training iteration.

    The arguments are captured by the returned ``train`` function:

    G, D: generator and discriminator; each exposes an ``optim`` attribute
        whose ``zero_grad``/``step`` methods are called here.
    GD: callable that evaluates D on generated and/or real samples.
    z_, y_: latent/label holders, resampled in place through ``sample_()``.
    ema: object whose ``update(itr)`` is called after the G step when
        ``config['ema']`` is set.
    state_dict: training state; only ``state_dict['itr']`` is read here.
    config: dict of settings. Keys read: batch_size, toggle_grads,
        num_D_steps, num_D_accumulations, DiffAugment, CR, CR_augment,
        use_Dreg, Dreg_weight, D_ortho, fix_G, num_G_accumulations,
        G_ortho, ema.
    '''
    def train(x, y, grad_real_ema, grad_fake_ema):
        '''Run the D update(s), then (unless ``fix_G``) one G update.

        x, y: data batch and labels; split along dim 0 into chunks of
            ``config['batch_size']``, one chunk consumed per D accumulation.
        grad_real_ema, grad_fake_ema: running averages of the gradient norms
            returned by ``get_grad``; ``None`` selects the no-history branch.
            Presumably these are the values returned by the previous call --
            verify against the caller.

        Returns ``(out, grad_real_ema, grad_fake_ema)`` where ``out`` is a
        dict of float losses and the two averages are the possibly updated
        values.
        '''
        G.optim.zero_grad()
        D.optim.zero_grad()
        # How many chunks to split x and y into?
        x = torch.split(x, config['batch_size'])
        y = torch.split(y, config['batch_size'])
        # Index of the next data chunk; it advances once per D accumulation
        # and is never reset between D steps.
        # assumes x holds at least num_D_steps * num_D_accumulations chunks
        # -- TODO confirm against the data loader.
        counter = 0

        # Optionally toggle D and G's "require_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, True)
            utils.toggle_grad(G, False)

        for step_index in range(config['num_D_steps']):
            # If accumulating gradients, loop multiple times before an optimizer step
            D.optim.zero_grad()
            for accumulation_index in range(config['num_D_accumulations']):
                z_.sample_()
                y_.sample_()
                # Only the first batch_size latents/labels are used for D.
                D_scores = GD(z_[:config['batch_size']], y_[:config['batch_size']],
                              x[counter], y[counter], train_G=False, policy=config['DiffAugment'],
                              CR=config['CR'] > 0, CR_augment=config['CR_augment'], return_G_z=True)

                D_loss_CR = 0
                if config['CR'] > 0:
                    # Consistency term: squared difference between the two
                    # real-score tensors returned by GD, scaled by config['CR'].
                    # NOTE(review): D_out and D_input are not bound on this
                    # path, so enabling use_Dreg together with CR > 0 would
                    # fail at the get_grad call below.
                    D_fake, D_real, D_real_aug = D_scores
                    D_loss_CR = torch.mean(
                        (D_real_aug - D_real) ** 2) * config['CR']
                else:
                    # D_out is split as [fake scores, real scores], each of
                    # length batch_size.
                    D_out, D_input = D_scores
                    D_fake, D_real = torch.split(D_out, [config['batch_size'], config['batch_size']])
                    # G_z_aug, x_aug = torch.split(D_input, [config['batch_size'], config['batch_size']])

                # Compute components of D's loss, average them, and divide by
                # the number of gradient accumulations
                D_loss_real, D_loss_fake = losses.discriminator_loss(
                    D_fake, D_real)
                D_loss = D_loss_real + D_loss_fake + D_loss_CR
                D_loss = D_loss / float(config['num_D_accumulations'])

                # The penalty is only added during the first D step.
                if step_index==0 and config['use_Dreg']:

                    # === diggan penalty ===
                    # Per-sample L2 norm of dD_out/dD_input (get_grad's
                    # default norm=True), split in the same fake/real order.
                    # NOTE(review): get_grad calls autograd.grad with
                    # create_graph=False, so the penalty below looks like it
                    # adds no gradient signal to D -- confirm this is intended.
                    grads = get_grad(D_input, D_out)
                    grad_fake, grad_real = torch.split(grads, [config['batch_size'], config['batch_size']])

                    if grad_real_ema is None:
                        # No history yet: penalise the raw fake/real gap and
                        # seed the running values with the current norms.
                        D_loss = D_loss + ((grad_fake-grad_real)**2).mean() * config['Dreg_weight']
                        grad_real_ema = grad_real.data
                        grad_fake_ema = grad_fake.data
                    else:
                        # Blend current norms 50/50 with the running values,
                        # penalise the gap between the blends, then store the
                        # blends (via .data) as the new running values.
                        grad_real_avg = 0.5 * grad_real + 0.5 * grad_real_ema
                        grad_fake_avg = 0.5 * grad_fake + 0.5 * grad_fake_ema
                        D_loss = D_loss + ((grad_fake_avg - grad_real_avg) ** 2).mean() * config['Dreg_weight']
                        grad_real_ema = grad_real_avg.data
                        grad_fake_ema = grad_fake_avg.data

                    # # ===== R1 reg =======
                    # grad_real = get_grad(D, x_aug, y[counter])
                    # D_loss = D_loss + ((grad_real)**2).mean() * config['Dreg_weight']

                    # # ===== R0 reg =======
                    # grad_real = get_grad(D, G_z_aug, y_[:config['batch_size']])
                    # D_loss = D_loss + ((grad_real)**2).mean() * config['Dreg_weight']

                    # # ===== Dragan ======
                    # alpha = torch.rand(x_aug.size()[0], 1, 1, 1).expand(x_aug.size()).cuda()
                    # x_interpo = alpha * x_aug.data + (1 - alpha) * (x_aug.data + 0.5 * x_aug.data.std() * torch.rand(x_aug.size()).cuda())
                    # gradients = get_grad(D, x_interpo, y[counter])

                    # D_loss = D_loss + ((gradients - 1) ** 2).mean() * config['Dreg_weight']

                    # # ========== lecam =======
                    # if state_dict['itr'] < 1000:
                    #     decay = 0.0
                    # else:
                    #     decay = 0.9
                    # if grad_real_ema is None:
                    #     grad_real_ema = 1000.
                    #     grad_fake_ema = 1000.
                    # grad_real_ema = (grad_real_ema*decay + D_real*(1 - decay)).data
                    # grad_fake_ema = (grad_fake_ema*decay + D_fake*(1 - decay)).data
                    # reg = torch.mean(F.relu(D_real - grad_fake_ema).pow(2)) + \
                    #                         torch.mean(F.relu(grad_real_ema - D_fake).pow(2))
                    # D_loss = D_loss + 0.3 * reg

                D_loss.backward()
                counter += 1

            # Optionally apply ortho reg in D
            if config['D_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in D.
                print('using modified ortho reg in D')
                utils.ortho(D, config['D_ortho'])

            D.optim.step()

        # Optionally toggle "requires_grad"
        if config['toggle_grads']:
            utils.toggle_grad(D, False)
            utils.toggle_grad(G, True)

        # Zero G's gradients by default before training G, for safety
        G.optim.zero_grad()

        # config['fix_G'] skips the whole G update, including the EMA update.
        if not config['fix_G']:
            # If accumulating gradients, loop multiple times
            for accumulation_index in range(config['num_G_accumulations']):
                z_.sample_()
                y_.sample_()
                D_fake = GD(z_, y_, train_G=True, policy=config['DiffAugment'])
                G_loss = losses.generator_loss(
                    D_fake) / float(config['num_G_accumulations'])
                G_loss.backward()

            # Optionally apply modified ortho reg in G
            if config['G_ortho'] > 0.0:
                # Debug print to indicate we're using ortho reg in G
                print('using modified ortho reg in G')
                # Don't ortho reg shared, it makes no sense. Really we should blacklist any embeddings for this
                utils.ortho(G, config['G_ortho'],
                            blacklist=[param for param in G.shared.parameters()])
            G.optim.step()

            # If we have an ema, update it, regardless of if we test with it or not
            if config['ema']:
                ema.update(state_dict['itr'])

        # The reported losses are those of the last accumulation only.
        # NOTE(review): D_loss_real/D_loss_fake (and G_loss) are unbound if
        # the corresponding loop ran zero times.
        out = {'G_loss': float(G_loss.item()) if not config['fix_G'] else 0,
               'D_loss_real': float(D_loss_real.item()),
               'D_loss_fake': float(D_loss_fake.item()),
               }
        if config['CR'] > 0:
            out['D_loss_CR'] = float(D_loss_CR.item())
        # Return G's loss and the components of D's loss.
        return out, grad_real_ema, grad_fake_ema
    return train


''' This function takes in the model, saves the weights (multiple copies if 
    requested), and prepares sample sheets: one consisting of samples given
    a fixed noise seed (to show how the model evolves throughout training),
    a set of full conditional sample sheets, and a set of interp sheets. '''


def save_and_sample(G, D, G_ema, z_, y_, fixed_z, fixed_y,
                    state_dict, config, experiment_name):
    '''Checkpoint the models and write sample, class and interpolation sheets.

    Args:
        G, D: generator and discriminator to checkpoint.
        G_ema: EMA copy of G; saved only when ``config['ema']`` is set and
            used for sampling only when ``config['use_ema']`` is also set.
        z_, y_: latent/label holders forwarded to the ``utils`` helpers.
        fixed_z, fixed_y: fixed inputs for the "fixed samples" image.
        state_dict: training state; reads ``'itr'`` and reads/updates
            ``'save_num'`` (rotating index of the backup copy).
        config: dict of settings (weights_root, samples_root, ema, use_ema,
            num_save_copies, accumulate_stats, num_standing_accumulations,
            n_classes, parallel, dataset).
        experiment_name: sub-directory name used for weights and samples.
    '''
    ema_to_save = G_ema if config['ema'] else None
    utils.save_weights(G, D, state_dict, config['weights_root'],
                       experiment_name, 'iter%d' % state_dict['itr'],
                       ema_to_save)
    # Save an additional copy to mitigate accidental corruption if the
    # process is killed during a save.
    if config['num_save_copies'] > 0:
        utils.save_weights(G, D, state_dict, config['weights_root'],
                           experiment_name,
                           'copy%d' % state_dict['save_num'],
                           ema_to_save)
        state_dict['save_num'] = (
            state_dict['save_num'] + 1) % config['num_save_copies']

    # Use EMA G for samples or non-EMA?
    which_G = G_ema if config['ema'] and config['use_ema'] else G

    # Accumulate standing statistics on the same model we sample from.
    if config['accumulate_stats']:
        utils.accumulate_standing_stats(which_G,
                                        z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])

    # Save a random sample sheet with fixed z and y
    with torch.no_grad():
        if config['parallel']:
            fixed_Gz = nn.parallel.data_parallel(
                which_G, (fixed_z, which_G.shared(fixed_y)))
        else:
            fixed_Gz = which_G(fixed_z, which_G.shared(fixed_y))

    # makedirs(exist_ok=True) also creates a missing samples_root and does
    # not fail if the directory appears between a check and the creation.
    os.makedirs('%s/%s' % (config['samples_root'], experiment_name),
                exist_ok=True)
    image_filename = '%s/%s/fixed_samples%d.jpg' % (config['samples_root'],
                                                    experiment_name,
                                                    state_dict['itr'])
    # Lay the batch out on a roughly square grid.
    torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename,
                                 nrow=int(fixed_Gz.shape[0] ** 0.5), normalize=True)

    # For now, every time we save, also save sample sheets
    utils.sample_sheet(which_G,
                       classes_per_sheet=utils.classes_per_sheet_dict[config['dataset']],
                       num_classes=config['n_classes'],
                       samples_per_class=10, parallel=config['parallel'],
                       samples_root=config['samples_root'],
                       experiment_name=experiment_name,
                       folder_number=state_dict['itr'],
                       z_=z_)

    # Also save interp sheets: nothing fixed, y fixed, then z fixed.
    for fix_z, fix_y in zip([False, False, True], [False, True, False]):
        utils.interp_sheet(which_G,
                           num_per_sheet=16,
                           num_midpoints=8,
                           num_classes=config['n_classes'],
                           parallel=config['parallel'],
                           samples_root=config['samples_root'],
                           experiment_name=experiment_name,
                           folder_number=state_dict['itr'],
                           sheet_number=0,
                           fix_z=fix_z, fix_y=fix_y, device='cuda')


''' This function runs the inception metrics code, checks if the results
    are an improvement over the previous best (either in IS or FID, 
    user-specified), logs the results, and saves a best_ copy if it's an 
    improvement. '''


def _real_score_accuracy(D, GD, config, train):
    '''Return the fraction of samples in one data split that D scores above 0.

    ``train`` selects the training (True) or held-out (False) split. D is put
    in eval mode for the pass and switched back to train mode afterwards,
    even if the pass raises. The result is ``nan`` for an empty loader.
    '''
    positives = 0
    total = 0
    D.eval()
    try:
        loader = utils.get_data_loaders(
            **{**config, 'train': train, 'use_multiepoch_sampler': False,
               'load_in_mem': False})[0]
        with torch.no_grad():
            for x, y in loader:
                D_real = GD(None, None, x=x, dy=y, policy=config['DiffAugment'])
                # Count samples rather than averaging per-batch means, so a
                # smaller final batch is not over-weighted.
                positives += (D_real > 0).sum().item()
                total += D_real.numel()
    finally:
        D.train()
    return positives / total if total else float('nan')


def test(G, D, GD, G_ema, z_, y_, state_dict, config, sample, get_inception_metrics,
         experiment_name, test_log, acc_metrics, acc_itrs):
    '''Evaluate the current models, checkpoint on improvement, and log.

    Steps: measure how often D scores real validation/training samples above
    zero, compute IS/FID through ``get_inception_metrics``, save a ``best``
    checkpoint when the metric selected by ``config['which_best']`` improved,
    update the best values in ``state_dict``, and write one row to
    ``test_log``.

    Args:
        G, D, GD, G_ema: models (GD evaluates D on data batches).
        z_, y_: latent/label holders used when accumulating standing stats.
        state_dict: training state; reads ``'itr'`` and reads/updates
            ``'best_IS'``, ``'best_FID'`` and ``'save_best_num'``.
        config: dict of settings.
        sample: sampling callable handed to ``get_inception_metrics``.
        get_inception_metrics: callable returning ``(IS_mean, IS_std, FID)``.
        experiment_name: name under which weights are saved.
        test_log: logger exposing ``log(**fields)``.
        acc_metrics: dict of accumulated metric sums; each is divided by
            ``acc_itrs`` before logging.
        acc_itrs: number of iterations accumulated in ``acc_metrics``.
    '''
    print('Calculating validation accuracy...')
    D_acc_val = _real_score_accuracy(D, GD, config, train=False)

    print('Calculating training accuracy...')
    D_acc_train = _real_score_accuracy(D, GD, config, train=True)

    print('Gathering inception metrics...')
    if config['accumulate_stats']:
        utils.accumulate_standing_stats(G_ema if config['ema'] and config['use_ema'] else G,
                                        z_, y_, config['n_classes'],
                                        config['num_standing_accumulations'])
    IS_mean, IS_std, FID = get_inception_metrics(sample,
                                                 config['num_inception_images'],
                                                 num_splits=10, use_torch=False)
    print('Itr %d: PYTORCH UNOFFICIAL Inception Score is %3.3f +/- %3.3f, PYTORCH UNOFFICIAL FID is %5.4f' %
          (state_dict['itr'], IS_mean, IS_std, FID))

    # If improved over previous best metric, save appropriate copy
    improved = ((config['which_best'] == 'IS' and IS_mean > state_dict['best_IS'])
                or (config['which_best'] == 'FID' and FID < state_dict['best_FID']))
    if improved:
        print('%s improved over previous best, saving checkpoint...' %
              config['which_best'])
        if config['num_best_copies'] > 1:
            best_name = 'best%d' % state_dict['save_best_num']
        else:
            best_name = 'best'
        utils.save_weights(G, D, state_dict, config['weights_root'],
                           experiment_name, best_name,
                           G_ema if config['ema'] else None)
        # Only rotate when there is something to rotate over; this avoids a
        # modulo-by-zero when num_best_copies is 0.
        if config['num_best_copies'] > 0:
            state_dict['save_best_num'] = (
                state_dict['save_best_num'] + 1) % config['num_best_copies']
    state_dict['best_IS'] = max(state_dict['best_IS'], IS_mean)
    state_dict['best_FID'] = min(state_dict['best_FID'], FID)

    # Log results to file
    test_log.log(itr=int(state_dict['itr']), IS_mean=float(IS_mean),
                 IS_std=float(IS_std), FID=float(FID), D_acc_val=D_acc_val, D_acc_train=D_acc_train,
                 Dreg_weight=config['Dreg_weight'],
                 **{k: v / acc_itrs for k, v in acc_metrics.items()})

    # if state_dict['itr'] > 20000 and FID >= state_dict['best_IS'] + 2:
    #     config['Dreg_weight'] = config['Dreg_weight'] * 2