import logging
from typing import List, Optional, Tuple

import os
import torchvision.transforms as transforms
import torch.distributed as dist
import json

import numpy as np
import torch
import torchvision.transforms as T
from accelerate import Accelerator, DistributedType
from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

eval_logger = logging.getLogger("eval_logger")

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

DEFAULT_GEN_KWARGS = dict(
    num_beams=1,
    max_new_tokens=1024,
    do_sample=False,
)


def build_transform(input_size):
    """Build the preprocessing pipeline applied to each PIL image tile.

    The returned transform, in order: converts non-RGB images to RGB,
    resizes to an ``input_size`` x ``input_size`` square with bicubic
    interpolation, converts to a tensor, and normalises with the module's
    ``IMAGENET_MEAN`` / ``IMAGENET_STD`` constants.
    """

    def _ensure_rgb(img):
        # Leave images that are already RGB untouched.
        return img if img.mode == "RGB" else img.convert("RGB")

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the grid in ``target_ratios`` whose aspect ratio is closest to ``aspect_ratio``.

    Args:
        aspect_ratio: Width / height of the source image.
        target_ratios: Iterable of ``(cols, rows)`` candidates, visited in order.
        width: Source image width, used only for tie-breaking.
        height: Source image height, used only for tie-breaking.
        image_size: Side length of one square tile.

    Returns:
        The best ``(cols, rows)`` candidate, or ``(1, 1)`` if there are no
        candidates. When a later candidate ties the current best distance, it
        replaces the current choice only if the source image area exceeds half
        of that candidate's total tile area.
    """
    chosen = (1, 1)
    smallest_gap = float("inf")
    image_area = width * height
    half_tile_area = 0.5 * image_size * image_size

    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)

        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue

        # Tie: prefer this candidate only when the image is large enough for it.
        is_tie = gap == smallest_gap
        if is_tie and image_area > half_tile_area * cols * rows:
            chosen = candidate

    return chosen


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Resize a PIL image onto a tile grid and cut it into square tiles.

    Args:
        image: Source image; must expose ``size``, ``resize`` and ``crop``.
        min_num: Minimum number of tiles a candidate grid may contain.
        max_num: Maximum number of tiles a candidate grid may contain.
        image_size: Side length of every tile.
        use_thumbnail: If true and more than one tile was produced, append a
            single ``image_size`` x ``image_size`` resize of the whole image.

    Returns:
        List of tiles in row-major order, optionally followed by the thumbnail.
    """
    src_width, src_height = image.size
    src_ratio = src_width / src_height

    # Enumerate every (cols, rows) grid whose tile count lies in [min_num, max_num],
    # ordered by tile count.
    grid_candidates = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if cols * rows <= max_num and cols * rows >= min_num
    )
    grid_candidates = sorted(grid_candidates, key=lambda grid: grid[0] * grid[1])

    # Choose the grid whose shape best matches the source image.
    best_grid = find_closest_aspect_ratio(src_ratio, grid_candidates, src_width, src_height, image_size)

    canvas_width = image_size * best_grid[0]
    canvas_height = image_size * best_grid[1]
    tile_count = best_grid[0] * best_grid[1]

    # Stretch the image to the grid canvas, then crop tiles row by row.
    canvas = image.resize((canvas_width, canvas_height))
    tiles_per_row = canvas_width // image_size
    tiles = []
    for tile_idx in range(tile_count):
        row, col = divmod(tile_idx, tiles_per_row)
        left, top = col * image_size, row * image_size
        tiles.append(canvas.crop((left, top, left + image_size, top + image_size)))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles


def load_image(image, input_size=448, max_num=6):
    """Tile an image and return the normalised tiles stacked into one tensor.

    The image is split with ``dynamic_preprocess`` (thumbnail enabled, at most
    ``max_num`` grid tiles), each tile goes through ``build_transform``, and
    the results are stacked along a new leading dimension.
    """
    to_tensor = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([to_tensor(tile) for tile in tiles])


