import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from models import DecoderWithAttention, VqaModel, Discriminator
# from datasets import *
from torch.optim import lr_scheduler
from utils import *
from loss import CustomLoss
from PIL import Image
from vocab import Vocabulary, load_vocab
from data_loader import get_loader
# from torchviz import make_dot
from collections import OrderedDict
import re
import glob
def sortedStringList(array=None):
    """Return the distinct strings of *array* in natural (numeric-aware) order.

    Each string is keyed by the list of integers formed from its runs of
    digits, so ``"img2.jpg"`` sorts before ``"img10.jpg"``.  Strings without
    any digits get an empty key and therefore sort first.  Duplicates are
    collapsed (first occurrence wins) and ties keep their input order because
    the sort is stable.

    Args:
        array: iterable of strings; ``None`` (the default) means "no strings".

    Returns:
        A new list with the unique input strings in sorted order.
    """
    if array is None:
        array = ()
    # Raw string: "\d" in a normal literal is an invalid escape sequence.
    digit_run = re.compile(r"\d+")
    # fromkeys() drops duplicates while remembering first-seen order.
    unique = OrderedDict.fromkeys(array)
    return sorted(unique, key=lambda s: [int(n) for n in digit_run.findall(s)])
from scipy.stats import rankdata


def fix_key(state_dict):
    """Return the weights stored under ``state_dict["state_dict"]`` with any
    leading ``'module.'`` removed from each parameter name.

    Only one leading prefix is stripped per key; values are passed through
    untouched and the original ordering of entries is kept.
    """
    prefix = 'module.'
    cleaned = OrderedDict()
    for name, value in state_dict["state_dict"].items():
        target = name[len(prefix):] if name.startswith(prefix) else name
        cleaned[target] = value
    return cleaned

def fix_key_2(state_dict):
    """Copy *state_dict* into a new ``OrderedDict`` whose keys have a single
    leading ``'module.'`` stripped when present.

    Unlike ``fix_key`` this expects the parameter mapping itself, not a
    checkpoint that nests it under a ``"state_dict"`` entry.
    """
    def _strip(name):
        # 7 == len('module.')
        return name[7:] if name.startswith('module.') else name

    return OrderedDict((_strip(name), value) for name, value in state_dict.items())


# Module-level configuration shared by main() and VQG_decoding().
data_name = 'vqa'  # base name shared by data files (not referenced elsewhere in this file)

# Model parameters (passed to DecoderWithAttention in main())
emb_dim = 512  # dimension of word embeddings
attention_dim = 512  # dimension of attention linear layers
decoder_dim = 512  # dimension of decoder RNN
dropout = 0.5  # dropout value handed to the decoder constructor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU when available, otherwise fall back to CPU
cudnn.benchmark = True  # set to true only if inputs to model are fixed size; otherwise lot of computational overhead



