import os
import random
import torch
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from .text_render import get_encoder_render_image_mask
import datasets
from torch.utils.data import random_split

import dataclasses
import logging
import math
import os
import io
import sys
import time
import json
import os.path as osp
from typing import Optional, Sequence, Union, Dict
from torch.utils.data import SubsetRandomSampler
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Subset

from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from torch.nn.utils.rnn import pad_sequence
import tqdm
import copy
# Replace the `datasets` builder's disk-space check with a stub that always
# returns True, so the free-space check never reports a shortage.
datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory='.': True 

def find_subsequence_indices(full_ids, pattern_ids):
    """Return the start index of the first window of ``full_ids`` equal to ``pattern_ids``.

    Windows are taken by slicing ``full_ids`` and compared to ``pattern_ids``
    with ``==``. Returns ``-1`` when no window matches (including when the
    pattern is longer than ``full_ids``).
    """
    width = len(pattern_ids)
    candidate_starts = range(len(full_ids) - width + 1)
    matches = (
        start
        for start in candidate_starts
        if full_ids[start:start + width] == pattern_ids
    )
    return next(matches, -1)

# Prompt templates with named placeholders. "prompt_input" expects both
# `instruction` and `input`; "prompt_no_input" expects only `instruction`.
# Both end with the "### Response:" header so the answer text can follow.
PROMPT_DICT = {
    "prompt_input": (
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}

def _make_r_io_base(f, mode: str):
    """Return a file object for ``f``.

    An ``io.IOBase`` instance is handed back untouched; anything else is
    passed to ``open`` together with ``mode``.
    """
    if isinstance(f, io.IOBase):
        return f
    return open(f, mode=mode)

def jload(f, mode="r"):
    """Load a .json file into a dictionary.

    Args:
        f: Either an already-open ``io.IOBase`` object or something accepted
            by ``open`` (e.g. a path).
        mode: Mode used when ``f`` has to be opened here.

    Returns:
        The object produced by ``json.load``.

    Note:
        The file object is always closed before returning, even when parsing
        fails, and even when the caller supplied an already-open handle.
    """
    # `with` guarantees the handle is released if json.load raises.
    with _make_r_io_base(f, mode) as handle:
        return json.load(handle)

def convert_conversations_to_instruction_format(conversations):
    """Convert role-tagged conversations into instruction/input/output records.

    Two conversation shapes are accepted:

    * ``[system, human, gpt]`` -> instruction = system text, input = human
      text, output = gpt text.
    * ``[human, gpt]`` -> instruction = human text, input = ``""``,
      output = gpt text.

    Anything else is dropped: samples that are not lists, have a different
    number of turns, contain non-dict turns, have unexpected ``"from"`` roles,
    or lack a ``"value"`` key. Two-turn samples with unexpected roles are also
    printed for inspection.

    Args:
        conversations: Iterable of conversations; each conversation is expected
            to be a list of dicts with ``"from"`` and ``"value"`` keys.

    Returns:
        A list of dicts with keys ``"instruction"``, ``"input"`` and
        ``"output"``, in the order the accepted samples were encountered.
    """
    expected_roles = {
        3: ["system", "human", "gpt"],
        2: ["human", "gpt"],
    }

    results = []
    for sample in conversations:
        if not isinstance(sample, list) or len(sample) not in expected_roles:
            continue
        # Malformed turns are skipped like every other bad sample instead of
        # aborting the whole conversion with an AttributeError.
        if not all(isinstance(turn, dict) for turn in sample):
            continue

        roles = [turn.get("from") for turn in sample]
        if roles != expected_roles[len(sample)]:
            if len(sample) == 2:
                print("🔍🔍🔍🔍!= [human, gpt]:", sample)
            continue

        try:
            values = [turn["value"] for turn in sample]
        except KeyError:
            # A turn without text cannot produce a usable record.
            continue

        if len(values) == 3:
            instruction, input_text, output_text = values
        else:
            instruction, output_text = values
            input_text = ""

        results.append({
            "instruction": instruction,
            "input": input_text,
            "output": output_text,
        })

    return results


class OpenHermesDataset(Dataset):
    """Dataset for supervised fine-tuning on instruction/response pairs.

    Conversations are converted to instruction/input/output records, filtered
    by character length, and formatted with ``PROMPT_DICT``. Each item is a
    two-turn chat (user, assistant) run through ``vl_chat_processor``; the
    user turn is either the prompt rendered as an image plus a short text
    instruction, or the raw prompt text when ``pure_text`` is truthy. Labels
    mask everything up to and including the assistant header with ``-100``.
    """

    # Character-count limits (not token counts) used to filter examples.
    MAX_PROMPT_CHARS = 2000
    MAX_OUTPUT_CHARS = 1200

    def __init__(self, data_path: str, vl_chat_processor, max_length=4096,
                 font_size=7,
                 font_path=None,
                 n_parts=1,
                 clip_token_num=576,
                 cache_dir=None,
                 pure_text=False,
                 raw_dataset=None,
                 ):
        """Build the prompt/target lists from ``raw_dataset``.

        Args:
            data_path: Dataset identifier; only used to load the data when
                ``raw_dataset`` is not supplied.
            vl_chat_processor: Processor exposing ``tokenizer``,
                ``apply_chat_template`` and ``__call__``.
            max_length: Stored on the instance; not applied by this class.
            font_size, font_path, n_parts, clip_token_num: Forwarded to
                ``get_encoder_render_image_mask`` when rendering prompts.
            cache_dir: Accepted for interface compatibility; unused here.
            pure_text: When truthy, items carry the prompt as text instead of
                as a rendered image.
            raw_dataset: Iterable of samples, each with a ``"conversations"``
                entry. When ``None`` the ``"train"`` split of ``data_path`` is
                loaded.
        """
        super().__init__()
        self.vl_chat_processor = vl_chat_processor
        self.tokenizer = vl_chat_processor.tokenizer
        self.max_length = max_length
        self.n_parts = n_parts
        self.font_size = font_size
        self.font_path = font_path
        self.clip_token_num = clip_token_num
        self.pure_text = pure_text

        if raw_dataset is None:
            # Previously this crashed on iterating None; fall back to loading
            # the same split the data module loads.
            raw_dataset = load_dataset(data_path, split="train")

        conversations = [sample["conversations"] for sample in raw_dataset]
        if conversations:  # an empty dataset must not raise IndexError here
            print("⚡ ⚡ ⚡ ⚡ conversations[0]", conversations[0])

        examples = convert_conversations_to_instruction_format(conversations)
        print("🔧 🔧 filter by len before: ", len(examples))
        self.list_data_dict = [
            example for example in examples if self._within_length_limits(example)
        ]
        print("🔧 🔧 🔧filter by len after: ", len(self.list_data_dict))

        prompt_input = PROMPT_DICT["prompt_input"]
        prompt_no_input = PROMPT_DICT["prompt_no_input"]
        self.sources = [
            prompt_input.format_map(example)
            if example.get("input", "") != ""
            else prompt_no_input.format_map(example)
            for example in self.list_data_dict
        ]
        # No EOS token is appended here; the chat template is applied later.
        self.targets = [f"{example['output']}" for example in self.list_data_dict]

    @classmethod
    def _within_length_limits(cls, example) -> bool:
        """Return True when the example's character lengths fit the limits."""
        prompt_chars = len(example.get("input", "")) + len(example.get("instruction", ""))
        output_chars = len(example.get("output", ""))
        return prompt_chars <= cls.MAX_PROMPT_CHARS and output_chars <= cls.MAX_OUTPUT_CHARS

    def _tokenize_fn(self, strings: Sequence[str], tokenizer) -> Dict:
        """Tokenize each string separately and report its non-pad length.

        Returns a dict with ``input_ids``/``labels`` (the same list of 1-D
        tensors under both keys) and ``input_ids_lens``/``labels_lens`` (the
        same list of ints under both keys).
        """
        tokenized_list = [
            tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                max_length=tokenizer.model_max_length,
                truncation=True,
            )
            for text in strings
        ]
        input_ids = [tokenized.input_ids[0] for tokenized in tokenized_list]
        input_ids_lens = [
            tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
            for tokenized in tokenized_list
        ]
        return dict(
            input_ids=input_ids,
            labels=input_ids,
            input_ids_lens=input_ids_lens,
            labels_lens=input_ids_lens,
        )

    def preprocess(
        self,
        sources: Sequence[str],
        targets: Sequence[str],
        tokenizer,
    ) -> Dict:
        """Tokenize ``source + target`` pairs and mask the source part in the labels.

        Returns a dict with ``input_ids`` (list of 1-D tensors) and ``labels``
        (a deep copy whose first ``len(tokenized source)`` positions are -100).
        """
        examples = [s + t for s, t in zip(sources, targets)]
        examples_tokenized, sources_tokenized = [
            self._tokenize_fn(strings, tokenizer) for strings in (examples, sources)
        ]
        input_ids = examples_tokenized["input_ids"]
        labels = copy.deepcopy(input_ids)
        for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
            label[:source_len] = -100
        return dict(input_ids=input_ids, labels=labels)

    def __len__(self):
        print(" 🔁 🔁dataset size: ", len(self.list_data_dict))
        return len(self.list_data_dict)

    def process_single_sample(self, messages):
        """Turn one chat (list of role/content messages) into model inputs.

        The chat template is applied without a generation prompt, the result
        is run through the processor together with any vision inputs, and
        labels are a copy of ``input_ids`` with everything up to and including
        the first assistant header set to -100.

        Returns:
            Dict of the processor outputs plus ``"labels"``. Tensors with two
            or more dims have their leading dim dropped, except
            ``pixel_values`` and ``pixel_values_videos`` which are kept as-is.

        Raises:
            ValueError: If the assistant header tokens are not found.
        """
        full_text = self.vl_chat_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        image_inputs, video_inputs = process_vision_info(messages)

        model_inputs = self.vl_chat_processor(
            text=[full_text],
            images=image_inputs,
            videos=video_inputs,
            return_tensors="pt",
            padding=True,
        )
        input_ids = model_inputs["input_ids"][0]  # [seq_len]

        assistant_prefix = "<|im_start|>assistant\n"
        assistant_ids = self.vl_chat_processor.tokenizer(
            assistant_prefix, add_special_tokens=False
        ).input_ids

        start_index = find_subsequence_indices(input_ids.tolist(), assistant_ids)
        if start_index == -1:
            raise ValueError("Failed to locate assistant prefix in input_ids.")

        # Only tokens after the assistant header contribute to the loss.
        labels = input_ids.clone()
        labels[: start_index + len(assistant_ids)] = -100

        unbatched_exempt = ("pixel_values", "pixel_values_videos")
        sample = {}
        for k, v in model_inputs.items():
            drop_batch_dim = (
                isinstance(v, torch.Tensor)
                and v.dim() >= 2
                and k not in unbatched_exempt
            )
            sample[k] = v[0] if drop_batch_dim else v
        sample["labels"] = labels

        return sample

    def __getitem__(self, i):
        """Return ``{"inputs_pure_text": ...}`` or ``{"inputs_text_image": ...}`` for item ``i``."""
        input_text = self.sources[i]
        target_text = self.targets[i]
        assistant_turn = {
            "role": "assistant",
            "content": [
                {"type": "text", "text": f"{target_text}"},
            ],
        }

        if self.pure_text:
            messages_puretext = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"{input_text}"},
                    ],
                },
                assistant_turn,
            ]
            return {
                "inputs_pure_text": self.process_single_sample(messages_puretext),
            }

        # Only the first rendered image is used; the mask and original texts
        # returned by the renderer are not needed here.
        input_text_image, _mask, _ori_texts = get_encoder_render_image_mask(
            encoder_token_id=None,
            tokenizer=self.tokenizer,
            encoder_text=input_text,
            n_parts=self.n_parts,
            font_size=self.font_size,
            font_path=self.font_path,
            clip_token_num=self.clip_token_num,
            image_size=(224*16, 14),
            add_instruction=None,
        )
        # NOTE: the spelling "anwser" is part of the prompt the model sees and
        # is kept byte-for-byte so training inputs do not change.
        instruction = "Let's anwser the question in the image. "
        messages_textimage = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": input_text_image[0],
                    },
                    {"type": "text", "text": f"{instruction}"},
                ],
            },
            assistant_turn,
        ]
        return {
            "inputs_text_image": self.process_single_sample(messages_textimage),
        }

