import os
import json
import pandas as pd
from torch.utils.data import Dataset
import pickle
import random
import cv2
import albumentations as A
import numpy as np 
import torch 
import copy
from PIL import Image

from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info



class DatasetRegistryInstruct:
    """Name-to-class registry for instruct-style datasets.

    Dataset classes add themselves with the ``register`` decorator and are
    later instantiated by name through ``get_dataset``.
    """

    # Shared lookup table: registered name -> dataset class.
    _registry = {}

    @classmethod
    def register(cls, name):
        """
        Build a class decorator that stores the decorated class under `name`.

        The decorated class is returned unchanged. Registering another class
        under a name that is already present replaces the earlier entry.
        """
        def decorator(dataset_cls):
            cls._registry[name] = dataset_cls
            return dataset_cls

        return decorator

    @classmethod
    def get_dataset(cls, name, *args, **kwargs):
        """
        Instantiate the class registered under `name`.

        All extra positional and keyword arguments are forwarded to the
        class constructor. Raises ValueError when nothing is registered
        under `name`.
        """
        factory = cls._registry.get(name)
        if factory is not None:
            return factory(*args, **kwargs)
        raise ValueError(f"Dataset {name} is not registered.")



class ImageTextDatasetInstruct(Dataset):
    """Base dataset that turns (image, question, answer) records into chat-formatted model inputs.

    The model family ("qwen" or "gemma") is derived from
    ``config.LLM.model_name`` and selects both the special tokens that are
    masked out of the labels and the per-item encoding routine.

    Subclasses must implement:
      * ``load_data()``      -- returns the indexable collection stored in ``self.data``.
      * ``get_item_data(i)`` -- returns a dict with the keys ``img_path``,
        ``question``, ``answer``, ``ans_type`` and ``mc_options``.
    """

    # System turn placed at the start of every conversation.
    _SYSTEM_PROMPT = "You are a helpful assistant."

    def __init__(self, config, dset_name, phase):
        """
        Args:
            config: Configuration object; this class reads ``LLM.model_name``,
                ``LLM.cache_dir``, ``LLM.seq_length`` and (Qwen only) ``image_size``.
            dset_name: Name attached to every returned sample as ``dset_name``.
            phase: Split identifier, stored for use by the subclass ``load_data``.
        """
        self.config = config
        self.processor = AutoProcessor.from_pretrained(
            self.config.LLM.model_name,
            cache_dir=self.config.LLM.cache_dir,
            force_download=False,
            device_map="auto",
        )
        self.seq_length = config.LLM.seq_length
        self.dset_name = dset_name
        self.phase = phase

        # Load data (to be defined in the subclass)
        self.data = self.load_data()

        # Get token IDs
        self._init_special_token_ids()

    def _model_family(self):
        """Return "qwen", "gemma" or None, matching the model name case-insensitively.

        Both ``__init__`` and ``__getitem__`` go through this helper so the
        two can never disagree (e.g. for a name such as "Qwen/...").
        """
        model_name = self.config.LLM.model_name.lower()
        if "qwen" in model_name:
            return "qwen"
        if "gemma" in model_name:
            return "gemma"
        return None

    def _token_id(self, token):
        """Return the first ID the tokenizer produces for ``token`` (no special tokens added)."""
        return self.processor.tokenizer.encode(token, add_special_tokens=False)[0]

    def _init_special_token_ids(self):
        """Resolve family-specific special-token IDs and build ``self.label_mask_ids``.

        Does nothing for an unrecognised model family; ``__getitem__`` raises
        for such models.
        """
        family = self._model_family()

        if family == "qwen":
            self.im_start_id = self._token_id("<|im_start|>")
            self.assistant_id = self._token_id("assistant")
            self.vision_start_id = self._token_id("<|vision_start|>")
            self.vision_end_id = self._token_id("<|vision_end|>")
            self.vision_pad_id = self._token_id("<|vision_pad|>")
            self.image_pad_id = self._token_id("<|image_pad|>")

            # im_start_id and assistant_id are resolved but deliberately left
            # out of the mask set.
            mask_ids = {
                self.vision_start_id,
                self.vision_end_id,
                self.vision_pad_id,
                self.image_pad_id,
                self.processor.tokenizer.pad_token_id,
            }
            self.label_mask_ids = {int(x) for x in mask_ids}

            print(f'{self.im_start_id=}')
            print(f'{self.assistant_id=}')
            print(f'{self.vision_start_id=}')
            print(f'{self.vision_end_id=}')
            print(f'{self.vision_pad_id=}')
            print(f'{self.image_pad_id=}')
            print(self.processor.tokenizer.pad_token_id)

        elif family == "gemma":
            self.im_start_id = self._token_id("<start_of_image>")
            self.im_end_id = self._token_id("<end_of_image>")
            self.pad_id = self._token_id("<pad>")
            self.bos_id = self._token_id("<bos>")
            self.eos_id = self._token_id("<eos>")
            self.im_soft_id = self._token_id("<image_soft_token>")

            mask_ids = {
                self.im_start_id,
                self.im_end_id,
                self.pad_id,
                self.bos_id,
                self.eos_id,
                self.im_soft_id,
            }
            self.label_mask_ids = {int(x) for x in mask_ids}

    def __getitem__(self, idx):
        """Encode sample ``idx`` with the routine matching the model family.

        Raises:
            ValueError: If the model name is neither a Qwen nor a Gemma model.
        """
        family = self._model_family()
        if family == "qwen":
            return self._getitem_qwen(idx)
        if family == "gemma":
            return self._getitem_gemma3(idx)
        raise ValueError(
            f"Unsupported model {self.config.LLM.model_name!r}: "
            "expected a name containing 'qwen' or 'gemma'."
        )

    def _build_labels(self, input_ids):
        """Return a copy of ``input_ids`` with every ID in ``label_mask_ids`` replaced by -100.

        NOTE(review): -100 is presumably the ignore index of the downstream
        loss -- confirm against the training loop.
        """
        mask_ids = torch.tensor(
            sorted(self.label_mask_ids),
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        labels = input_ids.clone()
        labels[torch.isin(input_ids, mask_ids)] = -100
        return labels

    def _attach_metadata(self, inputs, ans_type, mc_options):
        """Add the dataset name and, when not None, ``ans_type`` / ``mc_options`` to ``inputs``."""
        inputs["dset_name"] = self.dset_name
        if ans_type is not None:
            inputs["ans_type"] = ans_type
        if mc_options is not None:
            inputs["mc_options"] = mc_options
        return inputs

    def _gemma_prompt_messages(self, img_path, question):
        """Build a fresh system + user message list in the Gemma content format."""
        return [
            {
                "role": "system",
                "content": [{"type": "text", "text": self._SYSTEM_PROMPT}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": img_path},
                    {"type": "text", "text": question},
                ],
            },
        ]

    def _qwen_prompt_messages(self, img_path, question):
        """Build a fresh system + user message list in the Qwen content format.

        The image entry carries the target size taken from ``config.image_size``
        (index 0 as height, index 1 as width).
        """
        return [
            {"role": "system", "content": self._SYSTEM_PROMPT},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": img_path,
                        "resized_height": self.config.image_size[0],
                        "resized_width": self.config.image_size[1],
                    },
                    {"type": "text", "text": question},
                ],
            },
        ]

    def _getitem_gemma3(self, idx):
        """Encode sample ``idx`` for a Gemma model.

        Returns the processor output for the full conversation with the
        leading batch axis removed from ``input_ids``, ``attention_mask`` and
        ``pixel_values``, plus ``labels``, ``prompt`` (token IDs of the
        conversation without the answer, with the generation prompt appended),
        ``dset_name`` and, when available, ``ans_type`` / ``mc_options``.
        """
        item_data = self.get_item_data(idx)

        img_path = item_data['img_path']
        question = item_data['question']
        answer = item_data['answer']
        ans_type = item_data['ans_type']
        mc_options = item_data['mc_options']

        # Full conversation, including the assistant answer.
        messages = self._gemma_prompt_messages(img_path, question)
        messages.append(
            {"role": "assistant", "content": [{"type": "text", "text": answer}]}
        )
        inputs = self.processor.apply_chat_template(
            messages, add_generation_prompt=False, tokenize=True,
            return_dict=True, return_tensors="pt"
        )

        # Prompt-only conversation, built from fresh message dicts.
        prompt = self.processor.apply_chat_template(
            self._gemma_prompt_messages(img_path, question),
            add_generation_prompt=True, tokenize=True,
            return_dict=True, return_tensors="pt"
        )

        inputs["input_ids"] = inputs["input_ids"][0]
        inputs["labels"] = self._build_labels(inputs["input_ids"])
        inputs["prompt"] = prompt["input_ids"][0]
        inputs["attention_mask"] = inputs["attention_mask"][0]
        inputs["pixel_values"] = inputs["pixel_values"][0]

        return self._attach_metadata(inputs, ans_type, mc_options)

    def _getitem_qwen(self, idx):
        """Encode sample ``idx`` for a Qwen model.

        The full conversation is padded/truncated to ``self.seq_length``; the
        prompt-only encoding is left unpadded. Returns the processor output
        with the leading batch axis removed from ``input_ids``,
        ``attention_mask`` and ``image_grid_thw``, plus ``labels``, ``prompt``,
        ``dset_name`` and, when available, ``ans_type`` / ``mc_options``.
        """
        item_data = self.get_item_data(idx)

        img_path = item_data['img_path']
        question = item_data['question']
        answer = item_data['answer']
        ans_type = item_data['ans_type']
        mc_options = item_data['mc_options']

        # Full conversation, including the assistant answer.
        messages = self._qwen_prompt_messages(img_path, question)
        messages.append({"role": "assistant", "content": answer})

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False,
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding="max_length",
            max_length=self.seq_length,
            return_tensors="pt",
            truncation=True,
        )

        # Prompt-only conversation with the generation prompt appended.
        messages_prompt_only = self._qwen_prompt_messages(img_path, question)
        text_prompt_only = self.processor.apply_chat_template(
            messages_prompt_only, tokenize=False, add_generation_prompt=True,
        )
        # The prompt contains the same single image as the full conversation,
        # so the already-extracted vision inputs are reused instead of
        # decoding the image a second time.
        prompt = self.processor(
            text=[text_prompt_only],
            images=image_inputs,
            videos=video_inputs,
            return_tensors="pt",
            truncation=True,
        )

        inputs["input_ids"] = inputs["input_ids"][0]
        inputs["labels"] = self._build_labels(inputs["input_ids"])
        inputs["prompt"] = prompt["input_ids"][0]
        inputs["attention_mask"] = inputs["attention_mask"][0]
        inputs["image_grid_thw"] = inputs["image_grid_thw"][0]

        return self._attach_metadata(inputs, ans_type, mc_options)

    def __len__(self):
        """Return the number of records in ``self.data``."""
        return len(self.data)

    def load_data(self):
        """Method to load dataset data. Implement in subclass."""
        raise NotImplementedError

    def get_item_data(self, idx):
        """Method to return image path and text data. Implement in subclass."""
        raise NotImplementedError


