"""
比较TC与data cues在解耦上的差异。比较非独立采样
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt

import cv2
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae import Evaluator
from disvae.models.anneal import get_anneal
from disvae.models.losses import _kl_normal_loss, get_loss_f, get_loss_s
from disvae.models.vae import VAE
from disvae.utils.initialization import weights_init
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from disvae.utils.math import log_density_gaussian
from utils.data_generator import *
from utils.visualize import *
from utils import data_generator
import numpy as np

from utils.helpers import set_seed

# Default hyperparameters. Each entry is exposed below as a ``--<name>`` command
# line flag whose type is taken from the type of its default value.
hyperparameter_defaults = dict(
    batch_size=64,
    learning_rate=0.0001,
    epochs=151,
    dimension=6,
    dataset='related',
    loss='betaH',
    beta=100,
    random_seed=134,
)
parser = argparse.ArgumentParser()
for key, value in hyperparameter_defaults.items():
    parser.add_argument(f'--{key}', default=value, type=type(value))
args = parser.parse_args()
config = args  # alias read by the rest of the script (e.g. config.learning_rate)
set_seed(config.random_seed)

# Group runs by the script's file name without its extension. Path.stem works
# regardless of the platform's path separator and does not assume a 3-character
# suffix, unlike slicing the raw ``__file__`` string.
wandb.init(
    project="experiment",
    group=Path(__file__).stem,
    notes=__doc__,
    config=args,
)

def get_preds(model, dl):
    """Encode every batch of ``dl`` and collect the encoder outputs and labels.

    The model is switched to eval mode while encoding and its previous
    train/eval mode is restored before returning, so a caller that had put the
    model in eval mode keeps it that way.

    Args:
        model: ``nn.Module`` exposing ``encoder(img)`` which returns a pair;
            only the first element (``mu``, presumably the latent mean) is kept.
        dl: Iterable of ``(img, label)`` batches. Each image batch is reshaped
            to ``(-1, 1, 64, 64)`` and moved to the model's device.

    Returns:
        Tuple ``(preds, targets)``: the concatenated first encoder outputs
        (on the model's device) and the concatenated labels (left on the
        device the loader produced them on).
    """
    was_training = model.training

    # Follow the model's own device instead of assuming CUDA; fall back to the
    # previous behaviour (CUDA when available) for parameter-less models.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    preds = []
    targets = []
    model.eval()
    try:
        with torch.no_grad():
            for img, label in dl:
                img = img.view(-1, 1, 64, 64).to(device)
                mu, _ = model.encoder(img)
                preds.append(mu)
                targets.append(label)
    finally:
        # Restore whatever mode the caller had, even if encoding failed.
        model.train(was_training)
    return torch.cat(preds), torch.cat(targets)


def train(dl, model, loss_f, epochs):
    """Train ``model`` on ``dl`` for ``epochs`` epochs and log progress to wandb.

    Uses AdamW with the learning rate from the module-level ``config``. Batches
    are moved to CUDA, so the model is expected to live on the GPU.

    Every epoch the per-key means of the values gathered in the loss storer,
    the mean batch loss and the epoch number are logged. Every 50 epochs an
    input/reconstruction image pair is logged as well, together with a 2x3
    scatter grid plotting the first two label columns against the three
    latents whose ``kl_*`` storer entries are largest.

    Args:
        dl: Iterable of ``(img, label)`` batches.
        model: Module called as ``model(img)`` and returning
            ``(recon_batch, latent_dist, latent_sample)``.
        loss_f: Loss callable; if it raises ``ValueError`` the batch is handed
            to ``loss_f.call_optimize`` instead.
        epochs: Number of passes over ``dl``.

    Returns:
        The storer dict of the last epoch (empty when ``epochs`` is 0).
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    # Defined up front so the final ``return`` is valid even when epochs == 0.
    storer = defaultdict(list)

    for e in trange(epochs):
        storer = defaultdict(list)
        loss_set = []
        for img, _ in dl:
            img = img.cuda()

            try:
                recon_batch, latent_dist, latent_sample = model(img)
                loss = loss_f(img, recon_batch, latent_dist, model.training,
                              storer, latent_sample=latent_sample)

                opt.zero_grad()
                loss.backward()
                opt.step()
            except ValueError:
                # for losses that use multiple optimizers (e.g. Factor)
                loss = loss_f.call_optimize(img, model, opt, storer)

            loss_set.append(loss.item())

        # Collapse the per-batch lists gathered by the loss into epoch means.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)
        storer['loss'] = np.mean(loss_set)

        if e % 50 == 0:
            # Last batch of the epoch and its reconstruction.
            wandb.log({
                'recon': wandb.Image(recon_batch[0, 0]),
                'img': wandb.Image(img[0, 0])
            })
            # Rank the 'kl_*' storer entries from largest to smallest.
            # NOTE(review): every key starting with 'kl_' is included; if the
            # loss also stores an aggregate entry under such a key, the sorted
            # positions no longer map one-to-one onto latent dimensions —
            # confirm against the loss implementation.
            kl_loss = torch.Tensor([v for k, v in storer.items() if k[:3] == 'kl_'])
            _, index = kl_loss.sort(descending=True)
            # Encode with the model being trained (the argument), not a global.
            preds, targets = get_preds(model, dl)
            points = preds[:1000, index[:3]].cpu()
            colors = targets[:1000]
            fig, axes = plt.subplots(2, 3)
            for i in range(2):  # factor
                for j in range(3):  # variable
                    axes[i, j].scatter(points[:, j].numpy(), colors[:, i].numpy(), s=0.2)
                    axes[i, j].set_title(f'{i + 1},{index[j]}')
            storer['correlated'] = wandb.Image(fig)
            plt.close(fig)
        storer['epoch'] = e
        wandb.log(storer, sync=False)
    return storer


