import argparse
import logging
from collections import defaultdict
from pathlib import Path

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import wandb

from disvae.models.losses import _kl_normal_loss, get_loss_f
from disvae.models.vae import VAE
from disvae.utils.initialization import weights_init
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from utils.visualize import *
import numpy as np

from utils.helpers import set_seed

# Default hyperparameters of the run; passed to wandb.init as the run config
# and read back through `config` everywhere else in this script.
hyperparameter_defaults = {
    'batch_size': 128,       # DataLoader batch size
    'learning_rate': 0.001,  # AdamW learning rate
    'epochs': 251,           # number of training epochs
    'plot_interval': 20,     # log a projection figure every this many epochs
    'dimension': 2,          # latent dimension handed to the VAE
    'beta': 30,              # forwarded to get_loss_f as betaH_B
    'strip_k1': 0,           # value written on every odd row of the row mask
    'strip_k2': 1,           # value written on every odd column of the column mask
    'random_seed': 112,      # seed given to set_seed
}

# Start the wandb run, then seed from the resolved config.
wandb.init(project="strip", config=hyperparameter_defaults)
config = wandb.config
set_seed(config.random_seed)


def plot_projection(sample, model):
    """Plot the encoder means of a 4x4 sub-grid of ``sample`` in latent space.

    ``sample`` is reshaped to a (32, 32, 64, 64) grid of images and every 8th
    entry along both grid axes is encoded with ``model.encoder``. The first
    two components of each mean are drawn as a scatter plot, every point is
    labelled with its ``(j, i)`` sub-grid position, and the four points that
    share a ``j`` are joined by a line.

    Args:
        sample: tensor that can be reshaped to (32, 32, 64, 64).
        model: object whose ``encoder`` returns a ``(mu, log_var)`` pair.

    Returns:
        The matplotlib figure that was drawn on.
    """
    figure = plt.figure()
    with torch.no_grad():
        grid = sample.reshape(32, 32, 64, 64)
        # Stride 8 over both grid axes leaves 4 x 4 = 16 images.
        subset = grid[::8, ::8].reshape(-1, 1, 64, 64)
        means, _ = model.encoder(subset)
        means = means.cpu()
        plt.scatter(means[:, 0].data, means[:, 1].data)
        for row in range(4):
            first = 4 * row  # flat index of the first point in this row
            for col in range(4):
                point = means[first + col]
                plt.text(point[0], point[1], f'({row},{col})')
            segment = means[first:first + 4]
            plt.plot(segment[:, 0], segment[:, 1])

    return figure


def train(dl, model, loss_f, epochs):
    """Train ``model`` on ``dl`` with AdamW and log progress to wandb.

    The learning rate and the plotting interval are read from the module-level
    ``config`` (``config.learning_rate`` and ``config.plot_interval``).

    Each epoch logs:
      * the per-epoch mean of every list collected in ``storer`` plus the mean
        batch loss under ``'loss'``;
      * every ``config.plot_interval`` epochs, a ``'projection'`` image built
        by ``plot_projection`` from ``dl.dataset`` and ``model``;
      * ``mu_<d>`` / ``var_<d>`` histograms of the latent distribution of the
        last batch of the epoch, one pair per latent dimension.

    Args:
        dl: DataLoader whose batches can be viewed as (-1, 1, 64, 64).
        model: module whose forward pass returns
            ``(reconstruction, latent_dist, latent_sample)``, where
            ``latent_dist`` unpacks into ``(mu, log_var)``.
        loss_f: callable invoked as ``loss_f(img, recon_batch, latent_dist,
            model.training, storer, latent_sample=latent_sample)`` and
            returning a scalar loss tensor.
        epochs: number of passes over ``dl``.
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    interval = config.plot_interval
    for e in range(epochs):
        # storer is handed to loss_f (presumably filled with per-batch values
        # there); loss_set keeps the scalar loss of each batch.
        storer = defaultdict(list)
        loss_set = []
        for img in dl:
            img = img.view(-1, 1, 64, 64)
            recon_batch, latent_dist, latent_sample = model(img)
            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)

            opt.zero_grad()
            loss.backward()
            opt.step()
            loss_set.append(loss.item())

        # Collapse every collected list to its epoch mean before logging.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)
        storer['loss'] = np.mean(loss_set)

        wandb.log(storer, sync=False, step=e)

        if e % interval == 0:
            # Plot with the model being trained (the `model` argument), not a
            # module-level global.
            fig = plot_projection(dl.dataset, model)
            wandb.log({'projection': wandb.Image(fig)}, step=e)
            plt.close('all')

        with torch.no_grad():
            # latent_dist still refers to the last batch of this epoch.
            mu, log_var = latent_dist
            var = log_var.exp().cpu()
            mu = mu.cpu()
            # Take the latent size from the tensor itself so the function does
            # not depend on a module-level `dim`.
            for d in range(mu.shape[1]):
                wandb.log({f"mu_{d}": wandb.Histogram(mu[:, d])}, step=e, sync=False)
                wandb.log({f"var_{d}": wandb.Histogram(var[:, d])}, step=e, sync=False)


exp_dir = Path('results/')

# Synthetic data: one 64x64 image per (i, j) of a 32x32 grid, each holding a
# 10x10 square of ones whose top-left corner is at row 11 + i, column 11 + j.
imgs = torch.zeros(32, 32, 64, 64)
for i in range(32):
    for j in range(32):
        imgs[i, j, 11 + i:21 + i, 11 + j:21 + j] = 1
imgs = imgs.reshape(32 * 32, 1, 64, 64)

epochs, dim, beta = config.epochs, config.dimension, config.beta

# Row mask: even rows stay 1, odd rows are set to strip_k1.
mask1 = torch.ones(64, 64)
mask1[1::2, :] = config.strip_k1
# Column mask: even columns stay 1, odd columns are set to strip_k2.
mask2 = torch.ones(64, 64)
mask2[:, 1::2] = config.strip_k2
# Both (64, 64) masks broadcast over the leading image dimensions.
strip_imgs = imgs * mask1 * mask2

# Use the GPU when one is available and fall back to the CPU otherwise, so
# the script does not crash on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The whole dataset is moved to the training device up front; the DataLoader
# then only slices and shuffles it.
dataset = strip_imgs.to(device)
dl = DataLoader(dataset, num_workers=0, batch_size=config.batch_size, shuffle=True)

vae = VAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim, 1)
vae.to(device)
# reg_anneal is set to the total number of optimisation steps
# (epochs * batches per epoch).
loss_f = get_loss_f('Dis', betaH_B=beta,
                    record_loss_every=2, rec_dist='bernoulli', reg_anneal=epochs * len(dl))

train(dl, vae, loss_f, epochs)

# Final figures are produced on the CPU from the CPU copy of the data.
vae.cpu()

fig = plot_reconstruct(strip_imgs, (3, 2), vae)
wandb.log({'reconstruction': wandb.Image(fig)})
fig = plt_sample_traversal(strip_imgs[:1].view(1, 1, 64, 64), vae, 7, dim)
wandb.log({'traversal': wandb.Image(fig)})
# wandb.save('strip.py')
wandb.join()
