import argparse
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchdp import PrivacyEngine

from model import NonDPModel, DPEncoder, DPDecoder
from data_loaders import get_clean_loaders
from data_creation import percentile_from_md
from plotting_functions import plot_tabular_df, generated_synthetic_dataframe
import numpy as np
import os
from plotting_functions import plot_histograms_of_reconstructed_data

from sklearn.manifold import TSNE
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=128, type=int, help='training batch size')
parser.add_argument('--trainset_size', default=None, type=int,
                    help='Size of the training set (set to None for full training set)')

parser.add_argument('--log_interval', default=10, type=int, help='how frequently to print loss')

parser.add_argument('--task', default='MNIST', type=str, help='One of [MNIST, LendingClub]')

parser.add_argument('--rep_dim', default=8, type=int, help='dimension of representation')

parser.add_argument('--diagonal_x_std', action='store_true',
                    help='if true parameterise std with NN, else use learnable scalar')
parser.add_argument('--latent_distn', default='Laplace', type=str, help='Gaussian, Laplace or Binary')

parser.add_argument('--dropout_rate', default=0., type=float, help='Dropout rate.')
parser.add_argument('--validation', action='store_true', help='if true, split data into train/val')

parser.add_argument('--md', default=None, type=float,
                    help='Mahalanobis distance defining the clipping distance during pre-training. If None, do not clip.')
parser.add_argument('--posterior_std', default=None, type=float,
                    help='std of approx posterior for rep learning. If None, parameterise with NN.')
parser.add_argument('--prior_std', default=1.0, type=float, help='std of prior distn.')
# parser.add_argument('--truncated_prior', action='store_true', help='truncate the prior.')

parser.add_argument('--epochs_stage1', type=int, default=50, metavar='N',
                    help='number of epochs to train non-dp model')
parser.add_argument('--epochs_stage2', type=int, default=200, metavar='N',
                    help='number of epochs to train dp model')
parser.add_argument('--lr', default=1e-4, type=float, help='Learning Rate for non-dp-sgd stage.')
parser.add_argument('--dp_lr', default=1e-4, type=float, help='Learning Rate for dp-sgd stage.')
parser.add_argument('--dp_encoder', action='store_true', help='if true, dp encoder is trained')
parser.add_argument('--dp_decoder', action='store_true', help='if true, dp decoder is trained')
# Only these two optimisers are constructed for the DP stage; reject anything else
# at parse time instead of failing later when the optimiser is built.
parser.add_argument('--dp_optim', type=str, default='Adam', choices=['Adam', 'SGD'],
                    help='DP optimiser, one of {Adam, SGD}')
parser.add_argument('--delta', default=1e-5, type=float, help='δ for Differential Privacy')
parser.add_argument('--noise_multiplier', default=0.5, type=float, help='Noise to add to gradients')
parser.add_argument('--max_grad_norm', default=1.0, type=float, help='Value to clip the gradients norm')
parser.add_argument('--custom_directory', default=None, help='for saving to a custom directory, e.g. DPVAE/grid_runs/')

parser.add_argument('--fraction', default=1., type=float, help='fraction to split pretraining and classification data')

parser.add_argument('--synthetic_generation', action='store_true', help='use if generating synthetic data and label')
parser.add_argument('--novel_class', action='store_true', help='use if experimenting with the distributional shift')
parser.add_argument('--data_join_task', action='store_true', help='use if doing the data join task')

# parse_known_args collects unrecognised command-line flags in `unknown` instead of
# raising, so a mistyped flag is silently ignored rather than reported.
opt, unknown = parser.parse_known_args()


def save_opt_to_file(dic, path):
    """Write ``str(dic)`` to ``<path>/options.txt``, replacing any existing file.

    Args:
        dic: Object to record; only its ``str()`` representation is written
            (in this script, the parsed options namespace).
        path: Existing directory in which ``options.txt`` is created.
    """
    # The context manager closes the file even if the write raises.
    with open(os.path.join(path, 'options.txt'), 'w', encoding='utf-8') as f:
        f.write(str(dic))


