#tf-vaegan inductive sample-probing
from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.autograd as autograd
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import numpy as np
import math
import sys
from sklearn import preprocessing
import csv
#import functions
import model
import util
import classifier as classifier
from config import opt
# --------------------
from meta_dset import MetaDataset
from eszsl import ESZSL
F = torch.nn.functional  # short alias for the functional API
FN = torch.from_numpy    # numpy array -> torch tensor converter
# --------------------

# Dump every parsed option so the log records the exact configuration.
print('HYPER-PARAMETERS', flush=True)
for hp_name, hp_value in vars(opt).items():
    print('{:20s}:{}'.format(hp_name, hp_value), flush=True)
print(flush=True)

# Reproducibility: draw a seed when none was supplied, log it, then seed each RNG.
if opt.manualSeed is None:
    opt.manualSeed = random.randint(1, 10000)
print("Random Seed: ", opt.manualSeed, flush=True)
seed_fns = [random.seed, np.random.seed, torch.manual_seed]
if opt.cuda:
    seed_fns.append(torch.cuda.manual_seed_all)
for seed_fn in seed_fns:
    seed_fn(opt.manualSeed)
# cuDNN kernel autotuning; note this may reduce run-to-run determinism despite the seeds.
cudnn.benchmark = True
if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

# load data
data = util.DATA_LOADER(opt)
print("# of training samples: ", data.ntrain, flush=True)

# Core networks: VAE encoder, conditional generator and conditional critic.
# Construction order is kept fixed (E, G, D, F, Dec) since it determines RNG use.
netE = model.Encoder(opt)
netG = model.Generator(opt)
netD = model.Discriminator_D1(opt)

# Feedback module exists only for the two-pass feedback loop; the attribute
# decoder (auxiliary module) is always built.
netF = model.Feedback(opt) if opt.feedback_loop == 2 else None
netDec = model.AttDec(opt, opt.attSize)

for net in (netE, netG, netD):
    print(net, flush=True)
if opt.feedback_loop == 2:
    print(netF, flush=True)
print(netDec, flush=True)

###########
# Pre-allocated buffers; sample() and the training loop refill them in place.
input_res = torch.FloatTensor(opt.batch_size, opt.resSize)
input_att = torch.FloatTensor(opt.batch_size, opt.attSize)  # attSize: class-embedding size
noise = torch.FloatTensor(opt.batch_size, opt.nz)
# noise_meta is not moved here; the meta-learning code copies it to the device when used.
noise_meta = torch.FloatTensor(opt.n_meta_syn, opt.nz)
one = torch.FloatTensor([1])
mone = one * -1
##########
# Move networks and buffers to the GPU when requested.
if opt.cuda:
    gpu_nets = [netD, netE]
    if opt.feedback_loop == 2:
        gpu_nets.append(netF)
    gpu_nets.extend([netG, netDec])
    for net in gpu_nets:
        net.cuda()
    input_res = input_res.cuda()
    input_att = input_att.cuda()
    noise = noise.cuda()
    one, mone = one.cuda(), mone.cuda()

def loss_fn(recon_x, x, mean, log_var):
    """VAE objective: summed BCE reconstruction plus KL divergence, per sample.

    Args:
        recon_x: reconstructed features; passed to ``binary_cross_entropy`` so
            they are expected to lie in [0, 1].
        x: target features (detached, so no gradient flows into them).
        mean: posterior means, as returned by the encoder.
        log_var: posterior log-variances, same shape as ``mean``.

    Returns:
        ``(BCE_sum + KLD_sum) / x.size(0)``.
    """
    batch = x.size(0)
    target = x.detach()
    # 1e-12 nudges the prediction away from exactly 0 before the log inside BCE.
    recon_term = torch.nn.functional.binary_cross_entropy(
        recon_x + 1e-12, target, size_average=False).sum() / batch
    kl_term = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp()) / batch
    return recon_term + kl_term
           
