import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from models import VqaModel, Discriminator
from torch.optim import lr_scheduler
from utils import *
from loss import CustomLoss
from PIL import Image
from vocab import Vocabulary, load_vocab
from data_loader import get_loader
from collections import OrderedDict
import re
import glob
def sortedStringList(array=None):
    """Return the unique strings of *array* in natural (numeric-aware) order.

    Each string is keyed by the list of integers formed by its digit runs, so
    ``"img2.jpg"`` sorts before ``"img10.jpg"``. Non-digit text is ignored when
    ordering; strings with equal keys keep their first-seen order (the sort is
    stable). Duplicate strings are dropped, keeping the first occurrence.

    Args:
        array: iterable of strings (e.g. paths from ``glob.glob``). ``None``
            (the default) yields an empty list.

    Returns:
        A new list of strings.
    """
    if array is None:
        return []
    # dict preserves insertion order: de-duplicates while keeping first occurrences.
    unique = dict.fromkeys(array)
    return sorted(unique, key=lambda s: [int(run) for run in re.findall(r"\d+", s)])




# Run on the GPU when CUDA is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Let cuDNN benchmark kernels per input shape; this only pays off when input sizes stay fixed.
cudnn.benchmark = True


def main():
    """Fine-tune the VQG decoder so that its questions help re-rank candidate images.

    Configuration is read from the module-level ``args`` namespace (built in the
    ``__main__`` block). For every batch from ``data_loader``:

    1. ``VQG_decoding`` greedily generates one question per item from ``feats``;
       its per-step scores are turned into one-hot tokens with a hard
       (straight-through) Gumbel-softmax so gradients can reach the decoder.
    2. The frozen VQA model answers each question on ``GT_imgs`` and on every
       candidate image from ``img_loader``; a candidate's similarity is raised by
       ``args.hyp_sim`` when its Gumbel-sampled answer equals the GT answer.
    3. ``CustomLoss`` on the re-ranked similarities -- optionally blended with a
       cross-entropy loss from a frozen question discriminator -- is
       back-propagated; only the VQG decoder's parameters are optimized.

    The decoder's ``state_dict`` is saved under ``./checkpoints`` after every epoch.
    """
    # VQG_decoding() looks the question vocabulary up as the module-level name
    # ``word_map``; binding it only as a local here made that lookup fail.
    global word_map
    word_map = VocabDict(args.input_dir + '/vocab_questions.txt')
    ans_word_map = VocabDict(args.input_dir + '/vocab_answers.txt')

    candidate_path = "path/to/candidate/*.jpg"
    # Numeric-aware ordering so candidate indices do not depend on glob's order.
    CandidateList = sortedStringList(glob.glob(candidate_path))

    print("VQG loading...")
    checkpoint = torch.load(args.VQG_checkpoint)
    # The checkpoint stores the whole decoder module; it is accessed through
    # ``.module`` below (presumably a DataParallel wrapper).
    VQG_decoder = checkpoint['decoder'].to(device)

    # Only the VQG decoder is optimized; VQA and the discriminator stay fixed.
    optimizer = torch.optim.Adam([
        {"params": VQG_decoder.module.parameters()}
    ], lr=args.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

    ## Discriminator (optional)
    D = None
    word_embeddings_discriminator = None
    if args.discriminative_loss:
        print("Discriminator loading...")
        D = Discriminator(
            qst_vocab_size=len(word_map),
            word_embed_size=300,
            num_layers=2,
            hidden_size=512,
            embed_size=1024
        )
        D.load_state_dict(fix_key2(torch.load(args.discriminator_PATH)))
        D = D.to(device)
        # D is a fixed critic: gradients must flow through D's *input* into the
        # decoder, not accumulate in D's own (never-optimized) weights.
        for p in D.parameters():
            p.requires_grad = False
        # Embedding of every question-word id, used to map one-hot tokens into
        # D's input space. No autograd: the table is a constant for training.
        with torch.no_grad():
            word_embeddings_discriminator = D.word2vec(torch.arange(len(word_map), device=device))

    ## VQA
    print("VQA loading...")
    VQA = VqaModel(
        embed_size=1024,
        qst_vocab_size=len(word_map),
        ans_vocab_size=len(ans_word_map),
        word_embed_size=300,
        num_layers=2,
        hidden_size=512)

    VQA.load_state_dict(fix_key(torch.load(args.VQA_checkpoint)))
    VQA = VQA.to(device)

    # Embedding of every question-word id, for converting one-hot tokens to VQA
    # inputs. Built without autograd so no graph is shared (and retained) across
    # iterations and no gradient piles up in the frozen VQA weights.
    with torch.no_grad():
        word_embeddings = VQA.qst_encoder.word2vec(torch.arange(len(word_map), device=device))
    VQA = torch.nn.DataParallel(VQA)

    # VQA is not trained.
    for p in VQA.module.parameters():
        p.requires_grad = False

    # Loss functions
    criterion = CustomLoss(args.loss_margin, args.batch_size, args.weight_word, args.word_weight_index, args.weight_value).to(device)
    criterion_discriminative = torch.nn.CrossEntropyLoss().to(device)

    # Custom dataloaders
    data_loader, img_loader = get_loader(
        sim_PATH=args.sim_PATH,
        feat_PATH=args.feat_PATH,
        images_PATHs=CandidateList,
        batch_size=args.batch_size,
        img_batch_size=args.img_batch_size,
        num_workers=args.num_workers)

    # Only used in the progress message.
    # NOTE(review): this counts candidates, not data_loader batches -- confirm intended.
    num_steps = int(len(CandidateList) / args.batch_size)
    # Rolling window of the last 10 step losses, printed for monitoring.
    loss_disp = torch.zeros(10, dtype=torch.float64)

    for epoch in range(args.epochs):
        VQG_decoder.train()
        # NOTE(review): VQA is frozen but still in train mode, so any dropout /
        # batch-norm layers behave as during training -- confirm this is intended.
        VQA.train()

        for i, (GT_imgs, sims, feats, GT_idxs) in enumerate(data_loader):
            optimizer.zero_grad()

            GT_imgs = GT_imgs.to(device)
            sims = sims.to(device)
            feats = feats.to(device)
            GT_idxs = GT_idxs.to(device)
            sims.requires_grad = True
            re_sims = sims.clone()

            generated_qsts, scores = VQG_decoding(VQG_decoder, feats)
            # hard=True: one-hot tokens forward, soft-sample gradients backward.
            z = F.gumbel_softmax(scores, tau=2, hard=True, dim=2)
            z_VQA = torch.matmul(z, word_embeddings)

            # VQA answer for each generated question on its ground-truth image.
            GT_ans_score = VQA(GT_imgs, z_VQA)
            _, GT_ans = torch.max(GT_ans_score, 1)
            z_GT_ans = F.gumbel_softmax(GT_ans_score, tau=2, hard=True, dim=1)

            qst_time = AverageMeter()
            start = time.time()

            for j in range(generated_qsts.size(0)):
                for imgs, idxs in img_loader:
                    imgs = imgs.to(device)
                    # One copy of question j per image actually in this batch
                    # (the last batch may be shorter than args.img_batch_size).
                    now_qst = z_VQA[j].unsqueeze(0).repeat(imgs.size(0), 1, 1)
                    ans_score = VQA(imgs, now_qst)
                    z_ans_score = F.gumbel_softmax(ans_score, tau=2, hard=True, dim=1)
                    # Both answer vectors are one-hot (hard=True), so the product is
                    # 1 where a candidate's answer matches the GT answer, else 0.
                    re_sims[j][idxs] = sims[j][idxs] + args.hyp_sim * torch.matmul(z_ans_score, z_GT_ans[j])

                qst_time.update(time.time() - start)
                start = time.time()
                print("epoch: {}/{}, step: {}/{}, qst: {}/{}, qst Time {qst_time.val:.3f} ({qst_time.avg:.3f})".format(epoch, args.epochs, i, num_steps, j, generated_qsts.size(0), qst_time=qst_time))

            if args.discriminative_loss:
                z_Discriminator = torch.matmul(z, word_embeddings_discriminator)
                predicts = D(z_Discriminator)
                # Every generated question is labelled as class 0 for the discriminator.
                labels = torch.zeros(predicts.size(0), dtype=torch.long, device=predicts.device)
                loss_discriminative = criterion_discriminative(predicts, labels)

            loss_VQA = criterion(re_sims, GT_idxs, GT_ans)
            if args.discriminative_loss:
                loss = (1 - args.discriminative_loss_weight) * loss_VQA + args.discriminative_loss_weight * loss_discriminative
            else:
                loss = loss_VQA

            # The embedding tables are constants, so this step's graph is
            # self-contained and need not be retained.
            loss.backward()
            optimizer.step()

            loss_disp[i % 10] = loss.detach().to(torch.float)
            print(loss_disp)
            print(torch.mean(loss_disp))
            if args.discriminative_loss:
                print("All:{}, Ranking:{}, Discriminative:{}".format(loss, loss_VQA, loss_discriminative))
                del z_Discriminator, predicts, loss_discriminative

            # Drop this step's graph-holding tensors before the next batch is loaded.
            del loss, loss_VQA, re_sims, sims, scores, z, z_VQA, z_GT_ans, GT_ans_score

        scheduler.step()
        torch.save(VQG_decoder.state_dict(), "./checkpoints/chechpoints_{}epoch.pth".format(epoch))
            

def VQG_decoding(VQG_decoder, encoder_out, vocab=None, max_qst_length=None):
    """Greedily decode one question per sample with the VQG decoder.

    Args:
        VQG_decoder: decoder whose ``.module`` exposes ``vocab_size``,
            ``embedding``, ``attention``, ``sigmoid``, ``f_beta``,
            ``decode_step``, ``fc`` and ``init_hidden_state`` (as used below).
        encoder_out: image features; dim 0 is the batch and the last dim is the
            encoder feature size. Any dims in between are flattened into one
            "pixel" axis.
        vocab: question vocabulary providing ``word2idx``; defaults to the
            module-level ``word_map``.
        max_qst_length: number of time steps in the returned score tensor;
            defaults to ``args.max_qst_length``.

    Returns:
        ``(seqs, scores)``:

        - ``seqs`` is a LongTensor ``(batch, steps_run + 1)`` of greedy token ids
          starting with ``<start>``.
        - ``scores`` is ``(batch, max_qst_length, vocab_size)``. Row 0 is 20 at
          ``<start>`` and 0 elsewhere. Each decoded step ``t`` stores the
          decoder's log-probabilities in row ``t``. Rows never reached stay 0.

    Decoding runs for ``max_qst_length - 1`` steps so every score row is filled.
    It exits early only if *every* sequence has produced ``<end>`` at *every*
    step so far; this is the condition the original bookkeeping implemented.
    NOTE(review): a "stop once all sequences have finished" rule may have been
    intended, but the caller would then see all-zero rows after the break.
    """
    if vocab is None:
        vocab = word_map
    if max_qst_length is None:
        max_qst_length = args.max_qst_length

    decoder = VQG_decoder.module
    batch_size = encoder_out.size(0)
    encoder_dim = encoder_out.size(-1)
    vocab_size = decoder.vocab_size
    dev = encoder_out.device
    start_idx = vocab.word2idx('<start>')
    end_idx = vocab.word2idx('<end>')

    # Flatten image: (batch_size, num_pixels, encoder_dim)
    encoder_out = encoder_out.view(batch_size, -1, encoder_dim)

    # Every sequence starts with <start>.
    prev_words = torch.full((batch_size, 1), start_idx, dtype=torch.long, device=dev)
    seqs = prev_words

    scores_ret = torch.zeros(max_qst_length, batch_size, vocab_size, device=dev)
    # Large score on <start> at position 0 (presumably so the caller's
    # Gumbel-softmax favours <start> there).
    scores_ret[0, :, start_idx] = 20

    # True for every sequence that has emitted a non-<end> token at some step.
    ever_non_end = torch.zeros(batch_size, dtype=torch.bool, device=dev)
    h, c = decoder.init_hidden_state(encoder_out)

    for step in range(1, max_qst_length):
        embeddings = decoder.embedding(prev_words).squeeze(1)  # (batch_size, embed_dim)
        awe, _ = decoder.attention(encoder_out, h)  # attention-weighted encoding
        gate = decoder.sigmoid(decoder.f_beta(h))  # gating, same shape as awe
        awe = gate * awe
        h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))
        scores = F.log_softmax(decoder.fc(h), dim=1)  # (batch_size, vocab_size)
        scores_ret[step] = scores

        # Greedy choice of the next word, fed back in at the next step.
        next_word_inds = torch.argmax(scores, dim=1)
        seqs = torch.cat([seqs, next_word_inds.unsqueeze(1)], dim=1)

        ever_non_end |= next_word_inds != end_idx
        if not ever_non_end.any():
            break
        prev_words = next_word_inds.unsqueeze(1)

    return seqs, scores_ret.permute(1, 0, 2)






