"""
采样部分部分标签数据，检验transformation的性质。探究model bias，可以视为一种 model selection。
1-5: shape,scaling,orientation,x,y
"""
import argparse
from collections import defaultdict

import numpy as np
import wandb
from torch import optim
from torch.utils.data import DataLoader
from tqdm import trange
from fastai.vision import *
from disvae import Evaluator
from disvae.models.anneal import *
from disvae.models.decoders import get_decoder
from disvae.models.encoders import get_encoder
from disvae.models.losses import *
from disvae.models.vae import VAE
from utils.datasets import get_dataloaders, DSprites
from utils.helpers import set_seed
from utils.visualize import *

# Default hyper-parameters. Every entry is exposed below as a ``--<key>``
# command-line option whose argparse ``type`` is the Python type of its
# default value, so the literal written here decides what the CLI accepts.
hyperparameter_defaults = dict(
    batch_size=512,
    learning_rate=0.0005,
    epochs=3,
    loss='betaH',
    random_seed=156,
    # Kept as a float: with an int default the derived parser type is ``int``
    # and a fractional weight such as ``--beta 0.5`` is rejected on parsing.
    beta=1.0,
)

parser = argparse.ArgumentParser()
for key, value in hyperparameter_defaults.items():
    parser.add_argument(f'--{key}', default=value, type=type(value))
args = parser.parse_args()
# ``config`` is an alias of the parsed namespace; both names are used below.
config = args
# Seed before any dataset or model object is created.
set_seed(config.random_seed)



def _average_list_entries(storer):
    """Replace every list-valued entry of ``storer`` with its mean, in place."""
    for key, values in storer.items():
        if isinstance(values, list):
            storer[key] = np.mean(values)


def train(dl, model, epochs, loss_f):
    """Optimise ``model`` on ``dl`` with Adam and log running metrics to wandb.

    Every 100 optimiser steps the list-valued entries accumulated in the
    storer are averaged, the step counter is stored under ``'itr'``, the dict
    is passed to ``wandb.log`` and a fresh accumulator is started. At the end
    of each epoch a latent-traversal figure built from the first image of the
    last batch is stored under ``'traversal'``.

    Args:
        dl: Iterable yielding ``(img, label)`` pairs. ``img`` is reshaped to
            ``(-1, 1, 64, 64)`` and moved to the GPU; the label is not used.
        model: Module whose call returns
            ``(recon_batch, latent_dist, latent_sample)``.
        epochs: Number of passes over ``dl``.
        loss_f: Callable invoked as ``loss_f(img, recon_batch, latent_dist,
            model.training, storer, latent_sample=latent_sample)`` and
            returning a scalar tensor. It receives the storer, presumably to
            append its own metrics to it.

    Returns:
        The metrics accumulated since the last wandb flush, with list values
        replaced by their mean.
    """
    opt = optim.Adam(model.parameters(), config.learning_rate, weight_decay=0)
    storer = defaultdict(list)
    itr = 0
    for _epoch in trange(epochs):
        img = None
        for img, _label in dl:
            img = img.view(-1, 1, 64, 64).cuda().float()
            recon_batch, latent_dist, latent_sample = model(img)
            loss = loss_f(img, recon_batch, latent_dist, model.training,
                          storer, latent_sample=latent_sample)

            opt.zero_grad()
            loss.backward()
            opt.step()
            storer['loss'].append(loss.item())
            itr += 1
            if itr % 100 == 0:
                _average_list_entries(storer)
                storer['itr'] = itr
                # evaluate(model,train_loader,storer)
                wandb.log(storer, sync=False)
                storer = defaultdict(list)

        model.eval()
        if img is not None:  # an empty loader leaves nothing to traverse
            # Traverse the model being trained (the argument), not whichever
            # module-level model happens to exist. ``dim`` is still the
            # module-level value and must be defined before train() is called.
            fig = plt_sample_traversal(img[:1], model, 7, range(dim))
            storer['traversal'] = wandb.Image(fig)
        model.train()

    # NOTE(review): whatever is left in the storer here (including the last
    # epoch's traversal image) is returned but never passed to wandb.log.
    _average_list_entries(storer)
    return storer


# Hyper-parameters read from the parsed command line (``args`` and ``config``
# are the same namespace object).
beta = args.beta
epochs = config.epochs

# A single dataset instance backs both the shuffled training loader and the
# ordered evaluation loader created further down.
dsprites = DSprites()
train_loader = DataLoader(dsprites,num_workers=4,
                          batch_size=config.batch_size,
                          shuffle=True)

# The run group is this script's file name with its last three characters
# (the ".py" suffix) removed; the module docstring is attached as run notes.
# NOTE(review): splitting on '/' assumes POSIX-style path separators.
wandb.init(project="week_supervised", group=__file__.split('/')[-1][:-3], notes=__doc__, #tags=['equilibrium']
           config=args, )

image_size = (1, 64, 64)
# Passed to the VAE constructor and reused as the number of latent columns
# traversed in train() and scored at the bottom of the script -- presumably
# the latent dimensionality; confirm against VAE's signature.
dim=6
vae = VAE(image_size, get_encoder('Burgess'), get_decoder('Burgess'),dim)
vae.cuda()
vae.train()

