"""
data cue是指数据中的引导解耦的线索。
"""
import argparse
from collections import defaultdict

import numpy as np
import wandb
from torch import optim
from torch.utils.data import DataLoader
from tqdm import trange

from disvae import Evaluator
from disvae.models.anneal import *
from disvae.models.decoders import get_decoder
from disvae.models.encoders import get_encoder
from disvae.models.losses import get_loss_s
from disvae.models.vae import VAE
from utils.datasets import get_dataloaders, DSprites
from utils.helpers import set_seed
from utils.visualize import *

# Default hyperparameters; each entry is exposed below as a ``--<name>``
# command-line flag.
hyperparameter_defaults = {
    'batch_size': 128,
    'learning_rate': 0.0005,
    'iterations': 10001,
    'loss': 'betaH',
    'random_seed': 210,
}

# Build the CLI from the defaults: every flag is parsed with the type of its
# default value, so overrides keep the same type as the default.
parser = argparse.ArgumentParser()
for key, value in hyperparameter_defaults.items():
    parser.add_argument('--' + key, type=type(value), default=value)

# ``args`` and ``config`` are two names for the same parsed namespace.
args = config = parser.parse_args()
set_seed(config.random_seed)


def get_preds(model, dl):
    """Encode every batch of ``dl`` and collect latent means, inputs and labels.

    The model is switched to eval mode while encoding and is always put back
    into train mode before returning (also when encoding raises).

    Args:
        model: Module exposing ``encoder(img)`` that returns a pair whose
            first element is used as the prediction (``mu``).
        dl: Iterable of ``(img, label)`` batches. Each ``img`` is viewed as
            ``(-1, 1, 64, 64)`` before being encoded.

    Returns:
        Tuple ``(preds, reals, targets)``: the concatenated encoder means, the
        concatenated reshaped inputs (on the model's device) and the
        concatenated labels, all concatenated along dimension 0.
    """
    # Follow the model's own device rather than assuming CUDA; fall back to
    # CUDA only when the model exposes no parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    preds, reals, targets = [], [], []
    model.eval()
    try:
        with torch.no_grad():
            for img, label in dl:
                img = img.view(-1, 1, 64, 64).to(device)
                mu, _ = model.encoder(img)
                preds.append(mu)
                reals.append(img)
                targets.append(label)
        return torch.cat(preds), torch.cat(reals), torch.cat(targets)
    finally:
        # Callers continue training afterwards, so restore train mode.
        model.train()


def evaluate(model, dl, storer):
    """Plot the three highest-KL latent means against the first three label columns.

    Runs ``get_preds`` over ``dl`` and ranks the ``storer`` values whose key
    contains ``'kl_'`` in descending order. When at least three such values
    exist, a 3x3 scatter grid is drawn from the first 1000 samples: row ``i``
    uses label column ``i`` on the y-axis and column ``j`` uses the latent
    selected by the ``j``-th largest KL value on the x-axis. The figure is
    stored as a ``wandb.Image`` under ``storer['correlated']``.

    The current pyplot figure is closed before returning in every case.
    """
    latents, _, factors = get_preds(model, dl)

    kl_values = [value for name, value in storer.items() if 'kl_' in name]
    _, index = torch.Tensor(kl_values).sort(descending=True)

    if len(index) >= 3:
        xs = latents[:1000, index[:3]].cpu()
        ys = factors[:1000]
        fig, axes = plt.subplots(3, 3)
        for factor_idx in range(3):
            factor_column = ys[:, factor_idx].numpy()
            for latent_idx in range(3):
                ax = axes[factor_idx, latent_idx]
                ax.scatter(xs[:, latent_idx].numpy(), factor_column, s=0.2)
                ax.set_title(f'{factor_idx},{index[latent_idx]}')
        storer['correlated'] = wandb.Image(fig)

    plt.close()


def _reduce_storer(storer):
    """Replace every list-valued entry of ``storer`` by its mean, in place, and return it."""
    for k, v in storer.items():
        if isinstance(v, list):
            storer[k] = np.mean(v)
    return storer


def _final_stats(storer, last_logged):
    """Return the averaged current window, or the last logged one if the window is empty."""
    if not storer and last_logged is not None:
        return last_logged
    return _reduce_storer(storer)