def main():
    """
    Evaluate question-based re-ranking of candidate images (inference only).

    For every batch of the data loader the function

    1. generates a question per sample with the VQG decoder and draws hard
       one-hot tokens from the returned scores with Gumbel-softmax,
    2. appends the decoded question text to ``./results/question.txt``,
    3. lets the VQA model answer the question for the ground-truth image and
       for every batch of candidate images,
    4. adds ``args.hyp_sim`` to the similarity of each candidate whose
       predicted answer equals the ground-truth answer, and
    5. records the rank of the ground-truth index before and after that
       adjustment, plus the discriminator's predicted label per question.

    The collected arrays are written as ``.npy`` files under ``./results``.
    Reads the module-level ``args`` namespace and publishes the question
    vocabulary as the module-level ``word_map`` used by ``VQG_decoding``.
    """
    import os  # local import: only needed to create the output directory

    # VQG_decoding() looks `word_map` up as a module-level global, so it must
    # not stay local to this function.
    global word_map

    word_map = VocabDict(args.input_dir+'/vocab_questions.txt')
    ans_word_map = VocabDict(args.input_dir+'/vocab_answers.txt')

    candidate_path = "path/to/candidate/*.jpg"
    CandidateList = sortedStringList(glob.glob(candidate_path))

    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    print("VQG loading...")
    VQG_decoder = DecoderWithAttention(attention_dim=attention_dim,
                                       embed_dim=emb_dim,
                                       decoder_dim=decoder_dim,
                                       vocab_size=len(word_map),
                                       dropout=dropout)
    VQG_decoder.load_state_dict(
        fix_key_2(torch.load(args.VQG_checkpoint, map_location=device)))
    VQG_decoder = VQG_decoder.to(device)

    print("Discriminator loading...")
    D = Discriminator(
        qst_vocab_size=len(word_map),
        word_embed_size=300,
        num_layers=2,
        hidden_size=512,
        embed_size=1024
    )
    D.load_state_dict(
        fix_key_2(torch.load(args.discriminator_PATH, map_location=device)))
    D = D.to(device)

    ## VQA
    print("VQA loading...")
    VQA = VqaModel(
        embed_size=1024,
        qst_vocab_size=len(word_map),
        ans_vocab_size=len(ans_word_map),
        word_embed_size=300,
        num_layers=2,
        hidden_size=512)
    VQA.load_state_dict(
        fix_key(torch.load(args.VQA_checkpoint, map_location=device)))
    VQA = VQA.to(device)

    # Nothing is trained here: put every network in inference mode.
    VQG_decoder.eval()
    VQA.eval()
    D.eval()
    for p in VQA.parameters():
        p.requires_grad = False

    # Embedding tables (one row per vocabulary index) used to turn one-hot
    # token tensors into embedded questions with a single matmul.
    with torch.no_grad():
        vocab_indices = torch.arange(len(word_map)).to(device)
        word_embeddings = VQA.qst_encoder.word2vec(vocab_indices)
        word_embeddings_discriminator = D.word2vec(vocab_indices)

    data_loader, img_loader = get_loader(
        sim_PATH=args.sim_PATH,
        feat_PATH=args.feat_PATH,
        images_PATHs=CandidateList,
        batch_size=args.batch_size,
        img_batch_size=args.img_batch_size,
        num_workers=args.num_workers,
        test= True)

    original = []  # rank of the ground-truth index under the input similarities
    new = []  # rank of the ground-truth index after re-ranking
    save_idx = []
    ans_count = []
    discriminative_label = list()
    # NOTE(review): hard-coded for 5000 ground-truth indices x 5000 similarity
    # entries -- confirm against the similarity file in use.
    save_sim = np.zeros((5000,5000))

    os.makedirs("./results", exist_ok=True)
    gen_qustion_writer = "./results/question.txt"

    # Plain Python ints so the membership test below compares by value
    # (tensors hash by identity and would never match these entries).
    special_tokens = {word_map.word2idx(tok) for tok in ('<end>', '<start>', '<pad>')}

    with open(gen_qustion_writer,mode="a") as f, torch.no_grad():
        for GT_imgs, sims, feats, GT_idxs in data_loader:
            # to device
            GT_imgs = GT_imgs.to(device)
            sims = sims.to(device)
            feats = feats.to(device)
            GT_idxs = GT_idxs.to(device)
            re_sims = sims.clone()

            # VQG decoder
            generated_qsts, scores = VQG_decoding(VQG_decoder, feats)
            z = F.gumbel_softmax(scores, tau=2, hard=True, dim = 2)
            for token_ids in torch.argmax(z, dim=2).tolist():
                gen = [word_map.idx2word(w) for w in token_ids if w not in special_tokens]
                line = " ".join(gen)
                f.write(line + "\n")
                print(line)
            z_for_VQA = torch.matmul(z, word_embeddings)
            z_for_discriminative = torch.matmul(z, word_embeddings_discriminator)

            # VQA answer for the ground-truth image of each question
            GT_ans_score = VQA(GT_imgs, z_for_VQA)
            _, GT_ans = torch.max(GT_ans_score, 1)
            ans_count.extend(GT_ans.cpu().numpy())

            for j in range(generated_qsts.size(0)):# each question
                gt_idx = int(GT_idxs[j])
                gt_ans = GT_ans[j]

                for imgs, idxs in img_loader:# each candidate image batch
                    imgs = imgs.to(device)

                    # Tile the question to the actual batch size so that a
                    # smaller final batch is handled as well.
                    now_qst = z_for_VQA[j].unsqueeze(0).repeat(imgs.size(0), 1, 1)
                    ans_score = VQA(imgs, now_qst)
                    _, ans = torch.max(ans_score, 1)

                    # hyp_sim where the candidate's answer equals the
                    # ground-truth answer, 0 elsewhere (same as the dot
                    # product of the two one-hot answer vectors).
                    tmp = args.hyp_sim * (ans == gt_ans).float()
                    re_sims[j][idxs] = sims[j][idxs] + tmp

                x = sims[j].to('cpu').detach().numpy().copy()
                y = re_sims[j].to('cpu').detach().numpy().copy()
                # Negate so that the highest similarity gets rank 1.
                rank_before = rankdata(-x)[gt_idx]
                rank_after = rankdata(-y)[gt_idx]
                original.append(rank_before)
                new.append(rank_after)
                save_idx.append(GT_idxs[j].to('cpu').detach().numpy().copy())
                save_sim[gt_idx] = y
                print(GT_idxs[j])
                print( rank_before, rank_after)
                print("mean rank: {}, {}".format(np.mean(original),np.mean(new)))
                print("median rank: {}, {}".format(np.median(original),np.median(new)))
                print()

            ## Discriminative acc
            outs = D(z_for_discriminative)
            outs = outs.detach().cpu().max(1)[1].numpy().copy()
            discriminative_label.append(outs)

    np.save("./results/original.npy",np.array(original))
    np.save("./results/rerank.npy",np.array(new))
    np.save("./results/rerank_sim.npy",save_sim)
    np.save("./results/idx.npy",np.array(save_idx))
    np.save("./results/ans_count.npy",np.array(ans_count))
    np.save("./results/discriminative_labels.npy",np.array(discriminative_label))

