import os
import pandas as pd
import os
import json
import pandas as pd
import pickle
import random
import cv2
import albumentations as A
import numpy as np 
import torch 
import copy


from PIL import Image
from copy import deepcopy 
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from torch.utils.data import Dataset
from files.datasets.dataset_registry import DatasetRegistry, ImageTextDataset  



# JSON files holding the reasoning-annotated PMC-VQA records; they are read
# with ``json.load`` for the 'train' and 'valid' phases respectively.
PMCVQA_REASONING_TRAIN_FILE = "./reasoning_datasets/pmc-vqa_train.json"
PMCVQA_REASONING_VALID_FILE = "./reasoning_datasets/pmc-vqa_valid.json"
# Path template of the processed self-improve JSON file. ``{model_name}`` is
# filled with the short name looked up in ``model_name_converter``.
PMCVQA_SELFIMP_PROCESSED_FILE = "./self_imp_processed_files/{model_name}-pmc-vqa-COMCTS-SelfImprove-Inference_self_imp_processed_combined.json"
# Number of smallest distance values summed per record when ``load_data_``
# ranks by 'frechet_distances' or 'dtw_distances'.
N_KNN = 20

# Maps the full model identifier found in ``config.LLM.model_name`` to the
# short name used inside the self-improve file name.
model_name_converter = {
   "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" : "DS-R1-Qwen-1.5B",
   "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" : "DS-R1-Llama-8B"
}

# Per-record keys removed from self-improve records after they are loaded
# (and, in ``load_data_``, after they have been used for ranking).
TO_BE_DELETED_KEYS = [
    'dtw_distances',
    'frechet_distances',
    'Kmedoid_frdist_distances_ncluster_10',
    'Kmedoid_frdist_distances_ncluster_20',
    'Kmedoid_dtw_distances_ncluster_10',
    'Kmedoid_dtw_distances_ncluster_20',
    'SpectralClustering_frdist_distances_ncluster_10',
    'SpectralClustering_frdist_distances_ncluster_20',
    'SpectralClustering_dtw_distances_ncluster_10',
    'SpectralClustering_dtw_distances_ncluster_20',
]

def get_direct_answer_item_data(config, data, idx):
    """Build the prompt/target fields for one direct-answer PMC-VQA record.

    Args:
        config: Object exposing ``pmc_vqa.image_main_path``; the record's
            ``Figure_path`` is joined onto that directory.
        data: Indexable collection of records providing ``Figure_path``,
            ``Question`` and ``Answer`` entries.
        idx: Position of the record inside ``data``.

    Returns:
        dict with the keys ``img_path``, ``given_text`` (prompt without the
        answer), ``full_text`` (prompt followed by the answer), ``ans_type``
        (always ``'open'``), ``mc_options`` (always ``None``) and
        ``question`` (lower-cased, stripped question text).
    """
    record = data[idx]
    img_path = os.path.join(config.pmc_vqa.image_main_path, record["Figure_path"])

    # Coerce to ``str`` before normalising: records built from a CSV may hold
    # non-string cells (numbers, NaN), on which ``.lower()`` would raise.
    question = str(record["Question"]).lower().strip()
    answer = str(record["Answer"]).lower().strip()

    given_text = 'question: ' + question + ' the answer is:'
    full_text = given_text + ' ' + answer

    return {
        'img_path': img_path,
        'given_text': given_text,
        'full_text': full_text,
        'ans_type': 'open',
        'mc_options': None,
        "question": question,
    }


