"""
Partial rotation [:,:,20].
Study of the loss: extraction_rate, c_init, c_fin.
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path

import cv2
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models.anneal import *
from disvae.models.losses import _kl_normal_loss, get_loss_f, BetaBLoss, BaseLoss, _reconstruction_loss, \
    linear_annealing, BtcvaeLoss
from disvae.models.vae import VAE
from disvae.utils.initialization import weights_init
from disvae.models.encoders import get_encoder
from disvae.models.decoders import get_decoder
from utils.visualize import *
from utils import data_generator
import numpy as np
import os
from utils.helpers import set_seed


# Default hyperparameters. Every entry is also exposed as a `--<name>` CLI
# flag whose argparse type is taken from the type of its default value.
hyperparameter_defaults = {
    'batch_size': 64,
    'learning_rate': 0.001,
    'epochs': 441,
    'beta': 20,
    'c_fin': 6,
    'dimension': 6,
    'extraction_rate': 1.0,
    'img_id': 6,
    'random_seed': 210,
}

parser = argparse.ArgumentParser()
for opt_name, opt_default in hyperparameter_defaults.items():
    parser.add_argument('--' + opt_name, default=opt_default, type=type(opt_default))
args = parser.parse_args()

# Register the run (and its config) with wandb.
wandb.init(project="exps", group='rotation', tags=['loss'], config=args)

# Seed all RNGs before any data generation or model initialisation.
set_seed(args.random_seed)
config = args



def get_preds(model, dl):
    """Encode every batch of ``dl`` and return the encoder means with their labels.

    The model is put in eval mode for the pass. It is then restored to
    whatever train/eval mode it was in before the call, even if encoding
    raises.

    Args:
        model: module whose ``encoder`` returns a pair ``(mu, logvar)``.
        dl: iterable of ``(img, label)`` batches. Each ``img`` batch is
            reshaped to ``(-1, 1, 64, 64)`` before encoding.

    Returns:
        ``(preds, targets)``:
            preds: the concatenated encoder means, on the model's device.
            targets: the concatenated labels, left on their original device.

    Raises:
        RuntimeError: if ``dl`` yields no batches (``torch.cat`` of an empty list).
    """
    was_training = model.training
    # Encode on the model's own device instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    preds = []
    targets = []
    model.eval()
    try:
        with torch.no_grad():
            for img, label in dl:
                img = img.view(-1, 1, 64, 64).to(device)
                mu, _ = model.encoder(img)
                preds.append(mu)
                targets.append(label)
    finally:
        # Restore the caller's mode rather than unconditionally switching to train.
        model.train(was_training)
    return torch.cat(preds), torch.cat(targets)


def _kl_normal_loss(mean, logvar, storer=None):
    latent_dim = mean.size(1)
    # batch mean of kl for each latent dimension
    latent_kl = 0.5 * (-1 - logvar + mean.pow(2) + logvar.exp()).mean(dim=0)

    total_kl = (latent_kl).sum()
    if storer is not None:
        storer['KL_loss'].append(total_kl.item())
        for i in range(latent_dim):
            storer['kl_' + str(i)].append(latent_kl[i].item())

    return latent_kl

def train(dl, model, epochs):
    """Train ``model`` with ``BtcvaeLoss`` and log per-epoch statistics to wandb.

    Every 40 epochs the logged record also includes:
      * one input image and its reconstruction from the last batch;
      * scatter plots of the three highest-KL latent means against the first
        two label factors;
      * a ``plot_projection`` figure of ``imgs[::4, ::8, 20]`` over the two
        highest-KL latents. ``imgs`` is the module-level dataset tensor.

    Args:
        dl: DataLoader yielding ``(img, label)`` batches. Images are reshaped
            to ``(-1, 1, 64, 64)`` and moved to CUDA.
        model: VAE exposing ``encoder``, ``reparameterize`` and ``decoder``.
        epochs: number of passes over ``dl``.

    Returns:
        The final epoch's ``storer`` dict. Its per-batch lists have been
        averaged to scalars, and it also holds ``'epoch'`` and any images
        logged that epoch.
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)

    anneal = Constant(len(dl) * epochs, 1, 1, False)
    # NOTE(review): beta is hard-coded to 15 although the CLI exposes `--beta`
    # (default 20), which is never read. Confirm this is intended.
    loss_f = BtcvaeLoss(len(dl.dataset),
                        alpha=1,
                        beta=15,
                        gamma=1, record_loss_every=2,
                        rec_dist='bernoulli', anneal=anneal)
    for e in trange(epochs):
        storer = defaultdict(list)
        for img, _ in dl:
            img = img.view(-1, 1, 64, 64).cuda()
            latent_dist = model.encoder(img)
            latent_sample = model.reparameterize(*latent_dist)
            recon_batch = model.decoder(latent_sample)

            loss = loss_f(img, recon_batch, latent_dist, True, storer, latent_sample=latent_sample)
            opt.zero_grad()
            loss.backward()
            opt.step()
            storer['loss'].append(loss.item())

        # Collapse the per-batch lists into epoch means before logging.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)

        # Rank latent dimensions by epoch-mean KL, highest first.
        # Assumes the loss records one 'kl_<i>' entry per latent dim, in order,
        # and no other key containing 'kl_'. TODO: confirm against BtcvaeLoss.
        kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
        _, index = kl_loss.sort(descending=True)

        if (e + 1) % 40 == 0:
            storer['recon'] = wandb.Image(recon_batch[0, 0])
            storer['img'] = wandb.Image(img[0, 0])

            preds, targets = get_preds(model, dl)
            with torch.no_grad():
                points = preds[:1000, index[:3]].cpu()
                colors = targets[:1000]
                fig, axes = plt.subplots(2, 3)
                for i in range(2):  # label factor
                    for j in range(3):  # latent variable, by KL rank
                        axes[i, j].scatter(points[:, j].numpy(), colors[:, i].numpy(), s=0.2)
                        axes[i, j].set_title(f'{i},{int(index[j])}')
                storer['correlated'] = wandb.Image(fig)
                # wandb.Image has already rendered the figure, so close it now.
                # The plt.close() at the end of the epoch only closes the
                # *current* figure, which would leave this one open.
                plt.close(fig)

                fig = plot_projection(imgs[::4, ::8, 20].cuda(), model, dim=index[:2])
                storer['projection_ay'] = wandb.Image(fig)

        storer['epoch'] = e
        wandb.log(storer, step=e, sync=False)
        plt.close()
    return storer


