import copy
import json
import math
import os
import random
import re
import ast
from typing import Dict
from PIL import Image, ImageDraw
import torch
import transformers
import yaml
from qwen_vl_utils import smart_resize, process_vision_info
from torch.utils.data import Dataset

from uav_vln.constants import (
    IGNORE_INDEX,
    DEFAULT_IMAGE_TOKEN,
    assistant_template,
    chat_template,
    system_message,
)
from uav_vln.trainer import rank0_print

# One non-negative number: "12", "12.5", "12." or ".5".
_COORD_NUMBER = r'(\d+(?:\.\d*)?|\.\d+)'

# A bracketed list of exactly two or exactly four numbers, e.g. "[x, y]" or
# "[x1, y1, x2, y2]". Compiled once because the function is called per turn.
_COORD_LIST_PATTERN = re.compile(
    r'\[\s*' + _COORD_NUMBER + r'\s*,\s*' + _COORD_NUMBER
    + r'(?:\s*,\s*' + _COORD_NUMBER + r'\s*,\s*' + _COORD_NUMBER + r')?\s*\]'
)


def resize_coords_in_text(text, resize_ratio):
    """Rescale every bracketed 2- or 4-number coordinate list found in *text*.

    Each number is floor-divided by ``resize_ratio`` and truncated to ``int``;
    the list is re-emitted as ``"[a, b]"`` / ``"[a, b, c, d]"`` with
    ``", "`` separators. Text outside matching lists is returned unchanged.

    Args:
        text: String possibly containing coordinate lists.
        resize_ratio: Divisor applied to every coordinate.

    Returns:
        The text with all matching coordinate lists rescaled.

    Raises:
        ZeroDivisionError: If ``resize_ratio`` is zero and a list is matched.
    """

    def scale(token):
        # Integer tokens stay on the exact-int path; decimal tokens (which the
        # pattern accepts) are parsed as float instead of crashing in int().
        value = float(token) if '.' in token else int(token)
        return int(value // resize_ratio)

    def repl(match):
        # Groups 3 and 4 are None for two-element lists.
        scaled = [scale(token) for token in match.groups() if token is not None]
        return '[' + ', '.join(str(value) for value in scaled) + ']'

    return _COORD_LIST_PATTERN.sub(repl, text)

class LazySupervisedDataset(Dataset):
    """Supervised image + conversation dataset that tokenises samples on access.

    The JSON file at ``data_path`` is loaded eagerly, but images are opened and
    conversations tokenised only inside ``__getitem__``. Coordinate lists in
    the conversation turns at indices 1 and 2 are rescaled once, at
    construction time, by ``data_args.resize_ratio`` -- the same ratio by which
    each image is downscaled when it is loaded.
    """

    # Flat per-image count added to the length estimates below.
    _IMAGE_TOKEN_ESTIMATE = 1200

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        processor: transformers.ProcessorMixin,
        data_path: str,
        data_args,
    ):
        """Load the sample list and pre-scale coordinates in the text.

        Args:
            tokenizer: Tokenizer used to template and encode conversations.
            processor: Multimodal processor used to encode text and images.
            data_path: Path to a JSON file holding a list of sample dicts.
            data_args: Object providing ``resize_ratio``, ``teacher_map_path``
                and ``neg_map_path`` (other attributes are read later by
                other methods).
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.processor = processor
        self.data_args = data_args

        self.resize_ratio = data_args.resize_ratio
        self.teacher_map_path = data_args.teacher_map_path
        self.neg_map_path = data_args.neg_map_path

        # Context manager so the file handle is closed deterministically.
        with open(data_path, 'r') as f:
            self.list_data_dict = json.load(f)

        for data in self.list_data_dict:
            # Only turns 1 and 2 are rescaled; slicing (rather than indexing)
            # tolerates conversations with fewer than three turns.
            for turn in data["conversations"][1:3]:
                turn["value"] = resize_coords_in_text(turn["value"], self.resize_ratio)

        rank0_print(f"Loaded {len(self.list_data_dict)} samples from {data_path}")
        rank0_print("Formatting inputs...Skip in lazy mode")

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.list_data_dict)

    def _image_token_estimate(self, sample) -> int:
        """Return the image contribution to a sample's length estimate.

        A list-valued ``"image"`` entry counts once per element, any other
        present ``"image"`` entry counts once, and a missing entry counts zero.
        """
        # ``.get`` avoids the KeyError that direct indexing raised for samples
        # without an "image" key.
        images = sample.get("image")
        if isinstance(images, list):
            return self._IMAGE_TOKEN_ESTIMATE * len(images)
        return self._IMAGE_TOKEN_ESTIMATE if "image" in sample else 0

    @property
    def lengths(self):
        """Per-sample length estimate: whitespace word count plus image estimate."""
        return [
            sum(len(conv["value"].split()) for conv in sample["conversations"])
            + self._image_token_estimate(sample)
            for sample in self.list_data_dict
        ]

    @property
    def modality_lengths(self):
        """Per-sample signed length estimate.

        Samples that have an ``"image"`` or ``"video"`` key (or all samples
        when ``data_args.early_mix_text`` is truthy) get a positive length;
        the remaining samples get the negated word count.
        """
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
            assert cur_len > 0, f"Conversation length is 0 for {sample}"

            if "image" in sample or "video" in sample or self.data_args.early_mix_text:
                length_list.append(cur_len + self._image_token_estimate(sample))
            else:
                length_list.append(-cur_len)
        return length_list

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return sample ``i``, redrawing a random index while the result is None.

        Exceptions raised while building a sample propagate to the caller.
        """
        sample = self._get_item(i)
        # Iterative retry instead of recursion, so repeated misses cannot
        # exhaust the call stack.
        while sample is None:
            sample = self._get_item(random.randint(0, len(self.list_data_dict) - 1))
        return sample

    def _get_item(self, i) -> Dict[str, torch.Tensor]:
        """Build the model inputs for sample ``i``.

        Opens the first image of the sample, downscales it by
        ``self.resize_ratio``, tokenises the conversation, and loads the
        optional ``<image stem>.pt`` teacher / negative maps when the
        corresponding directory is configured and the file exists.

        Returns:
            Dict with ``input_ids``, ``labels``, ``pixel_values``,
            ``image_grid_thw``, ``image_token_index``, ``text_token_index``,
            ``teacher_map``, ``neg_map`` and ``image_path``. Entries that are
            unavailable are ``None``.
        """
        record = self.list_data_dict[i]
        image_path = record['image'][0]
        # Stem is everything before the FIRST dot of the file name.
        image_name = os.path.basename(image_path).split(".")[0]
        teacher_map_path = f"{self.teacher_map_path}/{image_name}.pt" if self.teacher_map_path else None
        neg_map_path = f"{self.neg_map_path}/{image_name}.pt" if self.neg_map_path else None
        orig_img = Image.open(image_path).convert("RGB")

        # Downscale the original image to reduce GPU memory usage
        width, height = orig_img.size
        resized_img = orig_img.resize((width // self.resize_ratio, height // self.resize_ratio))

        encoded = self.preprocess_qwen2vl(
            record['conversations'], self.tokenizer, self.processor, [resized_img]
        )

        teacher_map = None
        if teacher_map_path and os.path.exists(teacher_map_path):
            teacher_map = torch.load(teacher_map_path, map_location="cpu", weights_only=True)
            teacher_map = teacher_map.squeeze(0)

        neg_map = None
        if neg_map_path and os.path.exists(neg_map_path):
            neg_map = torch.load(neg_map_path, map_location="cpu", weights_only=True)

        return {
            # preprocess_qwen2vl returns (1, seq_len) tensors; drop that axis.
            "input_ids": encoded["input_ids"][0],
            "labels": encoded["labels"][0],
            "pixel_values": encoded.get("pixel_values"),
            "image_grid_thw": encoded.get("image_grid_thw"),
            "image_token_index": encoded.get("image_token_index"),
            "text_token_index": encoded.get("text_token_index"),
            "teacher_map": teacher_map,
            "neg_map": neg_map,
            "image_path": image_path,
        }

    def preprocess_qwen2vl(
        self,
        source, # conversations
        tokenizer: transformers.PreTrainedTokenizer,
        processor: transformers.ProcessorMixin,
        image: list,
        system_message: str = system_message,
        agent_mode: bool = True,
        chat_template: str = chat_template,
        assistant_template: str = assistant_template,
    ) -> Dict:
        """Tokenise one conversation and build its label mask and span indices.

        Args:
            source: List of turns, each a dict with ``from``/``value`` or
                ``role``/``content`` keys. The first turn must carry a
                ``from`` key; if it maps to ``system`` it replaces
                ``system_message``.
            tokenizer: Tokenizer; also assigned to ``processor.tokenizer``.
            processor: Multimodal processor used to encode each turn.
            image: Images consumed in order, one per ``DEFAULT_IMAGE_TOKEN``
                occurrence in the user turns.
            system_message: System prompt used when the conversation does not
                start with a system turn.
            agent_mode: When false, ``chat_template`` is used in place of
                ``assistant_template`` for turns without images.
            chat_template: Template for the system turn and image turns.
            assistant_template: Template for turns without images.

        Returns:
            Dict with ``input_ids`` and ``labels`` (both ``(1, seq_len)`` long
            tensors; ``labels`` is ``IGNORE_INDEX`` on system and user
            tokens), ``image_token_index`` and ``text_token_index``
            (``[start, end]`` pairs, ``[-1, -1]`` when not found), plus
            ``pixel_values`` / ``image_grid_thw`` when at least one image was
            encoded.
        """
        roles = {"human": "user", "gpt": "assistant", "system": "system"}
        assistant_template = assistant_template if agent_mode else chat_template
        processor.tokenizer = tokenizer

        # Apply prompt templates
        pixel_values, image_grid_thw = None, None
        input_id, target = [], []
        image_index = 0

        ## prepare the system message
        # NOTE: only the from/value format is recognised for the leading turn.
        if roles[source[0]["from"]] == "system":
            system_message = source[0]["value"]
            source = source[1:self.data_args.max_conv_turns]
        # else: use the constant system message
        system_input_id = tokenizer.apply_chat_template(
            conversation=[{"role": "system", "content": [{"type": "text", "text": system_message}]}],
            chat_template=chat_template,
        )
        input_id += system_input_id
        target += [IGNORE_INDEX] * len(system_input_id)

        ## prepare user-assistant conversation
        for conv in source:
            # regularize the conversation format
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]
            role = roles.get(role, role)

            # Count the number of <image> tokens in the content
            image_count = content.count(DEFAULT_IMAGE_TOKEN)
            if image_count > 0:
                assert role == "user", "Images are only supported for user messages"
                # include image information for the current conversation turn
                image_placeholders = []
                for _ in range(image_count):
                    image_placeholders.append({
                        "type": "image",
                        "image": image[image_index],
                        # Read from the processor that is actually used for
                        # encoding (the argument), not from ``self``.
                        "min_pixels": processor.image_processor.min_pixels,
                        "max_pixels": processor.image_processor.max_pixels,
                    })
                    image_index += 1

                content = content.replace(DEFAULT_IMAGE_TOKEN, "")
                conv = {"role": role, "content": image_placeholders + [{"type": "text", "text": content}]}

                image_inputs, _ = process_vision_info([conv]) # list of PIL.Image.Image

                templated_conv = tokenizer.apply_chat_template(
                    conversation=[conv], chat_template=chat_template, tokenize=False
                )
                inputs = processor(text=[templated_conv], images=image_inputs, return_tensors="pt")

                # pixel_values and image_grid_thw are always set together.
                if pixel_values is None:
                    pixel_values = inputs["pixel_values"]
                    image_grid_thw = inputs["image_grid_thw"]
                else:
                    pixel_values = torch.concat([pixel_values, inputs["pixel_values"]], dim=0)
                    image_grid_thw = torch.concat([image_grid_thw, inputs["image_grid_thw"]], dim=0)
            else:
                conv = {"role": role, "content": [{"type": "text", "text": content}]}
                templated_conv = tokenizer.apply_chat_template(
                    conversation=[conv],
                    chat_template=assistant_template,
                    tokenize=False,
                )
                inputs = processor(text=[templated_conv], return_tensors="pt")

            encode_id = inputs.input_ids[0].tolist()

            input_id += encode_id
            # Only assistant tokens contribute to the loss.
            if role in ["user", "system"]:
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"

        input_ids = torch.tensor([input_id], dtype=torch.long)
        targets = torch.tensor([target], dtype=torch.long)

        # search for image token index: first and last position of the pad id
        # NOTE(review): 151655 is presumably the tokenizer's image-pad token
        # id -- confirm against the tokenizer in use.
        image_pad_id = 151655
        positions = torch.where(input_ids[0] == image_pad_id)[0]
        if len(positions) > 0:
            image_token_index = [positions.min().item(), positions.max().item()]
        else:
            image_token_index = [-1, -1]  # not found

        data_dict = {
            "input_ids": input_ids,  # tensor(1 x seq_len)
            "labels": targets,  # tensor(1 x seq_len)
        }

        # search for text token index: the span between the two anchors
        anchor1 = "[Mission Objective]"
        anchor2 = "[Output Format Specification]"

        # Convert to token id list (without CLS/special tokens)
        # NOTE(review): anchors are tokenised in isolation; this assumes they
        # tokenise the same way inside the full prompt -- confirm.
        anchor1_ids = tokenizer(anchor1, add_special_tokens=False)["input_ids"]
        anchor2_ids = tokenizer(anchor2, add_special_tokens=False)["input_ids"]

        def find_sublist_positions(row, sublist):
            """Find the starting index of the first occurrence of sublist in row"""
            width = len(sublist)
            for start in range(len(row) - width + 1):
                if row[start:start + width] == sublist:
                    return start
            return -1

        pos1 = find_sublist_positions(input_id, anchor1_ids)
        pos2 = find_sublist_positions(input_id, anchor2_ids)

        if pos1 != -1 and pos2 != -1:
            # start after anchor1, end before anchor2
            text_token_index = [pos1 + len(anchor1_ids), pos2 - 1]
        else:
            text_token_index = [-1, -1]

        if pixel_values is not None:
            data_dict["pixel_values"] = pixel_values
            data_dict["image_grid_thw"] = image_grid_thw
        data_dict["image_token_index"] = image_token_index
        data_dict["text_token_index"] = text_token_index
        return data_dict