def get_reasoning_item_data(config, data, idx):
    """Assemble the prompt/target fields for one reasoning-annotated record.

    Args:
        config: Accepted for signature parity with the direct-answer helper;
            not read here.
        data: Indexable collection of records providing ``question``,
            ``image_path`` and ``reasoning_answer`` entries.
        idx: Position of the record inside ``data``.

    Returns:
        dict with the keys ``img_path``, ``given_text`` (prompt without the
        answer), ``full_text`` (prompt followed by the reasoning answer),
        ``ans_type`` (always ``'open'``), ``mc_options`` (always ``None``)
        and ``question`` (lower-cased, stripped question text).
    """
    record = data[idx]
    normalized_question = record["question"].lower().strip()
    image_location = record["image_path"]
    reasoning_text = record["reasoning_answer"]

    # The full text is the prompt plus a single space and the answer, which is
    # used verbatim (no lower-casing or stripping).
    prompt = 'question: ' + normalized_question + ' the answer is:'
    prompt_with_answer = prompt + ' ' + reasoning_text

    item = {
        'img_path': image_location,
        'given_text': prompt,
        'full_text': prompt_with_answer,
        'ans_type': 'open',
        'mc_options': None,
        'question': normalized_question,
    }
    return item


def load_data_(config, phase, _key, ratio):
    """Load reasoning records, adding ranked self-improve samples for training.

    For ``phase == 'valid'`` the validation reasoning file is returned as is.
    For ``phase == 'train'`` the training reasoning file is concatenated with
    a subset of the processed self-improve file: only records whose
    ``is_correct`` equals 1 are kept, they are sorted in ascending order of a
    score derived from ``_key``, and the first ``ratio`` fraction is retained.

    Args:
        config: Object exposing ``LLM.model_name``, used to select the
            self-improve file through ``model_name_converter``.
        phase: ``'train'`` or ``'valid'``.
        _key: Name of the per-record list of distances used for ranking. For
            ``'frechet_distances'`` / ``'dtw_distances'`` the score is the sum
            of the ``N_KNN`` smallest values; for any other key it is the
            smallest value.
        ratio: Fraction of the ranked self-improve records to keep; the count
            is truncated with ``int``.

    Returns:
        list of record dicts.

    Raises:
        ValueError: If ``phase`` is neither ``'train'`` nor ``'valid'``.
    """
    if phase == 'valid':
        with open(PMCVQA_REASONING_VALID_FILE, 'r') as f:
            return json.load(f)
    if phase != 'train':
        raise ValueError('Invalid phase')

    with open(PMCVQA_REASONING_TRAIN_FILE, 'r') as f:
        data1 = json.load(f)

    selfimp_path = PMCVQA_SELFIMP_PROCESSED_FILE.format(
        model_name=model_name_converter[config.LLM.model_name]
    )
    # Only the read needs the file handle; all processing happens after close.
    with open(selfimp_path, "r") as file:
        data2 = json.load(file)

    data2 = [x for x in data2 if x['is_correct'] == 1]

    if _key in ('frechet_distances', 'dtw_distances'):
        def score(x):
            # Sum of the N_KNN smallest distances of this record.
            return sum(sorted(x[_key])[:N_KNN])
    else:
        def score(x):
            # Smallest value stored under the key; no full sort is required.
            return min(x[_key])

    # ``list.sort`` is stable, so ties keep their order from the file.
    data2.sort(key=score)
    data2 = data2[:int(len(data2) * ratio)]

    for x in data2:
        x['image_path'] = x['img_path']
        for key_ in TO_BE_DELETED_KEYS:
            # ``pop`` with a default tolerates records that lack some of the
            # bulky distance keys instead of failing with KeyError.
            x.pop(key_, None)

    return data1 + data2


@DatasetRegistry.register('pmc-vqa')
class PMCVQADataset(ImageTextDataset):
    """Direct-answer PMC-VQA dataset read from the train/valid CSV files.

    Other classes in this module share the name ``PMCVQADataset``; the key
    passed to ``DatasetRegistry.register`` is what tells them apart.
    """

    def load_data(self):
        """Return the CSV rows of the current phase as a list of dicts.

        Raises:
            ValueError: If ``self.phase`` is neither ``'train'`` nor
                ``'valid'`` (previously this silently returned ``None``).
        """
        if self.phase == 'train':
            csv_path = self.config.pmc_vqa.train_csv_path
        elif self.phase == 'valid':
            csv_path = self.config.pmc_vqa.valid_csv_path
        else:
            # Same failure mode as the reasoning datasets in this module.
            raise ValueError('Invalid phase')
        return pd.read_csv(csv_path).to_dict(orient='records')

    def get_item_data(self, idx):
        """Return the direct-answer item dict for record ``idx``."""
        return get_direct_answer_item_data(self.config, self.data, idx)