def loss_function(data, label, stage):
    """Return the loss tuple for the requested training stage.

    Stage 1 delegates to ``nondp_model.loss(data, label)``. Stage 2 delegates to
    ``dp_model.loss`` and hands it the stage-1 model as the missing half of the
    autoencoder: as ``decoder`` when a DP encoder is being trained, otherwise
    as ``encoder``. Relies on the module-level ``opt``, ``nondp_model`` and
    ``dp_model``.

    Raises:
        NotImplementedError: if ``stage`` is neither 1 nor 2.
    """
    if stage not in (1, 2):
        raise NotImplementedError('Stage should be 1 or 2.')
    if stage == 1:
        return nondp_model.loss(data, label)
    partner_role = 'decoder' if opt.dp_encoder else 'encoder'
    return dp_model.loss(data, label, **{partner_role: nondp_model})


def train(model, data_loader, epoch, iteration, optimizer, device, writer, log_interval, opt, stage):
    """Run one training epoch for ``stage`` and log progress.

    The loss itself comes from the module-level ``loss_function``; ``model`` is
    only switched to train mode and, in stage 1, read for ``network.x_std``.

    Args:
        model: Module being trained.
        data_loader: Yields ``(data, label)``, or ``(data, _, label)`` when
            ``opt.data_join_task`` is set.
        epoch: Epoch number used in log output and as the epsilon scalar step.
        iteration: Global iteration count before this epoch; TensorBoard steps
            continue from it.
        optimizer: Optimizer stepped after every batch. In DP stage 2 it must
            expose ``privacy_engine.get_privacy_spent``.
        device: Device each batch is moved to.
        writer: TensorBoard ``SummaryWriter``.
        log_interval: Log every ``log_interval`` batches.
        opt: Parsed run options.
        stage: 1 (non-DP pre-training) or 2 (DP training).

    Returns:
        Stage 1: the last global iteration index.
        Stage 2: ``(last_iteration, epsilon)``; epsilon is the privacy spent so
        far, or ``None`` when ``opt.no_dp`` is set.
        Otherwise: ``(last_iteration, np.inf)``.
    """
    model.train()
    losses = []
    # Defined up front so an empty loader still returns a valid iteration count.
    current_iter = iteration
    for batch_idx, datapoint in enumerate(data_loader):
        if opt.data_join_task:
            data, _, label = datapoint
        else:
            data, label = datapoint
        current_iter = iteration + batch_idx + 1
        optimizer.zero_grad()

        if opt.synthetic_generation:
            # Size the one-hot matrix by the actual batch (not opt.batch_size) so a
            # short final batch is not padded with all-zero label rows.
            one_hot_label = torch.zeros((label.size(0), opt.n_categories)).scatter_(
                1, label.unsqueeze(1).long(), 1)
            loss, rec_x, rec_y, kl = loss_function(data=data.to(device), label=one_hot_label.float().to(device),
                                                   stage=stage)
        else:
            loss, rec_x, rec_y, kl = loss_function(data=data.to(device), label=None, stage=stage)
        if torch.isnan(loss):
            # Abandon the rest of the epoch rather than stepping on a NaN loss.
            print(loss)
            break
        loss.backward()
        losses.append(loss.item())

        optimizer.step()

        if batch_idx % log_interval == 0:

            # Tensorboard plots
            writer.add_scalar('Stage{}/TrainLoss'.format(stage), loss.item(), current_iter)
            writer.add_scalar('Stage{}/TrainRec'.format(stage), rec_x.item(), current_iter)
            writer.add_scalar('Stage{}/TrainKL'.format(stage), kl.item(), current_iter)
            if opt.synthetic_generation:
                writer.add_scalar('Stage{}/TrainRecY'.format(stage), rec_y.item(), current_iter)

            string = 'Train Stage {} Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}\tRec: {:.4f}\tKL: {:.4f}'.format(
                            stage, epoch, batch_idx * len(data), len(data_loader.dataset),
                            100. * batch_idx / len(data_loader),
                            loss.item(), rec_x.item(), kl.item())
            if not opt.diagonal_x_std and stage == 1:
                # A learnable scalar std is only present when it is not NN-parameterised.
                string += '\tp(x|r) std: {:.4f}'.format(model.network.x_std.item())
            print(string)

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, np.mean(losses)))

    if stage == 1:
        return current_iter
    elif stage == 2 and (not opt.no_dp):
        # Report the cumulative privacy budget tracked by the attached privacy engine.
        epsilon, best_alpha = optimizer.privacy_engine.get_privacy_spent(opt.delta)
        string = '\tEpsilon = {:.4f} (Delta = {}) for alpha = {}'.format(epsilon, opt.delta, best_alpha)
        print(string)
        writer.add_scalar('Stage{}/epsilon'.format(stage), epsilon, epoch)
        return current_iter, epsilon
    elif stage == 2 and opt.no_dp:
        return current_iter, None
    else:
        return current_iter, np.inf

