import os
import json
import pandas as pd
import pickle
import random
import cv2
import albumentations as A
import numpy as np 
import torch 
import copy


from PIL import Image
from copy import deepcopy 
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from torch.utils.data import Dataset


# Instruction stems used as the text prefix of captioning samples; one of them
# is drawn at random each time generated_caption_prompt() is called.
CAPTION_PROMPTS = [
    "description of the key findings in this medical image",
    "comprehensive caption summarizing the main findings in this image",
    "summarization of this medical image",
    "report",
    "caption",
    "findings",
    "key findings",
    "analysis of the provided image",
    "observations",
    "observations in the provided image",
]


def generated_caption_prompt():
    """Return one entry of CAPTION_PROMPTS chosen at random."""
    # random.sample with k=1 yields a one-element list; unpack its only item.
    (chosen,) = random.sample(CAPTION_PROMPTS, k=1)
    return chosen



class DatasetRegistry:
    """Name -> dataset-class lookup table filled in through a class decorator."""

    # Shared mapping from registered name to the dataset class.
    _registry = {}

    @classmethod
    def register(cls, name):
        """Return a class decorator that stores the decorated class under `name`.

        The decorated class is returned unchanged, so the decorator can be
        stacked directly on a class definition.
        """
        def _store(dataset_cls):
            cls._registry[name] = dataset_cls
            return dataset_cls

        return _store

    @classmethod
    def get_dataset(cls, name, *args, **kwargs):
        """Instantiate the class registered under `name`.

        All extra positional and keyword arguments are forwarded to the class
        constructor. Raises ValueError when nothing is registered under `name`.
        """
        factory = cls._registry.get(name)
        if factory is not None:
            return factory(*args, **kwargs)
        raise ValueError(f"Dataset {name} is not registered.")




