import os
import numpy as np
np.random.seed(101)
from argparse import ArgumentParser
from transformers import AutoTokenizer, AutoModelWithLMHead, T5ForConditionalGeneration
from transformers import StoppingCriteria, StoppingCriteriaList
import tqdm
import torch
import pickle
import clip
import timm
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import torch.nn.functional as nnf
import faiss
from torch.distributions import Categorical
import json
from clip.simple_tokenizer import SimpleTokenizer
import time
from gensim.models import KeyedVectors
import glob


def create_parser(argv=None):
    """Define the command-line interface of the captioning script and parse it.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is what the
            original no-argument call did.

    Returns:
        argparse.Namespace holding the parsed options. ``--decoding`` is the
        only required flag.
    """
    parser = ArgumentParser()
    parser.add_argument('--datadir', type=str, default='data/coco/imgs_val.pkl', help='path to images')
    parser.add_argument('--k', type=int, default=5, help='How many tokens to retrieve')
    parser.add_argument('--data', type=str, choices=['mscoco', 'flickr30k'],
                        help="Type of data used for training the mapping")
    parser.add_argument('--lm', type=str, default='google/flan-t5-large',
                        help='language model to align to')
    parser.add_argument('--vis-encoder', type=str, default='RN50x64',
                        choices=['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14',
                                 'ViT-L/14@336px', 'beit_base_patch16_224', 'vit_large_patch16_224_in21k'],
                        help='Vision encoder to use')
    parser.add_argument('--train_method', help='Method that was used to train the semantic mapping')
    # main() reads options.fraction (log line and mapping file name) and
    # options.temp (nucleus decoding); without these two flags it fails with
    # AttributeError.
    # NOTE(review): the defaults of 1.0 are assumptions - confirm them against
    # the naming of the mapping files under ./models.
    parser.add_argument('--fraction', type=float, default=1.0,
                        help='Fraction tag of the semantic mapping; part of the mapping file name')
    parser.add_argument('--temp', type=float, default=1.0,
                        help='Sampling temperature used with nucleus decoding')
    parser.add_argument('--paragraphs', action='store_true', default=False,
                        help="Whether to use paragraphs for retrieval")
    parser.add_argument('--localized-narratives', action='store_true', default=False,
                        help="Whether to use localized narratives for retrieval")
    parser.add_argument('--decoding', type=str, choices=['greedy', 'sampling', 'topk', 'nucleus'], required=True,
                        help='What decoding strategy to use')
    parser.add_argument('--mapping-transfer', action='store_true', default=False,
                        help="Use mapping computed on one dataset for another")
    parser.add_argument('--datastore-transfer', action='store_true', default=False,
                        help="Use the caption datastore of one dataset for another")
    parser.add_argument('--inverse-order', action='store_true', default=False,
                        help="Inverse prompting order from best-to-worst to worst-to-best")
    return parser.parse_args(argv)


def search_index(queries, index, k=100):
    """Look up the `k` nearest index entries for every query row.

    Each row of `queries` is scaled to unit L2 length before the lookup.
    The ids returned by `index.search` are moved to the CPU and handed back
    as a numpy array, so the index has to return torch tensors.
    # NOTE(review): presumably a faiss index patched by
    # faiss.contrib.torch_utils - confirm at the call site.
    """
    row_lengths = torch.linalg.norm(queries, ord=2, dim=-1, keepdim=True)
    unit_queries = queries / row_lengths
    neighbour_ids = index.search(unit_queries, k=k)[1]
    return neighbour_ids.cpu().numpy()


def get_index(embs):
    """Create a flat inner-product faiss index holding the rows of `embs`.

    The index dimensionality is the size of the last axis of `embs`; the
    vectors are cast to float32 before being added.
    """
    flat_index = faiss.IndexFlatIP(embs.shape[-1])
    vectors = embs.astype(np.float32)
    flat_index.add(vectors)
    return flat_index


