import base64
from io import BytesIO
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import torch
import transformers
import tokenizers
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from torch.utils.data import Dataset

from llava import conversation as conversation_lib
from llava.model import *
from llava.mm_utils import tokenizer_image_token
from llava.model.language_model.llava_phi3 import LlavaPhiForCausalLM, LlavaPhiConfig
from PIL import Image
import pickle
import argparse
from packaging import version
IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
from utils import find_all_linear_names, add_special_tokens_and_resize_model, load_weights, expand2square
from tqdm import tqdm

def str2bool(value):
    """Convert a command-line string into a ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string -- so ``--do_sample False`` would *enable*
    sampling. This converter recognises the usual spellings instead.

    Args:
        value: Raw flag value from the command line. A ``bool`` is passed
            through unchanged.

    Returns:
        The parsed boolean. The empty string maps to ``False``, matching the
        previous ``bool("")`` behaviour.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling;
            argparse reports this as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _parse_args():
    """Define the inference command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='microsoft/Phi-3-mini-4k-instruct')
    parser.add_argument('--dtype', type=str, default='FP32')
    # NOTE(review): --attn_implementation is parsed but not used anywhere below.
    parser.add_argument('--attn_implementation', type=str, default=None)
    parser.add_argument('--instruct_template', type=str, default='phi3_instruct')
    parser.add_argument('--vit_path', type=str, default='openai/clip-vit-large-patch14-336')
    parser.add_argument('--lora_path', type=str, default=None)
    parser.add_argument('--soft_moe_path', type=str, default=None)
    # str2bool rather than bool: bool("False") is True.
    parser.add_argument('--do_sample', type=str2bool, default=False)
    # NOTE(review): --soft_moe is parsed but never consulted; soft-MoE is
    # always force-enabled in _load_model.
    parser.add_argument('--soft_moe', type=str2bool, default=False)
    parser.add_argument('--temperature', type=float, default=0.0)
    parser.add_argument('--top_p', type=float, default=None)
    parser.add_argument('--num_beams', type=int, default=1)
    parser.add_argument('--max_new_tokens', type=int, default=1024)
    parser.add_argument('--lora_config_path', type=str)
    parser.add_argument('--hyper_lora_and_projector_config_path', type=str)
    parser.add_argument('--test_data', type=str)
    parser.add_argument('--output_dir', type=str)
    return parser.parse_args()


def _report_zero_sized_params(model):
    """Print every parameter of ``model`` that holds zero elements.

    Each printed entry is ``(name, shape, is_meta, tensor type name)``.
    """
    bad = [
        (name, tuple(param.shape), getattr(param, "is_meta", False), type(param).__name__)
        for name, param in model.named_parameters()
        # numel() == 0 exactly when some dimension is 0, so one check suffices.
        if param.numel() == 0
    ]
    print("zero-sized params:", len(bad))
    for entry in bad:
        print(entry)


def _load_model(args):
    """Build the LLaVA-Phi model, wrap it with LoRA, attach vision modules and load weights.

    Args:
        args: Parsed CLI namespace from ``_parse_args``.

    Returns:
        The PEFT-wrapped model with the LoRA and soft-MoE weights loaded
        (still on CPU in its default dtype).
    """
    config = LlavaPhiConfig.from_pretrained(args.hyper_lora_and_projector_config_path)
    # Soft-MoE is unconditionally enabled (here and in the LoRA config below).
    config.soft_moe = True
    model, loading_info = LlavaPhiForCausalLM.from_pretrained(
        pretrained_model_name_or_path=args.model_name_or_path,
        config=config,
        output_loading_info=True,
    )

    _report_zero_sized_params(model)
    print("loading_info")
    print(loading_info)

    from llava.peft import LoraConfig, get_peft_model
    with open(args.lora_config_path, "r") as f:
        cfg_dict = json.load(f)
    cfg_dict["soft_moe"] = True
    model = get_peft_model(model, LoraConfig(**cfg_dict))
    model.config.use_cache = False

    from utils import com_vision_args
    com_vision_args.model_name_or_path = args.model_name_or_path
    com_vision_args.vision_tower = args.vit_path
    com_vision_args.version = args.instruct_template
    model.get_model().initialize_vision_modules(model_args=com_vision_args)

    print("load 11 task lora!")
    model = load_weights(model, args.lora_path)
    print("load soft moe weight!")
    model = load_weights(model, args.soft_moe_path)
    return model