def VQG_decoding(VQG_decoder,encoder_out):
    """Greedily decode one question per sample and return per-step scores.

    Reads the module-level globals ``word_map`` (question vocabulary),
    ``args.max_qst_length`` and ``device``; they must be defined before the
    call.

    Args:
        VQG_decoder: the attention decoder, either bare or wrapped in an
            object that exposes the real model as ``.module``.
        encoder_out: image features; reshaped here to
            ``(batch_size, num_pixels, encoder_dim)``.

    Returns:
        ``(seqs, scores)`` where ``seqs`` holds the chosen word indices
        (``<start>`` first, one extra column per executed step) and ``scores``
        has shape ``(batch_size, args.max_qst_length, vocab_size)``.  Position
        0 is all zeros except 100 at ``<start>``; each executed step holds the
        log-softmax scores with 100 added at the arg-max word.  Steps that
        were not executed stay zero.
    """
    # Work with both a plain decoder and a wrapper exposing `.module`.
    decoder = getattr(VQG_decoder, "module", VQG_decoder)

    batch_size = encoder_out.size(0)
    encoder_dim = encoder_out.size(-1)
    vocab_size = decoder.vocab_size
    max_len = args.max_qst_length
    start_idx = word_map.word2idx('<start>')
    end_idx = word_map.word2idx('<end>')

    # Flatten image
    encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)

    # start word
    k_prev_words = torch.full((batch_size, 1), start_idx, dtype=torch.long, device=device)
    seqs = k_prev_words

    scores_ret = torch.zeros(max_len, batch_size, vocab_size, device=device)
    scores_ret[0, :, start_idx] = 100

    h, c = decoder.init_hidden_state(encoder_out)
    batch_index = torch.arange(batch_size, device=device)

    # True once any sequence has produced a token other than <end>.
    seen_non_end = False

    # Position 0 is <start>; the remaining max_len - 1 positions are decoded.
    for step in range(1, max_len):
        embeddings = decoder.embedding(k_prev_words).squeeze(1)
        awe, _ = decoder.attention(encoder_out, h)
        gate = decoder.sigmoid(decoder.f_beta(h))  # gating applied to the attention output
        awe = gate * awe
        h, c = decoder.decode_step(torch.cat([embeddings, awe], dim=1), (h, c))
        scores = F.log_softmax(decoder.fc(h), dim=1)  # (batch_size, vocab_size)

        # store return scores
        scores_ret[step] = scores

        # calculate max score words
        next_word_inds = torch.argmax(scores, dim=1)

        # add 100 to max score words for stabilizing gumbel_softmax
        scores_ret[step, batch_index, next_word_inds] += 100

        # concat words to return sequence
        seqs = torch.cat([seqs, next_word_inds.unsqueeze(1)], dim=1)

        # NOTE(review): this keeps the existing stop rule, which only fires
        # while no sequence has ever emitted a non-<end> token; it does not
        # stop once every sequence has finished.  Confirm the intent before
        # changing it, since later positions are fed to downstream models.
        seen_non_end = seen_non_end or bool((next_word_inds != end_idx).any())
        if not seen_non_end:
            break
        k_prev_words = next_word_inds.unsqueeze(1)

    return seqs, scores_ret.permute(1,0,2)






