import os
import pandas as pd
import os
import json
import pandas as pd
import pickle
import random
import cv2
import albumentations as A
import numpy as np 
import torch 
import copy


from PIL import Image
from copy import deepcopy 
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from torch.utils.data import Dataset
from files.datasets.dataset_registry import DatasetRegistry, ImageTextDataset  


# JSON files with the PathVQA reasoning samples, one per phase. They are read
# by the 'path-vqa-COMCTS*' datasets below; each entry is accessed through its
# "question", "image_path" and "reasoning_answer" keys.
PATHVQA_REASONING_TRAIN_FILE = "./reasoning_datasets/path-vqa_train.json"
PATHVQA_REASONING_VALID_FILE = "./reasoning_datasets/path-vqa_valid.json"
# Path template of the processed self-improvement samples; {model_name} is
# filled with model_name_converter[config.LLM.model_name].
PATHVQA_SELFIMP_PROCESSED_FILE = "./self_imp_processed_files/{model_name}-path-vqa-COMCTS-SelfImprove-Inference_self_imp_processed_combined.json"
# Number of smallest distances summed per sample when load_data_ ranks by
# 'frechet_distances' or 'dtw_distances'.
N_KNN = 20


# Maps the configured LLM model name to the short name used in the
# self-improvement file name.
model_name_converter = {
   "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B" : "DS-R1-Qwen-1.5B",
   "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" : "DS-R1-Llama-8B"
}



# Per-sample keys deleted from every self-improvement entry before it is
# merged with the reasoning training data (a missing key raises KeyError).
TO_BE_DELETED_KEYS = [
    'dtw_distances',
    'frechet_distances',
    'Kmedoid_frdist_distances_ncluster_10',
    'Kmedoid_frdist_distances_ncluster_20',
    'Kmedoid_dtw_distances_ncluster_10',
    'Kmedoid_dtw_distances_ncluster_20',
    'SpectralClustering_frdist_distances_ncluster_10',
    'SpectralClustering_frdist_distances_ncluster_20',
    'SpectralClustering_dtw_distances_ncluster_10',
    'SpectralClustering_dtw_distances_ncluster_20',
]


def get_direct_answer_item_data(config, data, idx):
    """Build the prompt fields for one direct-answer sample.

    Args:
        config: Unused here; kept for a signature parallel to
            ``get_reasoning_item_data``.
        data: Object whose ``.iloc[idx]`` yields a row with "image",
            "question" and "answer" entries (a pandas DataFrame in this file).
        idx: Positional row index.

    Returns:
        dict with keys 'img_path', 'given_text', 'full_text', 'ans_type',
        'mc_options' (always None) and 'question'.
    """
    row = data.iloc[idx]
    img_path = row["image"]
    question = row["question"].lower().strip()
    # Answers may be non-string (e.g. numbers), so normalise through str().
    answer = str(row["answer"]).lower().strip()

    # Yes/no answers are "closed" questions; everything else is "open".
    if answer in ('yes', 'no'):
        ans_type = 'closed'
    else:
        ans_type = 'open'

    given_text = f'question: {question} the answer is:'
    full_text = f'{given_text} {answer}'

    return dict(
        img_path=img_path,
        given_text=given_text,
        full_text=full_text,
        ans_type=ans_type,
        mc_options=None,
        question=question,
    )

def get_reasoning_item_data(config, data, idx):
    """Build the prompt fields for one reasoning sample.

    Args:
        config: Unused here; kept for a signature parallel to
            ``get_direct_answer_item_data``.
        data: Indexable collection whose ``data[idx]`` provides "question",
            "image_path" and "reasoning_answer" entries.
        idx: Index of the sample.

    Returns:
        dict with keys 'img_path', 'given_text', 'full_text', 'ans_type'
        (always 'open'), 'mc_options' (always None) and 'question'.
    """
    entry = data[idx]
    question = entry["question"].lower().strip()
    img_path = entry["image_path"]
    answer = entry["reasoning_answer"]

    given_text = 'question: ' + question + ' the answer is:'
    # The reasoning answer is appended verbatim (no lower-casing/stripping).
    full_text = given_text + ' ' + answer

    return dict(
        img_path=img_path,
        given_text=given_text,
        full_text=full_text,
        ans_type='open',
        mc_options=None,
        question=question,
    )


