"""
Shape comparison.

To study the effect of shape, we train a VAE on translated copies of a single
pattern (horizontal stripes, vertical stripes, solid, or dots) and compare the
learned latent representations.
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path

import cv2
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models.losses import _kl_normal_loss, get_loss_f
from disvae.models.vae import VAE
from disvae.utils.initialization import weights_init
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from utils.visualize import *
import numpy as np

from utils.helpers import set_seed

hyperparameter_defaults = dict(
    batch_size=128,
    learning_rate=0.001,
    epochs=201,
    dimension=6,
    beta=40,
    img_id=1,
    random_seed=224,
    loss='betaH',  # read when selecting the training objective: 'betaH', 'btcvae' or 'dis'
)

# Command-line overrides. Every flag defaults to its value in
# hyperparameter_defaults, so omitting a flag never injects None into the config.
parser = argparse.ArgumentParser(description='Train a VAE on translated copies of one pattern.')
parser.add_argument('--beta', type=float, default=hyperparameter_defaults['beta'],
                    help='beta coefficient handed to the selected loss')
parser.add_argument('--img_id', type=int, default=hyperparameter_defaults['img_id'],
                    help='pattern to load from patterns/<img_id>.pat')
parser.add_argument('--loss', choices=('betaH', 'btcvae', 'dis'),
                    default=hyperparameter_defaults['loss'],
                    help='training objective')
args = parser.parse_args()

hyperparameter_defaults.update(img_id=args.img_id, beta=args.beta, loss=args.loss)

wandb.init(project="exps", config=hyperparameter_defaults, group='pattern')

config = wandb.config
set_seed(config.random_seed)


def plot_projection(samples, model, dim=(0, 1)):
    """Scatter-plot the encoder means of a grid of images on two latent axes.

    Args:
        samples: 4-D tensor whose first two axes index a (rows, cols) grid of
            images; it is reshaped to (-1, 1, 64, 64) before encoding.
        model: object whose ``encoder`` returns a pair ``(mu, log_var)``.
        dim: pair of latent indices used as the x and y axes.

    Returns:
        The matplotlib figure. Each point is labelled "(row,col)", and the
        points of each grid row are joined by a line.
    """
    rows, cols, _height, _width = samples.shape
    x_axis, y_axis = dim[0], dim[1]
    fig = plt.figure()
    with torch.no_grad():
        means, _log_var = model.encoder(samples.reshape(-1, 1, 64, 64))
        means = means.cpu()
        plt.scatter(means[:, x_axis].data, means[:, y_axis].data, )
        for row in range(rows):
            start = cols * row  # flat index of the first image in this grid row
            for col in range(cols):
                flat = start + col
                plt.text(means[flat, x_axis], means[flat, y_axis], f'({row},{col})')
            segment = means[start:start + cols]
            plt.plot(segment[:, x_axis], segment[:, y_axis])

    return fig


def refine(dl, model, epochs):
    """Fine-tune only the decoder to reconstruct images from the encoder means.

    The encoder's parameters are frozen with ``requires_grad_(False)`` and stay
    frozen after this function returns.

    Args:
        dl: iterable of image batches already shaped for ``model.encoder``.
        model: VAE exposing ``encoder`` (returning ``(mu, log_var)``) and ``decoder``.
        epochs: number of passes over ``dl``.
    """
    model.encoder.requires_grad_(False)
    opt = optim.AdamW(model.decoder.parameters(), config.learning_rate)
    for _epoch in range(epochs):
        batch_losses = []
        for img in dl:
            # The encoder is frozen, so no autograd graph is needed for it.
            with torch.no_grad():
                mu, _ = model.encoder(img)
            recon_batch = model.decoder(mu)
            # Binary cross-entropy summed over pixels, averaged over the batch.
            loss = F.binary_cross_entropy(recon_batch, img, reduction='sum') / img.size(0)

            opt.zero_grad()
            loss.backward()
            opt.step()
            batch_losses.append(loss.item())
        # Log the epoch mean (like train() logs its 'loss') rather than just the
        # last batch; an empty loader logs nothing instead of raising NameError.
        if batch_losses:
            wandb.log({'recon_loss': float(np.mean(batch_losses))})


def train(dl, model, loss_f, epochs):
    """Train ``model`` end-to-end with ``loss_f`` and log per-epoch metrics to wandb.

    Args:
        dl: iterable of image batches; each batch is viewed as ``(-1, 1, 64, 64)``.
        model: VAE whose forward returns ``(recon_batch, (mu, log_var), latent_sample)``.
        loss_f: loss called as
            ``loss_f(img, recon, latent_dist, is_train, storer, latent_sample=...)``;
            it may append per-batch statistics to the lists in ``storer``.
        epochs: number of passes over ``dl``.

    Returns:
        The last epoch's statistics: every list recorded by ``loss_f`` averaged
        to a scalar, plus ``'loss'``, the mean batch loss of that epoch.

    Raises:
        ValueError: if ``dl`` yields no batches.
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    storer = defaultdict(list)

    for e in trange(epochs):
        storer = defaultdict(list)
        loss_set = []
        for img in dl:
            img = img.view(-1, 1, 64, 64)
            recon_batch, latent_dist, latent_sample = model(img)
            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)

            opt.zero_grad()
            loss.backward()
            opt.step()
            loss_set.append(loss.item())
        if not loss_set:
            raise ValueError('train() received a data loader that yields no batches')

        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)
        storer['loss'] = np.mean(loss_set)

        # Send the whole epoch in ONE wandb.log call. Every call advances wandb's
        # step counter, so logging each histogram separately would scatter one
        # epoch's metrics over many steps.
        metrics = dict(storer)
        with torch.no_grad():
            # Histograms describe the posterior of the epoch's last batch only.
            mu, log_var = latent_dist
            var = log_var.exp().cpu()
            mu = mu.cpu()
            # Latent size comes from the tensor, not from a module-level global.
            for d in range(mu.size(1)):
                metrics[f"mu_{d}"] = wandb.Histogram(mu[:, d])
                metrics[f"var_{d}"] = wandb.Histogram(var[:, d])
        if e % 50 == 1:  # epochs 1, 51, 101, ...
            metrics['recon'] = wandb.Image(recon_batch[0, 0])
            metrics['img'] = wandb.Image(img[0, 0])
        wandb.log(metrics)
    return storer