@DatasetRegistry.register('pmc-vqa-abbv')
class PMCVQADataset(ImageTextDataset):
    """Direct-answer PMC-VQA dataset restricted to annotated training rows.

    For training, only CSV rows whose (image path, question) pair also occurs
    in the reasoning files or in the processed self-improve file are kept.
    Validation uses the full validation CSV.

    Other classes in this module share the name ``PMCVQADataset``; the key
    passed to ``DatasetRegistry.register`` is what tells them apart.
    """

    def load_data(self):
        """Return the rows for the current phase as a list of dicts.

        Raises:
            ValueError: If ``self.phase`` is neither ``'train'`` nor
                ``'valid'`` (previously this silently returned ``None``).
        """
        if self.phase == 'valid':
            return pd.read_csv(self.config.pmc_vqa.valid_csv_path).to_dict(orient='records')
        if self.phase != 'train':
            raise ValueError('Invalid phase')

        data = pd.read_csv(self.config.pmc_vqa.train_csv_path).to_dict(orient='records')
        print(f'{len(data)=}')

        # Same files as before, now taken from the module-level constants.
        with open(PMCVQA_REASONING_TRAIN_FILE, "r") as file:
            comcts = json.load(file)
        with open(PMCVQA_REASONING_VALID_FILE, "r") as file:
            comcts = comcts + json.load(file)
        selfimp_path = PMCVQA_SELFIMP_PROCESSED_FILE.format(
            model_name=model_name_converter[self.config.LLM.model_name]
        )
        with open(selfimp_path, "r") as file:
            self_imp = json.load(file)
        print(f'{len(comcts)=}')
        print(f'{len(self_imp)=}')

        # Questions are normalised exactly like the CSV side below
        # (lower-case + strip) so that surrounding whitespace cannot prevent
        # a match.
        abbv = {
            (x['image_path'], x['question'].lower().strip())
            for x in comcts + self_imp
        }
        print(f'{len(abbv)=}')

        image_root = self.config.pmc_vqa.image_main_path
        data_new = [
            x for x in data
            if (os.path.join(image_root, x["Figure_path"]),
                str(x["Question"]).lower().strip()) in abbv
        ]
        print(f'{len(data_new)=}')

        return data_new

    def get_item_data(self, idx):
        """Return the direct-answer item dict for record ``idx``."""
        return get_direct_answer_item_data(self.config, self.data, idx)



@DatasetRegistry.register('pmc-vqa-COMCTS')
class PMCVQADataset(ImageTextDataset):
    """Reasoning-annotated PMC-VQA dataset read from the reasoning JSON files."""

    def load_data(self):
        """Return the reasoning records of the current phase.

        Raises:
            ValueError: If ``self.phase`` is neither ``'train'`` nor ``'valid'``.
        """
        phase_to_path = {
            'train': PMCVQA_REASONING_TRAIN_FILE,
            'valid': PMCVQA_REASONING_VALID_FILE,
        }
        if self.phase not in phase_to_path:
            raise ValueError('Invalid phase')

        with open(phase_to_path[self.phase], 'r') as handle:
            records = json.load(handle)
        return records

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('pmc-vqa-SelfImprove')
class PMCVQADataset(ImageTextDataset):
    """Training-CSV rows that are not yet covered by the reasoning files.

    Other classes in this module share the name ``PMCVQADataset``; the key
    passed to ``DatasetRegistry.register`` is what tells them apart.
    """

    def load_data(self):
        """Return training CSV rows without a reasoning annotation.

        ``self.phase`` is not consulted: the training CSV is always read. A
        row is dropped when its (image path, question) pair appears in the
        train or valid reasoning file, where reasoning questions are cut at
        the first ``"Choices"`` before being normalised.
        """
        data = pd.read_csv(self.config.pmc_vqa.train_csv_path).to_dict(orient='records')

        with open(PMCVQA_REASONING_TRAIN_FILE, 'r') as f:
            data_reasoning = json.load(f)
        with open(PMCVQA_REASONING_VALID_FILE, 'r') as f:
            data_reasoning += json.load(f)

        reasoning_pairs = {
            (x['image_path'], x['question'].split("Choices")[0].lower().strip())
            for x in data_reasoning
        }

        print(f'data before cleaning {len(data)}')
        image_root = self.config.pmc_vqa.image_main_path
        # ``str(...)`` guards against non-string CSV cells (e.g. NaN), on
        # which ``.lower()`` would raise.
        new_data = [
            x for x in data
            if (os.path.join(image_root, x["Figure_path"]),
                str(x["Question"]).lower().strip()) not in reasoning_pairs
        ]
        print(f'data after cleaning {len(new_data)}')
        return new_data

    def get_item_data(self, idx):
        """Return the direct-answer item dict for record ``idx``."""
        return get_direct_answer_item_data(self.config, self.data, idx)