def load_data_(config, phase, _key, ratio):
    """Load reasoning data, optionally extended with ranked self-improvement samples.

    Args:
        config: Configuration object; ``config.LLM.model_name`` selects the
            self-improvement file through ``model_name_converter``.
        phase: 'train' or 'valid'.
        _key: Name of the per-sample list of distances used for ranking.
        ratio: Fraction of the ranked correct self-improvement samples to keep.

    Returns:
        For 'valid': the content of the reasoning validation file.
        For 'train': the reasoning training samples followed by the kept
        self-improvement samples.

    Raises:
        ValueError: If ``phase`` is neither 'train' nor 'valid'.
    """
    if phase == 'valid':
        with open(PATHVQA_REASONING_VALID_FILE, "r") as f:
            return json.load(f)
    if phase != 'train':
        raise ValueError('Invalid phase')

    with open(PATHVQA_REASONING_TRAIN_FILE, "r") as f:
        reasoning_samples = json.load(f)

    selfimp_path = PATHVQA_SELFIMP_PROCESSED_FILE.format(
        model_name=model_name_converter[config.LLM.model_name]
    )
    with open(selfimp_path, "r") as file:
        selfimp_samples = json.load(file)

    # Only samples flagged as correct are candidates.
    candidates = [sample for sample in selfimp_samples if sample['is_correct'] == 1]

    if _key in ('frechet_distances', 'dtw_distances'):
        def score(sample):
            # Sum of the N_KNN smallest distances.
            return sum(sorted(sample[_key])[:N_KNN])
    else:
        def score(sample):
            # Smallest value of the list stored under ``_key``.
            return sorted(sample[_key])[0]

    # Ascending, stable sort: lowest scores come first and are the ones kept.
    candidates.sort(key=score)
    kept = candidates[:int(len(candidates) * ratio)]

    for sample in kept:
        # Reasoning samples use 'image_path'; mirror it from 'img_path'.
        sample['image_path'] = sample['img_path']
        for stale_key in TO_BE_DELETED_KEYS:
            del sample[stale_key]

    return reasoning_samples + kept


@DatasetRegistry.register('path-vqa')
class PathVQADataset(ImageTextDataset):
    """PathVQA direct-answer dataset read from the pickled QA splits.

    The 'train' phase concatenates the train and val splits; the 'valid'
    phase uses the test split.
    """

    def _load_split(self, split):
        """Load ``<split>/<split>_qa.pkl`` and expand each record's image name to a .jpg path.

        Args:
            split: One of "train", "val" or "test"; used both as the QA
                sub-directory / file prefix and as the image sub-directory.

        Returns:
            The unpickled records, each with its "image" entry rewritten in place.
        """
        # NOTE: pickle.load can execute arbitrary code from the file; only
        # point qa_main_path at trusted dataset files.
        with open(self.config.path_vqa.qa_main_path + f"{split}/{split}_qa.pkl", "rb") as f:
            records = pickle.load(f)
        image_root = self.config.path_vqa.image_main_path
        for record in records:
            record["image"] = os.path.join(image_root, split, record["image"] + ".jpg")
        return records

    def load_data(self):
        """Return the QA records of the current phase as a DataFrame.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (same behaviour as the other loaders in this file, instead of
                silently returning None).
        """
        if self.phase == 'train':
            return pd.DataFrame(self._load_split("train") + self._load_split("val"))
        if self.phase == 'valid':
            return pd.DataFrame(self._load_split("test"))
        raise ValueError('Invalid phase')

    def get_item_data(self, idx):
        """Return the direct-answer prompt fields for row ``idx`` of ``self.data``."""
        return get_direct_answer_item_data(self.config, self.data, idx)