def _load_pickle(path):
    """Unpickle the object stored at `path`, closing the file afterwards."""
    # NOTE(security): unpickling can execute arbitrary code; only load files
    # that come from a trusted source.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _load_language_model(lm, device):
    """Load the language model selected by the name `lm` and its tokenizer.

    The model family is picked by substring matching on `lm`. Classes that
    the module does not import at file level are imported locally from
    `transformers` inside the branch that needs them.

    Returns:
        (transformer, tokenizer)

    Raises:
        NotImplementedError: if `lm` matches none of the supported families.
    """
    cache_dir = "/system/user/publicdata/llm"

    if 'gpt-j' in lm:
        from transformers import GPTJForCausalLM
        transformer = GPTJForCausalLM.from_pretrained(
            lm, torch_dtype=torch.float16, low_cpu_mem_usage=True,
            cache_dir=cache_dir
        )
        transformer = transformer.to(device)
        tokenizer = AutoTokenizer.from_pretrained(lm, cache_dir=cache_dir)
    elif 'gpt2' in lm:
        from transformers import GPT2LMHeadModel, GPT2Tokenizer
        transformer = GPT2LMHeadModel.from_pretrained(lm, cache_dir=cache_dir)
        tokenizer = GPT2Tokenizer.from_pretrained(lm, cache_dir=cache_dir)
        transformer = transformer.to(device)
    elif 'GPT-JT' in lm:
        from transformers import AutoModelForCausalLM
        transformer = AutoModelForCausalLM.from_pretrained(
            "togethercomputer/GPT-JT-6B-v1", torch_dtype=torch.float32, low_cpu_mem_usage=True,
            cache_dir=cache_dir
        ).to(device)
        tokenizer = AutoTokenizer.from_pretrained("togethercomputer/GPT-JT-6B-v1",
                                                  cache_dir=cache_dir)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    elif 'flan-t5' in lm:
        transformer = AutoModelWithLMHead.from_pretrained(
            lm, torch_dtype=torch.float32, low_cpu_mem_usage=True,
            cache_dir=cache_dir
        )
        transformer = transformer.to(device)
        tokenizer = AutoTokenizer.from_pretrained(lm, cache_dir=cache_dir)
    elif 't5-v1_1' in lm:
        if 'xxl' in lm:
            # load 8-bit quantized model; device placement is left to
            # device_map='auto', so there is no explicit .to(device) here
            transformer = AutoModelWithLMHead.from_pretrained(lm, device_map='auto',
                                                              offload_folder="offload", load_in_8bit=True,
                                                              cache_dir=cache_dir)
        else:
            transformer = AutoModelWithLMHead.from_pretrained(
                lm, torch_dtype=torch.float32, low_cpu_mem_usage=True,
                cache_dir=cache_dir
            )
            transformer = transformer.to(device)
        tokenizer = AutoTokenizer.from_pretrained(lm, cache_dir=cache_dir,
                                                  use_fast=False)
    elif 'llama' in lm:
        from transformers import LlamaForCausalLM, LlamaTokenizer
        transformer = LlamaForCausalLM.from_pretrained(
            "decapoda-research/llama-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True,
            cache_dir=cache_dir
        )
        tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf",
                                                   cache_dir=cache_dir,
                                                   use_fast=False)
        transformer = transformer.to(device)
    else:
        raise NotImplementedError(f"{lm} - Language model not supported!!!!")

    return transformer, tokenizer


