# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from dataclasses import dataclass, field
import json
import os
import re
import sys
from time import sleep
from typing import Optional, List
import logging

import torch
import torch.distributed as dist

import accelerate
from transformers import set_seed
import transformers

try:
    # Prefer the fast tokenizer class when this transformers install exposes it.
    from transformers import LlamaTokenizerFast as LlamaTokenizer

    print("Using fast tokenizer")
except (ImportError, RuntimeError):
    # Only import failures trigger the fallback; a bare ``except`` would also
    # swallow KeyboardInterrupt / SystemExit raised while importing.
    # NOTE(review): RuntimeError is included because transformers' lazy module
    # loader appears to re-raise failed submodule imports as RuntimeError --
    # confirm against the installed transformers version.
    from transformers import LlamaTokenizer

    print("Using slow tokenizer")

from transformers import AutoTokenizer, AutoModelForCausalLM

from all_utils.data_utils.data_utils_ppo import DataCollatorForQueryResponseDataset, QueryResponseDataset, make_rl_data_module, pad_sequences

from llava.RLHF.models.ppo.ppo_trainer import truncate_after_eos_with_padding, remove_image_token
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm
import all_utils.data_utils.common_utils as common_utils
from torch.distributed import all_gather, get_rank, is_initialized
import time

from llava import conversation as conversation_lib
from llava.model import *
from llava.constants import (
    IMAGE_TOKEN_INDEX,
)
from typing import List, Optional, Callable, Dict
# Allow TF32 for CUDA matrix multiplications (process-wide PyTorch backend flag).
torch.backends.cuda.matmul.allow_tf32 = True

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class MMSafeBenchDataset(QueryResponseDataset):
    """Query dataset for MMSafeBench that also carries per-sample labels.

    Each item is whatever ``QueryResponseDataset.__getitem__`` returns, plus
    two scalar ``torch.long`` tensors:

    * ``"types"``     -- the record's ``"type"`` value, or ``-1`` if absent.
    * ``"scenarios"`` -- the record's ``"scenario"`` value, or ``-1`` if absent.

    The plural keys are the ones ``get_model_answer`` unpacks from a batch, so
    they are used for both the present and the missing case. (Previously the
    missing case wrote the singular keys ``"type"`` / ``"scenario"``, giving
    such records a different key set from the rest of the batch.)
    """

    # Maps the key in the raw record to the key emitted in the item dict.
    _LABEL_KEYS = (("type", "types"), ("scenario", "scenarios"))

    def __init__(
        self,
        list_dict_data: List[dict],
        tokenizer: transformers.PreTrainedTokenizer,
        query_len: int,
        data_args: Optional[Dict] = None,
    ):
        """Forward all arguments unchanged to ``QueryResponseDataset``."""
        super().__init__(
            list_dict_data=list_dict_data,
            tokenizer=tokenizer,
            query_len=query_len,
            data_args=data_args,
        )

    def __getitem__(self, idx):
        """Return the base-class item for ``idx`` extended with label tensors."""
        return_dict = super().__getitem__(idx)

        # Relies on ``self.list_dict_data`` holding the raw records, as the
        # original implementation did.
        record = self.list_dict_data[idx]
        for source_key, batch_key in self._LABEL_KEYS:
            # -1 is the placeholder for a label missing from the record.
            value = record[source_key] if source_key in record else -1
            return_dict[batch_key] = torch.tensor(value, dtype=torch.long)
        return return_dict