@DatasetRegistry.register('path-vqa-abbv')
class PathVQADataset(ImageTextDataset):
    """PathVQA direct-answer dataset restricted to already-annotated questions.

    For 'train', only the train+val QA records whose (image path, lower-cased
    question) pair also appears in the reasoning files or in the
    self-improvement file are kept. For 'valid', the full test split is used.
    """

    def _load_split(self, split):
        """Load ``<split>/<split>_qa.pkl`` and expand each record's image name to a .jpg path.

        Args:
            split: One of "train", "val" or "test"; used both as the QA
                sub-directory / file prefix and as the image sub-directory.

        Returns:
            The unpickled records, each with its "image" entry rewritten in place.
        """
        # NOTE: pickle.load can execute arbitrary code from the file; only
        # point qa_main_path at trusted dataset files.
        with open(self.config.path_vqa.qa_main_path + f"{split}/{split}_qa.pkl", "rb") as f:
            records = pickle.load(f)
        image_root = self.config.path_vqa.image_main_path
        for record in records:
            record["image"] = os.path.join(image_root, split, record["image"] + ".jpg")
        return records

    @staticmethod
    def _pair_of(entry):
        """Return the (image path, lower-cased question) pair identifying ``entry``."""
        # NOTE(review): the other loaders in this file read the image path of
        # self-improvement entries from 'img_path'; fall back to it when
        # 'image_path' is absent instead of failing with KeyError.
        image_path = entry['image_path'] if 'image_path' in entry else entry['img_path']
        return (image_path, entry['question'].lower())

    def load_data(self):
        """Return the QA records of the current phase as a DataFrame.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (same behaviour as the other loaders in this file, instead of
                silently returning None).
        """
        if self.phase == 'valid':
            return pd.DataFrame(self._load_split("test"))
        if self.phase != 'train':
            raise ValueError('Invalid phase')

        data = self._load_split("train") + self._load_split("val")
        print(f"{len(data)=}")

        # Same files as the 'path-vqa-COMCTS' datasets, via the shared constants.
        with open(PATHVQA_REASONING_TRAIN_FILE, "r") as file:
            comcts = json.load(file)
        with open(PATHVQA_REASONING_VALID_FILE, "r") as file:
            comcts = comcts + json.load(file)
        selfimp_path = PATHVQA_SELFIMP_PROCESSED_FILE.format(
            model_name=model_name_converter[self.config.LLM.model_name]
        )
        with open(selfimp_path, "r") as file:
            self_imp = json.load(file)

        comcts = [self._pair_of(x) for x in comcts]
        print(f"{len(comcts)=}")

        self_imp = [self._pair_of(x) for x in self_imp]
        print(f"{len(self_imp)=}")

        abbv = set(comcts + self_imp)
        print(f"{len(abbv)=}")

        print(f"{len(data)=}")

        # Keep only QA records whose pair is present in the annotated set.
        data_new = [x for x in data if (x["image"], x["question"].lower()) in abbv]
        print(f"{len(data_new)=}")
        return pd.DataFrame(data_new)

    def get_item_data(self, idx):
        """Return the direct-answer prompt fields for row ``idx`` of ``self.data``."""
        return get_direct_answer_item_data(self.config, self.data, idx)


@DatasetRegistry.register('path-vqa-COMCTS')
class PathVQADataset(ImageTextDataset):
    """PathVQA reasoning dataset read directly from the reasoning JSON files."""

    def load_data(self):
        """Return the parsed reasoning JSON file of the current phase.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'
                (same behaviour as the 'path-vqa-COMCTS-SelfImp' loaders,
                instead of silently returning None).
        """
        if self.phase == 'train':
            path = PATHVQA_REASONING_TRAIN_FILE
        elif self.phase == 'valid':
            path = PATHVQA_REASONING_VALID_FILE
        else:
            raise ValueError('Invalid phase')

        with open(path, "r") as f:
            return json.load(f)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        return get_reasoning_item_data(self.config, self.data, idx)
    