def main():
    """Caption every image of a pickled image set with a retrieval-prompted LM.

    Steps:
      1. Parse the CLI options and derive the result-file suffix from them.
      2. Load the pickled images (a mapping from key to image) and the
         training captions with their embeddings, and build a faiss
         inner-product index over the normalised caption embeddings.
      3. Encode the images with CLIP (cached as .npy under ./data/<dataset>)
         and, if a train method is given, centre them and apply the stored
         projection matrix.
      4. Retrieve the `k` nearest captions per image, place them in a prompt
         and let the language model generate a caption.
      5. Write all captions to ./results as JSON and print timing statistics.

    Raises:
        NotImplementedError: if `--datadir` is not a .pkl file or the language
            model is not supported.
    """
    options = create_parser()
    # The original code read options.fraction / options.temp although the
    # parser may not define them; fall back instead of raising AttributeError.
    # NOTE(review): the fallback values of 1.0 are assumptions - confirm.
    fraction = getattr(options, 'fraction', 1.0)
    temperature = getattr(options, 'temp', 1.0)

    suffix = ''
    if options.paragraphs:
        suffix += '_paragraphs'

    if options.localized_narratives:
        suffix += '_localized_narratives'

    if options.mapping_transfer:
        suffix += '_mapping_transfer'

    if options.datastore_transfer:
        suffix += '_datastore_transfer'

    if options.inverse_order:
        suffix += '_inverse_order'

    dataset = "flickr30k" if "flickr" in options.datadir else "coco"
    suffix += f'_{options.decoding}_{dataset}'
    if 'val' in options.datadir:
        split_tag = '_val'
    elif 'test' in options.datadir:
        split_tag = '_test'
    elif 'train' in options.datadir:
        split_tag = '_train'
    elif 'nocaps' in options.datadir:
        # unlike the other branches this tag carries no leading underscore
        split_tag = options.datadir.split('_')[-1].split('.')[0]
    else:
        split_tag = '_all'

    suffix += split_tag

    if not options.datadir.endswith('pkl'):
        raise NotImplementedError(f"Not able to load from {options.datadir}")
    images = _load_pickle(options.datadir)

    keys = list(images.keys())
    os.makedirs('./results', exist_ok=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    fm = options.vis_encoder
    lm = options.lm
    # Encoder names such as 'ViT-B/32' contain a path separator and cannot be
    # used in file names as they are.
    # NOTE(review): dropping the '/' is an assumed convention - confirm it
    # against the file names written by the embedding/mapping scripts.
    fm_clean = fm.replace('/', '')

    print(f"Generating for {lm}, {fraction}, {options.train_method}")
    lm_clean = lm.split('/')[-1]
    prompt = "Similar images show: "

    model, preprocess = clip.load(fm)
    model = model.to(device)
    model.eval()

    # Dataset whose training captions serve as the retrieval datastore.
    if options.datastore_transfer:
        datastore_name = 'flickr30k' if dataset == 'coco' else 'coco'
    else:
        datastore_name = 'coco' if dataset == 'coco' else 'flickr30k'

    if options.paragraphs:
        train_info = _load_pickle(f'./data/{datastore_name}/{fm_clean}_train_paragraphs.pkl')
        train_caps = np.array([p[1] for p in train_info])
    elif options.localized_narratives:
        train_info = _load_pickle(f'./data/{datastore_name}/{fm_clean}_train_caps_localized_narratives.pkl')
        train_caps = np.array([p[1] for p in train_info])
    else:
        train_info = _load_pickle(f'./data/{datastore_name}/{fm_clean}_train_caps.pkl')
        train_caps = np.concatenate([p[1] for p in train_info])
    transformer_embs = np.concatenate([p[-1] for p in train_info])

    if options.train_method is not None:
        # normalise, then centre - mirrors the image-feature treatment below
        transformer_embs /= np.linalg.norm(transformer_embs, ord=2, axis=-1, keepdims=True)
        transformer_embs -= transformer_embs.mean(0)

    transformer_embs /= np.linalg.norm(transformer_embs, ord=2, axis=-1, keepdims=True)
    index = get_index(transformer_embs)
    # Lets faiss indexes accept and return torch tensors, which search_index
    # relies on. The import also binds `faiss` locally, so it must stay ahead
    # of any other use of `faiss` inside this function.
    import faiss.contrib.torch_utils
    if device == 'cuda':
        # only move the index to the GPU when one is actually available
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)

    transformer, tokenizer = _load_language_model(lm, device)
    transformer.eval()

    # One cache path for the existence check, the save and the load; the
    # original checked one file name but saved/loaded another.
    # NOTE: options.data has no default, so the name contains 'None' when
    # --data is omitted.
    feature_cache = f'./data/{dataset}/image_features_{fm_clean}_{options.data}_{split_tag}.npy'
    if not os.path.exists(feature_cache):
        image_features = []
        with torch.no_grad():
            for i in tqdm.trange(0, len(keys), 16):
                batch_keys = keys[i: i + 16]
                batch = torch.stack([preprocess(images[key]) for key in batch_keys]).to(device)
                embeddings = model.encode_image(batch).float().cpu().numpy()
                image_features.append(embeddings)
        image_features = np.concatenate(image_features)
        np.save(feature_cache, image_features)
    else:
        image_features = np.load(feature_cache)

    # Dataset the semantic mapping was trained on.
    if options.mapping_transfer:
        mapping_name = 'flickr30k' if dataset == 'coco' else 'mscoco'
    else:
        mapping_name = 'mscoco' if dataset == 'coco' else 'flickr30k'

    if options.train_method is not None:
        image_features /= np.linalg.norm(image_features, ord=2, axis=-1, keepdims=True)
        image_features -= image_features.mean(0)

        mapping_stem = f'{fm_clean}_{options.train_method}_{mapping_name}_{fraction}_retrieval'
        if options.paragraphs:
            mapping_stem += '_paragraphs'
        elif options.localized_narratives:
            mapping_stem += '_localized_narratives'
        proj_mat = np.load(os.path.join('./models', f'{mapping_stem}.npy'))
        proj_features = image_features @ proj_mat
    else:
        # no mapping requested: query the index with the raw image features
        # (the original left proj_features unbound on this path)
        proj_features = image_features

    proj_features = torch.FloatTensor(proj_features).to(device)
    ranked_sims = search_index(proj_features, index, k=options.k)

    # Loop-invariant generation setup.
    is_seq2seq = isinstance(transformer, T5ForConditionalGeneration)
    if not is_seq2seq:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.padding_side = 'left'

    # Only consulted for T5 models; the other models always use generate()
    # defaults with max_new_tokens=67, regardless of --decoding.
    if options.decoding == "greedy":
        generation_kwargs = {'max_length': 67}
    elif options.decoding == 'sampling':
        generation_kwargs = {'do_sample': True, 'top_k': 0, 'max_length': 67}
    elif options.decoding == 'topk':
        generation_kwargs = {'do_sample': True, 'top_k': 640, 'max_length': 67}
    else:
        generation_kwargs = {'do_sample': True, 'top_p': 0.9, 'max_length': 67, 'temperature': temperature}

    ann_file = []
    inference_times = []
    for i, key in tqdm.tqdm(enumerate(keys), total=len(keys), desc="Creating Captions...."):
        start = time.time()
        cur_ranks = np.array(ranked_sims[i]).copy()
        if options.inverse_order:
            cur_ranks = cur_ranks[::-1]
        retrieved = [train_caps[tok] for tok in cur_ranks]
        # NOTE(review): the text appended after the retrieved captions is the
        # same string as the prefix - confirm this is the intended prompt.
        prompts = ['Similar images show: ' + ' '.join(retrieved) + prompt]

        with torch.no_grad():
            encoding = tokenizer(prompts, return_tensors='pt', padding=True)
            encoding = {name: tensor.to(device) for name, tensor in encoding.items()}
            length = encoding['input_ids'].shape[1]

            # prompt the LM to generate text
            if not is_seq2seq:
                generated = transformer.generate(**encoding, max_new_tokens=67, pad_token_id=tokenizer.pad_token_id)
                # drop the prompt tokens, then cut at the first '.' / newline
                generated = generated[:, length:]
                caps = tokenizer.batch_decode(generated, skip_special_tokens=True)
                caps = [cap[:cap.index('.')] if '.' in cap else cap for cap in caps]
                caption = [cap[:cap.index('\n')] if '\n' in cap else cap for cap in caps]
            else:
                generated = transformer.generate(**encoding, **generation_kwargs)
                caption = tokenizer.batch_decode(generated, skip_special_tokens=True)

        end = time.time()
        inference_times.append(end - start)
        if isinstance(key, str):
            # e.g. '.../COCO_val2014_000000123456.jpg' -> '123456'
            image_id = key.split("/")[-1].split('.')[0].split('_')[-1].lstrip('0')
        else:
            image_id = key
        # `caption` is the list returned by batch_decode (one entry per prompt)
        ann_file.append({'image_id': image_id, 'caption': caption,
                         'tokens': retrieved})

    print(f"Average inference time: {np.mean(inference_times)}+-{np.std(inference_times)}")
    with open(f'./results/captions_val_{fm_clean}_{lm_clean}{suffix}.json', 'w') as out_file:
        json.dump(ann_file, out_file)


# Run the captioning pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