@dataclass
class ModelArguments:
    """Command-line options describing which model to load and how."""

    # Checkpoint to load; eval() also loads the tokenizer from this location.
    model_name_or_path: Optional[str] = "XXX"
    # When set, eval() wraps the base model with a PEFT adapter.
    lora_enable: bool = False
    # Directory in which eval() looks for "adapter_model/lora_policy" and
    # "mm_projector.bin".
    peft_model_id_path: Optional[str] = "XXX"
    # Forwarded to generate() by get_model_answer().
    temperature: float = 1.0
    trust_remote_code: Optional[bool] = field(
        metadata={
            "help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."
        },
        default=False,
    )
    # Key into the conversation templates; eval() falls back to "vicuna_v1"
    # when the key is unknown.
    version: Optional[str] = "v1"
    tune_mm_mlp_adapter: bool = False
    # eval() only builds a model when this is not None.
    vision_tower: Optional[str] = None
    # NOTE(review): the original comment called -2 "the last layer"; confirm
    # how the vision tower interprets this index.
    mm_vision_select_layer: Optional[int] = -2
    pretrain_mm_mlp_adapter: Optional[str] = None
    # Copied onto data_args / training_args by eval().
    mm_use_im_start_end: bool = False
    mm_use_im_patch_token: bool = False
    mm_vision_select_feature: Optional[str] = "patch"


@dataclass
class DataArguments:
    """Command-line options describing the evaluation data.

    ``dataset_path`` used to be declared twice in this class; only the second
    default ever took effect. It is now declared once, in its original (first)
    field position and with that effective default, so parsing and positional
    construction behave exactly as before.
    """

    # Walked recursively for ``*.json`` files by make_mmsafebench_data_module().
    # NOTE(review): the default looks like a file path although the loader
    # treats it as a directory -- confirm the intended default.
    dataset_path: str = field(default="XXX/config_test.json")
    train_splits: List[str] = field(default_factory=lambda: ["unlabeled"])
    stop_token: Optional[str] = field(
        default=None,
        metadata={"help": "Token to stop generation with."},
    )
    # From LLaVA
    lazy_preprocess: bool = False
    # eval() forces this to True when a vision tower is configured.
    is_multimodal: bool = True
    image_folder: Optional[str] = field(default=None)
    # Copied onto model.config by eval().
    image_aspect_ratio: str = "pad"
    image_grid_pinpoints: Optional[str] = field(default=None)
    keywords: Optional[List[str]] = field(default_factory=lambda: ["harm"])


@dataclass
class TrainingArguments(transformers.Seq2SeqTrainingArguments):
    """Run-level options layered on top of ``Seq2SeqTrainingArguments``.

    Within this script the fields are read by ``eval()`` and
    ``get_model_answer()``; several (``truncate_tokens``,
    ``finetune_mm_projector``) are not referenced by the code visible in this
    file.
    """

    # Passed as ``cache_dir`` when loading the tokenizer.
    cache_dir: Optional[str] = field(default=None)
    finetune_mm_projector: bool = field(default=False)
    # From AlpacaFarm
    truncate_tokens: Optional[List[str]] = field(
        default_factory=lambda: None,
        metadata={
            "help": "Tokens in strings to truncate at first occurrence. "
            "This was used in original OAI summarization paper to avoid models returning incomplete sentences. "
        },
    )
    # When True, get_model_answer() passes the EOS id in ``suppress_tokens``.
    suppress_eos_at_generation: bool = field(
        default=False,
        metadata={
            "help": "Whether to suppress the end-of-sequence token at generation time."
        },
    )
    # Overwritten in eval() with ``vision_tower.num_patches``.
    num_patches: int = field(default=576)
    # Used as the tokenizer's ``model_max_length`` and as generate()'s ``max_length``.
    model_max_length: int = field(default=2048)
    # Overwritten in eval() with ``model_max_length - vision_tower.num_patches``.
    query_len: int = field(default=128)
    # Root directory under which the per-scenario result files are written.
    output_dir: str = field(
        default="./output", metadata={"help": "The output dir for logs and checkpoints"}
    )