def _extract_question(conversations):
    """Return the ``value`` of the last ``"human"`` turn, or ``None`` if there is none."""
    question = None
    for turn in conversations:
        if turn["from"] == "human":
            question = turn["value"]
    return question


def infer():
    """Run batch multimodal inference over a JSON test set.

    Reads ``--test_data`` as a JSON list of samples. Each sample has a
    ``"conversations"`` list of ``{"from", "value"}`` turns and an optional
    ``"image"`` field, which is decoded as base64.

    For every sample with a non-empty human turn, the function generates a
    response and appends the sample, with an added ``"response"`` field, as
    one JSON line. The output file is
    ``<output_dir>/<parent dir of test_data>_<test_data file name>``.
    Samples without a human turn are skipped. Uses CUDA for generation.
    """
    args = _parse_args()
    model_dtype = torch.float32 if args.dtype == 'FP32' else (torch.float16 if args.dtype == 'FP16' else torch.bfloat16)

    model = _load_model(args)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        padding_side="right",
        use_fast=False,
    )
    model.eval()
    model.to(model_dtype).cuda()

    with open(args.test_data, 'r') as f:
        data_list = json.load(f)
    print(args.test_data)
    os.makedirs(args.output_dir, exist_ok=True)

    dir_name = os.path.basename(os.path.dirname(args.test_data))
    file_name = os.path.basename(args.test_data)
    output_file_path = os.path.join(args.output_dir, f"{dir_name}_{file_name}")

    # Loop invariants: the image processor, the padding colour (processor mean
    # scaled to 0-255) and the prompt template are the same for every sample.
    image_processor = model.get_vision_tower().image_processor
    background = tuple(int(x * 255) for x in image_processor.image_mean)
    conv_template = conversation_lib.conv_templates[args.instruct_template]

    # Append mode is kept from the original; `with` guarantees the file is
    # flushed and closed even if generation raises part-way through.
    with open(output_file_path, 'a', encoding='utf-8') as output_file:
        for item in tqdm(data_list):
            qs = _extract_question(item["conversations"])
            if not qs:
                continue

            if "image" in item:
                image = Image.open(BytesIO(base64.b64decode(item['image']))).convert("RGB")
            else:
                # Text-only sample: feed a black placeholder image and record
                # an empty image field in the output.
                image = Image.new('RGB', (224, 224), (0, 0, 0))
                item["image"] = ""
            image = expand2square(image, background)
            image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0].unsqueeze_(0)

            conversation = conv_template.copy()
            conversation.append_message(conversation.roles[0], qs)
            conversation.append_message(conversation.roles[1], None)
            prompt = conversation.get_prompt()
            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').cuda().unsqueeze_(0)

            with torch.inference_mode():
                output_ids = model.base_model.generate(
                    input_ids,
                    images=image_tensor.to(dtype=model_dtype, device='cuda', non_blocking=True),
                    image_sizes=image.size,
                    do_sample=args.do_sample,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    num_beams=args.num_beams,
                    max_new_tokens=args.max_new_tokens,
                    use_cache=False,
                )
            # NOTE(review): the last 8 decoded characters are dropped, presumably
            # a trailing end-of-turn marker left after skip_special_tokens -- confirm.
            response = tokenizer.decode(output_ids[0], skip_special_tokens=True)[:-8]
            item["response"] = response
            output_file.write(json.dumps(item, ensure_ascii=False) + "\n")


if __name__ == "__main__":
    # Script entry point: run inference only when executed directly, not on import.
    infer()