"""
data cue是指数据中的引导解耦的线索。
"""
import argparse
from collections import defaultdict

import numpy as np
import wandb
from torch import optim
from torch.utils.data import DataLoader
from tqdm import trange

from disvae import Evaluator
from disvae.models.anneal import *
from disvae.models.decoders import get_decoder
from disvae.models.encoders import get_encoder
from disvae.models.losses import get_loss_s
from disvae.models.vae import VAE
from utils import *
from utils.data_generator import RotationTranslation
from utils.helpers import set_seed
from utils.visualize import *

# Default hyperparameters. Each entry becomes a `--<name>` command-line option
# whose argparse type is the Python type of its default value.
hyperparameter_defaults = {
    'batch_size': 128,
    'learning_rate': 0.0005,
    'iterations': 10001,
    'loss': 'betaH',
    'random_seed': 210,
}

parser = argparse.ArgumentParser()
for key, value in hyperparameter_defaults.items():
    parser.add_argument('--' + key, type=type(value), default=value)

# `args` and `config` are two names for the same parsed namespace.
args = parser.parse_args()
config = args

# Seed with the parsed value before anything else in the script runs.
set_seed(config.random_seed)


def get_preds(model, dl):
    """Encode every batch of ``dl`` and collect the outputs, inputs and labels.

    The model is switched to eval mode for the pass and its previous
    train/eval mode is restored afterwards, even if encoding raises.
    Batches are moved to the device of the model's parameters, so the
    function works for CPU and GPU models alike.

    Args:
        model: module exposing ``encoder(img)`` that returns a pair; the
            first element (``mu``) is what gets collected.
        dl: iterable of ``(img, label)`` batches. Each ``img`` is reshaped
            to ``(-1, 1, 64, 64)`` before encoding.

    Returns:
        Tuple ``(preds, reals, targets)``: the concatenated encoder
        outputs, the concatenated reshaped inputs, and the concatenated
        labels. Labels are left on whatever device ``dl`` yields them.
    """
    first_param = next(model.parameters(), None)
    # A model without parameters keeps the historical behaviour of using CUDA.
    device = first_param.device if first_param is not None else torch.device('cuda')

    was_training = model.training
    model.eval()
    preds, reals, targets = [], [], []
    try:
        with torch.no_grad():
            for img, label in dl:
                img = img.view(-1, 1, 64, 64).to(device)
                mu, _ = model.encoder(img)
                preds.append(mu)
                reals.append(img)
                targets.append(label)
    finally:
        # Restore the caller's mode instead of forcing training mode.
        model.train(was_training)
    return torch.cat(preds), torch.cat(reals), torch.cat(targets)


def evaluate(model, dl, storer):
    """Attach a 3x3 scatter grid of encoder outputs vs. labels to ``storer``.

    The values of all ``storer`` entries whose key contains ``'kl_'`` are
    sorted in descending order. The positions of the three largest are used
    as column indices into the encoder outputs returned by ``get_preds``.
    For the first 1000 samples, each selected column is plotted against each
    of the first three label columns, and the figure is stored as a
    ``wandb.Image`` under ``storer['correlated']``. When fewer than three
    ``'kl_'`` entries exist, nothing is plotted and ``storer`` is untouched.

    Args:
        model: model passed through to ``get_preds``.
        dl: data loader passed through to ``get_preds``.
        storer: mapping of scalar statistics; modified in place.
    """
    preds, _, targets = get_preds(model, dl)

    # NOTE(review): positions in this list are used as latent indices below,
    # which assumes the 'kl_' entries are stored in latent-dimension order and
    # that no other key contains 'kl_' - confirm against the loss function.
    kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
    _, index = kl_loss.sort(descending=True)

    if len(index) < 3:
        return

    points = preds[:1000, index[:3]].cpu()
    colors = targets[:1000]
    fig, axes = plt.subplots(3, 3)
    for i in range(3):  # factor (label column)
        for j in range(3):  # variable (selected encoder output column)
            axes[i, j].scatter(points[:, j].numpy(), colors[:, i].numpy(), s=0.2)
            # int() so the title reads "0,3" rather than "0,tensor(3)".
            axes[i, j].set_title(f'{i},{int(index[j])}')
    storer['correlated'] = wandb.Image(fig)

    # Close exactly the figure created here, not whichever one is current.
    plt.close(fig)


def _average_lists(storer):
    """Replace every list-valued entry of ``storer`` with its mean, in place."""
    for k, v in storer.items():
        if isinstance(v, list):
            storer[k] = np.mean(v)
    return storer