@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset whose training split also contains every record of
    the processed self-improve file.

    Other classes in this module share the name ``PMCVQADataset``; the key
    passed to ``DatasetRegistry.register`` is what tells them apart.
    """

    def load_data(self):
        """Return the records for the current phase.

        Raises:
            ValueError: If ``self.phase`` is neither ``'train'`` nor ``'valid'``.
        """
        if self.phase == 'valid':
            with open(PMCVQA_REASONING_VALID_FILE, 'r') as f:
                return json.load(f)
        if self.phase != 'train':
            raise ValueError('Invalid phase')

        with open(PMCVQA_REASONING_TRAIN_FILE, 'r') as f:
            data1 = json.load(f)

        selfimp_path = PMCVQA_SELFIMP_PROCESSED_FILE.format(
            model_name=model_name_converter[self.config.LLM.model_name]
        )
        with open(selfimp_path, "r") as file:
            data2 = json.load(file)

        for x in data2:
            # Reasoning records use 'image_path'; mirror that key here.
            x['image_path'] = x['img_path']
            for key_ in TO_BE_DELETED_KEYS:
                # Tolerate records that lack some of the distance keys
                # instead of failing with KeyError.
                x.pop(key_, None)

        return data1 + data2

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        return get_reasoning_item_data(self.config, self.data, idx)



@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset whose training split also contains the self-improve
    records marked with ``is_correct == 1``.

    Other classes in this module share the name ``PMCVQADataset``; the key
    passed to ``DatasetRegistry.register`` is what tells them apart.
    """

    def load_data(self):
        """Return the records for the current phase.

        Raises:
            ValueError: If ``self.phase`` is neither ``'train'`` nor ``'valid'``.
        """
        if self.phase == 'valid':
            with open(PMCVQA_REASONING_VALID_FILE, 'r') as f:
                return json.load(f)
        if self.phase != 'train':
            raise ValueError('Invalid phase')

        with open(PMCVQA_REASONING_TRAIN_FILE, 'r') as f:
            data1 = json.load(f)

        selfimp_path = PMCVQA_SELFIMP_PROCESSED_FILE.format(
            model_name=model_name_converter[self.config.LLM.model_name]
        )
        with open(selfimp_path, "r") as file:
            data2 = json.load(file)

        data2 = [x for x in data2 if x['is_correct'] == 1]
        for x in data2:
            # Reasoning records use 'image_path'; mirror that key here.
            x['image_path'] = x['img_path']
            for key_ in TO_BE_DELETED_KEYS:
                # Tolerate records that lack some of the distance keys
                # instead of failing with KeyError.
                x.pop(key_, None)

        return data1 + data2

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        return get_reasoning_item_data(self.config, self.data, idx)



