import os
import pandas as pd
import os
import json
import pandas as pd
import pickle
import random
import cv2
import albumentations as A
import numpy as np 
import torch 
import copy


from PIL import Image
from copy import deepcopy 
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from torch.utils.data import Dataset
from files.datasets.dataset_registry import DatasetRegistry, ImageTextDataset  


# JSON files holding the reasoning-annotated train / valid records.
OMNIMEDVQA_REASONING_TRAIN_FILE = "./reasoning_datasets/omnimed-vqa_train.json"
OMNIMEDVQA_REASONING_VALID_FILE = "./reasoning_datasets/omnimed-vqa_valid.json"
# Path template of the processed self-improvement records; ``{model_name}`` is
# filled with a short name taken from ``model_name_converter``.
OMNIMEDVQA_SELFIMP_PROCESSED_FILE = "./self_imp_processed_files/{model_name}-omnimed-vqa-COMCTS-SelfImprove-Inference_self_imp_processed_combined.json"
# Number of smallest distances summed per record when ranking by
# 'frechet_distances' or 'dtw_distances' (see ``load_data_``).
N_KNN = 20

# Maps ``config.LLM.model_name`` to the short name used in the
# self-improvement file name.
model_name_converter = {
   "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" : "DS-R1-Qwen-1.5B",
   "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" : "DS-R1-Llama-8B"
}


# Fields removed from every self-improvement record after loading; they are
# only consulted (if at all) for ranking before the records are returned.
TO_BE_DELETED_KEYS = [
    'dtw_distances',
    'frechet_distances',
    'Kmedoid_frdist_distances_ncluster_10',
    'Kmedoid_frdist_distances_ncluster_20',
    'Kmedoid_dtw_distances_ncluster_10',
    'Kmedoid_dtw_distances_ncluster_20',
    'SpectralClustering_frdist_distances_ncluster_10',
    'SpectralClustering_frdist_distances_ncluster_20',
    'SpectralClustering_dtw_distances_ncluster_10',
    'SpectralClustering_dtw_distances_ncluster_20',
]


def get_direct_answer_item_data(config, data, idx):
    """Build the direct-answer sample dict for ``data[idx]``.

    Args:
        config: object exposing ``omnimed_vqa.image_main_path``, the directory
            that the record's ``image_path`` is joined onto.
        data: indexable collection of records with the keys ``image_path``,
            ``question`` and ``gt_answer``.
        idx: index of the record to convert.

    Returns:
        dict with ``img_path``, ``given_text`` (the prompt), ``full_text``
        (prompt followed by the answer), ``ans_type`` (always ``'open'``),
        ``mc_options`` (always ``None``) and the normalised ``question``.
    """
    record = data[idx]
    image_file = os.path.join(
        config.omnimed_vqa.image_main_path, record["image_path"]
    )
    # Question and answer are both lower-cased and stripped of outer whitespace.
    question = record["question"].lower().strip()
    answer = record["gt_answer"].lower().strip()

    prompt = f'question: {question} the answer is:'

    sample = {}
    sample['img_path'] = image_file
    sample['given_text'] = prompt
    sample['full_text'] = f'{prompt} {answer}'
    sample['ans_type'] = 'open'
    sample['mc_options'] = None
    sample['question'] = question
    return sample
    
def get_reasoning_item_data(config, data, idx):
    """Build the reasoning-format sample dict for ``data[idx]``.

    Unlike the direct-answer variant, the record's ``image_path`` is used as
    is (``config`` is not consulted) and the ``reasoning_answer`` text is
    appended without any normalisation.

    Args:
        config: unused; kept for a signature parallel to
            ``get_direct_answer_item_data``.
        data: indexable collection of records with the keys ``question``,
            ``image_path`` and ``reasoning_answer``.
        idx: index of the record to convert.

    Returns:
        dict with ``img_path``, ``given_text``, ``full_text``, ``ans_type``
        (always ``'open'``), ``mc_options`` (always ``None``) and the
        lower-cased, stripped ``question``.
    """
    record = data[idx]
    question = record["question"].lower().strip()
    image_file = record["image_path"]
    reasoning = record["reasoning_answer"]

    prompt = 'question: ' + question + ' the answer is:'

    return dict(
        img_path=image_file,
        given_text=prompt,
        # Plain concatenation so a non-string answer still raises TypeError.
        full_text=prompt + ' ' + reasoning,
        ans_type='open',
        mc_options=None,
        question=question,
    )


