"""
叫图像预处理更合适。能否对图像进行一定的预处理，从而突出某一类因子。
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import wandb

from disvae.models.losses import _kl_normal_loss, get_loss_f
from disvae.models.vae import VAE
from disvae.utils.initialization import weights_init
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from utils.visualize import *
import numpy as np

from utils.helpers import set_seed

# Default hyperparameters for the run. They are handed to wandb, and the rest of
# the script reads them back through ``config`` (wandb.config) rather than from
# this dict, so values may be overridden through wandb (e.g. by a sweep agent).
hyperparameter_defaults = {
    'batch_size': 128,
    'learning_rate': 0.001,
    'epochs': 251,
    'plot_interval': 20,
    'dimension': 2,
    'beta': 4,
    'texture': 'diagonal',
    'random_seed': 112,
}

wandb.init(project="texture", config=hyperparameter_defaults)
config = wandb.config
# Seed before any tensors are created so the dataset and model init are reproducible.
set_seed(config.random_seed)


def get_texture(type):
    """Return the 10x10 fill pattern painted into the moving square of each image.

    Args:
        type: pattern name, one of ``'solid'``, ``'horiz'``, ``'vertical'``,
            ``'dot'``, ``'rand'`` or ``'diagonal'``. (The name shadows the
            builtin but is kept so keyword callers keep working.)

    Returns:
        ``1`` (a scalar broadcast over the patch) for ``'solid'``; a
        ``torch.bool`` (10, 10) mask whose entries are True with probability
        0.1 for ``'rand'``; a float (10, 10) tensor for every other pattern.

    Raises:
        ValueError: if ``type`` is not a known pattern name.
    """
    if type == 'solid':
        return 1
    if type == 'rand':
        return torch.rand(10, 10) < 0.1
    if type == 'diagonal':
        return torch.diag(torch.ones(10))

    # Alternating full rows: rows 0, 2, ..., 8 are ones, the odd rows zeros.
    horiz_strip = torch.cat([torch.ones(1, 10), torch.zeros(1, 10)]).repeat(5, 1)
    if type == 'horiz':
        return horiz_strip
    if type == 'vertical':
        return horiz_strip.t()
    if type == 'dot':
        # Ones only where an "on" row crosses an "on" column (a 5x5 dot lattice).
        return horiz_strip * horiz_strip.t()
    raise ValueError(f'Wrong pattern type: {type!r}')


def plot_projection(sample, model):
    """Scatter the first two encoder-mean coordinates of a 4x4 sub-grid of images.

    ``sample`` is reshaped to a (32, 32, 64, 64) grid and every 8th image along
    both grid axes is encoded. Each point is labelled ``(row,col)`` with its
    sub-grid index, and the points of each sub-grid row are joined by a line.
    Must be on the same device as ``model``.

    Returns:
        The matplotlib figure (the caller is responsible for closing it).
    """
    fig = plt.figure()
    with torch.no_grad():
        grid = sample.reshape(32, 32, 64, 64)
        subgrid = grid[::8, ::8].reshape(-1, 1, 64, 64)
        # presumably (mean, log-variance) of the posterior; only the first is used
        means, _ = model.encoder(subgrid)
        means = means.cpu()
        plt.scatter(means[:, 0].data, means[:, 1].data)
        for row in range(4):
            start = 4 * row
            for col in range(4):
                k = start + col
                plt.text(means[k, 0], means[k, 1], f'({row},{col})')
            plt.plot(means[start:start + 4, 0], means[start:start + 4, 1])

    return fig


def train(dl, model, loss_f, epochs):
    """Train ``model`` as a VAE, then refine its decoder with the encoder frozen.

    Phase 1 (epochs ``0 .. epochs-1``): optimise ``loss_f``. For each epoch it
    logs to wandb the epoch-averaged loss terms, the last batch's per-dimension
    posterior mean/variance histograms, and (every ``config.plot_interval``
    epochs) a latent projection plot.

    Phase 2 (epochs ``epochs .. int(epochs * 1.5) - 1``): freeze the encoder and
    train the decoder to reconstruct images from the encoder *means* with a
    summed binary cross-entropy, logging the epoch-mean ``recon_loss``.

    Args:
        dl: DataLoader whose batches can be viewed as (-1, 1, 64, 64) and whose
            ``dataset`` is accepted by ``plot_projection``.
        model: VAE whose ``forward`` returns ``(recon, latent_dist,
            latent_sample)`` and which exposes ``encoder`` / ``decoder``.
        loss_f: loss callable with the signature used by the project's losses
            (``img, recon, latent_dist, is_train, storer, latent_sample=...``).
        epochs: number of phase-1 epochs.

    Note:
        The encoder is left with ``requires_grad=False`` on return.
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    for e in range(epochs):
        storer, latent_dist = _train_vae_epoch(dl, model, loss_f, opt)
        wandb.log(storer, sync=False, step=e)

        if e % config.plot_interval == 0:
            fig = plot_projection(dl.dataset, model)
            wandb.log({'projection': wandb.Image(fig)}, step=e)
            plt.close('all')

        if latent_dist is not None:  # None only when the loader yielded no batch
            _log_latent_histograms(latent_dist, e)

    _refine_decoder(dl, model, opt, epochs, int(epochs * 1.5))


