import os
from typing import Dict

import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_from_disk, load_dataset, DatasetDict
from transformers import BatchFeature
from ..modules.load_pretrained import load_tokenizer

import PIL
import pickle
from absl import logging


class MLLMDataset(Dataset):
    """Torch ``Dataset`` that turns rows of a Hugging Face dataset split into model inputs.

    Rows are converted lazily in ``__getitem__`` via ``_process_input``; ``collate_fn``
    is meant to be passed to a ``DataLoader``.

    Args:
        dataset: The Hugging Face dataset split (indexed with ``[]`` and measured with ``len``).
        config: Experiment config dict; keys read here include ``dataset``, ``drafting``,
            ``batch_size``, ``is_drf_from_mllm``, ``eval_datasets``, ``drf``,
            ``is_drf_text_only``, ``is_tgt_text_only`` and ``input_datasets_dir``.
        tokenizer: Tokenizer/processor used to template and tokenize the prompt.
        drf_aux_tokenizer: Extra tokenizer whose outputs are stored under ``aux_*`` keys
            when ``config['is_drf_from_mllm']`` is false.
        drf_image_processor: Image processor whose ``pixel_values`` output is stored.
    """

    def __init__(self, dataset, config, tokenizer, drf_aux_tokenizer, drf_image_processor):
        self.dataset = dataset  # The Hugging Face dataset is passed directly
        self._config = config
        self.tokenizer = tokenizer
        self.drf_aux_tokenizer = drf_aux_tokenizer
        self.drf_image_processor = drf_image_processor
        if 'caption' in config['drafting']:
            # Only the first of the three values returned by load_tokenizer is kept.
            self.captioning_image_processor, _, _ = load_tokenizer(self._config, None, 'captioning_model')

        self.image_token = "<image>"
        # One-shot example row (ScienceQA only); None for every other dataset.
        self.example = self._get_example_data()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        return self._process_input(sample)

    def collate_fn(self, processed_samples: list[Dict[str, torch.Tensor]]):
        """Collate samples produced by ``_process_input`` into one ``BatchFeature``.

        A single sample is simply wrapped. For larger batches, the raw ``prompt`` strings
        (and the raw ``images`` when they were kept) are re-processed together by
        ``self.tokenizer`` with padding. Per-sample ``aux_*`` and captioning inputs are
        therefore not carried into multi-sample batches.
        """
        if len(processed_samples) == 1:
            return BatchFeature(processed_samples[0])

        # Refer tutorials/llava_batch.py: multi-image & batch>1 => just concat along image's batch axis
        text_list = [sample['prompt'] for sample in processed_samples]
        batched_raw_input = dict(
            text=text_list,
            padding=True,
            return_tensors="pt",
        )

        if self._config['is_drf_from_mllm'] or self._config['eval_datasets'] is None:
            # Flatten the per-sample image lists into one list, preserving order.
            batched_raw_input['images'] = [image for sample in processed_samples for image in sample['images']]

        # Unpack as keyword arguments: passing the dict positionally would bind it to the
        # processor's first parameter (text or images) rather than forwarding its entries.
        batched_samples = self.tokenizer(**batched_raw_input)

        return BatchFeature(batched_samples)

    def _process_input(self, sample):
        """Convert one dataset row into a dict of model inputs.

        Keys that may be produced: ``pixel_values``; ``pixel_values_caption`` /
        ``input_ids_caption`` (captioning processor); the tokenizer outputs for the
        prompt; ``aux_*`` (auxiliary tokenizer); and, when ``batch_size > 1``, the raw
        ``images`` and ``prompt`` that ``collate_fn`` re-processes as a batch.
        """
        processed_input = {}

        # Load image and process prompt
        images = self._get_image_data(sample)
        is_no_img = self._config['is_drf_text_only'] and self._config['is_tgt_text_only']
        if self._config['is_drf_from_mllm'] or self._config['eval_datasets'] is None:
            pixel_values = self.drf_image_processor(images, return_tensors="pt")["pixel_values"] if not is_no_img else None
            processed_input['pixel_values'] = pixel_values
            if hasattr(self, 'captioning_image_processor'):
                if ('caption' in self._config['drafting']) and "lorence-2" in self._config['captioning_model']:  # microsoft/Florence-2
                    # Florence-2 takes a task prompt (caption_type) alongside the images.
                    inputs_caption_processed = self.captioning_image_processor(text=self._config['caption_type'], images=images, return_tensors="pt")
                else:
                    inputs_caption_processed = self.captioning_image_processor(images=images, return_tensors="pt")

                # Keep only pixel_values / input_ids, renamed with a "_caption" suffix.
                processed_input.update({
                    _key + '_caption': inputs_caption_processed[_key]
                    for _key in inputs_caption_processed.keys()
                    if _key in ['pixel_values', 'input_ids']
                })

            if self._config['batch_size'] > 1:
                processed_input['images'] = images

        # Apply prompt tuning (only takes the first conversation)
        prompt = self._tune_prompt(sample)
        processed_input.update(self.tokenizer(prompt, return_tensors="pt"))

        if not self._config['is_drf_from_mllm']:
            # batch 1 only
            aux_inputs = {f"aux_{k}": v for k, v in self.drf_aux_tokenizer(prompt, return_tensors="pt").items()}
            processed_input.update(aux_inputs)

        if self._config['batch_size'] > 1:
            processed_input['prompt'] = prompt

        return processed_input

    def _get_example_data(self):
        """Return the one-shot example row used for ScienceQA prompts, or ``None``.

        For ScienceQA the row is unpickled from
        ``<input_datasets_dir>/example/batch_example.pkl``; other datasets get ``None``.
        """
        if self._config['dataset'] != "ScienceQA":
            return None
        logging.info("[Dataset] Loading batch_example from disk for a single shot")
        path_example = os.path.join(self._config['input_datasets_dir'], "example/batch_example.pkl")
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(path_example, 'rb') as f:
            return pickle.load(f)

    def _get_image_data(self, batch):
        """Load the RGB image(s) for one row, preceded by the one-shot example image if any.

        Multi-image datasets read every non-``None`` ``image_*`` column (in key order),
        LiveBench reads the ``images`` list, and all other datasets read ``image``. That
        value is either a file name under ``<input_datasets_dir>/images`` or an object
        with a ``convert`` method (e.g. a PIL image).
        """
        images = []
        # Add example image if available
        if self.example is not None:
            images.append(self.example['image'].convert('RGB'))

        # Every branch extends `images` (never reassigns it) so the example image is kept.
        # multi-image
        if self._config['dataset'] in ['Spot-the-Diff', 'Birds-to-Words', 'CLEVR-Change', 'HQ-Edit', 'MagicBrush', 'IEdit', 'AESOP','FlintstonesSV', 'PororoSV', 'VIST', 'WebQA', 'QBench', 'NLVR2_Mantis', 'OCR-VQA']:
            cols_img = [col for col in batch.keys() if col.startswith('image_')]
            images.extend(batch[col].convert('RGB') for col in cols_img if batch[col] is not None)

        # single-image - OCR
        elif self._config['dataset'] in ['LiveBench']:
            images.extend(image.convert('RGB') for image in batch['images'])

        # single-image - ELSE
        else:
            raw_image = batch['image']
            if isinstance(raw_image, str):  # file name relative to <input_datasets_dir>/images
                # `import PIL` alone does not load the Image submodule, so import it explicitly.
                from PIL import Image
                path_image = os.path.join(self._config['input_datasets_dir'], 'images', raw_image)
                # Context manager closes the file; convert() returns an independent, loaded copy.
                with Image.open(path_image) as opened:
                    images.append(opened.convert('RGB'))  # 'RGB' to prevent dimension issues
            else:
                images.append(raw_image.convert('RGB'))  # 'RGB' to prevent dimension issues

        return images

    def _get_text_data(self, batch):
        """Return the raw (un-templated) user prompt text for one row.

        Returns ``None`` for ScienceQA, whose prompt is built by
        ``_get_conversation_scienceqa`` instead.

        Raises:
            ValueError: if ``config['dataset']`` is not supported.
        """
        dataset_name = self._config['dataset']
        if dataset_name == "LLaVA-Instruct-150K":
            return batch["conversations"][0]['value']
        elif dataset_name == "COCO2014":
            return "Provide a detailed description of the given image."
        elif dataset_name in ['chartqa', 'docvqa_val', 'infovqa_val', 'ok_vqa_val2014', 'textvqa_val', 'vizwiz_vqa_val', 'vqav2_val', 'QBench', 'NLVR2_Mantis', 'OCR-VQA', 'MMVet', 'POPE', 'HallusionBench']:
            return f"For the following question, provide a detailed explanation of your reasoning leading to the answer.\n{batch['question']}"
        elif dataset_name == "ScienceQA":
            return None
        elif dataset_name in ["VibeEval", "DC100_EN", "LLaVA-Bench-Wilder"]:
            # Each of these datasets stores the question under a different column name.
            _map_key = {
                "VibeEval": 'prompt',
                "DC100_EN": 'question',
                "LLaVA-Bench-Wilder": 'Question',
            }
            return batch[_map_key[dataset_name]]
        elif dataset_name in ['llava-bench-in-the-wild', 'Spot-the-Diff', 'Birds-to-Words', 'CLEVR-Change', 'HQ-Edit', 'MagicBrush', 'IEdit', 'AESOP','FlintstonesSV', 'PororoSV', 'VIST', 'WebQA', 'LiveBench']:
            return batch['question']
        else:
            raise ValueError(f"Dataset not supported: {self._config['dataset']}")

    def _tune_prompt(self, batch):
        """Build the final prompt string for one row.

        If ``self.tokenizer`` has an ``image_processor`` attribute, a chat conversation is
        built and rendered with ``apply_chat_template``. A fallback template is installed
        when the tokenizer has none and ``config['drf']`` does not contain ``'llava-hf'``.
        Otherwise the raw text from ``_get_text_data`` is returned.

        Raises:
            ValueError: if ``config['dataset']`` is not supported.
        """
        if hasattr(self.tokenizer, 'image_processor'):
            # Get raw prompt
            if self._config['dataset'] in ['LLaVA-Instruct-150K']:
                prompt_raw = self._get_text_data(batch)
                conversation = [{
                    "role": "user",
                    "content": [{"type": "text", "text": prompt_raw}],
                }]
            elif self._config['dataset'] in ["llava-bench-in-the-wild", "COCO2014", "VibeEval", "DC100_EN", "LLaVA-Bench-Wilder", "LiveBench", "chartqa", "docvqa_val", "infovqa_val", "ok_vqa_val2014", "textvqa_val", "vizwiz_vqa_val", "vqav2_val", "MMVet", "POPE", "HallusionBench"]:
                # add image token separately
                prompt_raw = self._get_text_data(batch)
                conversation = [{
                    "role": "user",
                    "content": [{"type": "image"},
                                {"type": "text", "text": prompt_raw}],
                }]
            elif self._config['dataset'] in ['Spot-the-Diff', 'Birds-to-Words', 'CLEVR-Change', 'HQ-Edit', 'MagicBrush', 'IEdit', 'AESOP','FlintstonesSV', 'PororoSV', 'VIST', 'WebQA', 'QBench', 'NLVR2_Mantis', 'OCR-VQA']:
                # image tokens are already included in the prompt
                prompt_raw = self._get_text_data(batch)
                conversation = [{
                    "role": "user",
                    "content": [{"type": "text", "text": prompt_raw}],
                }]
            elif self._config['dataset'] == 'ScienceQA':
                conversation = self._get_conversation_scienceqa(batch)
            # multi image from lmms-eval
            else:
                raise ValueError(f"Dataset not supported: {self._config['dataset']}")
            if self.tokenizer.chat_template is None and 'llava-hf' not in self._config['drf']:
                # Fallback template: "ROLE: " prefix, "<image>\n" per image, then text; the
                # generation prompt is "ASSISTANT:".
                self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] != 'system' %}{{ message['role'].upper() + ': '}}{% endif %}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% if message['role'] != 'assistant' %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['text'] + ' '}}{% endfor %}{% else %}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{% generation %}{{ content['text'] + ' '}}{% endgeneration %}{% endfor %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'ASSISTANT:' }}{% endif %}"
            resulting_prompt = self.tokenizer.apply_chat_template(conversation, add_generation_prompt=True)

        else:
            resulting_prompt = self._get_text_data(batch)

        if "llava-hf/llava-interleave-qwen" in self._config['drf']:
            # Collapse the doubled newline after the role headers for this model family.
            resulting_prompt = resulting_prompt.replace('user\n\n', 'user\n')
            resulting_prompt = resulting_prompt.replace('assistant\n\n', 'assistant\n')

        return resulting_prompt

    def _get_conversation_scienceqa(self, batch):
        """Build the one-shot ScienceQA chat conversation for ``batch``.

        With the hard-coded flags below, the result is
        ``[user: example prompt, assistant: example reasoning, user: query prompt]``.
        """
        # CoT
        # Experiment toggles; only the current combination (True, True, False) is exercised.
        # NOTE(review): with is_template=False, `conversation` would be unbound.
        is_template = True
        do_split_template_role = True
        has_user_only = False
        split_example = do_split_template_role and (not has_user_only)

        if not split_example:
            # Generate the first shot (with answer and explanation)
            example_shot = self._build_common_prompt(self.example, include_answer=True)

            # Generate the second shot (without the answer, for prediction)
            subsequence_prompt = self._build_common_prompt(batch, include_answer=False)

            prompt_augmented = (
                f"{example_shot}\n\n"
                f"{subsequence_prompt}"
            )
        else:
            # Generate the first shot (with answer and explanation)
            example_shot_1, example_shot_2 = self._build_common_prompt(self.example, include_answer=True, split_example=split_example)

            # Generate the second shot (without the answer, for prediction)
            subsequence_prompt_1, subsequence_prompt_2 = self._build_common_prompt(batch, include_answer=False, split_example=split_example)

        if is_template:
            if not do_split_template_role:
                conversation = [{
                    "role": "user",
                    "content": [{"type": "text", "text": prompt_augmented}],
                }]
            else:
                if has_user_only:
                    conversation = [
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": example_shot}],
                        },
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": subsequence_prompt}],
                        },
                    ]
                else:
                    conversation = [
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": example_shot_1}],
                        },
                        {
                            "role": "assistant",
                            "content": [{"type": "text", "text": example_shot_2}],
                        },
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": subsequence_prompt_1}],
                        },
                    ]
        return conversation

    def _build_common_prompt(self, batch, include_answer=True, split_example=False):
        """Format one ScienceQA row as prompt text.

        Reads ``question``, ``choices`` and ``hint``; also ``answer`` (an index into
        ``choices``) and, when ``include_answer`` is true, ``lecture`` and ``solution``.
        Returns a single string when ``keep_original`` and not ``split_example``;
        otherwise returns a ``(prompt_part_1, prompt_part_2)`` tuple.
        """
        keep_original = False  # toggle for the older prompt format
        # Extract relevant data from the batch
        question = batch['question']
        choices = batch['choices']
        hint = batch['hint']

        # Build the options part: "(A) ... (B) ..."
        options = ' '.join([f"({chr(65 + i)}) {choice}" for i, choice in enumerate(choices)])

        if keep_original:
            if not split_example:
                # Start building the common prompt
                common_prompt = (
                    f"{self.image_token}\n"
                    f"Question: {question}\n"
                    f"Options: {options}\n"
                    f"Context: {hint}\n"
                )

                # Add the answer and explanation if requested
                if include_answer:
                    answer = batch['choices'][batch['answer']]
                    lecture = batch['lecture']
                    solution = batch['solution']
                    common_prompt += f"Answer: The answer is {answer}. \nBECAUSE: {lecture} explanation: {solution}"
                else:
                    common_prompt += "Answer: The answer is"

                return common_prompt
            else:
                # Start building the common prompt
                common_prompt_1 = (
                    f"{self.image_token}\n"
                    f"Question: {question}\n"
                    f"Options: {options}\n"
                )
                common_prompt_2 = (
                    f"Context: {hint}\n"
                )

                if include_answer:
                    answer = batch['choices'][batch['answer']]
                    lecture = batch['lecture']
                    solution = batch['solution']
                    common_prompt_2 += f"Answer: The answer is {answer}. \nBECAUSE: {lecture} explanation: {solution}"
                else:
                    common_prompt_2 += "Answer: The answer is"

                return common_prompt_1, common_prompt_2
        else:
            # The answer is given in the prompt; the model is asked to explain it.
            instruction = "Based on the image, provide Reasoning and detailed Explanation behind the provided Answer to the Question."
            answer = batch['choices'][batch['answer']]
            common_prompt_1 = (
                f"{self.image_token}\n"
                f"Question: {question}\n"
                f"Options: {options}\n"
                f"Answer: The answer is {answer}\n"
                f"{instruction}\n"
            )
            common_prompt_2 = ""

            if include_answer:
                lecture = batch['lecture']
                solution = batch['solution']
                common_prompt_2 += f"Reasoning: {lecture}\n"
                common_prompt_2 += f"Explanation: {solution}"

            return common_prompt_1, common_prompt_2


