import time
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from models import DecoderWithAttention, MLP, VqaModel
# from datasets import *
from torch.optim import lr_scheduler
from utils import *
from loss import CustomLoss
from PIL import Image
from vocab import Vocabulary, load_vocab
from data_loader import get_loader
# from torchviz import make_dot
from scipy.stats import rankdata
from collections import OrderedDict
import re
import glob
def sortedStringList(array=None):
    """
    Return the unique strings of ``array`` ordered by the integers embedded in them.

    Each string is keyed by the list of its digit runs read as integers, so
    ``"img2.jpg"`` sorts before ``"img10.jpg"`` (plain ``sorted`` would not).
    Strings whose embedded numbers are equal are ordered by the full string, so
    the result does not depend on the order of the input (e.g. the arbitrary
    order returned by ``glob.glob``). Duplicate strings are dropped.

    :param array: iterable of strings; ``None`` (the default) means empty
    :return: a new sorted list
    """
    if array is None:
        return []
    digit_run = re.compile(r"\d+")
    # dict.fromkeys de-duplicates while keeping first-seen order.
    unique = dict.fromkeys(array)
    return sorted(unique, key=lambda s: ([int(d) for d in digit_run.findall(s)], s))

from nltk.translate.bleu_score import corpus_bleu

def fix_key(state_dict):
    """
    Remove a leading ``module.`` from every parameter name in a checkpoint.

    Such prefixes appear, for example, when weights were saved from an
    ``nn.DataParallel``-wrapped model.

    :param state_dict: checkpoint dict whose ``"state_dict"`` entry maps
        parameter names to values
    :return: OrderedDict with the same values, prefix-free keys, original order
    """
    prefix = 'module.'
    cut = len(prefix)
    return OrderedDict(
        (name[cut:] if name.startswith(prefix) else name, tensor)
        for name, tensor in state_dict["state_dict"].items()
    )


# ---- Data ------------------------------------------------------------------
# data_folder = '/media/ssd/caption data'  # folder with data files saved by create_input_files.py
data_name = 'vqa'  # base name shared by the data files

# ---- Model -----------------------------------------------------------------
emb_dim = attention_dim = decoder_dim = 512  # word-embedding / attention / decoder-RNN widths
dropout = 0.5
# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# cuDNN autotuning only pays off when input sizes stay the same between calls.
cudnn.benchmark = True

# ---- Training --------------------------------------------------------------
# main() in this file takes its batch/worker settings from the CLI `args`;
# most of these constants are not referenced by main().
start_epoch, epochs = 0, 5  # epochs: upper bound unless early stopping triggers
epochs_since_improvement = 0  # epochs since validation BLEU last improved
batch_size = 2
workers = 1  # data-loading workers; only 1 works with h5py
encoder_lr, decoder_lr = 1e-4, 4e-4  # encoder rate applies only when fine-tuning
grad_clip = 5.0  # clip gradients at this absolute value
alpha_c = 1.0  # weight of the 'doubly stochastic attention' regulariser from the paper
best_bleu4 = 0.0  # best BLEU-4 score so far
print_freq = 1  # report stats every this many batches
fine_tune_encoder = True  # whether to fine-tune the encoder
checkpoint = None  # checkpoint path or None; main() rebinds it to the loaded checkpoint


def _load_checkpoint(path):
    """
    Load a pickled checkpoint onto ``device``.

    The checkpoint holds whole modules (main() uses ``checkpoint['encoder']``
    as a module), so it must be fully unpickled. Only load trusted files.
    """
    import inspect

    kwargs = {"map_location": device}
    # torch >= 2.6 defaults to weights_only=True, which refuses pickled modules;
    # older releases do not accept the keyword at all.
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    return torch.load(path, **kwargs)


def _extract_features(encoder, img_loader, num_images):
    """
    Encode every image yielded by ``img_loader``.

    :param encoder: module whose output for an image batch reshapes to
        (batch, 14, 14, 2048), the layout this script allocates storage for
    :param img_loader: iterable of ``(imgs, idxs)`` batches; batches are assumed
        to arrive in candidate order (``idxs`` is not used)
    :param num_images: number of images the loader is expected to yield
    :return: float32 array of shape (num_images, 14, 14, 2048)
    :raises RuntimeError: if the loader yields fewer images than expected
    """
    import numpy as np

    feats_all = np.zeros((num_images, 14, 14, 2048), dtype=np.float32)
    filled = 0
    encoder.eval()  # inference mode: no dropout, frozen batch-norm statistics
    with torch.no_grad():  # features only; an autograd graph would waste memory
        for imgs, _idxs in img_loader:
            feats = encoder(imgs.to(device))
            # reshape (not squeeze) keeps the batch axis when a batch holds >1 image
            feats = feats.reshape(-1, 14, 14, 2048).cpu().numpy()
            feats_all[filled:filled + len(feats)] = feats
            filled += len(feats)
    if filled != num_images:
        raise RuntimeError(
            f"image loader produced {filled} images, expected {num_images}")
    return feats_all