def load_data_(config, phase, _key, ratio):
    """Load the reasoning records for ``phase``, plus ranked self-improvement data for training.

    For ``phase == 'train'`` the result is the reasoning training records
    followed by a subset of the self-improvement records: only records with
    ``is_correct == 1`` are considered, they are ordered by ascending score
    and the first ``int(count * ratio)`` are kept.  Each kept record gets
    ``image_path`` copied from ``img_path`` and has the fields listed in
    ``TO_BE_DELETED_KEYS`` removed.

    Args:
        config: object exposing ``LLM.model_name``; that name must be a key of
            ``model_name_converter`` (only read for ``phase == 'train'``).
        phase: ``'train'`` or ``'valid'``.
        _key: name of the record field holding a collection of distances.
            For ``'frechet_distances'`` / ``'dtw_distances'`` the score is the
            sum of the ``N_KNN`` smallest values; for any other key it is the
            smallest value.
        ratio: fraction of the correct self-improvement records to keep.

    Returns:
        list of record dicts.

    Raises:
        ValueError: if ``phase`` is neither ``'train'`` nor ``'valid'``.
    """
    if phase == 'valid':
        with open(OMNIMEDVQA_REASONING_VALID_FILE, "r") as f:
            return json.load(f)
    if phase != 'train':
        raise ValueError('Invalid phase')

    with open(OMNIMEDVQA_REASONING_TRAIN_FILE, "r") as f:
        train_dicts = json.load(f)

    selfimp_path = OMNIMEDVQA_SELFIMP_PROCESSED_FILE.format(
        model_name=model_name_converter[config.LLM.model_name]
    )
    # Read the file, then release the handle before any post-processing.
    with open(selfimp_path, "r") as f:
        candidates = json.load(f)

    candidates = [x for x in candidates if x['is_correct'] == 1]

    if _key in ('frechet_distances', 'dtw_distances'):
        def score(x):
            return sum(sorted(x[_key])[:N_KNN])
    else:
        def score(x):
            # Only the smallest value matters; no need to sort the whole list.
            return min(x[_key])

    # list.sort is stable, like sorted(), so ties keep their file order.
    candidates.sort(key=score)
    kept = candidates[:int(len(candidates) * ratio)]

    for x in kept:
        x['image_path'] = x['img_path']
        for key_ in TO_BE_DELETED_KEYS:
            # pop() tolerates records that lack one of the bookkeeping fields.
            x.pop(key_, None)

    return train_dicts + kept


@DatasetRegistry.register('omnimed-vqa')
class OMNIMEDVQADataset(ImageTextDataset):
    """Direct-answer OmniMed-VQA dataset, registered as ``'omnimed-vqa'``."""

    def load_data(self):
        """Return the parsed JSON records for ``self.phase``.

        ``'train'`` reads ``config.omnimed_vqa.train_json_path`` and
        ``'valid'`` reads ``config.omnimed_vqa.valid_json_path``; any other
        phase yields ``None``.
        """
        phase = self.phase
        if phase not in ('train', 'valid'):
            return None

        paths = self.config.omnimed_vqa
        json_path = paths.train_json_path if phase == 'train' else paths.valid_json_path
        with open(json_path, 'r') as f:
            records = json.load(f)
        return records

    def get_item_data(self, idx):
        """Return the direct-answer sample dict for record ``idx``."""
        return get_direct_answer_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-abbv')