def merge_json_files(directory):
    """Concatenate the records of every ``*.json`` file found under ``directory``.

    The tree is walked recursively. Sub-directories and file names are visited
    in sorted order so that the merged list is identical on every machine and
    every process (plain ``os.walk`` order depends on the filesystem).

    Args:
        directory: Root directory to search.

    Returns:
        A single list holding the elements of every JSON array, in traversal
        order. An empty or non-existent directory yields ``[]``.

    Raises:
        ValueError: If a JSON file's top-level value is not a list. Extending
            with e.g. a dict would silently add its keys as records.
        json.JSONDecodeError: If a file is not valid JSON.
    """
    all_data = []

    for root, dirs, files in os.walk(directory):
        # Sorting ``dirs`` in place fixes the order in which os.walk descends.
        dirs.sort()
        for file_name in sorted(files):
            if not file_name.endswith(".json"):
                continue
            file_path = os.path.join(root, file_name)
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not isinstance(data, list):
                raise ValueError(
                    f"Expected a JSON list in {file_path}, got {type(data).__name__}"
                )
            all_data.extend(data)

    return all_data



def make_mmsafebench_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
    training_args,
):
    """Build the evaluation dataset and its collator.

    Records are read from every JSON file under ``data_args.dataset_path`` and
    wrapped in an ``MMSafeBenchDataset`` using ``training_args.query_len``.

    Returns:
        A dict with the keys ``"eval_dataset"`` and ``"data_collator"``.
    """
    records = merge_json_files(data_args.dataset_path)

    module = {}
    module["eval_dataset"] = MMSafeBenchDataset(
        list_dict_data=records,
        tokenizer=tokenizer,
        query_len=training_args.query_len,
        data_args=data_args,
    )
    module["data_collator"] = DataCollatorForQueryResponseDataset()
    return module

def custom_gather(data):
    """Gather a list of picklable items from every process into one flat list.

    Intended for Python objects (e.g. lists of strings or dicts) that tensor
    collectives cannot carry.

    Args:
        data: This process's list of items.

    Returns:
        A flat list containing each rank's items, concatenated in rank order.
        If no process group is initialized (single-process run) a shallow copy
        of ``data`` is returned instead of raising.
    """
    distributed_ready = (
        torch.distributed.is_available() and torch.distributed.is_initialized()
    )
    if not distributed_ready:
        return list(data)

    per_rank = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(per_rank, data)
    # Flatten the list of per-rank lists.
    return [item for rank_items in per_rank for item in rank_items]

def rank0_print(*args):
    """Print ``args`` only when ``LOCAL_RANK`` is 0 or unset."""
    if int(os.environ.get("LOCAL_RANK", 0)) != 0:
        return
    print(*args)

def adjust_queries(queries, pad_token_id=0):
    """Drop the leading columns of a left-padded batch that are padding in every row.

    The cut point is the first column that holds a non-pad token in at least
    one row. The previous implementation derived the width from the *count* of
    non-pad tokens per row, so a token equal to the pad id inside a query (the
    pad token is aliased to the unk token in ``eval``) made it trim real
    leading tokens.

    Args:
        queries: 2-D integer tensor of token ids, one row per query.
        pad_token_id: Id treated as padding. Defaults to 0, the value that was
            previously hard-coded.

    Returns:
        A view of ``queries`` without the all-padding leading columns. If every
        position is padding (or the batch is empty) ``queries`` is returned
        unchanged.
    """
    has_token = (queries != pad_token_id).any(dim=0)
    if not bool(has_token.any()):
        # Nothing but padding: keep the full width, as before.
        return queries
    first_kept = int(torch.nonzero(has_token)[0])
    return queries[:, first_kept:]

def strip_pad(seq: List[int], tokenizer):
    """Return the tokens of ``seq`` that differ from ``tokenizer.pad_token_id``."""
    kept = []
    for token in seq:
        if token == tokenizer.pad_token_id:
            continue
        kept.append(token)
    return kept

def extract_queries_responses(quereis, responses, tokenizer):
    """Convert padded query/response tensors into pad-free token-id lists.

    The first parameter keeps its historical (misspelled) name because it is
    part of the call interface.

    Returns:
        A tuple ``(queries, responses, sequences)`` of lists of token-id lists,
        where ``sequences[i]`` is ``queries[i] + responses[i]``.
    """
    query_rows = quereis.tolist()
    response_rows = responses.tolist()

    quereis_vec = []
    responses_vec = []
    sequences_vec = []
    for query_row, response_row in zip(query_rows, response_rows):
        clean_query = strip_pad(query_row, tokenizer)
        clean_response = strip_pad(response_row, tokenizer)
        quereis_vec.append(clean_query)
        responses_vec.append(clean_response)
        sequences_vec.append(clean_query + clean_response)
    return quereis_vec, responses_vec, sequences_vec

