"""

"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path

import cv2
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models.losses import _kl_normal_loss, get_loss_f
from disvae.models.vae import VAE
from disvae.utils.initialization import weights_init
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from utils.visualize import *
import numpy as np

from utils.helpers import set_seed

# Default run configuration: passed to wandb.init and read back through
# wandb.config below.
hyperparameter_defaults = {
    'batch_size': 128,
    'learning_rate': 0.001,
    'epochs': 201,
    'dimension': 6,      # latent size handed to the VAE
    'loss': 'betaH',     # 'betaH', 'btcvae'; any other value selects 'Dis'
    'beta': 10,
    'img_id': 6,         # selects patterns/<img_id>.pat
    'random_seed': 224,
}

wandb.init(project="exps", group='bijection', config=hyperparameter_defaults)

# Everything below reads hyperparameters from the wandb run config.
config = wandb.config
set_seed(config.random_seed)


def plot_projection(samples, model, dim=(0, 1)):
    """Scatter the encoder means of a grid of images in a 2-D latent projection.

    Each point is annotated with its ``(j, i)`` grid position. Points that
    share the same first grid index ``j`` are joined by a line, so the grid
    structure stays visible in latent space.

    Args:
        samples: 4-D tensor of shape ``(z1, z2, H, W)``. The encoder is fed
            the grid flattened to ``(z1 * z2, 1, H, W)``, i.e. one channel.
        model: object exposing ``encoder(x) -> (mu, log_var)``.
        dim: pair of latent indices to use as the x and y axes.

    Returns:
        The matplotlib figure that was drawn on.
    """
    fig = plt.figure()
    z1, z2, height, width = samples.shape
    with torch.no_grad():
        mu, _ = model.encoder(samples.reshape(z1 * z2, 1, height, width))
    # Plain float arrays so matplotlib never has to handle tensors.
    xs = mu[:, dim[0]].detach().cpu().numpy()
    ys = mu[:, dim[1]].detach().cpu().numpy()

    plt.scatter(xs, ys)
    for j in range(z1):
        row = slice(z2 * j, z2 * (j + 1))
        for i in range(z2):
            index = i + z2 * j
            plt.text(xs[index], ys[index], f'({j},{i})')
        # Connect all points of grid row j.
        plt.plot(xs[row], ys[row])

    return fig


def refine(dl, model, epochs):
    """Fine-tune only the decoder to reconstruct images from encoder means.

    The encoder is frozen for the duration of the call. Each encoder
    parameter's previous ``requires_grad`` flag is restored afterwards, even
    if training raises. The loss is summed binary cross-entropy divided by the
    batch size. Its mean over each epoch is logged to wandb as ``recon_loss``.

    Args:
        dl: iterable of batches. Each batch is either an image tensor or a
            list/tuple whose first element is the images. The module-level
            DataLoader yields ``[images, labels]``.
        model: object exposing ``encoder(x) -> (mu, log_var)``, ``decoder(z)``
            and their parameters.
        epochs: number of passes over ``dl``.
    """
    encoder_flags = [p.requires_grad for p in model.encoder.parameters()]
    model.encoder.requires_grad_(False)
    opt = optim.AdamW(model.decoder.parameters(), config.learning_rate)
    try:
        for _ in range(epochs):
            epoch_losses = []
            for batch in dl:
                # Accept both bare image batches and (images, labels, ...) batches.
                img = batch[0] if isinstance(batch, (list, tuple)) else batch
                mu, _ = model.encoder(img)
                recon_batch = model.decoder(mu)
                loss = F.binary_cross_entropy(recon_batch, img, reduction='sum') / img.size(0)

                opt.zero_grad()
                loss.backward()
                opt.step()
                epoch_losses.append(loss.item())
            # An empty loader has nothing to report (previously a NameError).
            if epoch_losses:
                wandb.log({'recon_loss': float(np.mean(epoch_losses))})
    finally:
        # Undo the freeze so later training of the encoder is not silently disabled.
        for param, flag in zip(model.encoder.parameters(), encoder_flags):
            param.requires_grad_(flag)


def train(dl, model, loss_f, epochs, n_random_codes=128):
    """Train ``model`` with ``loss_f`` and log per-epoch diagnostics to wandb.

    Each batch also runs a self-consistency check. It draws ``n_random_codes``
    standard-normal latent codes, decodes them, re-encodes the result, and
    records the MSE between the codes and the re-encoded means as ``rloss``.
    This value is only monitored and never contributes to the gradient.

    Each epoch's metrics, latent histograms, and (on epochs where
    ``e % 50 == 1``) sample images go to wandb in a single ``wandb.log`` call.
    That keeps them all on the same step.

    Args:
        dl: iterable yielding ``(images, labels)`` batches. Images are viewed
            as ``(-1, 1, 64, 64)``.
        model: VAE whose call returns ``(recon, (mu, log_var), latent_sample)``
            and which exposes ``encoder`` and ``decoder``.
        loss_f: callable invoked as ``loss_f(img, recon, latent_dist,
            model.training, storer, latent_sample=...)`` and returning a scalar
            loss. Presumably it appends per-batch values into ``storer``.
        epochs: number of passes over ``dl``.
        n_random_codes: number of random latent codes per batch for the
            ``rloss`` check. The default of 128 is the previously hard-coded
            value.

    Returns:
        dict mapping each recorded metric to its mean over the final epoch's
        batches, plus ``'loss'`` (mean total loss). It is empty when
        ``epochs == 0``.

    Raises:
        ValueError: if ``dl`` yields no batches.
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    storer = {}

    for e in trange(epochs):
        storer = defaultdict(list)
        loss_set = []
        recon_batch = latent_dist = rimgs = img = None
        for img, _label in dl:
            img = img.view(-1, 1, 64, 64)
            recon_batch, latent_dist, latent_sample = model(img)
            latent_dim = latent_dist[0].size(-1)

            # rloss is only logged, never backpropagated, so no autograd graph is needed.
            with torch.no_grad():
                rcodes = torch.randn((n_random_codes, latent_dim), device=img.device)
                rimgs = model.decoder(rcodes)
                rpreds, _ = model.encoder(rimgs)
                rloss = F.mse_loss(rpreds, rcodes)
            storer['rloss'].append(rloss.item())

            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)

            opt.zero_grad()
            loss.backward()
            opt.step()
            loss_set.append(loss.item())

        if latent_dist is None:
            raise ValueError('train() requires a data loader that yields at least one batch')

        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)
        storer['loss'] = np.mean(loss_set)

        # One log call per epoch keeps scalars, histograms, and images on one step.
        log_dict = dict(storer)
        mu, log_var = latent_dist  # from the last batch of the epoch
        mu = mu.detach().cpu()
        var = log_var.detach().exp().cpu()
        for d in range(mu.size(1)):
            log_dict[f"mu_{d}"] = wandb.Histogram(mu[:, d].numpy())
            log_dict[f"var_{d}"] = wandb.Histogram(var[:, d].numpy())
        if e % 50 == 1:
            log_dict.update({
                'recon': wandb.Image(recon_batch[0, 0].detach()),
                'img': wandb.Image(img[0, 0]),
                'rimgs': wandb.Image(rimgs[0, 0]),
            })
        wandb.log(log_dict)
    return storer