def load_datasets(config, tokenizer, drf_aux_tokenizer, drf_image_processor) -> Dict[str, Dataset]:
    """Load, split/trim and wrap the dataset named by ``config['dataset']``.

    LLaVA-Instruct-150K is read from ``<input_datasets_dir>/meta.json`` and split
    80/10/10 into train/validation/test. Every other dataset is read with
    ``load_from_disk(config['input_datasets_dir'])``. Per-dataset caps, ``tiny_data``
    and ``reduce_data`` are then applied, and each split is wrapped in ``MLLMDataset``.

    Returns:
        Mapping from split name to the wrapped ``MLLMDataset``.
    """
    # noqa
    if config['dataset'] == "LLaVA-Instruct-150K":
        path_dataset = os.path.join(config['input_datasets_dir'], 'meta.json')
        raw_datasets = load_dataset("json", data_files=path_dataset)
        # The JSON loads as a single split; carve it into train/validation/test.
        split_single = list(raw_datasets.keys())[0]
        map_datasets = _split_dataset(raw_datasets[split_single], config)
    else:  # COCO2014  DC100_EN  LLaVA-Bench-Wilder  ScienceQA  VibeEval ...
        map_datasets = load_from_disk(config['input_datasets_dir'])

    if config['dataset'] == "ScienceQA":
        reduce_map = {'train': 100, 'validation': 100, 'test': 200}
        map_datasets = _apply_tiny_data_filter(map_datasets, reduce_map)

    if config['dataset'] in ['chartqa', 'docvqa_val', 'infovqa_val', 'ok_vqa_val2014', 'textvqa_val', 'vizwiz_vqa_val', 'vqav2_val', 'HQ-Edit', 'MagicBrush', 'MMVet', 'POPE', 'HallusionBench', 'QBench', 'NLVR2_Mantis', 'OCR-VQA']:
        reduce_map = {'train': 100, 'validation': 100, 'test': 100}
        map_datasets = _apply_tiny_data_filter(map_datasets, reduce_map)

    # Apply tiny_data or reduce_data filters
    if config['tiny_data']:
        tiny_map = {'train': 80, 'validation': 10, 'test': 3}
        map_datasets = _apply_tiny_data_filter(map_datasets, tiny_map)

    if config['reduce_data'] is not None:
        logging.info(f"[Dataset] Using reduced data to {config['reduce_data']} rows")
        for split in map_datasets.keys():
            # Clamp so that splits shorter than reduce_data are kept whole instead of
            # triggering an out-of-range select().
            n_rows = min(config['reduce_data'], len(map_datasets[split]))
            map_datasets[split] = map_datasets[split].select(range(n_rows))

    # Log dataset info
    for split, dataset in map_datasets.items():
        logging.info(f"[Dataset] {split} dataset: {len(dataset)} samples")

    # Wrap datasets with MLLMDataset
    return _wrap_with_mllm_dataset(map_datasets, config, tokenizer, drf_aux_tokenizer, drf_image_processor)


