import torch
import os
import random

import numpy as np
from tqdm import tqdm
from datasets import load_dataset, concatenate_datasets

# from transformers import AutoModelForCausalLM 
from modeling_phi3_v import Phi3VForCausalLM, Phi3Attention
from transformers import AutoProcessor 

from utils.data_utils import load_yaml, construct_prompt, save_json, process_single_sample, CAT_SHORT2LONG
from utils.model_utils import phi3_image_processor, call_phi3_engine_df
from utils.eval_utils import parse_multi_choice_response, parse_open_response
from argparse import ArgumentParser

# Maps a method key to a module class (or None). The whole dict is passed as the
# last positional argument to the model-engine callable in run_model.
# NOTE(review): presumably the value is the attention class targeted by that
# method (None meaning "no target") -- confirm against call_phi3_engine_df.
# NOTE(review): the name looks like a typo of "TARGET_MODULE"; it is left as is
# because run_model refers to it by this exact name.
TAGET_MODULE = {
    "phi3": None,
    "phi3_h2o": Phi3Attention
}

def run_model(args, samples, model, call_model_engine_fn=None, tokenizer=None, processor=None):
    """
    Run the model on a prefix of ``samples`` and time every call with CUDA events.

    The first 5 samples are executed untimed as a GPU warm-up. After that, up to the
    first 300 samples (the warm-up samples included) are executed again, each call
    being timed individually. Latency statistics are computed over the number of
    samples that were actually timed and printed to stdout.

    Requires a CUDA device, because ``torch.cuda.Event`` is used for timing.

    :param args: Run arguments; forwarded unchanged to ``call_model_engine_fn``.
    :param samples: Sliceable sequence of sample dicts. Each dict must provide the
        keys ``'question_type'`` and ``'id'``; multiple-choice samples must also
        provide ``'all_choices'`` and ``'index2ans'``.
    :param model: Model object; forwarded unchanged to ``call_model_engine_fn``.
    :param call_model_engine_fn: Callable invoked as
        ``fn(args, sample, model, tokenizer, processor, TAGET_MODULE)`` that returns
        the model response for one sample.
    :param tokenizer: Forwarded unchanged to ``call_model_engine_fn``.
    :param processor: Forwarded unchanged to ``call_model_engine_fn``.
    :return: Dict mapping ``sample['id']`` to the predicted answer, for the timed
        samples only.
    """
    warmup_count = 5
    max_timed = 300
    timed_samples = samples[0:max_timed]

    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    # One row per sample that is really measured, so that short inputs do not
    # leave zero rows behind that would distort the mean and the std.
    timings = np.zeros((len(timed_samples), 1))

    out_samples = dict()
    with torch.no_grad():
        # GPU warm-up: results are discarded. Runs under no_grad like the timed loop.
        for sample in samples[0:warmup_count]:
            call_model_engine_fn(args, sample, model, tokenizer, processor, TAGET_MODULE)

        # Measure performance.
        for rep, sample in enumerate(tqdm(timed_samples)):
            starter.record()
            response = call_model_engine_fn(args, sample, model, tokenizer, processor, TAGET_MODULE)
            ender.record()
            # Wait for the GPU so that the elapsed time between both events is final.
            torch.cuda.synchronize()
            timings[rep] = starter.elapsed_time(ender)  # milliseconds

            if sample['question_type'] == 'multiple-choice':
                pred_ans = parse_multi_choice_response(response, sample['all_choices'], sample['index2ans'])
            else:  # open question: the raw response is kept as the prediction
                pred_ans = response
            out_samples[sample['id']] = pred_ans

    num_timed = len(timed_samples)
    if num_timed == 0:
        # Nothing was measured; skip the statistics instead of dividing by zero.
        print('No samples were timed; skipping latency statistics.')
        return out_samples

    mean_syn = np.sum(timings) / num_timed
    std_syn = np.std(timings)
    mean_fps = 1000. / mean_syn
    print(' * Mean@1 {mean_syn:.3f}ms Std@5 {std_syn:.3f}ms FPS@1 {mean_fps:.2f}'.format(mean_syn=mean_syn, std_syn=std_syn, mean_fps=mean_fps))
    print(mean_syn)
    return out_samples

def set_seed(seed_value):
    """
    Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (the CUDA generators too when
    CUDA is available), then sets the cuDNN ``deterministic`` flag to True and the
    cuDNN ``benchmark`` flag to False.

    :param seed_value: An integer value to be used as the seed for all generators.
    """
    # Generators that exist regardless of the available hardware.
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed_value)

    # CUDA generators: the current device first, then all devices (multi-GPU setups).
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed_value)

    # cuDNN flags.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def main():
    """
    Command-line entry point: load the data and the model, run the timed evaluation
    and save the predictions as JSON.

    Steps, in order: parse the CLI arguments, seed the RNGs, load the processor and
    its tokenizer, load the YAML config (collapsing single-element lists to their
    only value), load and concatenate one dataset per subject in ``CAT_SHORT2LONG``,
    load the model, turn every dataset row into a prompt sample, run ``run_model``
    and write its result to ``--output_path``.

    :raises ValueError: If a config entry other than ``eval_params`` is a list whose
        length is not exactly one.
    """
    parser = ArgumentParser()
    parser.add_argument('--output_path', type=str, default='./example_outputs/Phi3.5_4b_val.json',
                        help='name of saved json')
    parser.add_argument('--config_path', type=str, default="./configs/phi3.5-vision.yaml")
    parser.add_argument('--data_path', type=str, default="./MMMU") # hf dataset path.
    parser.add_argument('--model_path', type=str, default="/hy-tmp/Phi-3.5-vision-instruct")
    parser.add_argument('--split', type=str, default='validation')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()
    set_seed(args.seed)
    print('Phi3_initializing...')
    call_model_engine = call_phi3_engine_df
    processor = AutoProcessor.from_pretrained(args.model_path,
       trust_remote_code=True,
       num_crops=16
    )
    tokenizer = processor.tokenizer

    # Load the config and reduce every single-element list to its only value.
    args.config = load_yaml(args.config_path)
    for key, value in args.config.items():
        if key != 'eval_params' and isinstance(value, list):
            # Raised explicitly (instead of asserted) so the check survives `python -O`.
            if len(value) != 1:
                raise ValueError('key {} has more than one value'.format(key))
            args.config[key] = value[0]

    # Load one dataset per subject and merge them.
    sub_dataset_list = []
    for subject in CAT_SHORT2LONG.values():
        sub_dataset = load_dataset(args.data_path, subject, split=args.split)
        sub_dataset_list.append(sub_dataset)

    dataset = concatenate_datasets(sub_dataset_list)

    # Load the model.
    model = Phi3VForCausalLM.from_pretrained(
         args.model_path,
         device_map="cuda",
         trust_remote_code=True,
         torch_dtype="auto",
         _attn_implementation='eager'
    )

    samples = []
    for sample in dataset:
        sample = process_single_sample(sample)
        sample = construct_prompt(sample, args.config)
        samples.append(sample)

    # NOTE(review): the sample at position 678 is dropped unconditionally; the
    # reason is not recorded in this file -- confirm it is still required.
    # The length guard keeps smaller splits from failing with an IndexError.
    excluded_index = 678
    if len(samples) > excluded_index:
        del samples[excluded_index]

    # Run the timed evaluation and store the predictions.
    out_samples = run_model(args, samples, model, call_model_engine, tokenizer, processor)
    save_json(args.output_path, out_samples)
    
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()



    

   