# Load via Huggingface Style
import argparse
parser = argparse.ArgumentParser(
    description='Poll a checkpoint folder and evaluate every new LoRA checkpoint on R2R val-unseen.')
# --gpu stays a string so comma-separated lists such as "0,1" can be passed
# straight through to CUDA_VISIBLE_DEVICES.
parser.add_argument('-g', '--gpu', default=7,
                    help='value for CUDA_VISIBLE_DEVICES')
# Required: the script polls this folder and writes eval.txt into it, so it
# cannot do anything without it (previously a missing value crashed later in
# os.path.join with an opaque TypeError).
parser.add_argument('-p', '--ckpt_path', required=True,
                    help='folder to poll for checkpoint sub-directories; eval.txt is written here')
parser.add_argument('-e', '--exp_id', default='0.3.brkt',
                    help='dataset version suffix of data/vln/r2r_val_unseen_v<exp_id>.jsonl')
parser.add_argument('-bs', '--batch_size', type=int, default=1,
                    help='generation batch size (only 1 is currently supported)')
parser.add_argument('-t', '--time_to_sleep', type=int, default=10,
                    help='seconds to wait after a checkpoint appears before evaluating it')
parser.add_argument('-s', '--start_index', type=int, default=-1,
                    help='index of the last checkpoint already evaluated (-1: start from the oldest)')
args = parser.parse_args()

# import pdb;pdb.set_trace()
import torch
import os
import string
import re

from safetensors.torch import load_file
import torch


# Evaluation only: turn autograd off globally for the rest of the process.
torch.set_grad_enabled(False)
# Restrict which GPUs this process can see. This must stay before the first
# torch.cuda.* call below; presumably it works after `import torch` because no
# CUDA call has been made yet at this point.
os.environ['CUDA_VISIBLE_DEVICES']=str(args.gpu)
# import pdb;pdb.set_trace()
# Sanity print: number of GPUs visible after the restriction above.
print(torch.cuda.device_count())

def write_to_record_file(data, file_path, verbose=True):
    """Append one line of text to a log file, optionally echoing it to stdout.

    Args:
        data: text to record; a trailing newline is added.
        file_path: file to append to (created if it does not exist).
        verbose: if True, also ``print`` the text.
    """
    if verbose:
        print(data)
    # ``with`` guarantees the handle is closed even if the write raises.
    with open(file_path, 'a', encoding='utf-8') as record_file:
        record_file.write(data + '\n')

def poll_checkpoint_folder(
    checkpoint_folder: str, previous_ckpt_ind: int
):
    r"""Return the (previous_ckpt_ind + 1)-th checkpoint directory in a folder.

    Candidates are the immediate sub-directories of ``checkpoint_folder`` whose
    own name contains ``"checkpoint"`` and does not contain ``"latest"``. They
    are sorted by time of last modification, oldest first.

    Args:
        checkpoint_folder: directory to look for checkpoints.
        previous_ckpt_ind: index of checkpoint last returned (-1 to get the
            oldest one).

    Returns:
        The checkpoint path if the (previous_ckpt_ind + 1)-th checkpoint is
        found, else None.
    """
    # Local import: the file-level ``import glob`` sits further down the script.
    import glob

    assert os.path.isdir(checkpoint_folder), (
        f"invalid checkpoint folder " f"path {checkpoint_folder}"
    )
    # Filter on each directory's own name rather than on the full path.
    # Otherwise a parent folder named ".../checkpoints/" lets every
    # sub-directory through, and one named ".../latest_run/" hides every
    # checkpoint. glob.escape protects folder names containing glob
    # metacharacters such as '['.
    models_paths = [
        path
        for path in glob.glob(os.path.join(glob.escape(checkpoint_folder), "*"))
        if os.path.isdir(path)
        and "checkpoint" in os.path.basename(path)
        and "latest" not in os.path.basename(path)
    ]
    models_paths.sort(key=os.path.getmtime)
    ind = previous_ckpt_ind + 1
    if ind < len(models_paths):
        return models_paths[ind]
    return None


# NOTE(review): not referenced anywhere below in this script (placement comes
# from device_map="cuda" instead); possibly a leftover.
device="cuda"

import sys
# Presumably so the local `mantis` package is importable whether the script is
# launched from the repo root or from a sub-folder.
sys.path.append('.')
sys.path.append('..')
from mantis.models.mllava import MLlavaProcessor, LlavaForConditionalGeneration
from mantis.models.mllava import chat_mllava