@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-dtw_distances_ratio_50')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by ``'dtw_distances'`` and keeps the first 50%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key, keep_ratio = 'dtw_distances', 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-dtw_distances_ratio_80')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by ``'dtw_distances'`` and keeps the first 80%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key, keep_ratio = 'dtw_distances', 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-frechet_distances_ratio_50')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by ``'frechet_distances'`` and keeps the first 50%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key, keep_ratio = 'frechet_distances', 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-frechet_distances_ratio_80')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by ``'frechet_distances'`` and keeps the first 80%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key, keep_ratio = 'frechet_distances', 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_10_ratio_50')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by ``'Kmedoid_frdist_distances_ncluster_10'`` and
    keeps the first 50%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key, keep_ratio = 'Kmedoid_frdist_distances_ncluster_10', 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item



@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_10_ratio_80')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by ``'Kmedoid_frdist_distances_ncluster_10'`` and
    keeps the first 80%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key, keep_ratio = 'Kmedoid_frdist_distances_ncluster_10', 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_20_ratio_50')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by ``'Kmedoid_frdist_distances_ncluster_20'`` and
    keeps the first 50%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key, keep_ratio = 'Kmedoid_frdist_distances_ncluster_20', 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item



@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_20_ratio_80')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by ``'Kmedoid_frdist_distances_ncluster_20'`` and
    keeps the first 80%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key, keep_ratio = 'Kmedoid_frdist_distances_ncluster_20', 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_10_ratio_50')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by ``'Kmedoid_dtw_distances_ncluster_10'`` and keeps
    the first 50%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key, keep_ratio = 'Kmedoid_dtw_distances_ncluster_10', 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item




@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_10_ratio_80')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by ``'Kmedoid_dtw_distances_ncluster_10'`` and keeps
    the first 80%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key, keep_ratio = 'Kmedoid_dtw_distances_ncluster_10', 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item




@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_10_ratio_25')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by ``'Kmedoid_dtw_distances_ncluster_10'`` and keeps
    the first 25%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key, keep_ratio = 'Kmedoid_dtw_distances_ncluster_10', 0.25
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item




@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_20_ratio_50')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by ``'Kmedoid_dtw_distances_ncluster_20'`` and keeps
    the first 50%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key, keep_ratio = 'Kmedoid_dtw_distances_ncluster_20', 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_20_ratio_80')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by ``'Kmedoid_dtw_distances_ncluster_20'`` and keeps
    the first 80%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key, keep_ratio = 'Kmedoid_dtw_distances_ncluster_20', 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_10_ratio_50')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by
    ``'SpectralClustering_frdist_distances_ncluster_10'`` and keeps the first
    50%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key = 'SpectralClustering_frdist_distances_ncluster_10'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item

@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_10_ratio_80')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by
    ``'SpectralClustering_frdist_distances_ncluster_10'`` and keeps the first
    80%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key = 'SpectralClustering_frdist_distances_ncluster_10'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_20_ratio_50')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by
    ``'SpectralClustering_frdist_distances_ncluster_20'`` and keeps the first
    50%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key = 'SpectralClustering_frdist_distances_ncluster_20'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item

@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_20_ratio_80')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by
    ``'SpectralClustering_frdist_distances_ncluster_20'`` and keeps the first
    80%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key = 'SpectralClustering_frdist_distances_ncluster_20'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_10_ratio_50')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by
    ``'SpectralClustering_dtw_distances_ncluster_10'`` and keeps the first
    50%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key = 'SpectralClustering_dtw_distances_ncluster_10'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item



@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_10_ratio_80')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by
    ``'SpectralClustering_dtw_distances_ncluster_10'`` and keeps the first
    80%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key = 'SpectralClustering_dtw_distances_ncluster_10'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_20_ratio_50')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by
    ``'SpectralClustering_dtw_distances_ncluster_20'`` and keeps the first
    50%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key = 'SpectralClustering_dtw_distances_ncluster_20'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('pmc-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_20_ratio_80')
class PMCVQADataset(ImageTextDataset):
    """Reasoning dataset variant: ``load_data_`` ranks the training
    self-improve records by
    ``'SpectralClustering_dtw_distances_ncluster_20'`` and keeps the first
    80%."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's key and ratio."""
        ranking_key = 'SpectralClustering_dtw_distances_ncluster_20'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning-style item dict for record ``idx``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