def sample():
    """Fetch the next seen-class batch into the global ``input_res``/``input_att`` buffers.

    The buffers are overwritten in place with ``copy_``; nothing is returned.
    """
    features, attributes = data.next_seen_batch(opt.batch_size)
    for buf, src in ((input_res, features), (input_att, attributes)):
        buf.copy_(src)

def WeightedL1(pred, gt, eps=1e-12):
    """Weighted L1 loss, summed over elements and averaged over dim 0.

    For each row, with ``d = pred - gt``, every absolute error ``|d|`` is
    weighted by ``d**2 / ||d||_2``, where ``||d||_2`` is that row's L2 error
    norm. Larger errors therefore dominate the loss.

    Args:
        pred: predictions, 2-D; reduced along dim 1 for the row norm.
        gt: targets, same shape as ``pred``.
        eps: lower bound on each row's squared error sum before the square
            root. Without it, a row with ``pred == gt`` gives 0/0 = NaN (and
            a NaN gradient through ``sqrt(0)``). Rows whose squared error sum
            exceeds ``eps`` are unaffected.

    Returns:
        Scalar loss: sum over all elements divided by ``pred.size(0)``.
    """
    diff = pred - gt
    sq = diff.pow(2)
    # Clamp before sqrt so both the value and the sqrt gradient stay finite.
    row_norm = sq.sum(1).clamp(min=eps).sqrt().unsqueeze(1)
    # Out-of-place division: avoids mutating a tensor that autograd may track.
    loss = (sq / row_norm) * diff.abs()
    return loss.sum() / loss.size(0)
    
def generate_syn_feature(generator,classes, attribute,num,netF=None,netDec=None):
    """Synthesize ``num`` generator features for every class in ``classes``.

    Args:
        generator: conditional generator, called as ``generator(noise, c=attributes)``.
        classes: tensor of class ids, iterated along dim 0.
        attribute: per-class attribute matrix, indexed by class id.
        num: number of samples to generate per class.
        netF: optional feedback module. When given, each batch is generated
            twice: first plainly, then with feedback computed by ``netF`` from
            ``netDec``'s hidden layers (strength ``opt.a2``).
        netDec: attribute decoder; must be supplied whenever ``netF`` is.

    Returns:
        ``(features, labels)``. ``features`` is a CPU FloatTensor of shape
        ``(classes.size(0) * num, opt.resSize)``. ``labels`` is a LongTensor of
        class ids. Rows are grouped class by class, in the order of ``classes``.
    """
    n_cls = classes.size(0)
    feats = torch.FloatTensor(n_cls * num, opt.resSize)
    labels = torch.LongTensor(n_cls * num)
    att_buf = torch.FloatTensor(num, opt.attSize)
    z_buf = torch.FloatTensor(num, opt.nz)
    if opt.cuda:
        att_buf, z_buf = att_buf.cuda(), z_buf.cuda()

    for k in range(n_cls):
        cls_id = classes[k]
        att_buf.copy_(attribute[cls_id].repeat(num, 1))
        z_buf.normal_(0, 1)
        z_var = Variable(z_buf, volatile=True)
        att_var = Variable(att_buf, volatile=True)
        out = generator(z_var, c=att_var)
        if netF is not None:
            # Decoder forward pass is run only for its hidden activations,
            # which getLayersOutDet() then exposes.
            netDec(out)
            hidden = netDec.getLayersOutDet()
            out = generator(z_var, a1=opt.a2, c=att_var, feedback_layers=netF(hidden))
        start = k * num
        feats.narrow(0, start, num).copy_(out.data.cpu())
        labels.narrow(0, start, num).fill_(cls_id)

    return feats, labels

def _adam_beta1(params, lr):
    """Adam with ``betas=(opt.beta1, 0.999)``, shared by D, G, F and the decoder."""
    return optim.Adam(params, lr=lr, betas=(opt.beta1, 0.999))


# Encoder optimizer keeps Adam's default betas.
optimizer = optim.Adam(netE.parameters(), lr=opt.lr)
optimizerD = _adam_beta1(netD.parameters(), opt.lr)
optimizerG = _adam_beta1(netG.parameters(), opt.lr)
# The feedback optimizer exists only when the feedback module does.
if opt.feedback_loop == 2:
    optimizerF = _adam_beta1(netF.parameters(), opt.feed_lr)