anneal = get_anneal('constant', epochs, 1, 1)
loss_f = get_loss_s(config.loss, anneal, beta, len(train_loader.dataset))
# The returned storer is not used again in the visible part of the script.
storer = train(train_loader, vae, epochs, loss_f)
#
# torch.save(vae.state_dict(),'tmp.pt')
# vae.load_state_dict(torch.load('tmp.pt'))
vae.eval()
# Evaluation iterates the same dataset without shuffling; inv_rate() later
# reshapes the per-sample outputs to a fixed [3, 6, 40, 32, 32] grid, which
# presumably relies on this dataset order -- confirm against DSprites.
test_loader = DataLoader(dsprites,
                         batch_size=config.batch_size,
                         num_workers=0,
                         shuffle=False, )
# anneal

# NOTE(review): this rebinds ``anneal`` after ``loss_f`` was already built
# from the earlier object, and the new value is not referenced again in the
# visible code.
anneal = get_anneal('constant', 1, 1, 1)

# Evaluator results are saved under the wandb run directory.
evaluator = Evaluator(vae, loss_f,
                      device='cuda',
                      save_dir=wandb.run.dir,
                      is_progress_bar=False)
def inversion(seq):
    """Count inverted and non-inverted position pairs in every row of ``seq``.

    For positions ``j < i`` within a row, the pair is an inversion when
    ``seq[row, j] > seq[row, i]`` and a non-inversion when
    ``seq[row, j] <= seq[row, i]``.

    Args:
        seq: 2-D tensor of shape ``(rows, length)``.

    Returns:
        A tuple ``(smaller, larger)`` of 1-D long tensors with one entry per
        row: the element-wise minimum and maximum of the two counts. For rows
        of length 0 or 1 there are no pairs and both tensors are all zeros.
    """
    rows, length = seq.size(0), seq.size(1)
    # Tensor accumulators rather than Python ints, so that torch.stack below
    # also works when the loop body never runs (length < 2).
    inverted = torch.zeros(rows, dtype=torch.long, device=seq.device)
    ordered = torch.zeros(rows, dtype=torch.long, device=seq.device)
    for i in range(1, length):
        # Compare all earlier positions against position i in one shot.
        earlier = seq[:, :i]
        current = seq[:, i:i + 1]
        inverted += (earlier > current).sum(1)
        ordered += (earlier <= current).sum(1)
    counts = torch.stack([inverted, ordered])
    return counts.min(0)[0], counts.max(0)[0]

# Run the evaluator over the unshuffled loader. Both results are module-level
# globals consumed by inv_rate(): it reads ``params_zCX[0]`` as one value per
# sample and latent column (naming it ``mu``) and ``labels`` as six columns
# per sample. Exact semantics of Evaluator.compute are not visible here --
# confirm against its implementation.
params_zCX, labels = evaluator.compute(test_loader)

def inv_rate(target_z, transformation):
    """Score how monotonically latent ``target_z`` follows one label column.

    The values of column ``target_z`` in the global ``params_zCX[0]`` and the
    global ``labels`` are viewed on a ``[3, 6, 40, 32, 32]`` grid. Along grid
    axis ``transformation - 1`` the samples are ranked by the latent value,
    label column ``transformation`` is read in that ranking, and the pairs in
    each resulting sequence are counted by ``inversion``.

    Side effect: the global ``labels`` is rebound to its
    ``[3, 6, 40, 32, 32, 6]`` view.

    Args:
        target_z: Column index into ``params_zCX[0]``.
        transformation: Label column to test, 1-5 (the module docstring lists
            them as shape, scaling, orientation, x, y).

    Returns:
        Mean over all sequences of ``min(count) / max(count)`` as a Python
        float, where the counts are the two values returned by ``inversion``.
    """
    global params_zCX, labels
    grid = [3, 6, 40, 32, 32]
    factor_axis = transformation - 1
    last_axis = 5

    latent = params_zCX[0][:, target_z].view(grid + [1]).cpu()
    labels = labels.view(grid + [6])

    # Swap the varying axis into last place, then rank along it.
    ranking = latent.transpose(factor_axis, last_axis).argsort(last_axis)
    factor = labels[..., transformation:transformation + 1]
    factor = factor.transpose(factor_axis, last_axis)

    # Flatten to one row per sequence.
    length = ranking.size(last_axis)
    ranking = ranking.reshape(-1, length)
    factor = factor.reshape(-1, length)

    # Row indices broadcast against the ranking to reorder each row's labels.
    n_rows = len(ranking)
    row_ids = torch.arange(n_rows).reshape(-1, 1).expand(n_rows, length)
    lower, upper = inversion(factor[row_ids, ranking])

    ratio = lower.float() / upper.float()
    return ratio.mean().item()

# Move the model off the GPU before scoring.
vae.cpu()

# ans[z, f - 1] is inv_rate(z, f) for latent column z and label column f
# (f runs over 1..5); rows are filled in the same z-major order as a nested
# loop would use.
n_factors = 5
ans = np.zeros((dim, n_factors))
for latent_idx in range(dim):
    ans[latent_idx] = [inv_rate(latent_idx, factor)
                       for factor in range(1, n_factors + 1)]