# generate data
# Load the base pattern `patterns/<img_id>.pat` and expand it into the
# rotation dataset with data_generator.gen_rotation.
img_id = config.img_id
img = torch.load(f'patterns/{img_id}.pat')

# NOTE(review): `imgs` is also read as a module global inside train(); keep the name.
imgs, labels = data_generator.gen_rotation(img, img.shape)

epochs = config.epochs
dim = config.dimension

# Training subset: indices [:, :20, 20] of the generated grid, flattened into
# single-channel 64x64 images. The matching labels are flattened to rows of 3.
dataset = imgs[:, :20, 20].reshape(-1, 1, 64, 64)
nlabels = labels[:, :20, 20].reshape(-1, 3)
dl = DataLoader(list(zip(dataset, nlabels)), pin_memory=True,
                num_workers=0, batch_size=config.batch_size, shuffle=True)

# VAE with Burgess encoder/decoder. `dim` is presumably the latent size; TODO confirm against VAE.
vae = VAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim)
# vae = nn.DataParallel(vae)
vae.cuda();
vae.train()
# wandb.watch(vae)

storer = train(dl, vae, epochs)

# Rank latent dims by the final epoch's mean KL (descending); same extraction as in train().
kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
_, index = kl_loss.sort(descending=True)
print(index)
# torch.save(vae.state_dict(), 'tmp_model.pt')
# refine
# refine(dl,vae,5)
vae.eval()

# The final plots below run on CPU (note that the imgs slices are not moved to CUDA).
vae.cpu()

# NOTE(review): `config` is deleted, presumably so nothing below can depend on
# it. TODO confirm intent. `args` is still bound to the same object.
del config

fig = plot_reconstruct(dataset, (3, 2), vae)
wandb.log({'reconstruction': wandb.Image(fig)})

fig = plt_sample_traversal(None, vae, 7, index, r=2)
wandb.log({'traversal': wandb.Image(fig)})

fig = plot_projection(imgs[::4, ::4, 20], vae, dim=index[:2])
wandb.log({'projection_ay': wandb.Image(fig)})

# with torch.no_grad():
#     mu, _ = vae.encoder(imgs[::4, ::4, 20].reshape(-1, 1, 64, 64))
#     points = mu[:, index[:3]]
#     # points = 10*points
#     colors = labels[::4, ::4, 20].reshape(-1, 3) * torch.tensor([6, 6, 6])
#     point_cloud = torch.cat([points, colors, ], 1).numpy()
#     # np.savetxt('pc.txt', point_cloud)
#     wandb.log({'point_cloud': wandb.Object3D(point_cloud)})

# Show any figures still open, then finalize the wandb run.
plt.show()
wandb.join()