def test(model, data_loader, epoch, device, writer, stage):
    """Evaluate the stage's loss on at most 51 batches and log the means.

    Only batches 0-50 of ``data_loader`` are used. The loss comes from the
    module-level ``loss_function``; ``model`` is only switched to eval mode.
    Reads the module-level ``opt``.

    Args:
        model: Module to put in eval mode.
        data_loader: Yields ``(data, label)``, or ``(data, _, label)`` when
            ``opt.data_join_task`` is set.
        epoch: Step for the logged TensorBoard scalars.
        device: Device each batch is moved to.
        writer: TensorBoard ``SummaryWriter``.
        stage: 1 or 2, forwarded to ``loss_function`` and used in tag names.
    """
    model.eval()
    losses = []
    recs = []
    rec_ys = []
    kls = []
    n_evaluated = 0
    with torch.no_grad():
        for batch_idx, datapoint in enumerate(data_loader):
            if opt.data_join_task:
                data, _, label = datapoint
            else:
                data, label = datapoint
            if batch_idx > 50:
                print('Test Stage {} loss evaluated on {} datapoints'.format(stage, n_evaluated))
                break

            if opt.synthetic_generation:
                # One row per sample in this batch, so a short final batch still lines up with `data`.
                one_hot_label = torch.zeros((label.size(0), opt.n_categories)).scatter_(
                    1, label.unsqueeze(1).long(), 1)
                loss, rec_x, rec_y, kl = loss_function(data=data.to(device), label=one_hot_label.float().to(device),
                                                       stage=stage)
                rec_ys.append(rec_y.item())
            else:
                loss, rec_x, rec_y, kl = loss_function(data=data.to(device), label=None, stage=stage)
            losses.append(loss.item())
            recs.append(rec_x.item())
            kls.append(kl.item())
            n_evaluated += len(data)

    writer.add_scalar('Stage{}/TestLoss'.format(stage), np.mean(losses), epoch)
    writer.add_scalar('Stage{}/TestRec'.format(stage), np.mean(recs), epoch)
    writer.add_scalar('Stage{}/TestKL'.format(stage), np.mean(kls), epoch)
    if opt.synthetic_generation:
        # Log the label-reconstruction term, not the x-reconstruction term.
        writer.add_scalar('Stage{}/TestRecY'.format(stage), np.mean(rec_ys), epoch)

    print('\nEpoch: {}\tTest loss: {:.6f}\n\n'.format(epoch, np.mean(losses)))