optimizerDec = _adam_beta1(netDec.parameters(), opt.dec_lr)

def calc_gradient_penalty(netD,real_data, fake_data, input_att):
    """WGAN-GP gradient penalty for the conditional critic ``netD``.

    The critic is evaluated at random per-sample interpolations between real
    and fake features. The penalty is the squared deviation of the L2 norm of
    its input gradient from 1, averaged over the batch and scaled by
    ``opt.lambda1``.

    Args:
        netD: critic, called as ``netD(features, attributes)``.
        real_data: real feature batch (plain tensor, rows are samples).
        fake_data: generated feature batch, same shape as ``real_data``.
        input_att: attribute batch whose rows match ``real_data``.

    Returns:
        ``opt.lambda1 * mean((||grad||_2 - 1) ** 2)``. It is built with
        ``create_graph=True`` so it can be back-propagated into ``netD``.
    """
    # One coefficient per sample, sized from the batch itself rather than
    # opt.batch_size, so batches of any size work.
    alpha = torch.rand(real_data.size(0), 1)
    alpha = alpha.expand(real_data.size())
    if opt.cuda:
        alpha = alpha.cuda()
    # alpha already matches the inputs' device, so interpolates does too.
    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    interpolates = Variable(interpolates, requires_grad=True)
    disc_interpolates = netD(interpolates, Variable(input_att))
    ones = torch.ones(disc_interpolates.size())
    if opt.cuda:
        ones = ones.cuda()
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=ones,
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * opt.lambda1
    return gradient_penalty

def _to_device(t):
    """Return ``t`` moved to the GPU when ``opt.cuda`` is set, else ``t`` unchanged."""
    return t.cuda() if opt.cuda else t


def _scalar(v):
    """Return the single value held by a one-element Variable/tensor as a float.

    ``v.data[0]`` only works when reduced losses are 1-element tensors;
    flattening first also handles 0-dim results.
    """
    return float(v.data.view(-1)[0])