class ImageTextDataset(Dataset):
    """Base dataset turning (optional image, prompt, full text) records into model inputs.

    Subclasses provide the records by implementing ``load_data`` (called once
    from ``__init__``) and ``get_item_data`` (called for every index).
    """

    def __init__(self, config, transforms, tokenizer, img_padding, dset_name, phase):
        """Store the shared settings and load the records.

        Args:
            config: Configuration object; ``config.LLM.seq_length`` is read here.
            transforms: Callable applied to every image as
                ``transforms(image=array)["image"]``.
            tokenizer: Callable returning a mapping with ``input_ids``; must
                expose ``eos_token_id``.
            img_padding: List of label values prepended to every label sequence.
            dset_name: Name copied into every returned sample.
            phase: Split selector, interpreted by the subclass's ``load_data``.
        """
        self.config = config
        self.transforms = transforms
        self.tokenizer = tokenizer
        self.img_padding = img_padding
        self.seq_length = config.LLM.seq_length
        self.dset_name = dset_name
        self.phase = phase

        # Load data (to be defined in the subclass)
        self.data = self.load_data()

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build the sample for record `idx`.

        The result always holds ``dset_name``, ``ans_type`` and the entries
        produced by ``prepare_text``. ``image`` and ``img_path`` are present
        only when the record has an image path, ``mc_options`` only when it is
        not None, and ``question`` only when ``get_item_data`` supplies it.
        """
        item_data = self.get_item_data(idx)
        img_path = item_data['img_path']
        text_input = item_data['given_text']
        full_text = item_data['full_text']
        ans_type = item_data['ans_type']
        mc_options = item_data['mc_options']

        has_image = img_path is not None

        output = {}
        if has_image:
            output["image"] = self.read_img(img_path, self.transforms)
        output["dset_name"] = self.dset_name

        # 'question' is optional: requiring it raised KeyError for every
        # subclass whose get_item_data does not return that key.
        if 'question' in item_data:
            output['question'] = item_data['question']

        # Prepare the tokenized text output
        output.update(
            self.prepare_text(
                text_input, full_text, self.tokenizer, self.seq_length, self.img_padding
            )
        )

        if mc_options is not None:
            output['mc_options'] = mc_options
        output['ans_type'] = ans_type
        if has_image:
            output['img_path'] = img_path

        return output

    def read_img(self, img_path, image_transforms):
        """Load an image file and return it as a min-max scaled tensor.

        The image is converted to float32, passed through `image_transforms`,
        shifted so its minimum is 0 and divided by its maximum (unless that is
        0). Two-dimensional results are replicated to three channels, the
        channel axis is moved first and at most the first three channels are
        kept.
        """
        # Context manager releases the underlying file handle deterministically.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img).astype(np.float32)
        # NOTE(review): assumes the transform returns a NumPy array — confirm.
        img = image_transforms(image=img)["image"]
        img = img - img.min()
        if img.max() != 0:
            img = img / img.max()
        if img.ndim != 3:
            # Single-channel image: replicate it into three identical channels.
            img = img[..., np.newaxis]
            img = np.concatenate([img, img, img], axis=-1)
        img = np.transpose(img, (2, 0, 1))
        img = torch.tensor(img)
        img = img[:3]  # drop any channel beyond the third (e.g. alpha)
        return img

    @staticmethod
    def _fit_to_length(ids, seq_length):
        """Right-pad `ids` with 0 up to `seq_length`, or truncate it to that length."""
        if len(ids) < seq_length:
            return np.pad(
                ids,
                (0, seq_length - len(ids)),
                "constant",
                constant_values=0,
            )
        return ids[:seq_length]

    def prepare_text(self, given_text, text_all, tokenizer, seq_length, img_padding):
        """Tokenize the prompt and the full text into fixed-length arrays.

        Args:
            given_text: Prompt part of the sample (no target text).
            text_all: Prompt followed by the target text.
            tokenizer: Callable returning a mapping with ``input_ids``.
            seq_length: Length every token array is padded/truncated to.
            img_padding: List prepended to the label sequence.

        Returns:
            Dict with ``input_ids`` (full text plus EOS, length `seq_length`),
            ``labels`` (`img_padding` followed by the masked copy of
            ``input_ids``), ``given_text`` (prompt ids, length `seq_length`)
            and ``given_text_len`` (number of prompt ids before padding).
        """
        input_ids = tokenizer(text_all)["input_ids"]
        input_ids.append(tokenizer.eos_token_id)
        # Truncation may cut off the EOS token that was just appended.
        input_ids = self._fit_to_length(np.array(input_ids), seq_length)

        label = input_ids.copy()
        # NOTE(review): padding uses id 0, so a genuine token with id 0 is also
        # excluded from the loss here — confirm this suits the tokenizer in use.
        label[label == 0] = -100

        prompt_ids = tokenizer(given_text)["input_ids"]
        if len(prompt_ids) < len(label):
            # Prompt positions are masked so only the continuation is supervised.
            label[: len(prompt_ids)] = -100
        label = np.array(img_padding + label.tolist())

        prompt_arr = np.array(prompt_ids)
        given_text_len = np.array(prompt_arr.shape[0])
        prompt_arr = self._fit_to_length(prompt_arr, seq_length)

        return {
            "input_ids": input_ids,
            "labels": label,
            "given_text": prompt_arr,
            "given_text_len": given_text_len,
        }

    def load_data(self):
        """Method to load dataset data. Implement in subclass."""
        raise NotImplementedError

    def get_item_data(self, idx):
        """Method to return image path and text data. Implement in subclass."""
        raise NotImplementedError







@DatasetRegistry.register('mimic-cxr')
class MIMICCXRDataset(ImageTextDataset):
    """Image/report records read from the per-phase JSON file in ``config.mimic_cxr``."""

    def load_data(self):
        """Load the JSON file of the current phase.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this silently produced ``None`` as the data).
        """
        if self.phase == 'train':
            json_path = self.config.mimic_cxr.train_json_path
        elif self.phase == 'valid':
            json_path = self.config.mimic_cxr.valid_json_path
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r} for 'mimic-cxr'; expected 'train' or 'valid'."
            )
        with open(json_path, 'r') as f:
            return json.load(f)

    def get_item_data(self, idx):
        """Return image path plus a randomly prompted prompt/target pair for record `idx`."""
        record = self.data[idx]
        img_path = record["image_path"]
        caption = record["report"]

        # The same randomly drawn prompt prefixes both the prompt-only text
        # and the full text.
        prompt = generated_caption_prompt()
        given_text = prompt + ":"
        full_text = prompt + ": " + str(caption).lower().strip()

        return {
            'img_path': img_path,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            'mc_options': None,
        }






@DatasetRegistry.register('iu-xray')
class IUXRAYDataset(ImageTextDataset):
    """Image/report records read from the per-phase CSV file in ``config.iu_xray``."""

    def load_data(self):
        """Read the CSV of the current phase into a list of row dicts.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this silently produced ``None`` as the data).
        """
        if self.phase == 'train':
            csv_path = self.config.iu_xray.train_csv_path
        elif self.phase == 'valid':
            csv_path = self.config.iu_xray.valid_csv_path
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r} for 'iu-xray'; expected 'train' or 'valid'."
            )
        return pd.read_csv(csv_path).to_dict(orient='records')

    def get_item_data(self, idx):
        """Return image path plus a prompt/target pair built from findings and impression."""
        record = self.data[idx]
        img_path = os.path.join(self.config.iu_xray.image_main_path, record["filename"])

        # read_csv turns empty cells into NaN, which is truthy; without the
        # isna check a literal "nan" would be written into the report text.
        sections = [
            str(section)
            for section in (record["findings"], record["impression"])
            if not pd.isna(section) and section
        ]
        findings_and_impression = ' '.join(sections)

        prompt = generated_caption_prompt()
        given_text = prompt + ":"
        full_text = prompt + ": " + findings_and_impression.lower().strip()

        return {
            'img_path': img_path,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            'mc_options': None,
        }






@DatasetRegistry.register('pmc-oa')
class PMCOADataset(ImageTextDataset):
    """Image/caption records read from JSON-lines files under ``config.pmc_oa``."""

    def load_data(self):
        """Collect the records of the JSON-lines files belonging to the phase.

        'train' reads train.jsonl followed by valid.jsonl, 'valid' reads
        test.jsonl, and any other phase yields an empty list.
        """
        if self.phase == 'train':
            split_files = ('train.jsonl', 'valid.jsonl')
        elif self.phase == 'valid':
            split_files = ('test.jsonl',)
        else:
            split_files = ()

        records = []
        for split_file in split_files:
            jsonl_path = os.path.join(self.config.pmc_oa.json_files_main_path, split_file)
            with open(jsonl_path, 'r') as f:
                records.extend(json.loads(row) for row in f)
        return records

    def get_item_data(self, idx):
        """Return image path plus a randomly prompted prompt/target pair for record `idx`."""
        record = self.data[idx]
        img_path = os.path.join(self.config.pmc_oa.image_main_path, record["image"])
        caption = str(record["caption"]).lower().strip()

        prompt = generated_caption_prompt()

        item = {'img_path': img_path}
        item['given_text'] = prompt + ":"
        item['full_text'] = prompt + ": " + caption
        item['ans_type'] = 'open'
        item['mc_options'] = None
        return item



@DatasetRegistry.register('roco')
class ROCODataset(ImageTextDataset):
    """Image/caption records read from the JSON files configured in ``config.roco``."""

    def load_data(self):
        """Load the records of the current phase.

        'train' concatenates the train and valid JSON files; 'valid' reads the
        test JSON file.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this silently produced ``None`` as the data).
        """
        if self.phase == 'train':
            with open(self.config.roco.train_json_path, 'r') as f:
                train_records = json.load(f)
            with open(self.config.roco.valid_json_path, 'r') as f:
                valid_records = json.load(f)
            return train_records + valid_records
        if self.phase == 'valid':
            with open(self.config.roco.test_json_path, 'r') as f:
                return json.load(f)
        raise ValueError(
            f"Unsupported phase {self.phase!r} for 'roco'; expected 'train' or 'valid'."
        )

    def get_item_data(self, idx):
        """Return image path plus a randomly prompted prompt/target pair for record `idx`."""
        record = self.data[idx]
        img_path = os.path.join(self.config.roco.image_main_path, record["name"])
        caption = str(record["caption"]).lower().strip()

        prompt = generated_caption_prompt()
        given_text = prompt + ":"
        full_text = prompt + ": " + caption

        return {
            'img_path': img_path,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            'mc_options': None,
        }


# @DatasetRegistry.register('ubench-vqa')
# class UBENCHVQADataset(ImageTextDataset):
#     def load_data(self):
#         if self.phase == 'train':
#             with open(self.config.ubench.train_json_path, 'r') as f:
#                 return json.load(f)
#         elif self.phase == 'valid':
#             with open(self.config.ubench.valid_json_path, 'r') as f:
#                 return json.load(f)

#     def get_item_data(self, idx):
#         img_path = self.data[idx]["image"]
#         question = self.data[idx]["question"].lower().strip()
#         answer = self.data[idx]["answer"].lower().strip()
#         if answer.endswith('.'):
#             answer = answer[:-1]
        

#         given_text = 'question: ' + question + ' the answer is:'
#         full_text = 'question: ' + question + ' the answer is: ' + answer
#         return {
#             'img_path':img_path,
#             'given_text':given_text,
#             'full_text':full_text,
#             'ans_type': 'open',
#             'mc_options': None 
#         }



# NOTE(review): a later class in this module reuses the name MEDQADataset and
# rebinds the module-level name; this class stays reachable through the
# registry key 'medqa'.
@DatasetRegistry.register('medqa')
class MEDQADataset(ImageTextDataset):
    """Text-only multiple-choice QA records read from a JSON-lines file."""

    def load_data(self):
        """Parse the JSON-lines file of the current phase into a list of records.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this failed with an unbound local variable).
        """
        if self.phase == 'train':
            file_path = self.config.medqa.train_json_file_path
        elif self.phase == 'valid':
            file_path = self.config.medqa.valid_json_file_path
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r} for 'medqa'; expected 'train' or 'valid'."
            )

        # Each line of the file is one JSON-encoded sample.
        with open(file_path, 'r') as file:
            return [json.loads(line) for line in file]

    def get_item_data(self, idx):
        """Return the prompt/target pair for record `idx` with shuffled options.

        The options are lowercased, stripped, shuffled in place and listed in
        the prompt as ``a) ... b) ...``; every option is listed, however many
        the record has.
        """
        record = self.data[idx]
        question = record["question"].lower()
        answer = record["answer"].lower().strip()
        mc_options = [option.lower().strip() for option in record['options'].values()]
        random.shuffle(mc_options)

        # Letters are assigned after shuffling, so the answer's letter varies
        # between calls; the target text is the answer string itself.
        options_text = ' options:' + ''.join(
            ' ' + chr(ord('a') + position) + ') ' + option
            for position, option in enumerate(mc_options)
        )
        options_text = options_text.lower()
        given_text = 'question: ' + question + options_text + '. the answer is:'
        full_text = 'question: ' + question + options_text + '. the answer is: ' + answer

        return {
            'img_path': None,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            'mc_options': mc_options,
        }


# NOTE(review): this class reuses the name MEDQADataset of the class registered
# as 'medqa' and therefore rebinds the module-level name; both stay reachable
# through their registry keys.
@DatasetRegistry.register('medqa-CoT')
class MEDQADataset(ImageTextDataset):
    """Text-only multiple-choice QA records prefixed with a one-shot reasoning example."""

    def load_data(self):
        """Parse the JSON-lines file of the current phase into a list of records.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this failed with an unbound local variable).
        """
        if self.phase == 'train':
            file_path = self.config.medqa.train_json_file_path
        elif self.phase == 'valid':
            file_path = self.config.medqa.valid_json_file_path
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r} for 'medqa-CoT'; expected 'train' or 'valid'."
            )

        # Each line of the file is one JSON-encoded sample.
        with open(file_path, 'r') as file:
            return [json.loads(line) for line in file]

    def get_item_data(self, idx):
        """Return the prompt/target pair for record `idx` with shuffled options.

        Both texts start with the same instruction and worked example; the
        options are lowercased, stripped, shuffled in place and all listed in
        the prompt as ``a) ... b) ...``.
        """
        record = self.data[idx]
        question = record["question"].lower()
        answer = record["answer"].lower().strip()
        mc_options = [option.lower().strip() for option in record['options'].values()]
        random.shuffle(mc_options)

        options_text = ' options:' + ''.join(
            ' ' + chr(ord('a') + position) + ') ' + option
            for position, option in enumerate(mc_options)
        )
        options_text = options_text.lower()

        # Instruction plus one worked example, shared by the prompt-only text
        # and the full text. The wording is kept exactly as it was.
        few_shot_prefix = (
            "You are a medical expert providing detailed reasoning for multiple-choice medical questions. Answer each question step by step, explaining your reasoning before selecting the final answer. "
            "question: A 45-year-old male presents with sudden onset chest pain radiating to his back, associated with sweating and shortness of breath. Blood pressure is 190/110 mmHg in both arms. What is the most likely diagnosis? "
            " options:" + " a) " + "Myocardial infarction" + " b) " + "Pulmonary embolism" + " c) " + "Aortic dissection" + " d) " + "Pericarditis "
            "Reasoning: First, note the sudden onset of severe chest pain radiating to the back, which is characteristic of an aortic dissection. The blood pressure is elevated, and the symmetric readings in both arms support this diagnosis. A myocardial infarction typically presents with substernal chest pain without radiation to the back. Pulmonary embolism is less likely without evidence of dyspnea or hypoxemia. Pericarditis would present with positional chest pain, which is not described here."
            "Final Answer: Aortic dissection. "
            "Similar to this example answer the following question by reasoning step by step. "
        )

        given_text = few_shot_prefix + 'question: ' + question + options_text + '. the answer is:'
        full_text = few_shot_prefix + 'question: ' + question + options_text + '. the answer is: ' + answer

        return {
            'img_path': None,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            'mc_options': mc_options,
        }


@DatasetRegistry.register('medmcqa')
class MEDMCQADataset(ImageTextDataset):
    """Text-only four-option QA records read from a JSON-lines file."""

    def load_data(self):
        """Parse the JSON-lines file of the current phase into a list of records.

        Only the training file is read for 'train' and only the validation
        file for 'valid'.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this failed with an unbound local variable).
        """
        if self.phase == 'train':
            file_path = self.config.medmcqa.train_json_file_path
        elif self.phase == 'valid':
            file_path = self.config.medmcqa.valid_json_file_path
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r} for 'medmcqa'; expected 'train' or 'valid'."
            )

        # Each line of the file is one JSON-encoded sample.
        with open(file_path, 'r') as file:
            return [json.loads(line) for line in file]

    def get_item_data(self, idx):
        """Return the prompt/target pair for record `idx`.

        The options come from the ``opa``..``opd`` fields and the target is
        the option selected by the ``cop`` field.
        """
        record = self.data[idx]
        question = record["question"].lower()
        # NOTE(review): treats "cop" as 1-based; a value of 0 would index -1
        # and silently select the last option — confirm against the data files.
        ans_idx = int(record["cop"]) - 1
        mc_options = [record["op" + letter].lower().strip() for letter in ['a', 'b', 'c', 'd']]
        answer = mc_options[ans_idx].lower()

        options_text = ' options:' + ' a) ' + mc_options[0] + ' b) ' + mc_options[1] + ' c) ' + mc_options[2] + ' d) ' + mc_options[3]
        options_text = options_text.lower()
        given_text = 'question: ' + question + options_text + '. the answer is:'
        full_text = 'question: ' + question + options_text + '. the answer is: ' + answer

        return {
            'img_path': None,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            'mc_options': mc_options,
        }


@DatasetRegistry.register('pubmedqa')
class PUBMEDQADataset(ImageTextDataset):
    """Text-only QA records read from a JSON object keyed by record id."""

    def load_data(self):
        """Load the JSON file of the current phase and flatten it into a list.

        Every value of the top-level JSON object becomes one record, with its
        key stored under ``'id'``.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this failed with an unbound local variable).
        """
        if self.phase == 'train':
            file_path = self.config.pubmedqa.train_json_file_path
        elif self.phase == 'valid':
            file_path = self.config.pubmedqa.valid_json_file_path
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r} for 'pubmedqa'; expected 'train' or 'valid'."
            )

        with open(file_path, 'r') as file:
            data_dict = json.load(file)

        data = []
        for record_id, record in data_dict.items():
            record['id'] = record_id
            data.append(record)
        return data

    def get_item_data(self, idx):
        """Return the prompt/target pair for record `idx`; the target is ``final_decision``."""
        record = self.data[idx]
        question = record["QUESTION"].lower().strip()
        answer = record["final_decision"].lower().strip()

        given_text = 'question: ' + question + ' the answer is:'
        full_text = 'question: ' + question + ' the answer is: ' + answer

        return {
            'img_path': None,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'closed',
            'mc_options': None,
        }

    


# NOTE(review): several classes in this module are named MEDTRINITYDataset, so
# the module-level name ends up bound to the last one; this class stays
# reachable through the registry key 'med-trinity'.
@DatasetRegistry.register('med-trinity')
class MEDTRINITYDataset(ImageTextDataset):
    """Image/caption records read from the JSON files in ``config.med_trinity``."""

    def load_data(self):
        """Load the JSON file of the current phase.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this silently produced ``None`` as the data).
        """
        if self.phase == 'train':
            json_path = self.config.med_trinity.train_json_file_path
        elif self.phase == 'valid':
            json_path = self.config.med_trinity.valid_json_file_path
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r} for 'med-trinity'; expected 'train' or 'valid'."
            )
        with open(json_path, 'r') as f:
            return json.load(f)

    def get_item_data(self, idx):
        """Return image path plus the fixed-prompt prompt/target pair for record `idx`."""
        record = self.data[idx]
        img_path = record["file_name"]
        caption = record["caption"]

        given_text = 'the caption is:'
        full_text = 'the caption is: ' + str(caption).lower().strip()

        return {
            'img_path': img_path,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            'mc_options': None,
        }


# NOTE(review): several classes in this module are named MEDTRINITYDataset, so
# the module-level name ends up bound to the last one; this class stays
# reachable through the registry key 'med-trinity-part0'.
@DatasetRegistry.register('med-trinity-part0')
class MEDTRINITYDataset(ImageTextDataset):
    """Image/caption records read from the JSON files in ``config.med_trinity_part0``."""

    def load_data(self):
        """Load the JSON file of the current phase.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this silently produced ``None`` as the data).
        """
        if self.phase == 'train':
            json_path = self.config.med_trinity_part0.train_json_file_path
        elif self.phase == 'valid':
            json_path = self.config.med_trinity_part0.valid_json_file_path
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r} for 'med-trinity-part0'; expected 'train' or 'valid'."
            )
        with open(json_path, 'r') as f:
            return json.load(f)

    def get_item_data(self, idx):
        """Return image path plus the fixed-prompt prompt/target pair for record `idx`."""
        record = self.data[idx]
        img_path = record["file_name"]
        caption = record["caption"]

        given_text = 'the caption is:'
        full_text = 'the caption is: ' + str(caption).lower().strip()

        return {
            'img_path': img_path,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            'mc_options': None,
        }


# NOTE(review): several classes in this module are named MEDTRINITYDataset, so
# the module-level name ends up bound to the last one; this class stays
# reachable through the registry key 'med-trinity-part1'.
@DatasetRegistry.register('med-trinity-part1')
class MEDTRINITYDataset(ImageTextDataset):
    """Image/caption records read from the JSON files in ``config.med_trinity_part1``."""

    def load_data(self):
        """Load the JSON file of the current phase.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this silently produced ``None`` as the data).
        """
        if self.phase == 'train':
            json_path = self.config.med_trinity_part1.train_json_file_path
        elif self.phase == 'valid':
            json_path = self.config.med_trinity_part1.valid_json_file_path
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r} for 'med-trinity-part1'; expected 'train' or 'valid'."
            )
        with open(json_path, 'r') as f:
            return json.load(f)

    def get_item_data(self, idx):
        """Return image path plus the fixed-prompt prompt/target pair for record `idx`."""
        record = self.data[idx]
        img_path = record["file_name"]
        caption = record["caption"]

        given_text = 'the caption is:'
        full_text = 'the caption is: ' + str(caption).lower().strip()

        return {
            'img_path': img_path,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            'mc_options': None,
        }


# NOTE(review): several classes in this module are named MEDTRINITYDataset, so
# the module-level name ends up bound to the last one; this class stays
# reachable through the registry key 'med-trinity-part2'.
@DatasetRegistry.register('med-trinity-part2')
class MEDTRINITYDataset(ImageTextDataset):
    """Image/caption records read from the JSON files in ``config.med_trinity_part2``."""

    def load_data(self):
        """Load the JSON file of the current phase.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this silently produced ``None`` as the data).
        """
        if self.phase == 'train':
            json_path = self.config.med_trinity_part2.train_json_file_path
        elif self.phase == 'valid':
            json_path = self.config.med_trinity_part2.valid_json_file_path
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r} for 'med-trinity-part2'; expected 'train' or 'valid'."
            )
        with open(json_path, 'r') as f:
            return json.load(f)

    def get_item_data(self, idx):
        """Return image path plus the fixed-prompt prompt/target pair for record `idx`."""
        record = self.data[idx]
        img_path = record["file_name"]
        caption = record["caption"]

        given_text = 'the caption is:'
        full_text = 'the caption is: ' + str(caption).lower().strip()

        return {
            'img_path': img_path,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            'mc_options': None,
        }

# NOTE(review): several classes in this module are named MEDTRINITYDataset, so
# the module-level name ends up bound to the last one; this class stays
# reachable through the registry key 'med-trinity-part3'.
@DatasetRegistry.register('med-trinity-part3')
class MEDTRINITYDataset(ImageTextDataset):
    """Image/caption records read from the JSON files in ``config.med_trinity_part3``."""

    def load_data(self):
        """Load the JSON file of the current phase.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this silently produced ``None`` as the data).
        """
        if self.phase == 'train':
            json_path = self.config.med_trinity_part3.train_json_file_path
        elif self.phase == 'valid':
            json_path = self.config.med_trinity_part3.valid_json_file_path
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r} for 'med-trinity-part3'; expected 'train' or 'valid'."
            )
        with open(json_path, 'r') as f:
            return json.load(f)

    def get_item_data(self, idx):
        """Return image path plus the fixed-prompt prompt/target pair for record `idx`."""
        record = self.data[idx]
        img_path = record["file_name"]
        caption = record["caption"]

        given_text = 'the caption is:'
        full_text = 'the caption is: ' + str(caption).lower().strip()

        return {
            'img_path': img_path,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            'mc_options': None,
        }

# NOTE(review): several classes in this module are named MEDTRINITYDataset; the
# earlier ones remain reachable only through their registry keys, this one
# through 'med-trinity-part4'.
@DatasetRegistry.register('med-trinity-part4')
class MEDTRINITYDataset(ImageTextDataset):
    """Image/caption records read from the JSON files in ``config.med_trinity_part4``."""

    def load_data(self):
        """Load the JSON file of the current phase.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this silently produced ``None`` as the data).
        """
        if self.phase == 'train':
            json_path = self.config.med_trinity_part4.train_json_file_path
        elif self.phase == 'valid':
            json_path = self.config.med_trinity_part4.valid_json_file_path
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r} for 'med-trinity-part4'; expected 'train' or 'valid'."
            )
        with open(json_path, 'r') as f:
            return json.load(f)

    def get_item_data(self, idx):
        """Return image path plus the fixed-prompt prompt/target pair for record `idx`."""
        record = self.data[idx]
        img_path = record["file_name"]
        caption = record["caption"]

        given_text = 'the caption is:'
        full_text = 'the caption is: ' + str(caption).lower().strip()

        return {
            'img_path': img_path,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            'mc_options': None,
        }




# NOTE(review): a later class in this module reuses the name
# FreedomIntelligenceDataset and rebinds the module-level name; this class
# stays reachable through the registry key 'freedom-Intelligence'.
@DatasetRegistry.register('freedom-Intelligence')
class FreedomIntelligenceDataset(ImageTextDataset):
    """Text-only question / chain-of-thought records read from a JSON file."""

    def load_data(self):
        """Load the JSON file of the current phase.

        Raises:
            AssertionError: If ``config.LLM.seq_length`` is below 350.
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (previously this failed with an unbound local variable).
        """
        # Kept as an assert so the exception type stays unchanged.
        assert self.config.LLM.seq_length >= 350, 'YOU USE FREEDOM INTELLIGENCE DATASET, INCREASE THE SEQ LENGTH'

        if self.phase == 'train':
            file_path = self.config.freedom_intelligence.train_json_file_path
        elif self.phase == 'valid':
            file_path = self.config.freedom_intelligence.valid_json_file_path
        else:
            raise ValueError(
                f"Unsupported phase {self.phase!r} for 'freedom-Intelligence'; expected 'train' or 'valid'."
            )

        with open(file_path, 'r') as file:
            return json.load(file)

    def get_item_data(self, idx):
        """Return the prompt/target pair for record `idx`; the target is ``Complex_CoT``."""
        record = self.data[idx]
        question = record["Question"].lower()
        answer = record["Complex_CoT"].lower()

        given_text = 'question: ' + question + '. chain of thought:'
        full_text = 'question: ' + question + '. chain of thought: ' + answer

        return {
            'img_path': None,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            # Was set through the img_path variable; this dataset has no options.
            'mc_options': None,
        }



# NOTE(review): this class reuses the name FreedomIntelligenceDataset of the
# class registered as 'freedom-Intelligence' and rebinds the module-level name;
# both stay reachable through their registry keys.
@DatasetRegistry.register('breast-usound-vqa')
class FreedomIntelligenceDataset(ImageTextDataset):
    """Image/question records whose target is reasoning text rebuilt from a generated response."""

    def load_data(self):
        """Read the generated JSON-lines file and attach a ``reasoning`` field.

        The same file is read for every phase. For each line the
        ``'qwen2_vl_72b'`` entry is kept and its ``'response'`` is condensed
        into ``'reasoning'``.

        Raises:
            IndexError: If a response lacks one of the expected sections.
        """
        with open("./comts_out_files/breast_ultrasound_train_generated.jsonl", "r") as file:
            data = [json.loads(line)['qwen2_vl_72b'] for line in file]

        def _section(chunks, marker):
            """Return the first '###'-separated chunk containing `marker`, newlines flattened."""
            return [chunk for chunk in chunks if marker in chunk][0].replace('\n', ' ')

        def _prepare_text(text):
            """Join description, rationales and final answer into one sentence-per-period string."""
            chunks = text.split("###")

            # For these two sections only the text after the marker is kept.
            image_description = (
                _section(chunks, "Image Description:").split("Image Description:")[-1].strip().split('.')
            )
            rationales = _section(chunks, "Rationales:").split("Rationales:")[-1].strip().split('.')
            # The final-answer section keeps its "The final answer is:" marker.
            final_answer = _section(chunks, "The final answer is:").strip().split('.')

            sentences = [part for part in image_description + rationales + final_answer if part != '']
            return " ".join(part.strip() + "." for part in sentences)

        for record in data:
            record['reasoning'] = _prepare_text(record['response'])

        return data

    def get_item_data(self, idx):
        """Return image path plus the prompt/target pair for record `idx`."""
        record = self.data[idx]
        question = record["question"].lower()
        answer = record["reasoning"].lower()
        img_path = record["image_path"]

        given_text = 'question: ' + question + ' the answer is:'
        full_text = 'question: ' + question + ' the answer is: ' + answer

        return {
            'img_path': img_path,
            'given_text': given_text,
            'full_text': full_text,
            'ans_type': 'open',
            # Previously the image path was passed here and then emitted by
            # __getitem__ as if it were the answer options.
            'mc_options': None,
        }
