"""
首先进行消融测试，探究是什么因素导致旋转因子难以学习。
1. 模型学习速率不同
2. 数据中平移因子比较明显
"""
# Design note for this experiment (English translation of the string below):
#   "attention" indicates whether a given control factor should be learned.
#   If a latent unit becoming inactive is caused by too many training
#   iterations, latent variables that are not in the attention list should
#   be searched (sampled) less often.
#   Noise is produced with some probability p; for dimensions in the
#   attention list, p = 1.
# The string is bound to the module-level name ``a`` and kept only as a note;
# nothing in this file reads ``a`` at module level.
a = """
attention表示某个控制因子是否应该被学习。
如果失活是由于训练次数过多导致的，那么不在attention列表的隐变量应该减少搜索次数。
以一个概率p产生误差，在attention list，p=1.
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path

import cv2
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models.losses import _kl_normal_loss, get_loss_f, DisVAELoss, BetaBLoss
from disvae.models.vae import VAE
from disvae.utils.initialization import weights_init
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from utils.visualize import *
from utils import data_generator
import numpy as np

from utils.helpers import set_seed

# Default hyperparameters handed to wandb; the rest of the script reads the
# effective values back through ``wandb.config`` (bound to ``config`` below).
hyperparameter_defaults = {
    'batch_size': 128,
    'learning_rate': 0.001,
    'epochs': 1001,
    'dimension': 6,
    'loss': 'dis',
    'beta': 80,
    'img_id': 6,
    'random_seed': 224,
}

wandb.init(
    project="exps",
    config=hyperparameter_defaults,
    group='attention',
    tags=['rotation', 'ablation'],
    notes=__doc__,  # the module docstring describes this ablation run
)

config = wandb.config
# Seed through the project's helper using the configured seed.
set_seed(config.random_seed)


class VAE(nn.Module):
    """
    VAE whose per-dimension sampling noise is gated by ``self.attention``.

    This definition shadows the ``VAE`` imported from ``disvae.models.vae``.
    """

    def __init__(self, img_size, encoder, decoder, latent_dim, noise_prob=0.5):
        """
        Class which defines model and forward pass.

        Parameters
        ----------
        img_size : tuple of ints
            Size of images. E.g. (1, 32, 32) or (3, 64, 64).

        encoder : callable
            Factory called as ``encoder(img_size, latent_dim)``. The module it
            returns must map a batch of images to ``(mean, logvar)``.

        decoder : callable
            Factory called as ``decoder(img_size, latent_dim)``. The module it
            returns maps latent samples back to images.

        latent_dim : int
            Number of latent dimensions.

        noise_prob : float, optional
            During training, probability that a latent dimension NOT in
            ``self.attention`` receives sampling noise on a given forward pass.
            Dimensions in ``self.attention`` always receive noise (p = 1).
            Defaults to 0.5.

        Raises
        ------
        RuntimeError
            If the spatial size in ``img_size`` is neither 32x32 nor 64x64.
        """
        super().__init__()

        if list(img_size[1:]) not in [[32, 32], [64, 64]]:
            raise RuntimeError(
                "{} sized images not supported. Only (None, 32, 32) and (None, 64, 64) supported. Build your own architecture or reshape images!".format(
                    img_size))

        self.latent_dim = latent_dim
        self.img_size = img_size
        self.num_pixels = self.img_size[1] * self.img_size[2]
        self.encoder = encoder(img_size, self.latent_dim)
        self.decoder = decoder(img_size, self.latent_dim)
        # Dimensions that always get sampling noise. Start with all of them.
        # The set may be reassigned from outside during training.
        self.attention = range(self.latent_dim)
        self.noise_prob = noise_prob
        self.reset_parameters()

    def reparameterize(self, mean, logvar):
        """
        Samples from a normal distribution using the reparameterization trick,
        with per-dimension gating of the noise.

        In training mode, latent dimension ``d`` receives ``eps ~ N(0, 1)`` if
        ``d`` is in ``self.attention``. Otherwise it receives noise only with
        probability ``self.noise_prob``, and is left equal to ``mean``
        otherwise. There is one coin flip per dimension per call, shared by
        the whole batch. In eval mode ``mean`` is returned unchanged.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (batch_size, latent_dim)

        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape (batch_size,
            latent_dim)
        """
        if not self.training:
            # Reconstruction mode
            return mean

        std = torch.exp(0.5 * logvar)
        eps = torch.zeros_like(std)
        # Walk the actual latent width (not a fixed 6). Otherwise dims >= 6
        # would never be sampled and latent_dim < 6 would raise IndexError.
        for d in range(mean.shape[1]):
            # `or` short-circuits: the coin is only flipped for non-attention
            # dims, matching the original RNG consumption order.
            if d in self.attention or torch.rand(1).item() <= self.noise_prob:
                eps[:, d] = torch.randn_like(eps[:, d])
        return mean + eps * std

    def forward(self, x):
        """
        Forward pass of model.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)

        Returns
        -------
        tuple
            ``(reconstruct, latent_dist, latent_sample)``, where
            ``latent_dist`` is the encoder output ``(mean, logvar)``.
        """
        latent_dist = self.encoder(x)
        latent_sample = self.reparameterize(*latent_dist)

        reconstruct = self.decoder(latent_sample)
        return reconstruct, latent_dist, latent_sample

    def reset_parameters(self):
        """Apply the project's ``weights_init`` to every submodule."""
        self.apply(weights_init)

    def sample_latent(self, x):
        """
        Return latent distribution and samples.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, n_chan, height, width)

        Returns
        -------
        tuple
            ``(latent_dist, latent_sample)``, with ``latent_dist`` the encoder
            output ``(mean, logvar)``.
        """
        latent_dist = self.encoder(x)
        latent_sample = self.reparameterize(*latent_dist)
        return latent_dist, latent_sample