def get_model_answer(batch, model, accelerator, tokenizer, args):
    """Generate responses for one batch of queries.

    Args:
        batch: Collated batch providing the keys ``"images"``, ``"ids"``,
            ``"queries"``, ``"scenarios"`` and ``"types"``.
        model: The (possibly accelerator-wrapped) policy model.
        accelerator: ``accelerate.Accelerator`` used to unwrap the model and to
            pick the target device.
        tokenizer: Tokenizer supplying the pad and EOS token ids.
        args: Namespace providing ``model_max_length``, ``temperature`` and
            ``suppress_eos_at_generation``.

    Returns:
        ``(queries, responses, ids, scenarios, types)`` where ``queries`` has
        been passed through ``remove_image_token`` and ``responses`` has been
        passed through ``truncate_after_eos_with_padding``.
    """
    unwrapped_policy = accelerator.unwrap_model(model, keep_fp32_wrapper=True)

    images, ids, queries, scenarios, types = common_utils.unpack_dict(
        common_utils.prepare_inputs(batch, device=accelerator.device),
        keys=("images", "ids", "queries", "scenarios", "types"),
    )

    # Trim shared left padding before building the attention mask.
    queries = adjust_queries(queries=queries)
    query_attn_masks = queries.ne(tokenizer.pad_token_id)
    images = images.to(torch.bfloat16)

    suppress_tokens = (
        [tokenizer.eos_token_id] if args.suppress_eos_at_generation else None
    )
    # do_sample=False is requested; the sampling arguments (top_p, top_k,
    # temperature) are still passed through exactly as before.
    responses = unwrapped_policy.generate(
        inputs=queries,
        images=images,
        attention_mask=query_attn_masks,
        do_sample=False,
        max_length=args.model_max_length,
        pad_token_id=tokenizer.pad_token_id,
        suppress_tokens=suppress_tokens,
        top_p=1.0,
        top_k=0,
        temperature=args.temperature,
        num_return_sequences=1,
        synced_gpus=True,
    )
    # Drop the first generated position.
    # NOTE(review): presumably a leading start token emitted by generate() --
    # confirm against the model's generate implementation.
    responses = responses[:, 1:]

    responses = truncate_after_eos_with_padding(
        responses,
        tokenizer.eos_token_id,
        tokenizer.pad_token_id,
    )
    # The concatenated query+response tensor and its attention mask that used
    # to be built here were never read, so that dead work has been removed.
    return remove_image_token(queries), responses, ids, scenarios, types