def _rank_weighted_average(feats, rank_indices, alpha):
    """
    Blend each candidate's features with those of the candidates it ranks.

    Row ``i`` of the result is
    ``sum_j alpha**j * feats[rank_indices[i, j]] / sum_j alpha**j``.

    :param feats: array of shape (N, ...), one feature block per candidate
    :param rank_indices: integer array of shape (N, N)
    :param alpha: per-rank decay factor of the weights
    :return: float32 array with the same shape as ``feats``
    """
    import numpy as np

    n = feats.shape[0]
    weights = alpha ** np.arange(n, dtype=np.float64)
    weights /= weights.sum()
    # Scatter the normalised weights into a dense (N, N) mixing matrix so the
    # whole average is one matrix product instead of an O(N^2) Python loop.
    # np.add.at accumulates repeated indices within a row, like a summing loop.
    mix = np.zeros((n, n), dtype=np.float64)
    np.add.at(mix, (np.arange(n)[:, None], rank_indices), weights)
    flat = np.asarray(feats, dtype=np.float32).reshape(n, -1)
    return (mix.astype(np.float32) @ flat).reshape(feats.shape)


def main(rank_indices=None, alpha=0.99, candidate_path="path/to/candidate/*.jpg"):
    """
    Extract VQG encoder features for candidate images and save rank-weighted averages.

    Steps:
      1. Load the question vocabulary into the global ``word_map``.
      2. Collect images matching ``candidate_path``, ordered by ``sortedStringList``.
      3. Encode every image with the encoder stored in ``args.VQG_checkpoint``.
      4. For each candidate ``i``, average all candidates' features with weight
         ``alpha ** j`` for the candidate at position ``j`` of ``rank_indices[i]``.
         Save the result to ``./data/InputfeatureImageFeatWeightAccumurate.npy``.

    Reads the module-level ``args`` created under ``__main__``.

    :param rank_indices: integer array of shape (N, N), N = number of candidates.
        Row ``i`` lists candidate indices in weighting order (presumably a
        similarity ranking for candidate ``i`` — confirm with its producer).
    :param alpha: per-rank decay of the averaging weights
    :param candidate_path: glob pattern for the candidate images (the default
        is a placeholder)
    :raises ValueError: if no images match ``candidate_path``, or
        ``rank_indices`` is missing or not of shape (N, N)
    """
    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    import os
    import numpy as np

    word_map = VocabDict(args.input_dir + '/vocab_questions.txt')

    CandidateList = sortedStringList(glob.glob(candidate_path))
    num_candidates = len(CandidateList)
    if num_candidates == 0:
        raise ValueError(f"no candidate images match {candidate_path!r}")
    # Validate the ranking before the expensive checkpoint load / encoding.
    if rank_indices is None:
        raise ValueError(
            "rank_indices is required: an (N, N) integer array giving, for each "
            "candidate, the order in which candidates are weighted")
    rank_indices = np.asarray(rank_indices, dtype=np.intp)
    if rank_indices.shape != (num_candidates, num_candidates):
        raise ValueError(
            f"rank_indices must have shape {(num_candidates, num_candidates)}, "
            f"got {rank_indices.shape}")

    print("VQG loading...")
    checkpoint = _load_checkpoint(args.VQG_checkpoint)
    VQG_encoder = torch.nn.DataParallel(checkpoint['encoder'].to(device))

    # Only the image loader (second return value) is consumed here.
    _, img_loader = get_loader(
        sim_PATH=getattr(args, "sim_PATH", None),  # None when the CLI has no --sim_PATH
        images_PATHs=CandidateList,
        batch_size=args.batch_size,
        img_batch_size=args.img_batch_size,
        num_workers=args.num_workers)

    feat_List = _extract_features(VQG_encoder, img_loader, num_candidates)
    # Plain weighted mean (accumulation starts from zero, no constant offset).
    feat_weight_List = _rank_weighted_average(feat_List, rank_indices, alpha)

    out_path = "./data/InputfeatureImageFeatWeightAccumurate.npy"
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    np.save(out_path, feat_weight_List)

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(
        description='Extract VQG encoder features for candidate images and save '
                    'their rank-weighted averages.')
    parser.add_argument('--input_dir', type=str, default='./../Dataset',
                        help='input directory for visual question answering.')
    parser.add_argument('--log_dir', type=str, default='./logs',
                        help='directory for logs.')
    parser.add_argument('--VQG_checkpoint', type=str,
                        default="./../VQG/checkpoints/BEST_checkpoint_vqa.pth.tar",
                        help='path to saved checkpoint')
    # main() reads args.sim_PATH and hands it to get_loader().
    parser.add_argument('--sim_PATH', type=str, default=None,
                        help='path passed to get_loader() as sim_PATH.')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='batch size passed to get_loader().')
    parser.add_argument('--img_batch_size', type=int, default=1,
                        help='batch size of the image loader used for feature extraction.')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of processes working on cpu.')

    args = parser.parse_args()

    main()