def create_data_loaders(config, tokenizer, drf_image_processor, drf_aux_tokenizer=None) -> Dict[str, DataLoader]:
    """Build one ``DataLoader`` per dataset split returned by ``load_datasets``.

    Every loader uses ``config['batch_size']`` and the wrapped dataset's own
    ``collate_fn``; only the ``"train"`` split is shuffled.
    """
    wrapped_splits = load_datasets(config, tokenizer, drf_aux_tokenizer, drf_image_processor)
    return {
        split_name: DataLoader(
            split_dataset,
            batch_size=config['batch_size'],
            shuffle=(split_name == "train"),
            collate_fn=split_dataset.collate_fn,
        )
        for split_name, split_dataset in wrapped_splits.items()
    }


def _split_dataset(dataset, config):
    """Split ``dataset`` into train/validation/test (80% / 10% / 10%).

    Both shuffled splits are seeded with ``config['seed']`` so the partition is
    reproducible.
    """
    seed = config['seed']
    # Hold out 20% of the rows, then halve the held-out part into validation and test.
    first_cut = dataset.train_test_split(test_size=0.2, shuffle=True, seed=seed)
    held_out = first_cut['test'].train_test_split(test_size=0.5, shuffle=True, seed=seed)

    splits = {
        'train': first_cut['train'],
        'validation': held_out['train'],
        'test': held_out['test'],
    }
    return DatasetDict(splits)


def _apply_tiny_data_filter(map_datasets, tiny_map):
    logging.info("[Dataset] Using tiny data")
    for split in map_datasets.keys():
        map_datasets[split] = map_datasets[split].select(range(tiny_map[split]))
    return map_datasets


def _wrap_with_mllm_dataset(map_datasets, config, tokenizer, drf_aux_tokenizer, drf_image_processor):
    """Replace each split of ``map_datasets`` in place with an ``MLLMDataset`` wrapper.

    Returns the same (mutated) mapping.
    """
    # Snapshot the items first so the mapping is not iterated while it is being written.
    for split_name, raw_split in list(map_datasets.items()):
        map_datasets[split_name] = MLLMDataset(
            raw_split, config, tokenizer, drf_aux_tokenizer, drf_image_processor
        )
    return map_datasets