def save_samples_by_scenario_and_type(gathered_results, output_dir):
    """Write samples to ``<output_dir>/<scenario>/<type>.json``.

    This function used to be defined twice in the file; the second definition
    shadowed the first. This single definition keeps the behaviour of the one
    that was actually in effect: within each file, samples are sorted by
    ``"id"`` and only the first sample per id is kept.

    Args:
        gathered_results: Iterable of dicts with at least the keys ``"id"``,
            ``"scenario"`` and ``"type"``.
        output_dir: Root directory for the output tree; created as needed.

    Unknown scenario or type codes are filed under ``Unknown_Scenario`` /
    ``Unknown_Type``.
    """
    # Directory names are reproduced verbatim (including the spelling of
    # "Activitiy") because consumers locate results by these exact paths.
    scenario_list = {
        1: "01-Illegal_Activitiy",
        2: "02-HateSpeech",
        3: "03-Malware_Generation",
        4: "04-Physical_Harm",
        5: "05-EconomicHarm",
        6: "06-Fraud",
        7: "07-Sex",
        8: "08-Political_Lobbying",
        9: "09-Privacy_Violence",
        10: "10-Legal_Opinion",
        11: "11-Financial_Advice",
        12: "12-Health_Consultation",
        13: "13-Gov_Decision",
    }

    type_list = {
        0: "text_only",
        1: "sd",
        2: "typo",
        3: "sd_typo",
    }

    # Group samples by their (scenario directory, type file stem) pair.
    results_by_scenario_type = {}
    for sample in gathered_results:
        key = (
            scenario_list.get(sample["scenario"], "Unknown_Scenario"),
            type_list.get(sample["type"], "Unknown_Type"),
        )
        results_by_scenario_type.setdefault(key, []).append(sample)

    for (scenario, type_), samples in results_by_scenario_type.items():
        # Deduplicate before touching the filesystem so a malformed sample
        # cannot leave a truncated, empty output file behind. The sort is
        # stable, so the first gathered sample wins for each id.
        unique_samples = {}
        for item in sorted(samples, key=lambda x: x["id"]):
            unique_samples.setdefault(item["id"], item)
        final_samples = list(unique_samples.values())

        scenario_path = os.path.join(output_dir, scenario)
        os.makedirs(scenario_path, exist_ok=True)
        type_file_path = os.path.join(scenario_path, f"{type_}.json")
        with open(type_file_path, "w", encoding="utf-8") as f:
            json.dump(final_samples, f, indent=4)

    print(f"Samples successfully saved in structured directories under {output_dir}")