def plot_rep(model, data_loader, epoch, device, writer, stage, n_batches=10):
    """Log a 2-D t-SNE scatter of the posterior representations to TensorBoard.

    Takes up to ``n_batches`` consecutive batches from ``data_loader``. It encodes
    them with ``model.posterior`` and colours the points by label. When
    ``opt.synthetic_generation`` is set, the flattened inputs are concatenated
    with one-hot labels before encoding. Reads the module-level ``opt``.

    Args:
        model: Model exposing ``posterior``; put in eval mode here.
        data_loader: Yields ``(data, label)``, or ``(data, _, label)`` when
            ``opt.data_join_task`` is set.
        epoch: Global step for the logged figure.
        device: Device the batches are moved to.
        writer: TensorBoard ``SummaryWriter``.
        stage: Used in the figure tag.
        n_batches: Maximum number of batches to embed (default 10).
    """
    from itertools import islice

    model.eval()
    with torch.no_grad():
        # Draw consecutive batches from a single iterator. Calling
        # next(iter(data_loader)) repeatedly restarts the loader each time, so an
        # unshuffled loader would return the same first batch over and over.
        batches = list(islice(iter(data_loader), n_batches))
        # Label position matches train()/test(): last element of the tuple.
        label_idx = 2 if opt.data_join_task else 1
        images = torch.cat([batch[0] for batch in batches]).to(device)
        labels = torch.cat([batch[label_idx] for batch in batches]).to(device)
        num_classes = len(labels.unique())
        if opt.synthetic_generation:
            one_hot_labels = torch.zeros((len(labels), opt.n_categories), device=device).scatter_(1, labels.unsqueeze(1).long(), 1)
            images_and_labels = torch.cat([images.view(len(labels), -1), one_hot_labels.float()], dim=1)
            _, (rep, _) = model.posterior(images_and_labels)
        else:
            _, (rep, _) = model.posterior(images)

        colors = ("tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple", "tab:brown", "tab:pink",
                  "tab:gray", "tab:olive", "tab:cyan")

        rep_reduced = TSNE(n_components=2).fit_transform(rep.cpu())
        fig_r, ax = plt.subplots()
        # zip stops at the shorter sequence: at most len(colors) classes are drawn,
        # and classes are assumed to be labelled 0..num_classes-1.
        for j, color in zip(range(num_classes), colors):
            rep_j = rep_reduced[(labels.cpu() == j).nonzero().squeeze(1)]
            x, y = rep_j[:, 0], rep_j[:, 1]
            ax.scatter(x, y, alpha=0.8, c=color, edgecolors='none', s=30, label=str(j))

        ax.legend(loc=2)
        # close=True closes fig_r after rendering, so figures do not accumulate across epochs.
        writer.add_figure('Stage{}/TSNEplots'.format(stage), fig_r, global_step=epoch, close=True)

        print('TSNE reduction and plotting finished')


def sample(recon_data, epoch, device, writer, stage):
    """Log reconstructions of ``recon_data`` (and model samples where applicable).

    Uses the module-level ``nondp_model``, ``dp_model`` and ``opt``.

    Args:
        recon_data: ``(image, label)`` pair to reconstruct.
        epoch: Step passed through to the model's logging methods.
        device: Device the tensors are moved to.
        writer: TensorBoard writer passed through to the model's logging methods.
        stage: 1 logs with the non-DP model (reconstruct + sample). 2 logs with
            the DP model: a DP decoder samples and reconstructs using the non-DP
            encoder; a DP encoder reconstructs using the non-DP decoder.

    Raises:
        NotImplementedError: if ``stage`` is neither 1 nor 2.
    """
    image, label = recon_data
    nondp_model.eval()

    with torch.no_grad():
        image = image.to(device)
        one_hot_label = None
        if opt.synthetic_generation:
            label = label.to(device)
            # NOTE(review): unlike train()/test(), label is not unsqueezed before
            # scatter_, so it presumably already has shape (N, 1) here — confirm.
            blank = torch.zeros((label.size(0), opt.n_categories), device=device)
            one_hot_label = blank.scatter_(1, label.long(), 1)

        if stage == 1:
            nondp_model.reconstruct(image, one_hot_label, epoch, writer, stage=stage)
            nondp_model.sample(epoch, writer, stage=stage)
            return
        if stage != 2:
            raise NotImplementedError('Stage should be 1 or 2.')

        dp_model.eval()
        if opt.dp_decoder:
            dp_model.sample(epoch, writer, stage=stage)
            dp_model.reconstruct(image, one_hot_label, epoch, writer, stage=stage, encoder=nondp_model)
        elif opt.dp_encoder:
            dp_model.reconstruct(image, one_hot_label, epoch, writer, stage=stage, decoder=nondp_model)