# Base Mantis processor and model, loaded once and reused for every checkpoint
# (eval_checkpoint only swaps the LoRA weights). Weights are bfloat16 on the
# CUDA device(s) left visible by CUDA_VISIBLE_DEVICES.
processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
attn_implementation = "flash_attention_2"
model = LlavaForConditionalGeneration.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3", device_map="cuda", torch_dtype=torch.bfloat16, attn_implementation=attn_implementation)

from peft import LoraConfig, get_peft_model
# LoRA layout for the adapters that eval_checkpoint loads into `model`. The
# rank (r) fixes the adapter tensor shapes, so this presumably has to match the
# config used at training time. TODO confirm against the training script.
lora_config = LoraConfig(
    r=128,
    lora_alpha=256,
    # Regex over module names: only q_proj / v_proj inside the language model.
    target_modules=r'.*language_model.*\.(q_proj|v_proj)',
    lora_dropout=0.05,
    bias='none',
    task_type="CAUSAL_LM",
)
# model.to(torch.bfloat16)
print("Adding LoRA adapters...")
# NOTE(review): a training-time helper. Autograd is disabled globally at the top
# of this script, so this looks unnecessary for evaluation.
model.enable_input_require_grads()
# Wrap the base model with fresh LoRA layers. Their weights are overwritten per
# checkpoint by eval_checkpoint via load_state_dict.
model = get_peft_model(model, lora_config)
print("Successfully added LoRA adapters")

from unittest import result
from transformers import AutoTokenizer
import torch

import os
import glob
import time
import json
import numpy as np
import sys

# Local checkout of the caption-metric package (bleu/, meteor/, rouge/, cider/,
# spice/ sub-packages). sys.path entries are not tilde-expanded, so expand "~".
_CAPTION_EVAL_DIR = os.path.expanduser('~/bleu-rouge-meteor-cider-spice-eval4imagecaption')
# The capturing group makes re.split keep the non-word separators as tokens.
_TOKEN_SPLIT_RE = re.compile(r'(\W+)')
# Metrics computed when the caller does not choose any.
_DEFAULT_SCORERS = ('bleu1', 'bleu4', 'meteor', 'rouge', 'cider', 'spice')


def _split_sentence(sentence):
    """Break a sentence into a list of lower-cased words and punctuation tokens."""
    toks = []
    for word in [s.strip().lower() for s in _TOKEN_SPLIT_RE.split(sentence.strip()) if len(s.strip()) > 0]:
        # Break up any words containing punctuation only, e.g. '!?', unless it is multiple full stops e.g. '..'
        if all(c in string.punctuation for c in word) and not all(c in '.' for c in word):
            toks += list(word)
        else:
            toks.append(word)
    return toks


def _make_scorer(name):
    """Import and instantiate the caption scorer called *name*.

    Imports are deferred so a call only pays the setup cost of the metrics it
    actually computes.

    Raises:
        KeyError: for an unknown metric name.
    """
    if _CAPTION_EVAL_DIR not in sys.path:
        sys.path.append(_CAPTION_EVAL_DIR)
    if name in ('bleu1', 'bleu4'):
        from bleu.bleu import Bleu
        return Bleu(1 if name == 'bleu1' else 4)
    if name == 'meteor':
        from meteor.meteor import Meteor
        return Meteor()
    if name == 'rouge':
        from rouge.rouge import Rouge
        return Rouge()
    if name == 'cider':
        from cider.cider import Cider
        return Cider()
    if name == 'spice':
        from spice.spice import Spice
        return Spice()
    raise KeyError(name)