# generate data
# The base pattern comes from patterns/<img_id>.pat, chosen by the run config.
img_id = config.img_id
# NOTE(review): torch.load unpickles its input; only load trusted pattern files.
img = torch.load('patterns/{}.pat'.format(img_id))


def gen_imgs(img, size=(11, 11), step=1, n_shifts=32, canvas=64):
    """Paste ``img`` at every offset of an ``n_shifts x n_shifts`` grid.

    ``out[i, j]`` is a zero ``canvas x canvas`` image. ``img`` is written
    into rows ``i*step : i*step + h`` and columns ``j*step : j*step + w``.

    Args:
        img: 2-D array-like of shape ``size``; anything NumPy can assign from.
        size: ``(h, w)`` of ``img``.
        step: pixel offset between neighbouring grid positions.
        n_shifts: number of positions along each axis. The default of 32 is
            the previously hard-coded value.
        canvas: side length of each square output image. The default of 64
            is the previously hard-coded value.

    Returns:
        float64 ndarray of shape ``(n_shifts, n_shifts, canvas, canvas)``.

    Raises:
        ValueError: if the pattern at the last offset would not fit on the canvas.
    """
    h, w = size
    last = (n_shifts - 1) * step
    # Fail with a clear message instead of a NumPy broadcast error mid-loop.
    if n_shifts > 0 and (last + h > canvas or last + w > canvas):
        raise ValueError(
            f'pattern of size {tuple(size)} with step={step} and n_shifts={n_shifts} '
            f'does not fit on a {canvas}x{canvas} canvas')
    imgs = np.zeros((n_shifts, n_shifts, canvas, canvas), float)

    for i in range(n_shifts):
        rows = slice(i * step, i * step + h)
        for j in range(n_shifts):
            imgs[i, j, rows, j * step:j * step + w] = img
    return imgs


