import argparse
import itertools
import json
import os
import random
import subprocess
import time
from functools import partial
from typing import Optional

import torch
# from internvl.model import load_model_and_tokenizer
from internvl.train.dataset_test import build_transform, dynamic_preprocess
from PIL import Image
from tqdm import tqdm
import os
import warnings
import shutil
import math
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from internvl.model.internvl_chat import InternVLChatConfig, InternVLChatModel
# from internvl.model.llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT
def split_model(num_layers, vit_alpha=0.5, world_size=None):
    """Build a device map that spreads the language-model layers over GPUs.

    GPU 0 also hosts the vision tower and the embedding / output modules, so
    it is counted as only ``1 - vit_alpha`` of a GPU when the layers are
    divided up.

    Args:
        num_layers: Number of transformer layers in the language model.
        vit_alpha: Fraction of GPU 0 that is set aside for the ViT.
        world_size: Number of GPUs to split over. Defaults to
            ``torch.cuda.device_count()``.

    Returns:
        A dict mapping module names to GPU indices. Only layers
        ``0 .. num_layers - 1`` receive an entry.

    Raises:
        ValueError: If fewer than one GPU is available / requested.
    """
    if world_size is None:
        world_size = torch.cuda.device_count()
    if world_size < 1:
        raise ValueError(
            f'split_model needs at least one GPU, got world_size={world_size}')

    # Since the first GPU will be used for ViT, treat it as a partial GPU.
    per_gpu = math.ceil(num_layers / (world_size - vit_alpha))
    quotas = [per_gpu] * world_size
    quotas[0] = math.ceil(per_gpu * (1 - vit_alpha))

    device_map = {}
    next_layer = 0
    for gpu, quota in enumerate(quotas):
        for _ in range(quota):
            # The quotas are rounded up, so together they can exceed the real
            # layer count; never emit entries for layers that do not exist.
            if next_layer >= num_layers:
                break
            device_map[f'language_model.model.layers.{next_layer}'] = gpu
            next_layer += 1

    # Everything that is not a decoder layer lives on GPU 0.
    device_map['vision_model'] = 0
    device_map['mlp1'] = 0
    device_map['language_model.model.tok_embeddings'] = 0
    device_map['language_model.model.embed_tokens'] = 0
    device_map['language_model.output'] = 0
    device_map['language_model.model.norm'] = 0
    device_map['language_model.lm_head'] = 0
    # The last decoder layer is pinned to GPU 0 as well, next to the final
    # norm and output head placed there above.
    device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
    device_map['language_model.model.rotary_emb'] = 0

    return device_map


def _load_non_lora_trainables(language_model, lora_path):
    """Load the optional ``non_lora_trainables.bin`` stored next to the LoRA adapter."""
    weights_file = os.path.join(lora_path, 'non_lora_trainables.bin')
    if not os.path.exists(weights_file):
        print(f"Warning: {weights_file} not found, skipping non-LoRA weights.")
        return
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    state = torch.load(weights_file, map_location='cpu')
    # Strip the wrapper prefixes so the keys line up with the bare language model.
    state = {(k[len('base_model.'):] if k.startswith('base_model.') else k): v
             for k, v in state.items()}
    if any(k.startswith('model.model.') for k in state):
        state = {(k[len('model.'):] if k.startswith('model.') else k): v
                 for k, v in state.items()}
    language_model.load_state_dict(state, strict=False)


def _snapshot_parameters(module):
    """Return CPU copies of every materialised parameter of ``module``, keyed by name."""
    snapshot = {}
    for name, param in module.named_parameters():
        if param.is_meta:
            print(f"Warning: {name} is a meta tensor, skipping.")
            continue
        # copy=True guarantees an independent tensor even when the parameter
        # already lives on the CPU.
        snapshot[name] = param.detach().to(device="cpu", copy=True)
    return snapshot


def _any_parameter_changed(module, snapshot):
    """Return True as soon as one parameter of ``module`` differs from ``snapshot``."""
    for name, param in module.named_parameters():
        before = snapshot.get(name)
        if before is None or param.is_meta:
            # Not captured in the snapshot (or not materialised): nothing to compare.
            continue
        # Compare on the parameter's own device; a sharded model keeps its
        # layers on several GPUs, so a single model-wide device is not enough.
        if not torch.equal(before.to(param.device), param.detach()):
            return True
    return False