@DatasetRegistryInstruct.register('rad-vqa-Instruct')
class RADVQADatasetInstruct(ImageTextDatasetInstruct):
    """RAD-VQA records served through the instruct dataset interface.

    Registered in ``DatasetRegistryInstruct`` under 'rad-vqa-Instruct'.
    """

    def load_data(self):
        """Load the JSON annotation file for the current phase.

        Reads 'train_data.json' for phase "train" and 'valid_data.json' for
        phase "valid" from ``config.rad_vqa.main_path``; any other phase
        raises ValueError.
        """
        if self.phase == "train":
            json_name = 'train_data.json'
        elif self.phase == "valid":
            json_name = 'valid_data.json'
        else:
            raise ValueError('Invalid phase')

        json_path = os.path.join(self.config.rad_vqa.main_path, json_name)
        with open(json_path) as handle:
            return json.load(handle)

    def get_item_data(self, idx):
        """Return the image path and normalised text fields for record ``idx``.

        Question, answer and answer type are lower-cased and stripped of
        surrounding whitespace; ``mc_options`` is always None for this dataset.
        """
        image_root = self.config.rad_vqa.image_main_path
        record = self.data[idx]

        img_path = os.path.join(image_root, record["image_name"])
        question = str(record["question"]).lower().strip()
        answer = str(record["answer"]).lower().strip()
        ans_type = record['answer_type'].lower().strip()

        return {
            'img_path': img_path,
            'question': question,
            'answer': answer,
            'ans_type': ans_type,
            'mc_options': None,
        }