if __name__ == '__main__':
    # Script entry point: parse the command line into the module-level `args`
    # namespace (read by main() and VQG_decoding()) and run the evaluation.
    import argparse
    parser = argparse.ArgumentParser()

    # Paths
    parser.add_argument('--input_dir', type=str, default='./../Dataset',
                        help='input directory for visual question answering.')
    parser.add_argument('--log_dir', type=str, default='./logs',
                        help='directory for logs.')
    parser.add_argument('--VQG_checkpoint', type=str,
                        default="./checkpoints/chechpoints_0epoch.pth",
                        help='path to saved checkpoint')
    parser.add_argument('--sim_PATH', type=str,
                        default="/path/to/sim.npy", help='path to similarity')
    parser.add_argument('--feat_PATH', type=str,
                        default="./data/ImageFeatWeightAccumurateTest.npy",
                        help='path to input feat')
    parser.add_argument('--VQA_checkpoint', type=str,
                        default="./../VQA/checkpoints/model-epoch-30.ckpt",
                        help='path to saved checkpoint')
    parser.add_argument('--discriminator_PATH', type=str,
                        default="./../question_discriminator/checkpoints/model.pth",
                        help='path to saved checkpoint')

    # Hyper-parameters
    parser.add_argument('--loss_margin', type=float,
                        default=0.5, help='loss margin')
    parser.add_argument('--hyp_sim', type=float,
                        default=0.3,
                        help='parameter that balances initial and re-rank similarities')
    parser.add_argument('--lr', type=float,
                        default=0.0001, help='learning rate')
    parser.add_argument('--max_qst_length', type=int, default=30,
                        help='maximum length of question. the length in the VQA dataset = 26.')
    parser.add_argument('--max_num_ans', type=int, default=10,
                        help='maximum number of answers.')

    # Loader settings.  The two batch sizes previously shared one help text.
    parser.add_argument('--batch_size', type=int, default=10,
                        help='batch size of the query data loader.')
    parser.add_argument('--img_batch_size', type=int, default=1000,
                        help='batch size of the candidate image loader.')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of processes working on cpu.')

    parser.add_argument('--rnn_num_layers', type=int, default=1,
                        help='number of layers of rnn models.')
    parser.add_argument('--rec_unit', type=str,
                        default='lstm', help='choose "gru", "lstm" or "elman"')
    parser.add_argument('--sample', default=False,
                        action='store_true',
                        help='just show result, requires --checkpoint_file')
    parser.add_argument('--log_step', type=int,
                        default=500, help='number of steps in between calculating loss')
    parser.add_argument('--num_hidden', type=int,
                        default=512, help='number of hidden units in the RNN')
    parser.add_argument('--embed_size', type=int,
                        default=512, help='number of embeddings in the RNN')
    parser.add_argument('--gamma', type=float,
                        default=0.1, help='lr step scheduler gamma')
    parser.add_argument('--step_size', type=int,
                        default=1, help='lr step scheduler step')

    args = parser.parse_args()

    main()