def str2bool(value):
    """Parse a command-line boolean flag value.

    Without a ``type`` (or with ``type=bool``), argparse treats any non-empty
    string as truthy, including ``"False"``. Bool-valued flags use this instead.

    Args:
        value: a ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``
            (case-insensitive).

    Returns:
        The parsed boolean.

    Raises:
        ValueError: for an unrecognised spelling; argparse reports this as an
            invalid argument value.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError('expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    import argparse
    import json
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='./../Dataset',
                        help='input directory for visual question answering.')

    parser.add_argument('--log_dir', type=str, default='./logs',
                        help='directory for logs.')

    parser.add_argument('--VQG_checkpoint', type=str,
            default="./../VQG/checkpoints/BEST_checkpoint_vqa.pth.tar", help='path to saved checkpoint')

    parser.add_argument('--sim_PATH', type=str,
            default="/path/to/sim.npy", help='path to similarity')
    parser.add_argument('--feat_PATH', type=str,
            default="./data/ImageFeatWeightAccumurate.npy", help='path to input feat')

    parser.add_argument('--VQA_checkpoint', type=str,
            default="./../chechpoints/model-epoch-30.ckpt", help='path to saved checkpoint')

    parser.add_argument('--loss_margin', type=float,
            default=0.15, help='loss margin')

    parser.add_argument('--hyp_sim', type=float,
            default=0.3, help='parameter that balances initial and re-rank similarities')

    parser.add_argument('--lr', type=float,
            default=0.0001, help='learning rate')

    # Bool flags are parsed explicitly; without a type, "--unbias False" would be truthy.
    parser.add_argument('--unbias', type=str2bool,
            default=False, help='unbias vqa')

    parser.add_argument('--discriminative_loss', type=str2bool,
            default=True, help='whether to add the question-discriminator loss (true/false)')
    parser.add_argument('--discriminative_loss_weight', type=float,
            default=0.8, help='weight discriminative loss')
    parser.add_argument('--discriminator_PATH', type=str,
            default="./../question_discriminator/checkpoints/model.pth", help='path to saved checkpoint')

    parser.add_argument('--max_qst_length', type=int, default=30,
                        help='maximum length of question. the length in the VQA dataset = 26.')

    parser.add_argument('--max_num_ans', type=int, default=10,
                        help='maximum number of answers.')

    parser.add_argument('--batch_size', type=int, default=5,
                        help='batch_size.')
    parser.add_argument('--img_batch_size', type=int, default=250,
                        help='batch size for candidate images.')

    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of processes working on cpu.')

    parser.add_argument('--rnn_num_layers', type=int, default=1,
                        help='number of layers of rnn models.')

    parser.add_argument('--rec_unit', type=str,
            default='lstm', help='choose "gru", "lstm" or "elman"')
    parser.add_argument('--sample', default=False,
            action='store_true', help='just show result, requires --checkpoint_file')
    parser.add_argument('--log_step', type=int,
            default=500, help='number of steps in between calculating loss')
    parser.add_argument('--num_hidden', type=int,
            default=512, help='number of hidden units in the RNN')
    parser.add_argument('--embed_size', type=int,
            default=512, help='number of embeddings in the RNN')
    parser.add_argument('--gamma', type=float,
        default=0.1, help='lr step scheduler gamma')
    parser.add_argument('--step_size', type=int,
        default=1, help='lr step scheduler step')

    parser.add_argument('--epochs', type=int,
        default=5, help='Max epoch')

    args = parser.parse_args()

    # params.json and the per-epoch checkpoints written by main() both live here.
    os.makedirs("./checkpoints", exist_ok=True)
    with open("./checkpoints/params.json", mode="w") as f:
        json.dump(args.__dict__, f, indent=4)
    main()