@DatasetRegistry.register('path-vqa-SelfImprove')
class PathVQADataset(ImageTextDataset):
    """PathVQA direct-answer records that have no reasoning annotation yet.

    The train+val QA records are loaded and every record whose (image path,
    normalised question) pair appears in the reasoning train/valid files is
    dropped. The result does not depend on ``self.phase``.
    """

    def _load_split(self, split):
        """Load ``<split>/<split>_qa.pkl`` and expand each record's image name to a .jpg path.

        Args:
            split: "train" or "val"; used both as the QA sub-directory / file
                prefix and as the image sub-directory.

        Returns:
            The unpickled records, each with its "image" entry rewritten in place.
        """
        # NOTE: pickle.load can execute arbitrary code from the file; only
        # point qa_main_path at trusted dataset files.
        with open(self.config.path_vqa.qa_main_path + f"{split}/{split}_qa.pkl", "rb") as f:
            records = pickle.load(f)
        image_root = self.config.path_vqa.image_main_path
        for record in records:
            record["image"] = os.path.join(image_root, split, record["image"] + ".jpg")
        return records

    def load_data(self):
        """Return the not-yet-annotated train+val QA records as a DataFrame."""
        data = self._load_split("train") + self._load_split("val")

        with open(PATHVQA_REASONING_TRAIN_FILE, "r") as f:
            data_reasoning = json.load(f)
        with open(PATHVQA_REASONING_VALID_FILE, "r") as f:
            data_reasoning += json.load(f)

        # Pairs that already have a reasoning annotation.
        reasoning_pairs = {(x['image_path'], x['question'].lower().strip()) for x in data_reasoning}

        print(f'data before cleaning {len(data)}')
        new_data = [
            x for x in data
            if (x["image"], x["question"].lower().strip()) not in reasoning_pairs
        ]
        # This count is taken after filtering, so label it accordingly.
        print(f'data after cleaning {len(new_data)}')

        return pd.DataFrame(new_data)

    def get_item_data(self, idx):
        """Return the direct-answer prompt fields for row ``idx`` of ``self.data``."""
        return get_direct_answer_item_data(self.config, self.data, idx)



@DatasetRegistry.register('path-vqa-COMCTS-SelfImp')
class PathVQADataset(ImageTextDataset):
    """PathVQA reasoning dataset extended with all self-improvement samples."""

    def load_data(self):
        """Return the reasoning samples of the current phase.

        For 'train', every entry of the self-improvement file is appended
        after the reasoning training samples; for 'valid', the reasoning
        validation file is returned as is.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'.
        """
        if self.phase == 'valid':
            with open(PATHVQA_REASONING_VALID_FILE, "r") as f:
                return json.load(f)
        if self.phase != 'train':
            raise ValueError('Invalid phase')

        with open(PATHVQA_REASONING_TRAIN_FILE, "r") as f:
            train_dicts = json.load(f)

        selfimp_path = PATHVQA_SELFIMP_PROCESSED_FILE.format(
            model_name=model_name_converter[self.config.LLM.model_name]
        )
        with open(selfimp_path, "r") as file:
            data2 = json.load(file)

        for x in data2:
            # Reasoning samples use 'image_path'; mirror it from 'img_path'.
            x['image_path'] = x['img_path']
            for key_ in TO_BE_DELETED_KEYS:
                del x[key_]
        return train_dicts + data2

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        return get_reasoning_item_data(self.config, self.data, idx)


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect')
class PathVQADataset(ImageTextDataset):
    """PathVQA reasoning dataset extended with the correct self-improvement samples."""

    def load_data(self):
        """Return the reasoning samples of the current phase.

        For 'train', the self-improvement entries with ``is_correct == 1``
        are appended after the reasoning training samples; for 'valid', the
        reasoning validation file is returned as is.

        Raises:
            ValueError: If ``self.phase`` is neither 'train' nor 'valid'.
        """
        if self.phase == 'valid':
            with open(PATHVQA_REASONING_VALID_FILE, "r") as f:
                return json.load(f)
        if self.phase != 'train':
            raise ValueError('Invalid phase')

        with open(PATHVQA_REASONING_TRAIN_FILE, "r") as f:
            train_dicts = json.load(f)

        selfimp_path = PATHVQA_SELFIMP_PROCESSED_FILE.format(
            model_name=model_name_converter[self.config.LLM.model_name]
        )
        with open(selfimp_path, "r") as file:
            data2 = json.load(file)

        data2 = [x for x in data2 if x['is_correct'] == 1]
        for x in data2:
            # Reasoning samples use 'image_path'; mirror it from 'img_path'.
            x['image_path'] = x['img_path']
            for key_ in TO_BE_DELETED_KEYS:
                del x[key_]
        return train_dicts + data2

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        return get_reasoning_item_data(self.config, self.data, idx)