def train(dl, model, iterations, loss_f):
    """Train ``model`` on ``dl`` with Adam and log statistics to wandb.

    Runs ``iterations // len(dl)`` full passes over ``dl``. On every step
    whose global index is a multiple of 100, the accumulated statistics are
    averaged, extended by ``evaluate``, logged with ``wandb.log`` and then
    reset.

    Args:
        dl: data loader yielding ``(img, label)`` batches; labels are ignored
            and images are reshaped to ``(-1, 1, 64, 64)`` and moved to CUDA.
        model: callable returning ``(recon_batch, latent_dist, latent_sample)``.
        iterations: step budget that determines the number of passes.
        loss_f: loss callable; it receives the statistics dict and may append
            to it.

    Returns:
        The statistics accumulated since the last logging step, with every
        list-valued entry replaced by its mean.
    """
    opt = optim.Adam(model.parameters(), config.learning_rate, weight_decay=0)
    storer = defaultdict(list)
    steps_per_epoch = len(dl)
    for e in trange(iterations // steps_per_epoch):
        for i, (img, _) in enumerate(dl):
            img = img.view(-1, 1, 64, 64).cuda()

            recon_batch, latent_dist, latent_sample = model(img)
            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)

            opt.zero_grad()
            loss.backward()
            opt.step()
            storer['loss'].append(loss.item())

            # Global step index, derived from the loader that was passed in
            # rather than from a module-level global. It is always below
            # `iterations`, so no explicit early exit is needed.
            itr = e * steps_per_epoch + i
            if itr % 100 == 0:
                _average_lists(storer)
                storer['itr'] = itr
                evaluate(model, dl, storer)
                wandb.log(storer, sync=False)
                # Start a fresh accumulation window after each log.
                storer = defaultdict(list)
    return _average_lists(storer)


# Beta grid per cue: search_range[cue] holds the (start, stop, num) arguments
# that are unpacked into np.linspace when sweeping beta.
search_range = [
    [10, 40, 10],
    [60, 160, 10],
    [60, 160, 10],
]

# Sweep: for every pattern and every cue, train one VAE per beta value and log
# each run to wandb as a separate run.
for pattern in range(0, 3):
    # NOTE(review): the loaded object is not used below. The load is kept so a
    # missing pattern file still fails here - confirm whether it can be dropped.
    img = torch.load(f'exps/patterns/{pattern}.pat')
    dataset = RotationTranslation(pattern)
    test_loader = DataLoader(dataset,
                             batch_size=512,
                             shuffle=False)
    for cue in range(3):
        # Fix two of the three leading axes at index 0 and take the full
        # extent of axis `cue`, so the training data varies along one axis.
        s = [0, 0, 0]
        s[cue] = slice(None, None, None)
        s = tuple(s)
        train_ds = list(zip(dataset.o_imgs[s].reshape(-1, 1, 64, 64),
                            dataset.o_labels[s].reshape(-1, 3)))
        train_loader = DataLoader(train_ds,
                                  batch_size=config.batch_size,
                                  shuffle=True)

        for beta in np.linspace(*search_range[cue]):
            wandb.init(project="experiment", group=__file__.split('/')[-1][:-3], notes=__doc__,
                       reinit=True, config=args, )
            wandb.config.pattern = pattern
            wandb.config.beta = beta
            wandb.config.cue = cue
            iterations = config.iterations
            dim = 10
            vae = VAE((1, 64, 64), get_encoder('Burgess'), get_decoder('Burgess'), dim)

            vae.cuda()
            vae.train()
            if config.loss == 'betaB':
                anneal = get_anneal('monotonic', iterations, 0, 1)
            else:
                anneal = get_anneal('constant', iterations, 1, 1)
            loss_f = get_loss_s(config.loss, anneal, beta, len(train_loader.dataset))
            storer = train(train_loader, vae, iterations, loss_f)

            # torch.save(vae.state_dict(),'tmp.pt')
            # vae.load_state_dict(torch.load('tmp.pt'))
            evaluator = Evaluator(vae, loss_f,
                                  device='cuda',
                                  logger=wandb.logger,
                                  save_dir=wandb.wandb_dir(),
                                  is_progress_bar=False)

            evaluator(test_loader, is_metrics=True, is_losses=True)
            wandb.join()

            # Stop sweeping beta for this cue once the final 'KL_loss' value
            # falls below 0.3. Presumably 'KL_loss' is recorded by loss_f;
            # verify. If it is absent or not comparable, report and carry on
            # with the next beta. Only those failures are caught, so Ctrl-C
            # and other unrelated errors are no longer swallowed.
            try:
                if storer['KL_loss'] < 0.3:
                    break
            except (KeyError, TypeError):
                print('error')