# Every translated copy of the pattern, indexed by grid offset (i, j).
imgs = torch.Tensor(gen_imgs(img, img.shape))

# labels[i, j] == [i, j] (the grid offset of imgs[i, j]), flattened to (1024, 2).
labels = torch.from_numpy(
    np.stack(np.meshgrid(np.arange(32), np.arange(32), indexing='ij'), axis=-1)
    .astype(np.float32)
).reshape(32 * 32, 2)

epochs, dim, beta = config.epochs, config.dimension, config.beta

# Flatten to single-channel images and pair each with its label on the GPU.
dataset = imgs.reshape(-1, 1, 64, 64)
dl = DataLoader(
    list(zip(dataset.cuda(), labels.cuda())),
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=0,
)

# VAE with Burgess encoder/decoder over (1, 64, 64) inputs and `dim` latents.
vae = VAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim, 1)
vae.cuda()
vae.train()
# wandb.watch(vae)

# Options shared by every loss choice; reg_anneal is the total number of
# optimisation steps (epochs * batches per epoch).
loss_kwargs = dict(record_loss_every=2, rec_dist='bernoulli', reg_anneal=epochs * len(dl))
if config.loss == 'btcvae':
    loss_f = get_loss_f('btcvae', btcvae_A=1, btcvae_B=beta, btcvae_G=1,
                        n_data=len(dataset), **loss_kwargs)
elif config.loss == 'betaH':
    loss_f = get_loss_f('betaH', betaH_B=beta, **loss_kwargs)
else:
    # Any other config value falls back to the 'Dis' loss.
    loss_f = get_loss_f('Dis', betaH_B=beta, **loss_kwargs)

storer = train(dl, vae, loss_f, epochs)

# Rank latent dimensions by their KL term, averaged over the final epoch's
# batches, largest first.
# NOTE(review): relies on the loss recording kl_loss_<d> keys in dimension
# order (dict insertion order); confirm against disvae.models.losses.
kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_loss_' in k])
_, index = kl_loss.sort(descending=True)
print(index)
torch.save(vae.state_dict(), 'tmp_model.pt')
# refine
# refine(dl,vae,5)
vae.eval()

vae.cpu()

del config

fig = plot_reconstruct(dataset, (3, 2), vae)
wandb.log({'reconstruction': wandb.Image(fig)})

fig = plt_sample_traversal(dataset[:1], vae, 7, index, r=3)
wandb.log({'traversal': wandb.Image(fig)})

fig = plot_projection(imgs[::4, ::4], vae, dim=index[:2])
wandb.log({'projection_xy': wandb.Image(fig)})

with torch.no_grad():
    mu, _ = vae.encoder(dataset)
    # x, y, z are the three highest-KL latents; this needs dimension >= 3.
    points = mu[:, index[:3]]
    # points = 10*points
    n_points = len(labels)
    # Colour columns: red = 8 * first label (grid row), green = 255; blue = 255 below.
    colors = labels.reshape(n_points, 2) * torch.tensor([8, 0]) + torch.tensor([0, 255])
    point_cloud = torch.cat([points, colors, torch.ones(n_points, 1) * 255], 1).numpy()
    wandb.log({'point_cloud': wandb.Object3D(point_cloud)})

# wandb.save('strip.py')
plt.show()
# wandb.finish is the documented way to close the run (join is a legacy alias).
wandb.finish()