def evaluate_caption2(ref_caps=None, pred_caps=None, scorer_names=None, remove_puncts=True):
    """Score predicted instructions against reference instructions.

    Args:
        ref_caps: dict mapping a path id (str) to a list of reference sentences.
        pred_caps: dict mapping a prediction key to a list whose first element
            is the predicted sentence. Keys are expected to look like
            "<path_id>_<suffix>"; the part before the first "_" selects the
            references.
        scorer_names: metrics to compute, any of 'bleu1', 'bleu4', 'meteor',
            'rouge', 'cider', 'spice'; None means all of them.
        remove_puncts: if True, lower-case and re-tokenize predictions and
            references so punctuation becomes separate space-delimited tokens.
            Despite the name, punctuation is kept, not removed. If False, texts
            are used as given.

    Returns:
        dict of metric name -> score * 100, plus 'num_words' (number of distinct
        tokens in the predictions) and 'avg_lens' (mean prediction length in
        tokens; NaN when there are no predictions).
    """
    preds = {}
    for key, value in pred_caps.items():
        if remove_puncts:
            preds[key] = [' '.join(_split_sentence(value[0].lower()))]
        else:
            preds[key] = value

    refs = {}
    for key in preds:
        if remove_puncts:
            refs[key] = [' '.join(_split_sentence(sent.lower()))
                         for sent in ref_caps[key.split('_')[0]]]
        else:
            # Accept references keyed either by the full prediction key or by
            # its path-id prefix, matching the lookup in the branch above.
            refs[key] = ref_caps[key] if key in ref_caps else ref_caps[key.split('_')[0]]

    if scorer_names is None:
        scorer_names = list(_DEFAULT_SCORERS)

    scores = {}
    for measure_name in scorer_names:
        s, _ = _make_scorer(measure_name).compute_score(refs, preds)
        if measure_name in ('bleu1', 'bleu4'):
            # Bleu returns a sequence of scores; keep the last one, presumably
            # the highest n-gram order.
            s = s[-1]
        scores[measure_name] = s * 100

    unique_words = set()
    sent_lens = []
    for value in preds.values():
        for sent in value:
            tokens = sent.split()
            unique_words.update(tokens)
            sent_lens.append(len(tokens))
    scores['num_words'] = len(unique_words)
    # Avoid numpy's "mean of empty slice" warning when there are no predictions.
    scores['avg_lens'] = np.mean(sent_lens) if sent_lens else float('nan')
    return scores