def load_pretrained_model_both(model_path, model_base, prompt_tuning_adding, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
    """Load an InternVL chat model and fold the fine-tuned adapters into it.

    The base checkpoint and tokenizer are read from ``model_base``. Adapter
    weights are read from ``model_path``:

    * ``<model_path>/llava-lora`` holds the LoRA adapter and, optionally,
      ``non_lora_trainables.bin``;
    * ``<model_path>/llava-prompt_tuning`` optionally holds a prompt-tuning
      adapter that is stacked on top when ``prompt_tuning_adding`` is true.

    Args:
        model_path: Directory that contains the adapter sub-directories.
        model_base: Path or hub id of the base InternVL chat checkpoint.
        prompt_tuning_adding: Whether to load the prompt-tuning adapter.
        load_8bit: Forwarded to ``from_pretrained`` as ``load_in_8bit``.
        load_4bit: Forwarded to ``from_pretrained`` as ``load_in_4bit``.
        device_map: ``"auto"`` shards the model with ``split_model``;
            ``None`` loads it without a device map and, when it is not
            quantized, moves it to the current CUDA device; any other value
            is handed to ``from_pretrained`` unchanged.
        device: Accepted for backward compatibility; currently ignored.
        use_flash_attn: Accepted for backward compatibility; currently ignored.
        **kwargs: Accepted for backward compatibility; currently ignored.
            (The loader options formerly derived from these three arguments
            were overwritten before the model was loaded, so they never had
            an effect.)

    Returns:
        A ``(model, tokenizer)`` tuple; ``model.language_model`` carries the
        merged LoRA weights (wrapped by the prompt-tuning adapter if loaded).

    Raises:
        RuntimeError: If no language-model parameter changed after the LoRA
            merge, which indicates that the adapter was not applied.
    """
    from peft import PeftModel

    if device_map == "auto":
        config = InternVLChatConfig.from_pretrained(model_base)
        placement = {'device_map': split_model(config.llm_config.num_hidden_layers)}
    elif device_map is not None:
        placement = {'device_map': device_map}
    else:
        placement = {}

    tokenizer = AutoTokenizer.from_pretrained(model_base, trust_remote_code=True, use_fast=False)
    model = InternVLChatModel.from_pretrained(
        model_base, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
        load_in_8bit=load_8bit, load_in_4bit=load_4bit, **placement).eval()
    if not load_8bit and not load_4bit and not placement:
        model = model.cuda()

    lora_path = os.path.join(model_path, "llava-lora")
    # Work on the language model in place: it is assigned back to the model
    # at the end, so a deep copy would only duplicate the weights in memory.
    language_model = model.language_model

    print('Loading additional LLaVA weights...')
    _load_non_lora_trainables(language_model, lora_path)

    # Reference copy used below to verify that the LoRA merge changed something.
    reference = _snapshot_parameters(language_model)

    print('Loading LoRA weights...')
    language_model = PeftModel.from_pretrained(language_model, lora_path)
    print('Merging LoRA weights...')
    language_model = language_model.merge_and_unload()
    print('Model is loaded...')
    print(f"prompt_tuning_adding is {prompt_tuning_adding}")

    # If every parameter is still identical to the snapshot, the adapter was
    # not applied; fail loudly instead of evaluating the base model.
    if not _any_parameter_changed(language_model, reference):
        raise RuntimeError("withoutlora_model和model的所有参数完全相同，可能未正确加载LoRA或Prompt Tuning权重。")
    print("withoutlora_model和model的参数已成功加载，且不相同。")
    del reference

    if prompt_tuning_adding:
        print(f"prompt_tuning_adding is {prompt_tuning_adding}, loading prompt tuning weights...")
        prompt_tuning_path = os.path.join(model_path, "llava-prompt_tuning")
        if os.path.exists(prompt_tuning_path):
            print('Loading Prompt Tuning weights...')
            language_model = PeftModel.from_pretrained(language_model, prompt_tuning_path)
            print('Prompt Tuning weights loaded.')
        else:
            print(f"Warning: {prompt_tuning_path} not found, continuing without prompt tuning.")

    model.language_model = language_model
    return model, tokenizer
# from .load_my_model import load_pretrained_model_both
class VQADataset(torch.utils.data.Dataset):
    """Line-delimited JSON question set that yields transformed image tiles.

    Every line of ``test`` is a JSON object with the keys ``image`` (path
    relative to ``image_folder``), ``text`` (the question) and
    ``question_id``.
    """

    def __init__(self, train, test, prompt, few_shot, image_folder, input_size=224, dynamic_image_size=False,
                 use_thumbnail=False, max_num=6):
        """Read the question file(s) and build the image transform.

        Args:
            train: Path of the few-shot example file; only read when
                ``few_shot > 0``.
            test: Path of the question file (one JSON object per line).
            prompt: Text appended to every question; may be empty.
            few_shot: Number of few-shot examples requested.
            image_folder: Directory the ``image`` entries are relative to.
            input_size: Passed to ``build_transform`` and, as ``image_size``,
                to ``dynamic_preprocess``.
            dynamic_image_size: Whether to run ``dynamic_preprocess`` on each image.
            use_thumbnail: Forwarded to ``dynamic_preprocess``.
            max_num: Forwarded to ``dynamic_preprocess``.

        Raises:
            ValueError: If ``few_shot > 0`` but no ``train`` file is given.
        """
        # Context managers make sure the file handles are released again.
        with open(test, encoding='utf-8') as handle:
            self.test = handle.readlines()
        self.prompt = prompt
        self.input_size = input_size
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.few_shot = few_shot
        self.max_num = max_num
        if few_shot > 0:
            if train is None:
                raise ValueError('few_shot > 0 requires a train file')
            with open(train, encoding='utf-8') as handle:
                self.train = handle.readlines()
        self.transform = build_transform(is_train=False, input_size=input_size)
        self.image_folder = image_folder

    def __len__(self):
        """Return the number of questions."""
        return len(self.test)

    def __getitem__(self, idx):
        """Return the question, its id and the stacked image tensor for item ``idx``."""
        data = json.loads(self.test[idx].strip())
        image_name, question, question_id = data['image'], data['text'], data['question_id']

        # NOTE: the few-shot examples in self.train are not injected into the
        # question. The prompt string that used to be assembled from them here
        # was never used, so that dead per-item work has been dropped.

        # convert() returns a new, fully loaded image, so the file behind the
        # lazily opened original can be closed right away.
        with Image.open(os.path.join(self.image_folder, image_name)) as raw:
            image = raw.convert('RGB')
        if self.dynamic_image_size:
            tiles = dynamic_preprocess(image, image_size=self.input_size,
                                       use_thumbnail=self.use_thumbnail,
                                       max_num=self.max_num)
        else:
            tiles = [image]
        pixel_values = torch.stack([self.transform(tile) for tile in tiles])
        if self.prompt:
            question = question + ' ' + self.prompt
        return {
            'question_id': question_id,
            'question': question,
            'pixel_values': pixel_values,
        }

def collate_fn(batch, tokenizer):
    """Merge a list of per-sample dicts into batch-level fields.

    Args:
        batch: List of dicts carrying ``pixel_values``, ``question`` and
            ``question_id``.
        tokenizer: Accepted so the function can be bound with
            ``functools.partial``; it is not used here.

    Returns:
        ``(pixel_values, questions, question_ids)`` where ``pixel_values`` is
        the concatenation of all samples' tensors along dimension 0 and the
        other two are plain lists in batch order.
    """
    tensors, questions, question_ids = [], [], []
    for sample in batch:
        tensors.append(sample['pixel_values'])
        questions.append(sample['question'])
        question_ids.append(sample['question_id'])
    return torch.cat(tensors, dim=0), questions, question_ids
def post_process(response):
    """Reduce a free-form model answer to a short, lower-cased answer phrase.

    Only the first clause (text before the first ``.``, ``,`` or ``!``) is
    kept. Leading copulas / articles (``is``, ``are``, ``a``, ``an``,
    ``the``) and a trailing ``of ...`` part are then stripped.

    The markers are matched as whole words only, so they are no longer
    stripped out of the middle of other words (``"this is a cat"`` yields
    ``"cat"`` and ``"pizza slice"`` stays intact).

    Args:
        response: Raw answer text.

    Returns:
        The normalised answer string (possibly empty).
    """
    import re  # function-scope import: the module does not import re at the top

    response = response.strip().split('.')[0].split(',')[0].split('!')[0].lower()
    for marker in ('is', 'are', 'a', 'an', 'the'):
        # \b anchors the marker at a word start; the trailing space is part of
        # the marker, exactly as in the plain-substring version.
        pieces = re.split(rf'\b{marker} ', response)
        if len(pieces) > 1:
            # Keep what follows the first occurrence (up to the next one).
            response = pieces[1]
    # Drop "of ..." and everything after it, again only for the whole word.
    response = re.split(r' of\b', response)[0]
    return response.strip()


def evaluate_chat_model():
    """Answer every question in ``args.question_file`` and write JSON lines.

    Reads the module-level globals ``args``, ``model``, ``tokenizer``,
    ``image_size`` and ``use_thumbnail`` that the ``__main__`` block sets up.
    One JSON object with the keys ``question``, ``question_id`` and ``answer``
    is written to ``args.answers_file`` per question.
    """
    from torch.utils.data import SequentialSampler

    random.seed(args.seed)

    dataset = VQADataset(
        train=None,
        test=args.question_file,
        prompt='',
        few_shot=args.few_shot,
        input_size=image_size,
        dynamic_image_size=args.dynamic,
        use_thumbnail=use_thumbnail,
        max_num=args.max_num,
        image_folder=args.image_folder
    )
    # A SequentialSampler replaces the distributed sampler.
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        sampler=SequentialSampler(dataset),
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=partial(collate_fn, tokenizer=tokenizer),
    )

    # The generation settings do not depend on the sample; build them once.
    generation_config = dict(
        num_beams=args.num_beams,
        max_new_tokens=128,
        min_new_tokens=1,
        do_sample=args.temperature > 0,
        temperature=args.temperature,
    )

    # Create the target directory up front so a missing folder does not
    # surface only when the file is opened.
    answers_dir = os.path.dirname(args.answers_file)
    if answers_dir:
        os.makedirs(answers_dir, exist_ok=True)

    # The context manager closes (and flushes) the file even if generation
    # fails part-way through.
    with open(args.answers_file, "w") as ans_file:
        # Iterating the loader directly lets tqdm display the total.
        for pixel_values, questions, question_ids in tqdm(dataloader):
            pixel_values = pixel_values.to(torch.bfloat16).cuda()
            # Only the first entry of each batch is used.
            pred = model.chat(
                tokenizer=tokenizer,
                pixel_values=pixel_values,
                question=questions[0],
                generation_config=generation_config,
                verbose=False
            )
            record = {
                'question': questions[0],
                'question_id': question_ids[0],
                'answer': pred,
            }
            ans_file.write(json.dumps(record) + "\n")