def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
    if bound:
        start, end = bound[0], bound[1]
    else:
        start, end = -100000, 100000
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    frame_indices = np.array([int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
    return frame_indices


def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
    """Sample frames from a video and return them as normalised tile tensors.

    Args:
        video_path: Path handed to ``decord.VideoReader``.
        bound: Optional ``(start, end)`` window forwarded to ``get_index``.
        input_size: Tile side length used for tiling and resizing.
        max_num: Maximum grid tiles per frame for ``dynamic_preprocess``.
        num_segments: Number of frames to sample.

    Returns:
        ``(pixel_values, num_patches_list)``: all frame tiles concatenated
        along dim 0, and the number of tiles contributed by each frame.
    """
    reader = VideoReader(video_path, ctx=cpu(0))
    last_frame = len(reader) - 1
    avg_fps = float(reader.get_avg_fps())

    to_tensor = build_transform(input_size=input_size)
    sampled = get_index(bound, avg_fps, last_frame, first_idx=0, num_segments=num_segments)

    frame_tensors = []
    patches_per_frame = []
    for frame_no in sampled:
        frame = Image.fromarray(reader[frame_no].asnumpy()).convert("RGB")
        tiles = dynamic_preprocess(frame, image_size=input_size, use_thumbnail=True, max_num=max_num)
        stacked = torch.stack([to_tensor(tile) for tile in tiles])
        patches_per_frame.append(stacked.shape[0])
        frame_tensors.append(stacked)

    return torch.cat(frame_tensors), patches_per_frame


from datetime import timedelta

from accelerate.state import AcceleratorState
from accelerate.utils import InitProcessGroupKwargs

# from transformers import AutoModel, GenerationConfig


# import sys; sys.path = ["LLaVA-NeXT/"] + sys.path
# try:
#     from llava.mm_utils import KeywordsStoppingCriteria
# except ImportError as e:
#     eval_logger.debug(f"LLaVA is not installed. Please install LLaVA to use this model.\nError: {e}")
# import types

# @torch.no_grad()
# def generate(
#         self,
#         pixel_values: Optional[torch.FloatTensor] = None,
#         input_ids: Optional[torch.FloatTensor] = None,
#         attention_mask: Optional[torch.LongTensor] = None,
#         visual_features: Optional[torch.FloatTensor] = None,
#         generation_config: Optional[GenerationConfig] = None,
#         output_hidden_states: Optional[bool] = None,
#         return_dict: Optional[bool] = None,
#         **generate_kwargs,
# ) -> torch.LongTensor:
    
#     # if generate_kwargs.get('eos_token_id', None):
#     #     generate_kwargs.pop('eos_token_id')
#     #     eos_token = self.tokenizer.eos_token
#     #     keywords = [eos_token]
#     #     stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
#     #     generate_kwargs['stopping_criteria'] = [stopping_criteria]

#     assert self.img_context_token_id is not None
#     if pixel_values is not None:
#         if visual_features is not None:
#             vit_embeds = visual_features
#         else:
#             vit_embeds = self.extract_feature(pixel_values)
#         input_embeds = self.language_model.get_input_embeddings()(input_ids)
#         B, N, C = input_embeds.shape
#         input_embeds = input_embeds.reshape(B * N, C)

#         input_ids = input_ids.reshape(B * N)
#         selected = (input_ids == self.img_context_token_id)
#         assert selected.sum() != 0
#         input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

#         input_embeds = input_embeds.reshape(B, N, C)
#     else:
#         input_embeds = self.language_model.get_input_embeddings()(input_ids)

#     outputs = self.language_model.generate(
#         inputs_embeds=input_embeds,
#         attention_mask=attention_mask,
#         generation_config=generation_config,
#         output_hidden_states=output_hidden_states,
#         return_dict=return_dict,
#         use_cache=True,
#         **generate_kwargs,
#     )

#     return outputs


@register_model("internvl2")
class InternVL2(lmms):
    """lmms-eval wrapper around an InternVL2 checkpoint loaded with ``AutoModel``.

    Supports an ``"image"`` and a ``"video"`` modality. In the video modality,
    additional per-request arguments (``alpha_q``, ``alpha_k``, ``alpha_v``,
    ``num_classes_total``, ``num_classes_selected``, ``pca_rank``,
    ``cluster_method``, ``rho``, ``eps``, ``layer_wise_scale``, ``boost_layer``)
    are forwarded unchanged to ``model.chat``; their meaning is defined by the
    model's remote code, which is not part of this file.
    """

    def __init__(
        self,
        pretrained: str = "OpenGVLab/InternVL2-2B",
        modality: str = "image",
        device: str = "cuda:0",
        device_map: str = "cuda:0",
        batch_size: str = "1",
        max_frames_num: int = 32,
        **kwargs,
    ):
        """Load the model and tokenizer and configure the Accelerate environment.

        Args:
            pretrained: Model id or path passed to ``from_pretrained``.
            modality: ``"image"`` or ``"video"``; selects the branch used in
                ``generate_until``.
            device: Device used when running a single process with
                ``device_map="auto"``.
            device_map: ``"auto"`` shards the model via ``from_pretrained``;
                any other value loads the full model onto the current CUDA device.
            batch_size: Must convert to ``1``.
            max_frames_num: Number of frames sampled per video.
            **kwargs: Accepted for compatibility; not used.
        """
        super().__init__()

        self.path = pretrained

        # make sure internvl2 works well with pipeline parallel
        if device_map == "auto":
            self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, device_map=device_map).eval()
            from accelerate.hooks import add_hook_to_module, AlignDevicesHook

            # Run lm_head on the same device as the token embeddings when the model is sharded.
            add_hook_to_module(self._model.language_model.lm_head, AlignDevicesHook(execution_device=self._model.language_model.model.embed_tokens.weight.device))
        else:
            self._model = AutoModel.from_pretrained(self.path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True).eval().cuda()

        # Backing value for the ``config`` property (previously never assigned).
        self._config = self._model.config
        self._tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)

        batch_size = int(batch_size)
        assert batch_size == 1, f"Batch size should be 1 for InternVL2, but got {batch_size}."
        self.batch_size_per_gpu = batch_size

        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        self.accelerator = accelerator
        if accelerator.num_processes == 1 and device_map == "auto":
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [DistributedType.FSDP, DistributedType.MULTI_GPU, DistributedType.DEEPSPEED], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            # If you want to use DistributedType.DEEPSPEED, you have to run accelerate config before using the model
            # Also, you have to select zero stage 0 (equivalent to DDP) in order to make the prepare model works
            # I tried to set different parameters in the kwargs to let default zero 2 stage works, but it didn't work.
            if accelerator.distributed_type == DistributedType.DEEPSPEED:
                # Local name chosen so the constructor's **kwargs is not shadowed.
                ds_kwargs = {
                    "train_micro_batch_size_per_gpu": self.batch_size_per_gpu,
                    "train_batch_size": self.batch_size_per_gpu * accelerator.num_processes,
                }
                AcceleratorState().deepspeed_plugin.deepspeed_config_process(must_match=True, **ds_kwargs)
                eval_logger.info("Detected that you are using DistributedType.DEEPSPEED. Make sure you run `accelerate config` and set zero stage to 0")

            if accelerator.distributed_type in (DistributedType.FSDP, DistributedType.DEEPSPEED):
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        elif accelerator.num_processes == 1 and device_map == "auto":
            eval_logger.info(f"Using {accelerator.num_processes} devices with tensor parallelism")
            self._rank = 0
            # Fixed typo: this used to set ``_word_size``, leaving ``world_size`` undefined.
            self._world_size = 1
        else:
            eval_logger.info(f"Using single device: {self._device}")
            self.model.to(self._device)
            self._rank = 0
            self._world_size = 1

        self.modality = modality
        self.max_frames_num = max_frames_num

    @property
    def config(self):
        """Return the configuration object of the loaded model."""
        return self._config

    @property
    def tokenizer(self):
        """Return the tokenizer loaded alongside the model."""
        return self._tokenizer

    @property
    def model(self):
        """Return the model, unwrapped from Accelerate when an accelerator exists."""
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def batch_size(self):
        """Return the per-GPU batch size (always 1)."""
        return self.batch_size_per_gpu

    @property
    def device(self):
        """Return the ``torch.device`` selected in ``__init__``."""
        return self._device

    @property
    def rank(self):
        """Return this process's rank as recorded in ``__init__``."""
        return self._rank

    @property
    def world_size(self):
        """Return the number of processes as recorded in ``__init__``."""
        return self._world_size

    def flatten(self, input):
        """Flatten one level of nesting: an iterable of iterables becomes a list."""
        return [item for group in input for item in group]

    def generate_until(self, requests) -> List[str]:
        """Generate one response per request.

        Each ``request.args`` must unpack into 17 values: ``contexts``,
        ``gen_kwargs``, ``doc_to_visual``, ``doc_id``, ``task``, ``split``,
        followed by the extra arguments listed in the class docstring.
        ``gen_kwargs`` is filtered in place down to the keys of
        ``DEFAULT_GEN_KWARGS``, with missing keys filled from the defaults.

        Returns:
            The model responses, in request order.

        Raises:
            ValueError: If visuals are present and ``self.modality`` is neither
                ``"image"`` nor ``"video"``.
        """
        res = []
        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")

        for (
            contexts,
            gen_kwargs,
            doc_to_visual,
            doc_id,
            task,
            split,
            alpha_q,
            alpha_k,
            alpha_v,
            num_classes_total,
            num_classes_selected,
            pca_rank,
            cluster_method,
            rho,
            eps,
            layer_wise_scale,
            boost_layer,
        ) in [reg.args for reg in requests]:
            # Keep only the supported generation options, filling in defaults.
            gen_kwargs.pop("until", None)
            for key, default in DEFAULT_GEN_KWARGS.items():
                gen_kwargs.setdefault(key, default)
            for key in [k for k in gen_kwargs if k not in DEFAULT_GEN_KWARGS]:
                gen_kwargs.pop(key)

            visuals = [doc_to_visual(self.task_dict[task][split][doc_id])]

            # Synchronisation left over from the visualisation/dump workflow:
            # all ranks meet at a barrier, then step through ranks 0..15 in order
            # (the per-rank mapping update that ran inside this loop is disabled).
            # NOTE(review): the sequence is hard-coded for 16 ranks and runs once
            # per request, so it presumably requires every rank to process the
            # same number of requests - confirm before relying on it.
            if dist.is_initialized():
                rank = dist.get_rank()
                dist.barrier()
                for r in range(16):
                    if rank == r:
                        dist.barrier()
                    # Make sure every process has passed this rank's turn.
                    dist.barrier()

            # Debug dumps can be produced with the module-level helpers
            # save_gt_to_file / save_tensor_as_image / save_question_to_file /
            # save_response_to_file; they are not called during normal evaluation.

            if visuals != [None]:
                visuals = self.flatten(visuals)
                if self.modality == "image":
                    if visuals:
                        visuals = [load_image(visual).to(torch.bfloat16).cuda() for visual in visuals]
                        pixel_values = torch.cat(visuals, dim=0)
                        num_patches_list = [visual.size(0) for visual in visuals]
                        image_tokens = " ".join(["<image>"] * len(visuals))
                        contexts = image_tokens + "\n" + contexts
                    else:
                        pixel_values = None
                        # Fixed typo: this used to assign ``num_patch_list``.
                        num_patches_list = None
                    response, history = self.model.chat(self.tokenizer, pixel_values, contexts, gen_kwargs, num_patches_list=num_patches_list, history=None, return_history=True)
                elif self.modality == "video":
                    assert len(visuals) == 1, f"Only one video is supported, but got {len(visuals)} videos."
                    video_path = visuals[0]
                    pixel_values, num_patches_list = load_video(video_path, num_segments=self.max_frames_num, max_num=1)
                    pixel_values = pixel_values.to(torch.bfloat16).cuda()
                    video_prefix = "".join([f"Frame{i+1}: <image>\n" for i in range(len(num_patches_list))])
                    question = video_prefix + contexts

                    # Always bind the path so a request without a cluster method
                    # neither fails nor reuses the previous request's directory.
                    save_cluster_path = None
                    if cluster_method is not None:
                        save_cluster_root_path = 'SET-ROOT-PATH-HERE' + self.path.split('/')[-1] + '_' + str(self.max_frames_num) + 'f_' + str(num_classes_total) + 'classes_' + str(rho) + 'rho_' + str(eps) + 'eps'
                        os.makedirs(save_cluster_root_path, exist_ok=True)
                        save_cluster_path = os.path.join(save_cluster_root_path, f"{doc_id:04d}")
                        os.makedirs(save_cluster_path, exist_ok=True)

                    chat_extras = dict(
                        alpha_q=alpha_q,
                        alpha_k=alpha_k,
                        alpha_v=alpha_v,
                        num_classes_total=num_classes_total,
                        num_classes_selected=num_classes_selected,
                        pca_rank=pca_rank,
                        cluster_method=cluster_method,
                        rho=rho,
                        eps=eps,
                        save_cluster_path=save_cluster_path,
                    )
                    # Layer-wise options are only passed when a scale is given.
                    if layer_wise_scale is not None:
                        chat_extras.update(layer_wise_scale=layer_wise_scale, boost_layer=boost_layer)

                    response, history = self.model.chat(self.tokenizer, pixel_values, question, gen_kwargs, num_patches_list=num_patches_list, history=None, return_history=True, **chat_extras)
                else:
                    # Without this, ``response`` would be unbound or stale.
                    raise ValueError(f"Unsupported modality: {self.modality}")
            else:
                response, history = self.model.chat(self.tokenizer, None, contexts, gen_kwargs, num_patches_list=None, history=None, return_history=True)

            res.append(response)
            pbar.update(1)
        pbar.close()
        return res

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """Not supported by this wrapper; always fails the assertion below."""
        assert False, "Not implemented yet."

        