class OpenHermesDataModule(pl.LightningDataModule):
    """Lightning data module that serves ``OpenHermesDataset`` for training.

    ``setup`` loads the ``"train"`` split of ``data_path`` and wraps it in an
    ``OpenHermesDataset``; ``train_dataloader`` batches it with
    ``custom_collate``.
    """

    # Top-level keys a dataset item may carry; exactly one is expected.
    _SAMPLE_KEYS = ("inputs_text_image", "inputs_pure_text")

    def __init__(self, data_path, vl_chat_processor,
                 batch_size=4,
                 max_length=4096,
                 font_size=7,
                 font_path=None,
                 n_parts=1,
                 clip_token_num=576,
                 cache_dir=None,
                 pure_text=False,
                 num_workers=4,
                 ):
        """Store the configuration; no data is loaded until ``setup``.

        Args:
            data_path: Dataset identifier passed to ``load_dataset``.
            vl_chat_processor: Processor forwarded to the dataset; its
                ``tokenizer`` attribute is also kept on the module.
            batch_size: Batch size of the training loader.
            max_length, font_size, font_path, n_parts, clip_token_num,
            cache_dir, pure_text: Forwarded unchanged to ``OpenHermesDataset``.
            num_workers: Worker processes for the training loader (defaults
                to the previously hard-coded value of 4).
        """
        super().__init__()
        self.data_path = data_path
        self.vl_chat_processor = vl_chat_processor
        self.tokenizer = vl_chat_processor.tokenizer
        self.batch_size = batch_size
        self.max_length = max_length
        self.n_parts = n_parts
        self.font_size = font_size
        self.font_path = font_path
        self.clip_token_num = clip_token_num
        self.cache_dir = cache_dir
        self.pure_text = pure_text
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Load the raw training split and build ``self.train_dataset``."""
        # NOTE(review): cache_dir is handed to the dataset but not to
        # load_dataset — confirm whether it should be forwarded here too.
        self.raw_dataset = load_dataset(self.data_path, split="train")
        self.train_dataset = OpenHermesDataset(
            data_path=self.data_path,
            vl_chat_processor=self.vl_chat_processor,
            max_length=self.max_length,
            font_size=self.font_size,
            font_path=self.font_path,
            n_parts=self.n_parts,
            clip_token_num=self.clip_token_num,
            cache_dir=self.cache_dir,
            pure_text=self.pure_text,
            raw_dataset=self.raw_dataset,
        )

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over the training dataset."""
        print("tarin len: ", len(self.train_dataset))
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=self.custom_collate,
        )

    @staticmethod
    def collate_fields(batch_list):
        """Merge a list of per-sample dicts into one dict of batched tensors.

        Keys are taken from the first sample. ``input_ids`` and
        ``attention_mask`` are right-padded with 0 and ``labels`` with -100;
        ``pixel_values`` tensors are concatenated along dim 0; every other
        tensor is stacked. Non-tensor values are omitted from the result.
        """
        pad_values = {"input_ids": 0, "attention_mask": 0, "labels": -100}

        collated = {}
        for key in batch_list[0].keys():
            values = [sample[key] for sample in batch_list]
            if not isinstance(values[0], torch.Tensor):
                continue

            if key in pad_values:
                collated[key] = pad_sequence(
                    values, batch_first=True, padding_value=pad_values[key]
                )
            elif key == "pixel_values":
                # Samples keep their own leading dim, so they are joined
                # end-to-end rather than stacked.
                collated[key] = torch.cat(values)
            else:
                # Includes image_grid_thw / video_grid_thw / pixel_values_videos.
                collated[key] = torch.stack(values)

        return collated

    @staticmethod
    def custom_collate(batch):
        """Collate dataset items into ``{<sample key>: collated fields}``.

        The sample key (``"inputs_text_image"`` or ``"inputs_pure_text"``) is
        detected from the first item, so this works as a static method and
        does not need access to the module's ``pure_text`` flag.

        Raises:
            KeyError: If the first item carries neither expected key.
        """
        first = batch[0]
        for key in OpenHermesDataModule._SAMPLE_KEYS:
            if key in first:
                per_sample = [sample[key] for sample in batch]
                return {key: OpenHermesDataModule.collate_fields(per_sample)}
        raise KeyError(
            f"Batch items must contain one of {OpenHermesDataModule._SAMPLE_KEYS}; "
            f"got keys {list(first.keys())}"
        )