if __name__ == '__main__':
    # NOTE: args, model, tokenizer, image_size and use_thumbnail are kept as
    # module-level names on purpose: other functions in this file read them
    # as globals.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-base', type=str, default='')
    parser.add_argument('--model-path', type=str, default='OpenGVLab/internvl-chat-7b')
    parser.add_argument('--use_prompt_tuning', action='store_true', default=False)
    parser.add_argument('--batch-size', type=int, default=1)
    parser.add_argument('--num-workers', type=int, default=1)
    parser.add_argument('--image-folder', type=str, default='/data2/dmz/llava_test/LLaVA-main')
    parser.add_argument('--num-beams', type=int, default=1)
    parser.add_argument('--temperature', type=float, default=0.0)
    parser.add_argument('--few-shot', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--dynamic', action='store_true')
    parser.add_argument('--max-num', type=int, default=6)
    parser.add_argument('--load-in-8bit', action='store_true')
    parser.add_argument('--load-in-4bit', action='store_true')
    parser.add_argument('--auto', action='store_true')
    # NOTE(review): this default is identical to the --model-path default and
    # is probably a copy/paste leftover; pass --answers-file explicitly.
    parser.add_argument('--answers-file', type=str, default='OpenGVLab/internvl-chat-7b')
    parser.add_argument('--question-file', type=str, default='data/vqa/vizwiz_val.jsonl')
    # The former default was the question-file path itself, which made the
    # makedirs call below create a directory in place of a missing question file.
    parser.add_argument('--out-dir', type=str, default='results')
    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts vanish under -O.
    if args.batch_size != 1:
        parser.error('Only batch size 1 is supported')
    if not args.model_base:
        parser.error('--model-base is required')

    os.makedirs(args.out_dir, exist_ok=True)

    # Single-process run: no torch.distributed process group is initialised.
    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))

    model, tokenizer = load_pretrained_model_both(args.model_path, args.model_base, prompt_tuning_adding=args.use_prompt_tuning, load_8bit=args.load_in_8bit, load_4bit=args.load_in_4bit, device_map="auto" if args.auto else None)
    image_size = model.config.force_image_size or model.config.vision_config.image_size
    use_thumbnail = model.config.use_thumbnail

    total_params = sum(p.numel() for p in model.parameters()) / 1e9
    if total_params > 20 or args.dynamic:
        # Beam search is switched off for large models and for dynamic tiling.
        args.num_beams = 1
        print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
    else:
        print(f'[test] total_params: {total_params}B')
    print(f'[test] image_size: {image_size}')
    print(f'[test] template: {model.config.template}')
    print(f'[test] dynamic_image_size: {args.dynamic}')
    print(f'[test] use_thumbnail: {use_thumbnail}')

    evaluate_chat_model()