if __name__ == '__main__':
    # --- Validate flag combinations and build the run directory name. ---
    out_dir = './runs/'
    if opt.synthetic_generation:
        # Synthetic generation trains a DP decoder only. (This check previously
        # read opt.no_dp_encoder / opt.no_dp_decoder, which are never defined,
        # so every --synthetic_generation run died with AttributeError.)
        if opt.dp_encoder or not opt.dp_decoder:
            raise ValueError('Need dp decoder but dp encoder not necessary')
        out_dir += 'Synth'
    if opt.novel_class:
        out_dir += 'dShift'

    if opt.custom_directory is not None:
        out_dir += opt.custom_directory
    # Stage 2 (DP training) only runs when at least one DP component is requested.
    opt.no_dp = not (opt.dp_encoder or opt.dp_decoder)
    if opt.data_join_task:
        out_dir += 'JoinTask'
    if opt.no_dp:
        out_dir += 'VAE_lr{}_epoch{}_B{}_frac{}'.format(opt.lr, opt.epochs_stage1, opt.batch_size, opt.fraction)
    else:
        if opt.synthetic_generation:
            # A DP encoder was already rejected above for synthetic generation.
            out_dir += 'Synth'
        elif opt.dp_encoder:
            out_dir += 'DPEnc'
        elif opt.dp_decoder:
            out_dir += 'DPDec'
        out_dir += '_lr{}_delta{}_norm{}' \
                   '_mult{}_epoch{}_B{}_frac{}'.format(opt.dp_lr, opt.delta, opt.max_grad_norm,
                                                opt.noise_multiplier, opt.epochs_stage2, opt.batch_size, opt.fraction)
    out_dir += '_task{}'.format(opt.task)  # Everything after task will appear in classifier directory name
    out_dir += '_latent{}{}'.format(opt.latent_distn, opt.rep_dim)
    if opt.diagonal_x_std:
        out_dir += '_NNxstd'

    if opt.md is not None:
        out_dir += '_MDtrain{}_std{}'.format(opt.md, opt.posterior_std)
    out_dir += '_priorstd{}'.format(opt.prior_std)
    # NOTE: no timestamp is appended, so re-running the same configuration reuses
    # this directory and overwrites the saved model files.

    os.makedirs(out_dir, exist_ok=True)
    writer = SummaryWriter(out_dir)

    np.random.seed(15)
    torch.manual_seed(11)

    use_cuda = torch.cuda.is_available()
    opt.device = torch.device("cuda" if use_cuda else "cpu")

    clean_train_loader, clean_val_loader, clean_test_loader, recon_dataset, opt.image_dim, opt.tabular, \
    opt.n_continuous_features, opt.ncat_of_cat_features, opt.n_categories, _ = get_clean_loaders(opt, writer)

    opt.z_dim = opt.rep_dim                                     # TODO fix naming here. Need a z_dim for nn_achitectures
    opt.color_chans = None if opt.tabular else opt.image_dim[0]

    if opt.md is not None:
        opt.prior_clip_percentile = percentile_from_md(opt.latent_distn, opt.rep_dim, opt.md)

    # nondp_model and dp_model are module-level names: loss_function() and
    # sample() read them directly.
    nondp_model = NonDPModel(opt).to(opt.device)

    optimizer1 = optim.Adam(nondp_model.network.parameters(), lr=opt.lr)
    if opt.dp_encoder:
        dp_model = DPEncoder(opt).to(opt.device)
        dp_params = list(dp_model.network.parameters())
    else:
        dp_model = DPDecoder(opt).to(opt.device)
        if not opt.diagonal_x_std:
            dp_params = list(dp_model.network.parameters()) + \
                             [dp_model.network.x_std]
        else:
            dp_params = list(dp_model.network.parameters())

    if opt.no_dp:
        print("WARNING: Non DP training occurring in Stage 2.")
        optimizer_dp = optim.Adam(dp_params, lr=opt.lr)       # use non-dp learning rate
    else:
        if opt.dp_optim == 'SGD':
            optimizer_dp = optim.SGD(dp_params, lr=opt.dp_lr)
        elif opt.dp_optim == 'Adam':
            optimizer_dp = optim.Adam(dp_params, lr=opt.dp_lr)
        else:
            # Fail clearly here instead of with a NameError on optimizer_dp below.
            raise ValueError('Unknown --dp_optim {!r}; expected one of {{Adam, SGD}}'.format(opt.dp_optim))

        privacy_engine = PrivacyEngine(
            dp_model,
            opt.batch_size,
            len(clean_train_loader.dataset),
            alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
            noise_multiplier=opt.noise_multiplier,
            max_grad_norm=opt.max_grad_norm,
        )
        privacy_engine.attach(optimizer_dp)

    print("\n", "=" * 50)
    print("Hyperparameters")
    for x, y in vars(opt).items():
        print(x, y)
    print("=" * 50, "\n")

    # Record the options before training so a crashed or interrupted run still
    # has them; they are written again at the end of the run.
    save_opt_to_file(opt, out_dir)

    # Stage 1: Non-DP training
    iters = 0
    for epoch in range(1, opt.epochs_stage1 + 1):
        iters = train(nondp_model, clean_train_loader, epoch, iters, optimizer1, opt.device,
                      writer, opt.log_interval, opt, stage=1)
        test(nondp_model, clean_test_loader, epoch, opt.device, writer, stage=1)
        plot_rep(nondp_model, clean_test_loader, epoch, opt.device, writer, stage=1)

        # sample(recon_dataset, epoch, opt.device, writer, stage=1)

        if opt.tabular and not opt.synthetic_generation:
            generated_df = generated_synthetic_dataframe(nondp_model, 15000)
            fig_g = plot_tabular_df(generated_df, opt.ncat_of_cat_features, opt.n_continuous_features)
            writer.add_figure('plot_features/stage1generated_data', fig_g, global_step=epoch)

            plot_histograms_of_reconstructed_data(nondp_model, nondp_model, clean_test_loader, epoch, writer, opt, stage=1)

        print('Saving state_dict...')
        torch.save(nondp_model.state_dict(),
                   os.path.join(out_dir, 'rep_model_state_dict.pth'))
        print('state_dict saved.')
        print('Saving nonDP model...')
        torch.save(nondp_model,
                   os.path.join(out_dir, 'rep_model.pth'))
        print('nonDP model saved.')

    # Stage 2: DP Training
    if not opt.no_dp:
        iters = 0
        for epoch in range(1, opt.epochs_stage2 + 1):
            iters, epsilon = train(dp_model, clean_train_loader, epoch, iters, optimizer_dp, opt.device,
                                   writer, opt.log_interval, opt, stage=2)
            test(dp_model, clean_test_loader, epoch, opt.device, writer, stage=2)
            if opt.dp_encoder:
                plot_rep(dp_model, clean_test_loader, epoch, opt.device, writer, stage=2)

            if not opt.tabular:
                sample(recon_dataset, epoch, opt.device, writer, stage=2)

            if opt.tabular and opt.dp_decoder and not opt.synthetic_generation:
                generated_df = generated_synthetic_dataframe(dp_model, 15000)
                fig_g = plot_tabular_df(generated_df, opt.ncat_of_cat_features, opt.n_continuous_features)
                writer.add_figure('plot_features/stage2generated_data', fig_g, global_step=epoch)

                plot_histograms_of_reconstructed_data(nondp_model, dp_model, clean_test_loader, epoch, writer, opt, stage=2)

            # File names carry the privacy budget spent so far, so each epoch keeps its own checkpoint.
            print('Saving state_dict...')
            torch.save(dp_model.state_dict(),
                       os.path.join(out_dir, 'dp_rep_model_state_dict_eps{}.pth'.format(str(round(epsilon, 2)))))
            print('state_dict saved.')
            print('Saving DP model...')
            torch.save(dp_model,
                       os.path.join(out_dir, 'dp_rep_model_eps{}.pth'.format(str(round(epsilon, 2)))))
            print('DP model saved.')

    save_opt_to_file(opt, out_dir)