# ----------------------------------------------------------------------------
# Main training loop.
#
# Each epoch makes opt.feedback_loop passes over the seen training data. In the
# pass with loop == 1, each generator output is refined once:
#   1. the decoder is run on a first generated sample;
#   2. netF maps the decoder's hidden activations to feedback;
#   3. netG is called again with that feedback.
# After each epoch, unseen-class features are synthesized, and GZSL/ZSL
# classifiers are trained on them and evaluated.
# ----------------------------------------------------------------------------
best_gzsl_acc = 0
best_zsl_acc = 0
# Defined up front so the final report never hits an undefined name.
best_acc_seen = 0
best_acc_unseen = 0
for epoch in range(0, opt.nepoch):
    for loop in range(0, opt.feedback_loop):
        for i in range(0, data.ntrain, opt.batch_size):
            ######### Discriminator training ##############
            for p in netD.parameters():  # unfreeze discriminator
                p.requires_grad = True

            for p in netDec.parameters():  # unfreeze decoder
                p.requires_grad = True
            # Train D1 and Decoder (and Decoder Discriminator)
            gp_sum = 0  # accumulated gradient penalty, used below to adapt lambda1
            for iter_d in range(opt.critic_iter):
                sample()
                netD.zero_grad()
                input_resv = Variable(input_res)
                input_attv = Variable(input_att)

                # Decoder step: regress class attributes from real features.
                netDec.zero_grad()
                recons = netDec(input_resv)
                R_cost = opt.recons_weight * WeightedL1(recons, input_attv)
                R_cost.backward()
                optimizerDec.step()

                # Critic on real samples. The minus sign is folded into the value.
                criticD_real = netD(input_resv, input_attv)
                criticD_real = -opt.gammaD * criticD_real.mean()
                criticD_real.backward()
                if opt.encoded_noise:  # --> True for CUB
                    # z = means + eps * std, with eps ~ N(0, 1).
                    means, log_var = netE(input_resv, input_attv)
                    std = torch.exp(0.5 * log_var)
                    eps = torch.randn([opt.batch_size, opt.latent_size])
                    eps = Variable(_to_device(eps))
                    z = eps * std + means
                else:
                    noise.normal_(0, 1)
                    z = Variable(noise)

                if loop == 1:
                    fake = netG(z, c=input_attv)
                    dec_out = netDec(fake)
                    dec_hidden_feat = netDec.getLayersOutDet()
                    feedback_out = netF(dec_hidden_feat)
                    fake = netG(z, a1=opt.a1, c=input_attv, feedback_layers=feedback_out)
                else:
                    fake = netG(z, c=input_attv)

                criticD_fake = netD(fake.detach(), input_attv)
                criticD_fake = opt.gammaD * criticD_fake.mean()
                criticD_fake.backward()
                # gradient penalty
                gradient_penalty = opt.gammaD * calc_gradient_penalty(netD, input_res, fake.data, input_att)
                gp_sum += gradient_penalty.data
                gradient_penalty.backward()
                # criticD_real holds -gammaD*E[D(real)]; criticD_fake holds +gammaD*E[D(fake)].
                # So the gammaD-scaled Wasserstein estimate E[D(real)] - E[D(fake)]
                # and the critic loss actually minimized are (logging only):
                Wasserstein_D = -criticD_real - criticD_fake
                D_cost = criticD_fake + criticD_real + gradient_penalty
                optimizerD.step()

            # Normalize to the mean raw penalty mean((||grad|| - 1)^2) over the
            # critic iterations. Then raise lambda1 by 10% if it is above 1.05,
            # or lower it by 10% if it is below 1.001.
            gp_sum /= (opt.gammaD * opt.lambda1 * opt.critic_iter)
            if (gp_sum > 1.05).sum() > 0:
                opt.lambda1 *= 1.1
            elif (gp_sum < 1.001).sum() > 0:
                opt.lambda1 /= 1.1

            ############# Generator training ##############
            # Train Generator and Decoder
            for p in netD.parameters():  # freeze discriminator
                p.requires_grad = False
            if opt.recons_weight > 0 and opt.freeze_dec:
                for p in netDec.parameters():  # freeze decoder
                    p.requires_grad = False

            netE.zero_grad()
            netG.zero_grad()
            if opt.feedback_loop == 2:
                netF.zero_grad()
            # Reuses the last batch drawn by sample() in the critic loop above.
            input_resv = Variable(input_res)
            input_attv = Variable(input_att)
            means, log_var = netE(input_resv, input_attv)
            std = torch.exp(0.5 * log_var)
            eps = torch.randn([opt.batch_size, opt.latent_size])
            eps = Variable(_to_device(eps))
            z = eps * std + means
            if loop == 1:
                recon_x = netG(z, c=input_attv)
                dec_out = netDec(recon_x)
                dec_hidden_feat = netDec.getLayersOutDet()
                feedback_out = netF(dec_hidden_feat)
                recon_x = netG(z, a1=opt.a1, c=input_attv, feedback_layers=feedback_out)
            else:
                recon_x = netG(z, c=input_attv)

            vae_loss_seen = loss_fn(recon_x, input_resv, means, log_var)
            # errG is grown with out-of-place additions below. `+=` on a
            # Variable is in-place, and errG starts as the same object as
            # vae_loss_seen. That would make the logged vae_loss_seen equal the
            # full generator loss.
            errG = vae_loss_seen

            if opt.encoded_noise:
                criticG_fake = netD(recon_x, input_attv).mean()
                fake = recon_x
            else:
                noise.normal_(0, 1)
                noisev = Variable(noise)
                if loop == 1:
                    fake = netG(noisev, c=input_attv)
                    # NOTE(review): feedback here is decoded from recon_x.
                    # Every other feedback pass in this file decodes the sample
                    # being refined (fake). Kept as-is; confirm this is intended.
                    dec_out = netDec(recon_x)  # Feedback from Decoder encoded output
                    dec_hidden_feat = netDec.getLayersOutDet()
                    feedback_out = netF(dec_hidden_feat)
                    fake = netG(noisev, a1=opt.a1, c=input_attv, feedback_layers=feedback_out)
                else:
                    fake = netG(noisev, c=input_attv)
                criticG_fake = netD(fake, input_attv).mean()

            # ------------------------------
            # PERFORM META-LEARNING
            # ------------------------------
            meta_loss = 0.  # init meta-loss
            if opt.meta:
                eszsl = ESZSL(d_ft=opt.resSize, d_attr=opt.attSize, alpha=opt.alpha, gamma=opt.gamma)  # init ESZSL
                # ------------------------------
                # META DATASET & ITERATOR INITIALIZATION
                # ------------------------------
                meta_dset = MetaDataset(data=data, opt=opt)  # init meta-dataset
                meta_iterator = torch.utils.data.DataLoader(meta_dset, batch_size=1)  # init meta-iterator
                # ------------------------------
                for x_support, y_support, x_query, y_query in meta_iterator:
                    # ----------------------------------------------------------------
                    # CREATE A SYNTHETIC DATASET USING SUPPORT CLASSES
                    # ----------------------------------------------------------------
                    unq_support_classes = FN(np.unique(y_support[0].numpy()))
                    support_classes_onehot = torch.diag(torch.ones(opt.n_support))
                    X, Y = [], []
                    for idx, c in enumerate(unq_support_classes):
                        noisev = Variable(_to_device(noise_meta.normal_(0, 1)))  # init noise
                        av = Variable(_to_device(data.attribute[[c]].repeat(opt.n_meta_syn, 1)))  # attribute
                        if loop == 1:  # loop == 1: feedback loop is on and this is the feedback pass
                            x_syn = netG(noisev, c=av)
                            dec_out = netDec(x_syn)
                            dec_hidden_feat = netDec.getLayersOutDet()
                            feedback_out = netF(dec_hidden_feat)
                            x_syn = netG(noisev, a1=opt.a1, c=av, feedback_layers=feedback_out)
                        else:
                            x_syn = netG(noisev, c=av)

                        y_syn = Variable(_to_device(support_classes_onehot[[idx]].repeat(opt.n_meta_syn, 1)))
                        X.append(x_syn)
                        Y.append(y_syn)

                    Xv = torch.cat(X)  # synthetic image features
                    Yv = torch.cat(Y)  # synthetic one-hot labels over the support classes
                    # ----------------------------------------------------------------
                    # COMPUTE W WITH SYNTHETIC DATASET
                    # ----------------------------------------------------------------
                    Sv = Variable(_to_device(data.attribute[unq_support_classes]))
                    W = eszsl.find_solution(Xv, Yv, Sv)  # compute W
                    # ----------------------------------------------------------------
                    # COMPUTE ESZSL-LOSS WITH QUERY SAMPLES
                    # ----------------------------------------------------------------
                    task_loss = 0.
                    # Named sub_idx (not i) so the batch-loop index is not shadowed.
                    for sub_idx in range(opt.n_subset):
                        unq_query_classes = FN(np.unique(y_query[sub_idx].numpy()))
                        if opt.meta_loss_type == 'gzsl':
                            unq_classes = torch.sort(torch.cat((unq_support_classes, unq_query_classes)))[0]
                        elif opt.meta_loss_type == 'zsl':
                            unq_classes = unq_query_classes
                        else:
                            # Otherwise unq_classes would be undefined or
                            # silently reused from a previous subset.
                            raise ValueError("unknown meta_loss_type: %r" % (opt.meta_loss_type,))

                        x = Variable(_to_device(x_query[sub_idx][0]))
                        # Map global class ids to their positions within unq_classes.
                        y = Variable(_to_device(FN(np.searchsorted(unq_classes.numpy(), y_query[sub_idx][0].numpy()))))
                        s = Variable(_to_device(data.attribute[unq_classes]))
                        # compute loss
                        y_ = eszsl.solve(x, W, s)
                        task_loss += torch.nn.functional.cross_entropy(y_, y)

                    task_loss /= opt.n_subset  # compute average loss
                    meta_loss += task_loss

                meta_loss /= opt.n_task  # compute average meta-loss

                MLoss = opt.meta_weight * meta_loss
                MLoss.backward()  # backprop using meta loss
            # ------------------------------

            G_cost = -criticG_fake
            errG = errG + opt.gammaG * G_cost
            netDec.zero_grad()
            recons_fake = netDec(fake)
            R_cost = WeightedL1(recons_fake, input_attv)
            errG = errG + opt.recons_weight * R_cost
            errG.backward()

            optimizer.step()
            optimizerG.step()
            if loop == 1:
                optimizerF.step()
            if opt.recons_weight > 0 and not opt.freeze_dec:  # not train decoder at feedback time
                optimizerDec.step()

    print('[%d/%d]  Loss_D: %.4f Loss_G: %.4f, Wasserstein_dist:%.4f, vae_loss_seen:%.4f' % (
        epoch, opt.nepoch,
        _scalar(D_cost),
        _scalar(G_cost),
        _scalar(Wasserstein_D),
        _scalar(vae_loss_seen)),
        end=" ", flush=True)

    netG.eval()
    netDec.eval()
    if opt.feedback_loop == 2:
        netF.eval()
    syn_feature, syn_label = generate_syn_feature(netG, data.unseenclasses, data.attribute, opt.syn_num, netF=netF, netDec=netDec)

    # Generalized zero-shot learning
    if opt.gzsl:
        # Concatenate real seen features with synthesized unseen features
        train_X = torch.cat((data.train_feature, syn_feature), dim=0)
        train_Y = torch.cat((data.train_label, syn_label), dim=0)
        nclass = opt.nclass_all
        # Train GZSL classifier
        gzsl_cls = classifier.CLASSIFIER(train_X, train_Y, data, nclass, opt.cuda, opt.classifier_lr, 0.5,
                                         25, opt.syn_num, generalized=True, netDec=netDec, dec_size=opt.attSize,
                                         dec_hidden_size=4096)
        # Deliberately keeps the latest epoch's results as "best" (no max tracking).
        best_acc_seen, best_acc_unseen, best_gzsl_acc = gzsl_cls.acc_seen, gzsl_cls.acc_unseen, gzsl_cls.H
        print('GZSL: seen=%.4f, unseen=%.4f, h=%.4f' % (gzsl_cls.acc_seen, gzsl_cls.acc_unseen, gzsl_cls.H), end=" ", flush=True)

    # Zero-shot learning
    # Train ZSL classifier on synthetic unseen features only (labels passed through util.map_label)
    zsl_cls = classifier.CLASSIFIER(syn_feature, util.map_label(syn_label, data.unseenclasses),
                                    data, data.unseenclasses.size(0), opt.cuda, opt.classifier_lr, 0.5, 25, opt.syn_num,
                                    generalized=False, netDec=netDec, dec_size=opt.attSize, dec_hidden_size=4096)
    acc = zsl_cls.acc
    # Deliberately keeps the latest epoch's result as "best" (no max tracking).
    best_zsl_acc = acc
    print('ZSL: unseen accuracy=%.4f' % (acc), flush=True)
    # reset G to training mode
    netG.train()
    netDec.train()
    if opt.feedback_loop == 2:
        netF.train()

    # save generator (currently disabled)
    if not opt.validation:
        #save_filepath = "classifier/{}/generator_iter-{:03}_meta-{}.pt".format(opt.dataset, epoch, opt.meta)
        #torch.save(netG.state_dict(), save_filepath)
        #print('saved')
        pass
    # ---

# Final summary: values reported are those from the last evaluated epoch.
print('Dataset', opt.dataset, flush=True)
print('the best ZSL unseen accuracy is', best_zsl_acc, flush=True)
if opt.gzsl:
    print('Dataset', opt.dataset, flush=True)
    gzsl_summary = (
        ('the best GZSL seen accuracy is', best_acc_seen),
        ('the best GZSL unseen accuracy is', best_acc_unseen),
        ('the best GZSL H is', best_gzsl_acc),
    )
    for caption, value in gzsl_summary:
        print(caption, value, flush=True)