def load_jsonl(filename):
    """Read a JSON-Lines file and return its records as a list.

    Blank (whitespace-only) lines are skipped instead of being passed to
    ``json.loads``, which would raise on them. Surrounding whitespace on each
    line, including "\\r\\n" endings, is tolerated by ``json.loads``.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

# Evaluation items for R2R val-unseen; --exp_id selects the dataset version.
val = load_jsonl("data/vln/r2r_val_unseen_v%s.jsonl"%args.exp_id)
# val = val[:100]

# Reference instructions keyed by path id (as a string), the key that
# evaluate_caption2 derives from each prediction id.
# open() does not expand "~", so expand it explicitly.
with open(os.path.expanduser('~/eval/R2R_val_unseen.json'), 'r', encoding='utf-8') as f:
    unseen = json.load(f)
ref_caps = {str(item['path_id']): item['instructions'] for item in unseen}

def _load_lora_weights(checkpoint_path):
    """Load the LoRA adapter tensors saved in *checkpoint_path* into ``model``.

    Prefers ``adapter_model.safetensors`` over ``adapter_model.bin`` when both
    exist.

    Raises:
        FileNotFoundError: if the folder contains neither adapter file.
    """
    files = os.listdir(checkpoint_path)
    use_bin = 'adapter_model.bin' in files
    use_safetensors = 'adapter_model.safetensors' in files
    if use_bin and use_safetensors:
        print('wft? try safetensors')

    if use_safetensors:
        lora_path = os.path.join(checkpoint_path, 'adapter_model.safetensors')
        print('load lora from {}'.format(lora_path))
        prefix_state_dict = load_file(lora_path)
    elif use_bin:
        lora_path = os.path.join(checkpoint_path, 'adapter_model.bin')
        print('load lora from {}'.format(lora_path))
        # NOTE(review): torch.load unpickles the file. That is acceptable for
        # self-produced checkpoints, but never point this at untrusted files.
        prefix_state_dict = torch.load(lora_path, map_location='cpu')
    else:
        raise FileNotFoundError(
            'no adapter_model.safetensors or adapter_model.bin in {}'.format(checkpoint_path))

    # Saved keys end in ".weight", while the live PEFT model expects
    # ".default.weight" (hence the rewrite). Only the trailing component is
    # rewritten, so a module path that happens to contain "weight" stays intact.
    state_dict = {}
    for key, tensor in prefix_state_dict.items():
        if key.endswith('.weight'):
            key = key[:-len('weight')] + 'default.weight'
        state_dict[key] = tensor

    # strict=False because the file holds only adapter weights. Report keys that
    # matched nothing: if the renaming above is wrong, nothing is loaded, and
    # evaluation silently runs with whatever adapter weights were loaded before.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.unexpected_keys:
        print('WARNING: {} adapter keys were not loaded, e.g. {}'.format(
            len(load_result.unexpected_keys), load_result.unexpected_keys[:3]))
    model.eval()


def eval_checkpoint(checkpoint_path, index, bs=1):
    """Load one LoRA checkpoint, generate an instruction for every item in
    ``val`` and score the results against ``ref_caps``.

    Args:
        checkpoint_path: folder holding adapter_model.safetensors or
            adapter_model.bin.
        index: position of the checkpoint in polling order (currently unused).
        bs: batch size; only 1 is supported.

    Returns:
        The metrics dict produced by ``evaluate_caption2``.

    Raises:
        NotImplementedError: if ``bs > 1``.
        FileNotFoundError: if no adapter file is found.
    """
    if bs > 1:
        # Batched generation is not implemented. The generation below goes
        # through chat_mllava one item at a time. Fail before loading weights
        # rather than after a full pass over the data.
        raise NotImplementedError('batched evaluation (bs > 1) is not supported; use bs=1')

    _load_lora_weights(checkpoint_path)

    generate_kwargs = {
        "max_new_tokens": 100,
        "num_beams": 1,
        "do_sample": False,
        'num_return_sequences': 1,
    }

    import tqdm

    infers = {}
    for item in tqdm.tqdm(val):
        images = [os.path.join('data/vln/imgs_90fov', name) for name in item['image']]
        with torch.no_grad():
            response, _ = chat_mllava(item['text'], images, model, processor, **generate_kwargs)
        print(response)
        print('---')
        infers[item['instruction_id']] = response

    print(len(infers))
    # evaluate_caption2 expects each prediction wrapped in a one-element list.
    return evaluate_caption2(ref_caps, {k: [v] for k, v in infers.items()})

def compare(best_bleu, best_spice, checkpoint_path, results, record_file):
    """Log one checkpoint's metrics and update the running bests in place.

    Args:
        best_bleu: dict with 'bleu4' (best score so far) and 'state' (the log
            line of that checkpoint); mutated when beaten.
        best_spice: same layout for 'spice'.
        checkpoint_path: path of the evaluated checkpoint; its last
            '/'-separated component names it in the log.
        results: metric name -> score; must contain 'bleu4' and 'spice'.
        record_file: log file the summary lines are appended to.
    """
    summary = checkpoint_path.split('/')[-1] + ''.join(
        ', %s: %.2f' % (metric_name, score) for metric_name, score in results.items()
    )

    # Strictly-greater comparison: ties keep the earlier checkpoint.
    for best, metric_name in ((best_bleu, 'bleu4'), (best_spice, 'spice')):
        if results[metric_name] > best[metric_name]:
            best[metric_name] = results[metric_name]
            best['state'] = summary

    for line in (
        summary,
        "BEST RESULT TILL NOW",
        'best bleu' + ' | ' + best_bleu['state'],
        'best spice' + ' | ' + best_spice['state'],
    ):
        write_to_record_file(line, record_file)

exp_path = args.ckpt_path
record_file = os.path.join(exp_path, 'eval.txt')

# Running bests, updated in place by compare().
best_bleu = {"bleu4": 0., "state": ""}
best_spice = {"spice": 0., "state": ""}

# Index (in modification-time order) of the last checkpoint evaluated; the
# next poll asks for the one after it.
prev_ckpt_ind = args.start_index

# Watch exp_path forever, evaluating each new checkpoint as it appears.
while True:
    # Poll every 2 s until the next checkpoint shows up. The 2 s sleep also runs
    # once after the successful poll.
    while True:
        current_ckpt = poll_checkpoint_folder(exp_path, prev_ckpt_ind)
        time.sleep(2)
        if current_ckpt is not None:
            break
    write_to_record_file(f"=======current_ckpt: {current_ckpt}=======", record_file)  # type: ignore
    prev_ckpt_ind = prev_ckpt_ind + 1
    # Extra grace period before evaluating, presumably so the trainer finishes
    # writing the checkpoint files.
    time.sleep(int(args.time_to_sleep))
    results = eval_checkpoint(checkpoint_path=current_ckpt, index=prev_ckpt_ind, bs=int(args.batch_size))
    compare(best_bleu, best_spice, current_ckpt, results, record_file)