@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-dtw_distances_ratio_50')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'dtw_distances' with keep ratio 0.5."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'dtw_distances'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-dtw_distances_ratio_80')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'dtw_distances' with keep ratio 0.8."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'dtw_distances'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-frechet_distances_ratio_50')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'frechet_distances' with keep ratio 0.5."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'frechet_distances'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-frechet_distances_ratio_80')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'frechet_distances' with keep ratio 0.8."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'frechet_distances'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_10_ratio_50')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'Kmedoid_frdist_distances_ncluster_10' with keep ratio 0.5."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'Kmedoid_frdist_distances_ncluster_10'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item



@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_10_ratio_80')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'Kmedoid_frdist_distances_ncluster_10' with keep ratio 0.8."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'Kmedoid_frdist_distances_ncluster_10'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_20_ratio_50')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'Kmedoid_frdist_distances_ncluster_20' with keep ratio 0.5."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'Kmedoid_frdist_distances_ncluster_20'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item



@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_20_ratio_80')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'Kmedoid_frdist_distances_ncluster_20' with keep ratio 0.8."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'Kmedoid_frdist_distances_ncluster_20'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_10_ratio_50')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'Kmedoid_dtw_distances_ncluster_10' with keep ratio 0.5."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'Kmedoid_dtw_distances_ncluster_10'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item




@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_10_ratio_80')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'Kmedoid_dtw_distances_ncluster_10' with keep ratio 0.8."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'Kmedoid_dtw_distances_ncluster_10'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_10_ratio_25')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'Kmedoid_dtw_distances_ncluster_10' with keep ratio 0.25."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'Kmedoid_dtw_distances_ncluster_10'
        keep_ratio = 0.25
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item




@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_20_ratio_50')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'Kmedoid_dtw_distances_ncluster_20' with keep ratio 0.5."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'Kmedoid_dtw_distances_ncluster_20'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_20_ratio_80')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'Kmedoid_dtw_distances_ncluster_20' with keep ratio 0.8."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'Kmedoid_dtw_distances_ncluster_20'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_10_ratio_50')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'SpectralClustering_frdist_distances_ncluster_10' with keep ratio 0.5."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'SpectralClustering_frdist_distances_ncluster_10'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item

@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_10_ratio_80')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'SpectralClustering_frdist_distances_ncluster_10' with keep ratio 0.8."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'SpectralClustering_frdist_distances_ncluster_10'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_20_ratio_50')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'SpectralClustering_frdist_distances_ncluster_20' with keep ratio 0.5."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'SpectralClustering_frdist_distances_ncluster_20'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item

@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_20_ratio_80')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'SpectralClustering_frdist_distances_ncluster_20' with keep ratio 0.8."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'SpectralClustering_frdist_distances_ncluster_20'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_10_ratio_50')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'SpectralClustering_dtw_distances_ncluster_10' with keep ratio 0.5."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'SpectralClustering_dtw_distances_ncluster_10'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item



@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_10_ratio_80')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'SpectralClustering_dtw_distances_ncluster_10' with keep ratio 0.8."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'SpectralClustering_dtw_distances_ncluster_10'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_20_ratio_50')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'SpectralClustering_dtw_distances_ncluster_20' with keep ratio 0.5."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'SpectralClustering_dtw_distances_ncluster_20'
        keep_ratio = 0.5
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


@DatasetRegistry.register('path-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_20_ratio_80')
class PATHVQADataset(ImageTextDataset):
    """Reasoning dataset variant ranked by 'SpectralClustering_dtw_distances_ncluster_20' with keep ratio 0.8."""

    def load_data(self):
        """Delegate to ``load_data_`` with this variant's ranking key and ratio."""
        ranking_key = 'SpectralClustering_dtw_distances_ncluster_20'
        keep_ratio = 0.8
        return load_data_(self.config, self.phase, ranking_key, keep_ratio)

    def get_item_data(self, idx):
        """Return the reasoning prompt fields for entry ``idx`` of ``self.data``."""
        item = get_reasoning_item_data(self.config, self.data, idx)
        return item