def mean(storer):
    """
    Collapse every list-valued entry of ``storer`` to its NumPy mean, in place.

    Entries whose value is not a ``list`` are left untouched. The same mapping
    object that was passed in is returned.
    """
    list_keys = [key for key, value in storer.items() if isinstance(value, list)]
    for key in list_keys:
        storer[key] = np.mean(storer[key])
    return storer


def train(dl, model, loss_f, epochs, log_every=10):
    """
    Train ``model`` with AdamW and periodically log metrics to wandb.

    Parameters
    ----------
    dl : iterable of (img, label) batches
        Typically a ``DataLoader``. ``label`` is not used. Each ``img`` batch
        is reshaped to ``(-1, *model.img_size)``.
    model : torch.nn.Module
        Must expose ``img_size``, ``latent_dim`` and ``attention``. Its forward
        returns ``(recon_batch, (mu, log_var), latent_sample)``.
    loss_f : callable
        Called as ``loss_f(img, recon_batch, latent_dist, model.training,
        storer, latent_sample=latent_sample)`` and must return a scalar loss
        tensor. It receives ``storer`` and may record extra metrics in it.
    epochs : int
        Number of passes over ``dl``.
    log_every : int, optional
        Log to wandb every ``log_every`` global iterations (default 10).

    Returns
    -------
    collections.defaultdict
        Metrics accumulated after the last wandb log (not yet logged).

    Notes
    -----
    The learning rate is read from the module-level wandb ``config``. After the
    second epoch (``e == 1``), latent dimensions 0 and 1 are removed from
    ``model.attention``.
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    storer = defaultdict(list)
    for e in trange(epochs):
        for itr, (img, label) in enumerate(dl):
            # Global iteration index across all epochs.
            itr = itr + e * len(dl)

            # Follow the model's declared image size instead of a fixed 64x64.
            img = img.view(-1, *model.img_size)
            recon_batch, latent_dist, latent_sample = model(img)
            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)

            opt.zero_grad()
            loss.backward()
            opt.step()
            storer['loss'].append(loss.item())

            with torch.no_grad():
                if (itr + 1) % log_every == 0:
                    # Average every list-valued metric gathered since the last log.
                    storer = mean(storer)

                    # histograms of the per-dimension posterior mean / variance
                    mu, log_var = latent_dist
                    var = log_var.detach().exp().cpu()
                    mu = mu.detach().cpu()
                    # Use the batch's actual latent width, not a module-level global.
                    for d in range(mu.shape[1]):
                        storer[f"mu_{d}"] = wandb.Histogram(mu[:, d])
                        storer[f'var_{d}'] = wandb.Histogram(var[:, d])

                    # first image of the batch and its reconstruction
                    storer['recon'] = wandb.Image(recon_batch[0, 0])
                    storer['img'] = wandb.Image(img[0, 0])

                    # if (itr + 1)%100 == 0:
                    #     model.cpu()
                    #     model.eval()
                    #     fig = plot_projection(imgs[2, ::4, ::4], model, dim=[0,1])
                    #     storer['projection_xy'] = wandb.Image(fig)
                    #     model.cuda()
                    #     model.train()

                    wandb.log(storer)
                    storer = defaultdict(list)
        if e == 1:
            # Drop dims 0 and 1 from the always-noised set (see
            # VAE.reparameterize). The upper bound follows the model's latent
            # width so every dim >= 2 stays included.
            model.attention = range(2, model.latent_dim)

    return storer


# --- run settings read back from the wandb config ---
img_id = config.img_id
epochs = config.epochs
dim = config.dimension
beta = config.beta

# --- data: rotations generated from a single stored pattern ---
img = torch.load(f'patterns/{img_id}.pat')
imgs, labels = data_generator.gen_rotation(img, img.shape)

# Keep the [0, 0] slice along axes 1-2 of ``imgs`` and every (40 * 40)-th label.
# NOTE(review): this pairing assumes labels are the flattened leading axes of
# ``imgs`` with 40x40 inner axes - confirm against gen_rotation.
dataset = imgs[:, 0, 0].reshape(-1, 1, 64, 64)
labels = labels[::40 * 40]
dl = DataLoader(
    list(zip(dataset.cuda(), labels)),
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=0,
)

# --- model ---
vae = VAE(
    (1, 64, 64),
    get_encoder('Burgess'),
    get_decoder('Burgess'),
    dim,
)
vae.cuda()
vae.train()
wandb.watch(vae.decoder, log="all")

# Build the objective named by ``config.loss``. All variants are configured with
# rec_dist='bernoulli', record_loss_every=2 and an annealing horizon of
# epochs * len(dl) steps.
if config.loss == 'betaH':
    loss_f = get_loss_f('betaH', betaH_B=beta,
                        record_loss_every=2, rec_dist='bernoulli', reg_anneal=epochs * len(dl))
elif config.loss == 'btcvae':
    loss_f = get_loss_f('btcvae', btcvae_A=1, btcvae_B=beta, btcvae_G=1, n_data=len(dataset),
                        record_loss_every=2, rec_dist='bernoulli', reg_anneal=epochs * len(dl))
elif config.loss == 'dis':
    loss_f = DisVAELoss(beta=beta,
                        record_loss_every=2, rec_dist='bernoulli', steps_anneal=epochs * len(dl))
elif config.loss == 'betaB':
    loss_f = BetaBLoss(C_init=0, C_fin=beta, gamma=100,
                       record_loss_every=2, rec_dist='bernoulli', steps_anneal=epochs * len(dl))
else:
    # ValueError (a subclass of Exception) and report the offending value.
    raise ValueError(
        f"unknown loss {config.loss!r}; expected one of 'betaH', 'btcvae', 'dis', 'betaB'")

storer = train(dl, vae, loss_f, epochs)

# torch.save(vae.state_dict(), 'tmp_model.pt')
# refine
# refine(dl,vae,5)
# Eval mode: VAE.reparameterize then returns the posterior mean (no noise).
vae.eval()

# Plotting below runs on CPU tensors (``dataset`` was only copied to GPU for the loader).
vae.cpu()

# Training is finished; delete ``config`` so nothing below can read it.
del config

fig = plot_reconstruct(dataset, (3, 2), vae)
wandb.log({'reconstruction': wandb.Image(fig)})

fig = plt_sample_traversal(dataset[:1], vae, 7, range(dim), r=3)
wandb.log({'traversal': wandb.Image(fig)})
#
# with torch.no_grad():
#     mu, _ = vae.encoder(dataset)
#     points = mu[:, [0,1,2]]
#     # points = 10*points
#     colors = labels.reshape(20 * 40 * 40, 3) * torch.tensor([12, 6, 6])
#     point_cloud = torch.cat([points, colors, ], 1).numpy()
#     wandb.log({'point_cloud': wandb.Object3D(point_cloud)})

plt.show()
# ``wandb.finish`` is the supported way to end a run; ``wandb.join`` is a deprecated alias.
wandb.finish()