# Load the base pattern and rotate it by 90 before building the dataset.
img_id = 0
img = rotate(torch.load(f'exps/patterns/{img_id}.pat'), 90)
epochs, dim, beta = config.epochs, config.dimension, config.beta

# 'related' selects RelatedTranslation; any other value selects PolarTranslation.
dataset_cls = RelatedTranslation if config.dataset == 'related' else PolarTranslation
dataset = dataset_cls(img)
dl = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=0)

# Model with a single-channel 64x64 input and ``dim`` latents, placed on the GPU.
vae = VAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim)
vae.cuda()
vae.train()

# wandb.watch(vae, log_freq=10)

# Annealing schedule arguments per loss name; unknown losses get a constant one.
iterations = len(dl) * epochs
anneal_args = {
    'betaB': ('monotonic', iterations, 0, 1),
    'betaH': ('monotonic', iterations, 1, 0.2),
}.get(config.loss, ('constant', iterations, 1, 1))
anneal = get_anneal(*anneal_args)
loss_f = get_loss_s(config.loss, anneal, beta, len(dl.dataset))

# Training is currently disabled: weights come from the checkpoint 'tmp.pt'
# and the pair of latent dimensions used for the projection plot is fixed.
vae.load_state_dict(torch.load('tmp.pt'))
index = [1, 2]
# To retrain, rank the latents by KL and refresh the checkpoint instead:
# storer = train(dl, vae, loss_f, epochs)
#
# kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' == k[:3]])
# _, index = kl_loss.sort(descending=True)
# print(index)
# torch.save(vae.state_dict(), 'tmp.pt')

# ---- evaluation -----------------------------------------------------------
vae.eval()
test_loader = DataLoader(dataset,
                         batch_size=512,
                         shuffle=True)

evaluator = Evaluator(vae, loss_f,
                      device='cuda',
                      logger=wandb.logger,
                      save_dir=wandb.wandb_dir(),
                      is_progress_bar=False)

evaluator.compute_metrics(test_loader)

# Encode the whole dataset, then continue on the CPU for plotting.
preds, target = get_preds(vae, test_loader)
preds = preds.cpu()
vae.cpu()
# Re-assert eval mode: get_preds may hand the model back in train mode, and
# the figures below are meant to be produced by the model in eval mode.
vae.eval()

fig = plot_reconstruct(dataset.imgs, (3, 2), vae)
wandb.log({'reconstruction': wandb.Image(fig)})

# Traverse every latent dimension starting from the middle sample of the dataset.
data_len = len(dataset)
fig = plt_sample_traversal(dataset.imgs[data_len // 2].reshape(1, 1, 64, 64), vae, 7, range(dim), r=2)
wandb.log({'traversal': wandb.Image(fig)})

# Project the encodings onto the two latent dimensions selected in ``index``.
# One curve is drawn per value of the first label column (0, 5, ..., 35), and
# every 4th point of a curve is annotated with its full label.
# NOTE(review): test_loader is shuffled, so the points of each curve are
# connected in shuffled order — confirm this is intended.
z = preds[:, index[:2]]
fig = plt.figure()
for y in range(0, 40, 5):
    selected = target[:, 0] == y
    plt.plot(z[selected, 0], z[selected, 1], marker='o')
    for i in selected.nonzero().flatten()[::4]:
        plt.text(z[i, 0], z[i, 1], str(target[i].tolist()))
wandb.log({'projection_xy': wandb.Image(fig)})
# fig = plot_projection(dataset.o_imgs[::5,::4], vae, dim=index[:2])
# wandb.log({'projection_xy': wandb.Image(fig)})
plt.show()