def save_tensor_as_image(pixel_values, directory):
    """Undo ImageNet normalisation and save each tensor as ``frame{i}.png``.

    Args:
        pixel_values: Iterable of image tensors that were normalised with
            ``IMAGENET_MEAN`` / ``IMAGENET_STD`` (assumes 3 x H x W each -
            TODO confirm against caller).
        directory: Output directory; created if it does not exist.
    """
    os.makedirs(directory, exist_ok=True)

    std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
    to_pil = transforms.ToPILImage()

    for frame_no, normalised in enumerate(pixel_values):
        # De-normalise, then clamp into the [0, 1] range expected by ToPILImage.
        restored = torch.clamp(normalised.cpu() * std + mean, 0, 1)
        to_pil(restored).save(os.path.join(directory, f"frame{frame_no}.png"))
        
        
def save_question_to_file(directory, question):
    """Write ``question`` to ``question_nopreprompt.txt`` inside ``directory``.

    The directory is created if needed and an existing file is overwritten.
    """
    os.makedirs(directory, exist_ok=True)
    target = os.path.join(directory, "question_nopreprompt.txt")
    with open(target, "w", encoding="utf-8") as out:
        out.write(question)
        
        
def save_gt_to_file(directory, gt):
    """Write the ground-truth text ``gt`` to ``gt.txt`` inside ``directory``.

    The directory is created if needed and an existing file is overwritten.
    """
    os.makedirs(directory, exist_ok=True)
    target = os.path.join(directory, "gt.txt")
    with open(target, "w", encoding="utf-8") as out:
        out.write(gt)
        
        
def save_response_to_file(directory, response):
    """Write ``response`` to ``response_nopreprompt.txt`` inside ``directory``.

    The directory is created if needed and an existing file is overwritten.
    """
    os.makedirs(directory, exist_ok=True)
    target = os.path.join(directory, "response_nopreprompt.txt")
    with open(target, "w", encoding="utf-8") as out:
        out.write(response)
        
        
# **每个进程更新自己的映射**
def update_mapping(rank, directory, map_file):
    """Record ``directory`` under key ``str(rank)`` in the JSON file ``map_file``.

    The file is read, updated and rewritten. Any exception during that cycle
    is printed and the cycle is retried, up to 10 attempts in total; if all
    attempts fail the function returns without raising.
    """
    attempts_left = 10
    while attempts_left > 0:
        attempts_left -= 1
        try:
            # Load the current mapping.
            with open(map_file, "r") as src:
                mapping = json.load(src)

            # Set this rank's entry and write the whole mapping back.
            mapping[str(rank)] = directory
            with open(map_file, "w") as dst:
                json.dump(mapping, dst)
        except Exception as e:
            print(f"Rank {rank}: File write conflict, retrying... {e}")
        else:
            # Written successfully; no further attempts needed.
            return

