"""
Shape comparison.
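To study the effects of shape, we investigate several shape patterns: horizontal stripe, vertical stripe, solid, and dot.
NOTE(review): the script below renders a filled triangle rotated in 16 steps rather than these patterns; confirm this description is still current.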
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
import cv2
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import wandb

from disvae.models.losses import _kl_normal_loss, get_loss_f
from disvae.models.vae import VAE
from disvae.utils.initialization import weights_init
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from utils.visualize import *
import numpy as np

from utils.helpers import set_seed

# Default hyperparameters, handed to ``wandb.init`` as the run configuration.
hyperparameter_defaults = {
    "batch_size": 128,
    "learning_rate": 0.001,
    "epochs": 3001,
    "plot_interval": 150,
    "dimension": 2,
    "width": 3,
    "beta": 35,
    "random_seed": 112,
}

wandb.init(
    project="exps",
    group="clock",
    config=hyperparameter_defaults,
)

# All later code reads its settings from the wandb config object.
config = wandb.config
set_seed(config.random_seed)


def plot_projection(samples, model):
    """Plot the encoder means of ``samples`` on a new figure and return it.

    ``samples`` is passed through ``model.encoder`` without gradient tracking
    and the first returned value (the means) is moved to the CPU.

    * If the means have exactly two columns, they are drawn as a scatter plot,
      each point is labelled with its sample index, and the points are joined
      by a line in sample order.
    * Otherwise only the first column is drawn, one scatter point per sample,
      against the sample index.

    Args:
        samples: Batch accepted by ``model.encoder``; must support ``len``.
        model: Object exposing an ``encoder`` that returns a pair whose first
            element is the mean tensor.

    Returns:
        The figure created for the plot.
    """
    figure = plt.figure()
    with torch.no_grad():
        means, _ = model.encoder(samples)
        means = means.cpu()
        n_samples = len(samples)

        # Fallback for non-2D latents: first latent coordinate vs. index.
        if means.size(1) != 2:
            for i in range(n_samples):
                plt.scatter(i, means[i, 0])
            return figure

        first, second = means[:, 0], means[:, 1]
        plt.scatter(first.data, second.data)
        for i in range(n_samples):
            plt.text(means[i, 0], means[i, 1], f'{i}')
        # The connecting line shows the ordering of the samples in latent space.
        plt.plot(first, second)

    return figure


def train(dl, model, loss_f, epochs):
    """Optimise ``model`` on ``dl`` with AdamW and log progress to wandb.

    Each epoch runs one pass over ``dl`` and then logs the epoch-averaged
    entries of the loss storer together with the mean batch loss. Every
    ``config.plot_interval`` epochs a latent projection of ``dl.dataset`` is
    logged as an image. After every epoch, per-dimension histograms of the
    posterior mean and variance of the *last* batch are logged.

    Args:
        dl: Iterable of image batches that also exposes ``dataset``. It must
            yield at least one batch per epoch, because the histograms are
            built from the last batch's latent distribution.
        model: Callable returning ``(recon_batch, latent_dist, latent_sample)``
            where ``latent_dist`` unpacks to ``(mu, log_var)``.
        loss_f: Loss callable invoked as ``loss_f(img, recon_batch,
            latent_dist, model.training, storer, latent_sample=...)``.
        epochs: Number of epochs to run.
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    interval = config.plot_interval
    for e in range(epochs):
        storer = defaultdict(list)
        loss_set = []
        for img in dl:
            recon_batch, latent_dist, latent_sample = model(img)
            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)

            opt.zero_grad()
            loss.backward()
            opt.step()
            loss_set.append(loss.item())

        # Collapse the per-batch records collected by loss_f into epoch means.
        # Only values of existing keys are replaced, so iterating is safe.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)
        storer['loss'] = np.mean(loss_set)

        wandb.log(storer, sync=False)

        if e % interval == 0:
            # Plot with the model being trained, not a module-level global.
            fig = plot_projection(dl.dataset, model)
            wandb.log({'projection': wandb.Image(fig)})
            plt.close('all')

            # fig = plot_trajectory(dl.dataset, model)
            # wandb.log({f'trajectory_{e // interval}':wandb.Image(fig)},)

        with torch.no_grad():
            # latent_dist still refers to the last batch of this epoch.
            mu, log_var = latent_dist
            var = log_var.exp().cpu()
            mu = mu.cpu()
            # Take the latent width from the tensor itself so the function
            # does not depend on a global defined later in the script.
            for d in range(mu.size(1)):
                wandb.log({f"mu_{d}": wandb.Histogram(mu[:, d])}, sync=False, )
                wandb.log({f"var_{d}": wandb.Histogram(var[:, d])}, sync=False, )


def refine(dl, model, epochs):
    """Fine-tune only the decoder of ``model`` on a reconstruction loss.

    Gradients are disabled on the encoder and an AdamW optimiser is built over
    the decoder parameters alone. Each batch is encoded to its mean, decoded,
    and scored with binary cross-entropy summed over all elements and divided
    by the batch size. After every epoch the loss of the final batch is logged
    to wandb under ``recon_loss``.

    Args:
        dl: Iterable of image batches; must yield at least one batch per
            epoch, since the logged value is the last batch's loss.
        model: Object exposing ``encoder`` (returning a pair whose first
            element is the mean) and ``decoder``.
        epochs: Number of passes over ``dl``.
    """
    model.encoder.requires_grad_(False)
    decoder_opt = optim.AdamW(model.decoder.parameters(), config.learning_rate)
    for _ in range(epochs):
        for batch in dl:
            latent_mean, _ = model.encoder(batch)
            reconstruction = model.decoder(latent_mean)
            summed_bce = F.binary_cross_entropy(reconstruction, batch, reduction='sum')
            loss = summed_bce / batch.size(0)

            decoder_opt.zero_grad()
            loss.backward()
            decoder_opt.step()
        # Only the most recent batch loss is reported for the epoch.
        wandb.log({'recon_loss': loss.item()})


def evaluate(model, ds, steps=2000):
    """Fit a linear map from sample index to latent mean and log its error.

    ``ds`` is encoded once without gradient tracking. A ``nn.Linear(1, d)``
    probe is then trained with AdamW to predict each sample's latent mean from
    its index ``0 .. n-1``. The MSE of every optimisation step is logged to
    wandb under ``prediction``.

    The number of samples, the latent width and the device are all taken from
    the encoder output, so the function works on CPU or GPU and does not
    rely on module-level globals.

    Args:
        model: Object exposing an ``encoder`` that returns a pair whose first
            element is the latent mean (assumed 2-D: samples x latent dims).
        ds: Batch accepted by ``model.encoder``.
        steps: Number of optimisation steps for the probe (default 2000).
    """
    with torch.no_grad():
        z, _ = model.encoder(ds)

    n_samples = z.size(0)
    device = z.device
    # Column vector of sample indices used as the single regression input.
    factor = torch.arange(n_samples, dtype=torch.float, device=device).reshape(n_samples, 1)

    G = nn.Linear(1, z.size(1)).to(device)
    opt = optim.AdamW(G.parameters(), config.learning_rate)

    for _ in range(steps):
        preds = G(factor)
        ploss = F.mse_loss(preds, z)
        wandb.log({'prediction': ploss.item()})
        opt.zero_grad()
        ploss.backward()
        opt.step()


epochs = config.epochs
dim = config.dimension
beta = config.beta

# Base image: a 32x32 canvas with one filled triangle of value 1. Two vertices
# sit at x=16, ``config.width`` pixels above and below y=16; the third is at
# (31, 16).
img = np.zeros((32, 32))
h, w = 32, 32
center = (w // 2, h // 2)

# OpenCV contour functions need 32-bit integer points. The platform default
# integer dtype can be int64, which drawContours rejects, so set it explicitly.
triangle_cnt = np.array(
    [(16, 16 + config.width), (16, 16 - config.width), (31, 16)],
    dtype=np.int32,
)
# thickness=-1 fills the contour; ``img`` is drawn on in place.
nimg = cv2.drawContours(img, [triangle_cnt], 0, (1), -1)

# Dataset: N copies of the base image, rotated about the centre in equal
# angular steps of 360 / N degrees.
N = 16
imgs = torch.zeros(N, 1, 32, 32)
for i in range(N):
    M = cv2.getRotationMatrix2D(center, 360 / N * i, 1)
    rotated = cv2.warpAffine(img, M, (w, h))
    imgs[i, 0] = torch.from_numpy(rotated)
s = imgs.sum(0)[0]

dataset = imgs
# The loader iterates over a GPU copy; ``dataset`` itself stays on the CPU for
# the plotting calls after training.
dl = DataLoader(dataset.cuda(), num_workers=0, batch_size=config.batch_size, shuffle=True)

vae = VAE((1, 32, 32), get_encoder('Burgess'), get_decoder('Burgess'), dim, 1)
vae.cuda()
# wandb.watch(vae)
loss_f = get_loss_f('Dis', betaH_B=beta,
                    record_loss_every=2, rec_dist='bernoulli', reg_anneal=epochs * len(dl))
vae.train()
train(dl, vae, loss_f, epochs)
# refine(dl,vae, int(epochs*0.5))
vae.cpu()
# NOTE(review): the model is still in training mode here; confirm whether the
# plotting helpers expect ``vae.eval()`` first.

fig = plot_reconstruct(dataset, (3, 2), vae)
wandb.log({'reconstruction': wandb.Image(fig)})

# Only a scalar range is needed from this pass, so skip autograd bookkeeping.
with torch.no_grad():
    mu, _ = vae.encoder(dataset)

# Traverse slightly beyond the largest absolute latent mean seen in the data.
fig = plt_sample_traversal(dataset[:1], vae, 7, dim, r=(mu.abs().max()).item() + 0.5)
wandb.log({'traversal': wandb.Image(fig)})
# wandb.save('strip.py')
del config
wandb.join()