# Load the base pattern chosen by --img_id; it is tiled into the dataset below.
img_id = config.img_id
pattern_file = f'patterns/{img_id}.pat'
img = torch.load(pattern_file)


def gen_imgs(img, n=32, canvas=64, size=(11, 11), step=1):
    """Paint ``img`` at every offset of an ``n`` x ``n`` translation grid.

    Args:
        img: array-like patch assignable to a ``size``-shaped slice.
        n: number of offsets along each axis.
        canvas: side length of each square output image.
        size: ``(height, width)`` of the painted patch.
        step: pixel distance between consecutive offsets.

    Returns:
        float64 array of shape ``(n, n, canvas, canvas)``. ``out[i, j]`` is zero
        except for the patch whose top-left corner is at ``(i * step, j * step)``.

    Raises:
        ValueError: if the furthest offset would push the patch off the canvas.
    """
    h, w = size
    last = (n - 1) * step  # offset of the final row/column
    if last + h > canvas or last + w > canvas:
        raise ValueError(
            f'patch {size} at offsets up to {last} does not fit a {canvas}x{canvas} canvas')

    imgs = np.zeros((n, n, canvas, canvas))
    for i in range(n):
        for j in range(n):
            imgs[i, j, i * step:i * step + h, j * step:j * step + w] = img
    return imgs


imgs = torch.Tensor(gen_imgs(img))

# labels[j, k] = (j, k): the grid position of image imgs[j, k].
n_rows, n_cols = imgs.shape[0], imgs.shape[1]
row_idx, col_idx = np.meshgrid(np.arange(n_rows), np.arange(n_cols), indexing='ij')
labels = torch.from_numpy(np.stack([row_idx, col_idx], axis=-1)).float()

epochs = config.epochs
dim = config.dimension
beta = config.beta

dataset = imgs.reshape(-1, 1, 64, 64).clone()
dl = DataLoader(dataset.cuda(), num_workers=0, batch_size=config.batch_size, shuffle=True)

vae = VAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim, 1)
vae.cuda()
# wandb.watch(vae)

# Options shared by every objective. Regularisation annealing spans the whole
# run (epochs * batches per epoch).
common_loss_kwargs = dict(record_loss_every=2, rec_dist='bernoulli',
                          reg_anneal=epochs * len(dl))
# 'loss' may be absent from older configs; fall back to the beta-VAE objective.
loss_name = config.get('loss', 'betaH')
if loss_name == 'betaH':
    loss_f = get_loss_f('betaH', betaH_B=beta, **common_loss_kwargs)
elif loss_name == 'btcvae':
    loss_f = get_loss_f('btcvae', btcvae_A=1, btcvae_B=beta, btcvae_G=1,
                        n_data=len(dataset), **common_loss_kwargs)
elif loss_name == 'dis':
    loss_f = get_loss_f('Dis', betaH_B=beta, **common_loss_kwargs)
else:
    raise ValueError(f'unknown loss: {loss_name!r}')

vae.train()
storer = train(dl, vae, loss_f, epochs)

# Per-dimension KL terms ordered by latent index. Keys look like 'kl_loss_<d>';
# sort explicitly instead of relying on dict insertion order.
kl_items = sorted(
    (int(k[len('kl_loss_'):]), v)
    for k, v in storer.items()
    if k.startswith('kl_loss_') and k[len('kl_loss_'):].isdigit()
)
kl_loss = torch.Tensor([v for _d, v in kl_items])
_, index = kl_loss.sort(descending=True)  # latent dims, most informative first
print(index)
torch.save(vae.state_dict(), 'tmp_model.pt')
# Optional decoder-only fine-tuning:
# refine(dl, vae, 5)
vae.eval()

vae.cpu()

fig = plot_reconstruct(dataset, (3, 2), vae)
wandb.log({'reconstruction': wandb.Image(fig)})

fig = plt_sample_traversal(dataset[:1], vae, 7, index, r=3)
wandb.log({'traversal': wandb.Image(fig)})

fig = plot_projection(imgs[::4, ::4], vae, dim=index[:2])
wandb.log({'projection_xy': wandb.Image(fig)})

with torch.no_grad():
    mu, _ = vae.encoder(dataset)
    points = 30 * mu[:, index[:3]]  # xyz = three highest-KL latents, scaled up
    # Colour each point by the grid position of its image: the two label
    # columns (0..31) map to red/green 7..255; blue is constant. The previous
    # reshape(8 * 32 * 32, 3) did not match labels' size and always crashed.
    grid = labels.reshape(-1, labels.size(-1))
    colors = torch.cat([grid * 8 + 7, torch.full((grid.size(0), 1), 127.0)], dim=1)
    point_cloud = torch.cat([points, colors], 1).numpy()  # (N, 6): x, y, z, r, g, b
    wandb.log({'point_cloud': wandb.Object3D(point_cloud)})

plt.show()
wandb.finish()
del config