@torch.inference_mode()
def eval():
    """Run MMSafeBench generation end to end and write the results to disk.

    Steps: parse the three argument dataclasses, load the tokenizer and the
    LLaVA model (optionally with a PEFT adapter and projector weights), build
    the evaluation dataset, generate one response per sample, gather the
    samples from all processes and let the main process write them under
    ``output_dir``.

    Raises:
        ValueError: If no ``vision_tower`` is configured; this script only
            knows how to build the multimodal model.

    Note:
        The function name shadows the ``eval`` builtin; it is kept because it
        is this module's entry point.
    """
    hfparser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    (
        model_args,
        data_args,
        training_args,
        extra_args,
    ) = hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
    # Flat view over all three argument groups.
    args = argparse.Namespace(
        **vars(model_args), **vars(data_args), **vars(training_args)
    )
    training_args.data_config = data_args

    accelerator = accelerate.Accelerator()
    tokenizer_model_name = args.model_name_or_path
    TokenizerClass = AutoTokenizer
    # Tokenizer
    tokenizer = TokenizerClass.from_pretrained(
        tokenizer_model_name,
        cache_dir=args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="left",
        truncation_side="right",
        use_fast=False,
    )

    # The unk token doubles as the pad token.
    tokenizer.pad_token = tokenizer.unk_token

    # Select the conversation template, falling back to "vicuna_v1".
    if model_args.version in conversation_lib.conv_templates:
        conversation_lib.default_conversation = conversation_lib.conv_templates[
            model_args.version
        ]
    else:
        conversation_lib.default_conversation = conversation_lib.conv_templates[
            "vicuna_v1"
        ]

    if model_args.vision_tower is None:
        # Previously this path fell through and failed later with an
        # UnboundLocalError on ``model``.
        raise ValueError(
            "vision_tower must be set: this script only supports evaluating "
            "the multimodal LLaVA model."
        )

    from llava.model import LlavaLlamaForCausalLM

    model = LlavaLlamaForCausalLM.from_pretrained(
        args.model_name_or_path, torch_dtype=torch.bfloat16
    )
    if args.lora_enable:
        model = PeftModel.from_pretrained(
            model, os.path.join(args.peft_model_id_path, "adapter_model/lora_policy")
        )
        mm_projector_path = os.path.join(args.peft_model_id_path, "mm_projector.bin")
        if os.path.exists(mm_projector_path):
            mm_projector_weights = torch.load(mm_projector_path, map_location='cpu')

            def get_w(weights, keyword):
                # Keep only the entries for ``keyword`` and strip everything up
                # to and including "<keyword>." from their names.
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            model.get_model().mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
            print(f"Success loaed mm_projector at {mm_projector_path}")
        else:
            print(f"Warning: mm_projector not found at {mm_projector_path}")
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()

    data_args.image_processor = vision_tower.image_processor
    # The image patches consume part of the context window; the rest is
    # available for the text query.
    training_args.query_len = args.query_len = training_args.model_max_length - vision_tower.num_patches
    training_args.num_patches = args.num_patches = vision_tower.num_patches
    data_args.is_multimodal = True
    data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
    training_args.use_im_start_end = model_args.mm_use_im_start_end

    model.config.use_cache = False
    model.config.tokenizer_padding_side = 'left'
    model.config.image_aspect_ratio = args.image_aspect_ratio
    model.config.image_grid_pinpoints = args.image_grid_pinpoints
    vision_tower.to(device="cuda", dtype=torch.bfloat16)
    mm_projector = model.get_model().mm_projector
    mm_projector.to(device="cuda", dtype=torch.bfloat16)
    model.to(dtype=torch.bfloat16, device=training_args.device)

    args.lora_enable = False

    # Dataset
    data_module: dict = make_mmsafebench_data_module(
        tokenizer=tokenizer, data_args=data_args, training_args=training_args
    )

    if accelerator.is_main_process:
        # Print a few decoded examples as a sanity check.
        training_data = data_module["eval_dataset"]
        for i in range(min(3, len(training_data))):
            # Clone before editing in place so the preview cannot alter the
            # dataset's own tensors.
            # NOTE(review): whether __getitem__ returns a view of stored data
            # is not visible here; the clone is harmless either way.
            ex_input_ids_0 = training_data[i]["queries"].clone()
            # Show the ids with the padding omitted.
            rank0_print(ex_input_ids_0[ex_input_ids_0 != tokenizer.pad_token_id])
            # The image placeholder id cannot be decoded; swap in EOS for display.
            ex_input_ids_0[ex_input_ids_0 == IMAGE_TOKEN_INDEX] = tokenizer.eos_token_id
            rank0_print(tokenizer.decode(ex_input_ids_0, skip_special_tokens=False))
            rank0_print("=" * 20)

    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    node_id = rank // torch.cuda.device_count()

    print(f"Distributed info: rank={rank}, world_size={world_size}, node_id={node_id}")

    eval_dataloader = DataLoader(
        dataset=data_module['eval_dataset'],
        batch_size=1,  # one sample per generate() call
        collate_fn=data_module['data_collator'],
        shuffle=False,  # keep dataset order for evaluation
    )
    model, eval_dataloader = accelerator.prepare(model, eval_dataloader)
    model.eval()

    all_samples = []
    for batch in tqdm(eval_dataloader, desc="Evaluating", total=len(eval_dataloader)):
        queries, responses, ids, scenarios, types = get_model_answer(
            batch, model=model, accelerator=accelerator, tokenizer=tokenizer, args=args
        )
        for query, response, sample_id, scenario, sample_type in zip(
            queries, responses, ids, scenarios, types
        ):
            all_samples.append(
                {
                    "id": sample_id.item(),
                    "scenario": scenario.item(),
                    "type": sample_type.item(),
                    "query": tokenizer.decode(query, skip_special_tokens=True),
                    "response": tokenizer.decode(response, skip_special_tokens=True),
                }
            )

    # Collect the samples of every process. Without a process group this
    # process already holds everything and is the writer; the previous
    # unguarded get_rank() call failed in that case.
    if is_initialized():
        gathered_results = custom_gather(all_samples)
        is_writer = torch.distributed.get_rank() == 0
    else:
        gathered_results = all_samples
        is_writer = True

    # Only one process writes the output files.
    if is_writer:
        save_samples_by_scenario_and_type(gathered_results, args.output_dir)




# Script entry point: run the evaluation, then pause before exiting.
if __name__ == "__main__":
    eval()
    # NOTE(review): fixed 20 s pause after evaluation; presumably gives other
    # ranks time to finish before this process exits -- confirm it is needed.
    time.sleep(20)