def train(dl, model, iterations, loss_f):
    """Train ``model`` on ``dl`` with Adam and return the latest averaged statistics.

    Runs ``iterations // len(dl)`` full passes over ``dl``. Every 100 steps
    the list-valued entries accumulated in the storer are averaged, tagged
    with the step index under ``'itr'``, sent to ``wandb.log`` and a fresh
    storer is started.

    Args:
        dl: Loader yielding ``(img, _)`` batches; must support ``len``.
            Each ``img`` is viewed as ``(-1, 1, 64, 64)`` and moved to CUDA.
        model: Module called as ``model(img)`` and returning
            ``(recon_batch, latent_dist, latent_sample)``.
        iterations: Step budget; converted to a whole number of epochs.
        loss_f: Callable invoked as ``loss_f(img, recon_batch, latent_dist,
            model.training, storer, latent_sample=latent_sample)`` and
            returning the scalar loss to back-propagate. It receives the
            storer, presumably to record per-step statistics — confirm
            against the loss implementation.

    Returns:
        The averaged statistics of the last (possibly partial) logging
        window. When that window is empty because the final step was itself a
        logging step, the statistics that were just logged are returned
        instead of an empty mapping.
    """
    opt = optim.Adam(model.parameters(), config.learning_rate, weight_decay=0)
    steps_per_epoch = len(dl)
    storer = defaultdict(list)
    last_logged = None
    for e in trange(iterations // steps_per_epoch):
        for i, (img, _) in enumerate(dl):
            img = img.view(-1, 1, 64, 64).cuda()

            recon_batch, latent_dist, latent_sample = model(img)
            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)

            opt.zero_grad()
            loss.backward()
            opt.step()
            storer['loss'].append(loss.item())

            # Global step index, derived from the loader that was passed in
            # rather than from a module-level variable.
            itr = e * steps_per_epoch + i
            if itr % 100 == 0:
                _reduce_storer(storer)
                storer['itr'] = itr
                # evaluate(model,train_loader,storer)
                wandb.log(storer, sync=False)
                # Remember what was logged so the caller still gets statistics
                # if training ends right after this reset.
                last_logged = storer
                storer = defaultdict(list)
            if itr > iterations:
                return _final_stats(storer, last_logged)
    return _final_stats(storer, last_logged)


# Load the dataset once and view images and labels as a grid whose first five
# axes have sizes (3, 6, 40, 32, 32); the loop below slices one axis at a time.
# NOTE(review): presumably one axis per dSprites generative factor, in the
# same order as ``lat_names`` — confirm against the DSprites class.
dsprites = DSprites()
imgs = dsprites.imgs.reshape((3, 6, 40, 32, 32, 64, 64))
labels = dsprites.lat_values.reshape((3, 6, 40, 32, 32, 6))
lat_sizes = dsprites.lat_sizes
lat_names = dsprites.lat_names

# One ``[start, stop, num]`` triple per cue, unpacked into ``np.linspace`` to
# produce the beta values swept for that cue.
search_range = [
    [20, 30, 5],
    [80, 160, 10],
    [5, 20, 5],
    [80, 160, 10],
    [80, 160, 10],
]
# For each cue (one axis of the factor grid), train one VAE per beta value in
# that cue's search range and evaluate it on the same single-factor dataset.
for cue in range(5):
    # Keep every value along axis ``cue`` and pin the other four grid axes to
    # index 0, so the dataset varies along exactly one axis.
    s = [0, 0, 0, 0, 0]
    s[cue] = slice(None, None, None)
    s = tuple(s)
    dsprites.imgs = imgs[s]
    dsprites.lat_values = labels[s]
    dsprites.lat_sizes = lat_sizes[cue:cue + 1]
    # One-element tuple (note the trailing comma) so it pairs up with the
    # one-element ``lat_sizes`` instead of being a bare string.
    dsprites.lat_names = (lat_names[cue],)
    train_loader = DataLoader(dsprites,
                              batch_size=config.batch_size,
                              shuffle=True)

    for beta in np.linspace(*search_range[cue]):
        # One wandb run per (cue, beta) pair, grouped by this script's name.
        # NOTE(review): splitting ``__file__`` on '/' assumes POSIX paths.
        wandb.init(project="experiment", group=__file__.split('/')[-1][:-3], notes=__doc__,
                   reinit=True, config=args, )
        wandb.config.beta = beta
        wandb.config.cue = cue
        iterations = config.iterations
        dim = 10
        vae = VAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim)

        vae.cuda()
        vae.train()
        if config.loss == 'betaB':
            anneal = get_anneal('monotonic', iterations, 0, 1)
        else:
            anneal = get_anneal('constant', iterations, 1, 1)
        # NOTE(review): ``len(imgs)`` is the first axis of the reshaped grid
        # (3), not the number of samples in the sliced dataset — confirm which
        # of the two ``get_loss_s`` expects.
        loss_f = get_loss_s(config.loss, anneal, beta, len(imgs))
        storer = train(train_loader, vae, iterations, loss_f)

        # torch.save(vae.state_dict(),'tmp.pt')
        # vae.load_state_dict(torch.load('tmp.pt'))

        # Evaluate on the same dataset, in a fixed order.
        test_loader = DataLoader(dsprites,
                                 batch_size=512,
                                 shuffle=False)
        evaluator = Evaluator(vae, loss_f,
                              device='cuda',
                              logger=wandb.logger,
                              save_dir=wandb.wandb_dir(),
                              is_progress_bar=False)

        evaluator(test_loader, is_metrics=True, is_losses=True)
        wandb.join()
        # Stop sweeping beta for this cue once the reported KL term is below
        # 0.3. A missing or non-numeric 'KL_loss' entry is reported and the
        # sweep continues; unrelated exceptions (e.g. KeyboardInterrupt) are
        # no longer swallowed.
        try:
            if storer['KL_loss'] < 0.3:
                break
        except (KeyError, TypeError, ValueError) as err:
            print('error', repr(err))