def _train_vae_epoch(dl, model, loss_f, opt):
    """Run one VAE epoch; return (epoch-averaged storer, last batch's latent_dist)."""
    storer = defaultdict(list)
    loss_set = []
    latent_dist = None
    for img in dl:
        img = img.view(-1, 1, 64, 64)
        recon_batch, latent_dist, latent_sample = model(img)
        loss = loss_f(img, recon_batch, latent_dist, model.training,
                      storer, latent_sample=latent_sample)

        opt.zero_grad()
        loss.backward()
        opt.step()
        loss_set.append(loss.item())

    # loss_f appends per-term values into storer; reduce them to epoch means.
    for k, v in storer.items():
        if isinstance(v, list):
            storer[k] = np.mean(v)
    storer['loss'] = np.mean(loss_set)
    return storer, latent_dist


def _log_latent_histograms(latent_dist, step):
    """Log per-dimension histograms of the posterior mean and variance to wandb."""
    with torch.no_grad():
        mu, log_var = latent_dist
        var = log_var.exp().cpu()
        mu = mu.cpu()
        # Width of mu is the latent dimension of the model actually being trained.
        for d in range(mu.size(1)):
            wandb.log({f"mu_{d}": wandb.Histogram(mu[:, d])}, step=step, sync=False)
            wandb.log({f"var_{d}": wandb.Histogram(var[:, d])}, step=step, sync=False)


def _refine_decoder(dl, model, opt, start_epoch, end_epoch):
    """Freeze the encoder and fit the decoder to reconstruct from encoder means."""
    model.encoder.requires_grad_(False)
    # Drop stale encoder gradients. On torch versions where zero_grad() zeroes
    # grads instead of clearing them, AdamW would otherwise keep stepping the
    # "frozen" encoder through its momentum buffers and weight decay.
    for p in model.encoder.parameters():
        p.grad = None

    for e in range(start_epoch, end_epoch):
        losses = []
        for img in dl:
            img = img.view(-1, 1, 64, 64)
            mu, _ = model.encoder(img)
            recon_batch = model.decoder(mu)
            loss = F.binary_cross_entropy(recon_batch, img, reduction='sum') / img.size(0)

            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        if losses:
            wandb.log({'recon_loss': np.mean(losses)}, step=e)


def evaluate(model, ds, steps=2000):
    """Fit a linear probe mapping ground-truth grid position to encoder output.

    The first output of ``model.encoder(ds)`` is computed once without
    gradients. A ``Linear(2, latent_dim)`` is then trained with AdamW for
    ``steps`` iterations to predict it from each sample's (row, col) grid
    index. The MSE of every iteration is logged to wandb as ``prediction``.

    Args:
        model: model exposing ``encoder`` returning a pair whose first element
            is (N, latent_dim).
        ds: image tensor of N = 32 * 32 samples. Assumed to be the 32x32
            position grid flattened row-major; TODO confirm for other callers.
        steps: number of optimisation steps for the probe (default 2000).

    Returns:
        The fitted ``nn.Linear`` probe (previously discarded; callers that
        ignore the return value are unaffected).
    """
    device = ds.device
    # Row-major (row, col) index of each sample: [[0, 0], [0, 1], ..., [31, 31]].
    factor = torch.cartesian_prod(torch.arange(32), torch.arange(32)).float().to(device)
    with torch.no_grad():
        z, _ = model.encoder(ds)
    probe = nn.Linear(factor.size(1), z.size(1)).to(device)
    opt = optim.AdamW(probe.parameters(), config.learning_rate)

    for _ in range(steps):
        preds = probe(factor)
        ploss = F.mse_loss(preds, z)
        wandb.log({'prediction': ploss.item()})
        opt.zero_grad()
        ploss.backward()
        opt.step()
    return probe


exp_dir = Path('results/')
imgs = torch.zeros(32, 32, 64, 64)
texture = get_texture(config.texture)

# Grid cell (i, j) holds one 64x64 image with the 10x10 texture patch placed at
# rows 11+i.. and columns 11+j.., so i and j are the two generative factors.
for i in range(32):
    for j in range(32):
        imgs[i, j][11 + i:11 + i + 10, 11 + j:11 + j + 10] = texture
imgs = imgs.view(32 * 32, 1, 64, 64)

epochs = config.epochs
dim = config.dimension
beta = config.beta

dataset = imgs
# One GPU copy of the dataset, shared by the training loader and the probe.
gpu_dataset = dataset.cuda()
dl = DataLoader(gpu_dataset, num_workers=0, batch_size=config.batch_size, shuffle=True)

vae = VAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim, 1)
vae.cuda()
loss_f = get_loss_f('Dis', betaH_B=beta,
                    record_loss_every=2, rec_dist='bernoulli', reg_anneal=epochs * len(dl))
vae.train()
train(dl, vae, loss_f, epochs)
vae.eval()
evaluate(vae, gpu_dataset)
# Final plots are produced on the CPU copy of the dataset.
vae.cpu()

fig = plot_reconstruct(dataset, (3, 2), vae)
wandb.log({'reconstruction': wandb.Image(fig)})
fig = plt_sample_traversal(dataset[:1], vae, 7, dim)
wandb.log({'traversal': wandb.Image(fig)})
wandb.join()