class OMNIMEDVQADataset(ImageTextDataset):
    """Direct-answer dataset registered as ``'omnimed-vqa-abbv'``.

    For training, only records whose (image path, question) pair also occurs
    in the reasoning files or in the processed self-improvement file are kept.
    """

    def load_data(self):
        """Return the records for ``self.phase``.

        ``'valid'`` returns the parsed ``config.omnimed_vqa.valid_json_path``.
        ``'train'`` returns the subset of ``config.omnimed_vqa.train_json_path``
        described on the class.  Any other phase yields ``None``.

        The sizes of the intermediate collections are printed while filtering.
        """
        if self.phase == 'valid':
            with open(self.config.omnimed_vqa.valid_json_path, 'r') as f:
                return json.load(f)
        if self.phase != 'train':
            return None

        def normalize(question):
            # Same normalisation on both sides of the comparison; strip() keeps
            # it consistent with the item builders and the SelfImprove filter.
            return str(question).lower().strip()

        with open(self.config.omnimed_vqa.train_json_path, 'r') as f:
            data = json.load(f)
        print(f'{len(data)=}')

        # Reasoning train + valid records (the same files as the module constants).
        comcts = []
        for path in (OMNIMEDVQA_REASONING_TRAIN_FILE, OMNIMEDVQA_REASONING_VALID_FILE):
            with open(path, "r") as file:
                comcts += json.load(file)

        selfimp_path = OMNIMEDVQA_SELFIMP_PROCESSED_FILE.format(
            model_name=model_name_converter[self.config.LLM.model_name]
        )
        with open(selfimp_path, "r") as file:
            self_imp = json.load(file)

        comcts = [{
            'image_path': x['image_path'],
            'question': normalize(x['question']),
        } for x in comcts]
        print(f'{len(comcts)=}')

        # Self-improvement records carry their path under 'img_path' (the other
        # loaders in this file copy it into 'image_path'); fall back to
        # 'image_path' only when 'img_path' is missing.
        self_imp = [{
            'image_path': x['img_path'] if 'img_path' in x else x['image_path'],
            'question': normalize(x['question']),
        } for x in self_imp]
        print(f'{len(self_imp)=}')

        abbv = {(x['image_path'], x['question']) for x in comcts + self_imp}
        print(f'{len(abbv)=}')

        image_root = self.config.omnimed_vqa.image_main_path
        data_new = [
            x for x in data
            if (os.path.join(image_root, x["image_path"]), normalize(x["question"])) in abbv
        ]
        print(f'{len(data_new)=}')

        return data_new

    def get_item_data(self, idx):
        """Return the direct-answer sample dict for record ``idx``."""
        return get_direct_answer_item_data(self.config, self.data, idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning-format dataset registered as ``'omnimed-vqa-COMCTS'``."""

    def load_data(self):
        """Return the parsed reasoning JSON for ``self.phase``.

        ``'train'`` and ``'valid'`` map to the corresponding reasoning file;
        any other phase yields ``None``.
        """
        phase = self.phase
        if phase == 'train':
            json_path = OMNIMEDVQA_REASONING_TRAIN_FILE
        elif phase == 'valid':
            json_path = OMNIMEDVQA_REASONING_VALID_FILE
        else:
            return None

        with open(json_path, "r") as f:
            records = json.load(f)
        return records

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)
    


@DatasetRegistry.register('omnimed-vqa-SelfImprove')
class OmnimedVQADataset(ImageTextDataset):
    """Direct-answer dataset registered as ``'omnimed-vqa-SelfImprove'``.

    Contains the training records that do NOT already appear in the reasoning
    train/valid files.
    """

    def load_data(self):
        """Return the training records without a reasoning counterpart.

        ``self.phase`` is not consulted: the records always come from
        ``config.omnimed_vqa.train_json_path``.  A record is dropped when its
        (joined image path, lower-cased and stripped question) pair occurs in
        either reasoning file.  Record counts before and after filtering are
        printed.
        """
        with open(self.config.omnimed_vqa.train_json_path, 'r') as f:
            data = json.load(f)

        with open(OMNIMEDVQA_REASONING_TRAIN_FILE, "r") as f:
            data_reasoning = json.load(f)

        with open(OMNIMEDVQA_REASONING_VALID_FILE, "r") as f:
            data_reasoning += json.load(f)

        reasoning_pairs = {
            (x['image_path'], x['question'].lower().strip()) for x in data_reasoning
        }

        print(f'data before cleaning {len(data)}')

        image_root = self.config.omnimed_vqa.image_main_path
        new_data = []
        for x in data:
            pair = (
                os.path.join(image_root, x["image_path"]),
                x["question"].lower().strip(),
            )
            if pair not in reasoning_pairs:
                new_data.append(x)

        # Fixed message: this count is taken after filtering, not before.
        print(f'data after cleaning {len(new_data)}')

        return new_data

    def get_item_data(self, idx):
        """Return the direct-answer sample dict for record ``idx``."""
        return get_direct_answer_item_data(self.config, self.data, idx)



@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset registered as ``'omnimed-vqa-COMCTS-SelfImp'``.

    Training data is the reasoning train file plus every record of the
    processed self-improvement file (no correctness filter, no ranking).
    """

    def load_data(self):
        """Return the records for ``self.phase``.

        Returns:
            For ``'valid'``: the parsed reasoning validation file.
            For ``'train'``: the reasoning training records followed by all
            self-improvement records, each with ``image_path`` copied from
            ``img_path`` and the ``TO_BE_DELETED_KEYS`` fields removed.

        Raises:
            ValueError: if ``self.phase`` is neither ``'train'`` nor ``'valid'``.
        """
        if self.phase == 'valid':
            with open(OMNIMEDVQA_REASONING_VALID_FILE, "r") as f:
                return json.load(f)
        if self.phase != 'train':
            raise ValueError('Invalid phase')

        with open(OMNIMEDVQA_REASONING_TRAIN_FILE, "r") as f:
            train_dicts = json.load(f)

        selfimp_path = OMNIMEDVQA_SELFIMP_PROCESSED_FILE.format(
            model_name=model_name_converter[self.config.LLM.model_name]
        )
        # Read the file, then release the handle before post-processing.
        with open(selfimp_path, "r") as file:
            data2 = json.load(file)

        for x in data2:
            x['image_path'] = x['img_path']
            for key_ in TO_BE_DELETED_KEYS:
                # pop() tolerates records that lack one of the bookkeeping fields.
                x.pop(key_, None)

        return train_dicts + data2

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(self.config, self.data, idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset registered as ``'omnimed-vqa-COMCTS-SelfImp-iscorrect'``.

    Training data is the reasoning train file plus the self-improvement
    records whose ``is_correct`` field equals 1 (no ranking, no ratio cut).
    """

    def load_data(self):
        """Return the records for ``self.phase``.

        Returns:
            For ``'valid'``: the parsed reasoning validation file.
            For ``'train'``: the reasoning training records followed by the
            correct self-improvement records, each with ``image_path`` copied
            from ``img_path`` and the ``TO_BE_DELETED_KEYS`` fields removed.

        Raises:
            ValueError: if ``self.phase`` is neither ``'train'`` nor ``'valid'``.
        """
        if self.phase == 'valid':
            with open(OMNIMEDVQA_REASONING_VALID_FILE, "r") as f:
                return json.load(f)
        if self.phase != 'train':
            raise ValueError('Invalid phase')

        with open(OMNIMEDVQA_REASONING_TRAIN_FILE, "r") as f:
            train_dicts = json.load(f)

        selfimp_path = OMNIMEDVQA_SELFIMP_PROCESSED_FILE.format(
            model_name=model_name_converter[self.config.LLM.model_name]
        )
        # Read the file, then release the handle before post-processing.
        with open(selfimp_path, "r") as file:
            data2 = json.load(file)

        data2 = [x for x in data2 if x['is_correct'] == 1]
        for x in data2:
            x['image_path'] = x['img_path']
            for key_ in TO_BE_DELETED_KEYS:
                # pop() tolerates records that lack one of the bookkeeping fields.
                x.pop(key_, None)

        return train_dicts + data2

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(self.config, self.data, idx)



@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-dtw_distances_ratio_50')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ranking key ``dtw_distances``, ratio 0.5."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='dtw_distances',
            ratio=0.5,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-dtw_distances_ratio_80')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ranking key ``dtw_distances``, ratio 0.8."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='dtw_distances',
            ratio=0.8,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-frechet_distances_ratio_50')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ranking key ``frechet_distances``, ratio 0.5."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='frechet_distances',
            ratio=0.5,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-frechet_distances_ratio_80')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ranking key ``frechet_distances``, ratio 0.8."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='frechet_distances',
            ratio=0.8,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_10_ratio_50')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``Kmedoid_frdist_distances_ncluster_10``, ratio 0.5."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='Kmedoid_frdist_distances_ncluster_10',
            ratio=0.5,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)



@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_10_ratio_80')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``Kmedoid_frdist_distances_ncluster_10``, ratio 0.8."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='Kmedoid_frdist_distances_ncluster_10',
            ratio=0.8,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_20_ratio_50')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``Kmedoid_frdist_distances_ncluster_20``, ratio 0.5."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='Kmedoid_frdist_distances_ncluster_20',
            ratio=0.5,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)



@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_20_ratio_80')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``Kmedoid_frdist_distances_ncluster_20``, ratio 0.8."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='Kmedoid_frdist_distances_ncluster_20',
            ratio=0.8,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_10_ratio_50')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``Kmedoid_dtw_distances_ncluster_10``, ratio 0.5."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='Kmedoid_dtw_distances_ncluster_10',
            ratio=0.5,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)




@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_10_ratio_80')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``Kmedoid_dtw_distances_ncluster_10``, ratio 0.8."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='Kmedoid_dtw_distances_ncluster_10',
            ratio=0.8,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_10_ratio_25')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``Kmedoid_dtw_distances_ncluster_10``, ratio 0.25."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='Kmedoid_dtw_distances_ncluster_10',
            ratio=0.25,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)




@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_20_ratio_50')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``Kmedoid_dtw_distances_ncluster_20``, ratio 0.5."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='Kmedoid_dtw_distances_ncluster_20',
            ratio=0.5,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_20_ratio_80')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``Kmedoid_dtw_distances_ncluster_20``, ratio 0.8."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='Kmedoid_dtw_distances_ncluster_20',
            ratio=0.8,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_10_ratio_50')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``SpectralClustering_frdist_distances_ncluster_10``, ratio 0.5."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='SpectralClustering_frdist_distances_ncluster_10',
            ratio=0.5,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)

@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_10_ratio_80')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``SpectralClustering_frdist_distances_ncluster_10``, ratio 0.8."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='SpectralClustering_frdist_distances_ncluster_10',
            ratio=0.8,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_20_ratio_50')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``SpectralClustering_frdist_distances_ncluster_20``, ratio 0.5."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='SpectralClustering_frdist_distances_ncluster_20',
            ratio=0.5,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)

@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_20_ratio_80')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``SpectralClustering_frdist_distances_ncluster_20``, ratio 0.8."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='SpectralClustering_frdist_distances_ncluster_20',
            ratio=0.8,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_10_ratio_50')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``SpectralClustering_dtw_distances_ncluster_10``, ratio 0.5."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='SpectralClustering_dtw_distances_ncluster_10',
            ratio=0.5,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)



@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_10_ratio_80')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``SpectralClustering_dtw_distances_ncluster_10``, ratio 0.8."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='SpectralClustering_dtw_distances_ncluster_10',
            ratio=0.8,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_20_ratio_50')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``SpectralClustering_dtw_distances_ncluster_20``, ratio 0.5."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='SpectralClustering_dtw_distances_ncluster_20',
            ratio=0.5,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('omnimed-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_20_ratio_80')
class OmnimedVQADataset(ImageTextDataset):
    """Reasoning dataset variant: key ``SpectralClustering_dtw_distances_ncluster_20``, ratio 0.8."""

    def load_data(self):
        """Return the records for ``self.phase`` via ``load_data_``."""
        return load_data_(
            config=self.config,
            phase=self.phase,
            _key='SpectralClustering_dtw_distances_ncluster_20',
            ratio=0.8,
        )

    def get_item_data(self, idx):
        """Return the reasoning-format sample dict for record ``idx``."""
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


