"""
Same images and same latent sequence; only their order differs (shuffled when `--transformation 1`).
"""
import argparse
from collections import defaultdict

import numpy as np
import wandb
from torch import nn, optim
from tqdm import trange

from disvae.models.decoders import get_decoder
from utils import data_generator
from utils.helpers import set_seed
from utils.visualize import *

# Default hyperparameters. Each entry is also exposed as a `--<name>` command-line
# flag whose argparse type is taken from the type of its default value.
hyperparameter_defaults = {
    'batch_size': 256,
    'learning_rate': 0.0005,
    'epochs': 2001,
    'dimension': 1,
    'width': 500,
    'transformation': 0,
    'angle': 0,
    'loss': 'betaH',
    'img_id': 3,
    'random_seed': 224,
}
parser = argparse.ArgumentParser()
for name, default in hyperparameter_defaults.items():
    parser.add_argument('--' + name, type=type(default), default=default)
# `config` is an alias of the parsed namespace; the rest of the script reads it.
config = args = parser.parse_args()

set_seed(config.random_seed)
wandb.init(project="exps", config=config, group='transformation', tags=['shuffle'])


def train(model, epochs, log=True):
    """Fit ``model`` to reconstruct the module-level ``dataset`` from a fixed latent grid.

    Image ``i`` of ``dataset`` is paired with the scalar latent ``z[i]`` taken from
    ``torch.linspace(-2, 2, len(dataset))``. The whole dataset is used as one batch
    every epoch, so ``config.batch_size`` is not used here.

    Args:
        model: decoder called as ``model(z)`` with ``z`` of shape
            ``(len(dataset), 1)``. It is expected to be on the GPU, because the
            inputs are moved there. Its outputs must lie in [0, 1], as
            ``F.binary_cross_entropy`` requires.
            NOTE(review): ``z`` is always 1-D here, so this presumably only works
            when the decoder was built with latent dimension 1 — confirm.
        epochs: number of full-batch optimisation steps.
        log: if True, send the per-epoch metrics (and the periodic
            visualisations) to wandb.

    Returns:
        The metrics dict of the last epoch: ``loss`` and ``epoch``, plus
        ``recon``/``img``/``traversal`` if that epoch was a multiple of 200.
        Returns an empty dict when ``epochs`` is 0.
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    # Loop-invariant tensors: move the data and the latent grid to the GPU once
    # instead of on every epoch.
    img = dataset.cuda()
    z = torch.linspace(-2, 2, len(dataset)).cuda().unsqueeze(1)
    storer = defaultdict(list)
    for e in trange(epochs):
        storer = defaultdict(list)

        recon_batch = model(z)
        loss = F.binary_cross_entropy(recon_batch, img)
        opt.zero_grad()
        loss.backward()
        opt.step()
        storer['loss'].append(loss.item())

        # Reduce per-batch lists to their mean (a single full batch here).
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)

        if e % 200 == 0:
            # The images go in the same dict as the scalar metrics. That way each
            # epoch makes at most one wandb.log call (one wandb step), and the
            # `log` flag also controls image logging.
            storer['recon'] = wandb.Image(recon_batch[0, 0])
            storer['img'] = wandb.Image(img[0, 0])
            storer['traversal'] = _traversal_figure(model)
        storer['epoch'] = e
        if log:
            wandb.log(storer, sync=False)
    return storer


def _traversal_figure(model):
    """Render a latent-traversal figure for ``model`` on the CPU, then restore it.

    ``plt_sample_traversal`` receives a bare ``nn.Module`` wrapper whose
    ``decoder`` attribute is ``model``. The model is moved to the CPU and put in
    eval mode while plotting. The ``finally`` block always moves it back to the
    GPU in train mode and closes the current pyplot figure, even if plotting raises.
    """
    model.cpu()
    model.eval()
    try:
        wrapper = nn.Module()
        wrapper.decoder = model
        fig = plt_sample_traversal(None, wrapper, 7, range(dim), r=2)
    finally:
        model.cuda()
        model.train()
        plt.close()
    return fig


# generate data
# Build the training set from rotations of one stored pattern. Optionally shuffle
# it, so the same images are paired with the latent grid in a different order.
img_id = config.img_id
epochs = config.epochs
dim = config.dimension

img = torch.load(f'patterns/{img_id}.pat')
imgs = data_generator.gen_rotation(img, img.shape)[0]
angle = config.angle
# NOTE(review): assumes the selected frames reshape to (N, 1, 64, 64), matching
# the decoder's (1, 64, 64) output shape below — confirm against data_generator.
dataset = imgs[angle, :, 0].reshape(-1, 1, 64, 64)
if config.transformation == 1:
    # Permute over the actual number of frames rather than a hard-coded 40, so
    # every image is kept exactly once whatever the dataset size.
    shuffle = torch.randperm(len(dataset))
    dataset = dataset[shuffle]

decoder = get_decoder('Burgess')((1, 64, 64), dim, config.width)

decoder.cuda()
decoder.train()

storer = train(decoder, epochs)
plt.show()
