import os
import pandas as pd
import os
import json
import pandas as pd
import pickle
import random
import cv2
import albumentations as A
import numpy as np 
import torch 
import copy


from PIL import Image
from copy import deepcopy 
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from torch.utils.data import Dataset
from files.datasets.dataset_registry import DatasetRegistry, ImageTextDataset  

# Reasoning-annotated SLAKE splits (read by the COMCTS-based datasets).
SLAKE_REASONING_TRAIN_FILE = "./reasoning_datasets/slake-vqa_train.json"
SLAKE_REASONING_VALID_FILE = "./reasoning_datasets/slake-vqa_valid.json"

# Processed self-improvement records; ``{model_name}`` is filled with a short
# name taken from ``model_name_converter``.
SLAKE_SELFIMP_PROCESSED_FILE = "./self_imp_processed_files/{model_name}-slake-vqa-COMCTS-SelfImprove-Inference_self_imp_processed_combined.json"

# How many of the smallest values are summed when ranking records by the raw
# 'dtw_distances' / 'frechet_distances' lists in ``load_data_``.
N_KNN = 20

# Full Hugging Face model id -> short prefix used in the processed file name.
model_name_converter = {
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B": "DS-R1-Qwen-1.5B",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B": "DS-R1-Llama-8B",
}

# Per-record ranking fields removed from self-improvement records after
# they have been loaded (and, where applicable, used for sorting).
TO_BE_DELETED_KEYS = [
    # raw distance lists
    'dtw_distances',
    'frechet_distances',
    # K-medoids variants
    'Kmedoid_frdist_distances_ncluster_10',
    'Kmedoid_frdist_distances_ncluster_20',
    'Kmedoid_dtw_distances_ncluster_10',
    'Kmedoid_dtw_distances_ncluster_20',
    # spectral-clustering variants
    'SpectralClustering_frdist_distances_ncluster_10',
    'SpectralClustering_frdist_distances_ncluster_20',
    'SpectralClustering_dtw_distances_ncluster_10',
    'SpectralClustering_dtw_distances_ncluster_20',
]


def get_direct_answer_item_data(config, data, idx):
    """Build the prompt fields for a plain (direct-answer) SLAKE record.

    Args:
        config: Config object; ``config.slake_vqa.image_main_path`` is the
            directory joined in front of the record's ``img_name``.
        data: Indexable collection of record dicts with ``img_name``,
            ``question``, ``answer`` and ``answer_type`` keys.
        idx: Position of the record to convert.

    Returns:
        dict with ``img_path``, ``given_text`` (prompt ending in
        "the answer is:"), ``full_text`` (prompt followed by the answer),
        ``ans_type``, ``mc_options`` (always ``None``) and ``question``.
        Question, answer and answer type are lower-cased and stripped.
    """
    record = data[idx]
    image_file = os.path.join(config.slake_vqa.image_main_path, record["img_name"])

    def _normalised(field):
        return record[field].lower().strip()

    question = _normalised("question")
    answer = _normalised("answer")
    answer_kind = _normalised('answer_type')

    prompt = f'question: {question} the answer is:'

    return {
        'img_path': image_file,
        'given_text': prompt,
        'full_text': f'{prompt} {answer}',
        'ans_type': answer_kind,
        'mc_options': None,
        'question': question,
    }


def get_reasoning_item_data(config, data, idx):
    """Build the prompt fields for a reasoning-annotated SLAKE record.

    Args:
        config: Not used by this function.
        data: Indexable collection of record dicts with ``question``,
            ``image_path`` and ``reasoning_answer`` keys.
        idx: Position of the record to convert.

    Returns:
        dict with ``img_path``, ``given_text``, ``full_text``, ``ans_type``
        (always ``'open'``), ``mc_options`` (always ``None``) and
        ``question``. Only the question is lower-cased/stripped; the
        reasoning answer is appended verbatim.
    """
    sample = data[idx]
    question = sample["question"].lower().strip()
    prefix = f"question: {question} the answer is:"

    return dict(
        img_path=sample["image_path"],
        given_text=prefix,
        full_text=prefix + ' ' + sample["reasoning_answer"],
        ans_type='open',
        mc_options=None,
        question=question,
    )


def load_data_(config, phase, _key, ratio):
    """Load reasoning records, optionally extended with ranked self-improvement records.

    For ``phase == 'train'`` the result is the reasoning training records
    followed by the correct (``is_correct == 1``) self-improvement records
    with the smallest ranking scores, keeping the first
    ``int(len(correct_records) * ratio)`` of them. For ``phase == 'valid'``
    the reasoning validation records are returned unchanged.

    Ranking score of a self-improvement record:
      * ``_key`` is ``'frechet_distances'`` or ``'dtw_distances'``: the sum
        of the ``N_KNN`` smallest values in ``record[_key]``.
      * any other ``_key``: the smallest value in ``record[_key]``.

    Args:
        config: Config whose ``LLM.model_name`` must be a key of
            ``model_name_converter``.
        phase: ``'train'`` or ``'valid'``.
        _key: Name of the per-record list of values used for ranking.
        ratio: Fraction in ``[0, 1]`` of the ranked records to keep.

    Returns:
        list of record dicts. Kept self-improvement records get an
        ``image_path`` copied from ``img_path`` and have every key in
        ``TO_BE_DELETED_KEYS`` removed.

    Raises:
        ValueError: if ``phase`` is neither ``'train'`` nor ``'valid'``, or
            if ``ratio`` is outside ``[0, 1]`` in the train phase.
        KeyError: if a kept self-improvement record lacks one of
            ``TO_BE_DELETED_KEYS``.
    """
    import heapq

    if phase == 'train':
        # A negative ratio would otherwise slice from the end (data[:-k])
        # instead of keeping a prefix; >1 would silently keep everything.
        if not 0 <= ratio <= 1:
            raise ValueError(f'ratio must be in [0, 1], got {ratio!r}')

        with open(SLAKE_REASONING_TRAIN_FILE, 'r') as f:
            data1 = json.load(f)

        selfimp_path = SLAKE_SELFIMP_PROCESSED_FILE.format(
            model_name=model_name_converter[config.LLM.model_name])
        with open(selfimp_path, "r") as file:
            data2 = json.load(file)

        data2 = [x for x in data2 if x['is_correct'] == 1]
        if _key in ('frechet_distances', 'dtw_distances'):
            # nsmallest yields the same ascending values as sorted(...)[:N_KNN]
            # without sorting the whole list.
            data2.sort(key=lambda x: sum(heapq.nsmallest(N_KNN, x[_key])))
        else:
            data2.sort(key=lambda x: min(x[_key]))
        data2 = data2[:int(len(data2) * ratio)]

        for x in data2:
            x['image_path'] = x['img_path']
            # Ranking fields are not read past this point.
            for key_ in TO_BE_DELETED_KEYS:
                del x[key_]

        return data1 + data2

    elif phase == 'valid':
        with open(SLAKE_REASONING_VALID_FILE, 'r') as f:
            return json.load(f)
    else:
        raise ValueError('Invalid phase')



@DatasetRegistry.register('slake-vqa')
class SlakeVQADataset(ImageTextDataset):
    """Plain SLAKE VQA with direct (non-reasoning) answers.

    'train' combines the configured train and valid JSON files; 'valid'
    reads the configured test JSON file.
    """

    def load_data(self):
        """Return the records for ``self.phase``.

        Raises:
            ValueError: if the phase is neither 'train' nor 'valid'.
        """
        if self.phase == 'valid':
            with open(self.config.slake_vqa.test_json_path, 'r') as fh:
                return json.load(fh)
        if self.phase != 'train':
            raise ValueError('Invalid phase')

        split_paths = self.config.slake_vqa
        with open(split_paths.train_json_path, 'r') as fh:
            train_records = json.load(fh)
        with open(split_paths.valid_json_path, 'r') as fh:
            valid_records = json.load(fh)
        return train_records + valid_records

    def get_item_data(self, idx):
        return get_direct_answer_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-abbv')
class SlakeVQADataset(ImageTextDataset):
    """Direct-answer SLAKE restricted to samples that also have reasoning data.

    In the 'train' phase the configured train+valid records are filtered to
    those whose (image path, question) pair appears in the reasoning
    train/valid files or in the self-improvement file for the configured
    LLM. The 'valid' phase returns the configured test records unfiltered.
    """

    def load_data(self):
        """Return the phase-specific records described in the class docstring.

        Raises:
            ValueError: if ``self.phase`` is neither 'train' nor 'valid'.
        """
        if self.phase == 'train':
            slake_cfg = self.config.slake_vqa
            with open(slake_cfg.train_json_path, 'r') as f:
                data_train = json.load(f)
            with open(slake_cfg.valid_json_path, 'r') as f:
                data_valid = json.load(f)
            data = data_train + data_valid

            dset = 'slake-vqa'
            with open(f"./reasoning_datasets/{dset}_train.json", "r") as file:
                comcts = json.load(file)
            with open(f"./reasoning_datasets/{dset}_valid.json", "r") as file:
                comcts = comcts + json.load(file)
            with open(SLAKE_SELFIMP_PROCESSED_FILE.format(model_name=model_name_converter[self.config.LLM.model_name]), "r") as file:
                self_imp = json.load(file)

            # Questions are normalised with lower().strip() on both sides,
            # matching the 'slake-vqa-SelfImprove' filter.
            comcts = [(x['image_path'], x['question'].lower().strip()) for x in comcts]
            print(f"{len(comcts)=}")

            # Self-improvement records carry their path under 'img_path';
            # the other loaders of this file copy it into 'image_path'.
            self_imp = [(x['img_path'], x['question'].lower().strip()) for x in self_imp]
            print(f"{len(self_imp)=}")

            abbv = set(comcts) | set(self_imp)
            print(f"{len(abbv)=}")
            print(f"{len(data)=}")

            data_new = [
                x for x in data
                if (os.path.join(slake_cfg.image_main_path, x["img_name"]),
                    str(x["question"]).lower().strip()) in abbv
            ]
            print(f"{len(data_new)=}")
            return data_new

        elif self.phase == 'valid':
            with open(self.config.slake_vqa.test_json_path, 'r') as f:
                data = json.load(f)
            return data
        else:
            raise ValueError('Invalid phase')

    def get_item_data(self, idx):
        return get_direct_answer_item_data(self.config, self.data, idx)


@DatasetRegistry.register('slake-vqa-COMCTS')
class SlakeVQADataset(ImageTextDataset):
    """SLAKE with reasoning answers only (no self-improvement records)."""

    def load_data(self):
        """Return the reasoning train file for 'train' or the valid file for 'valid'.

        Raises:
            ValueError: for any other phase.
        """
        if self.phase == 'train':
            source = SLAKE_REASONING_TRAIN_FILE
        elif self.phase == 'valid':
            source = SLAKE_REASONING_VALID_FILE
        else:
            raise ValueError('Invalid phase')

        with open(source, 'r') as fh:
            return json.load(fh)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-SelfImprove')
class SlakeVQADataset(ImageTextDataset):
    """Direct-answer SLAKE records that have no reasoning annotation.

    Returns the configured train+valid records minus those whose
    (image path, question) pair appears in the reasoning train/valid files.
    """

    def load_data(self):
        """Return train+valid records not covered by the reasoning files.

        NOTE(review): ``self.phase`` is not consulted, so every phase gets the
        same records. This may be deliberate (e.g. to run inference over the
        uncovered training samples) — confirm before changing.
        """
        slake_cfg = self.config.slake_vqa
        with open(slake_cfg.train_json_path, 'r') as f:
            data_train = json.load(f)
        with open(slake_cfg.valid_json_path, 'r') as f:
            data_valid = json.load(f)
        data = data_train + data_valid

        with open(SLAKE_REASONING_TRAIN_FILE, 'r') as f:
            data_reasoning = json.load(f)
        with open(SLAKE_REASONING_VALID_FILE, 'r') as f:
            data_reasoning += json.load(f)

        reasoning_pairs = {(x['image_path'], x['question'].lower().strip()) for x in data_reasoning}

        print(f'data before cleaning {len(data)}')
        new_data = [
            x for x in data
            if (os.path.join(slake_cfg.image_main_path, x["img_name"]),
                x["question"].lower().strip()) not in reasoning_pairs
        ]
        # Count after the reasoning-covered records have been removed.
        print(f'data after cleaning {len(new_data)}')
        return new_data

    def get_item_data(self, idx):
        return get_direct_answer_item_data(self.config, self.data, idx)



@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp')
class SlakeVQADataset(ImageTextDataset):
    """Reasoning SLAKE plus every self-improvement record for the configured LLM."""

    def load_data(self):
        """Return phase-specific reasoning records.

        'train' appends all self-improvement records (with ``image_path``
        copied from ``img_path`` and ``TO_BE_DELETED_KEYS`` removed) to the
        reasoning train records; 'valid' returns the reasoning valid records.

        Raises:
            ValueError: for any other phase.
            KeyError: if a self-improvement record lacks a key to delete.
        """
        if self.phase == 'valid':
            with open(SLAKE_REASONING_VALID_FILE, 'r') as fh:
                return json.load(fh)
        if self.phase != 'train':
            raise ValueError('Invalid phase')

        with open(SLAKE_REASONING_TRAIN_FILE, 'r') as fh:
            reasoning_records = json.load(fh)

        short_name = model_name_converter[self.config.LLM.model_name]
        with open(SLAKE_SELFIMP_PROCESSED_FILE.format(model_name=short_name), "r") as fh:
            selfimp_records = json.load(fh)

        for record in selfimp_records:
            record['image_path'] = record['img_path']
            for field in TO_BE_DELETED_KEYS:
                del record[field]

        return reasoning_records + selfimp_records

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect')
class SlakeVQADataset(ImageTextDataset):
    """Reasoning SLAKE plus the self-improvement records flagged ``is_correct == 1``."""

    def load_data(self):
        """Return phase-specific reasoning records.

        'train' appends the correct self-improvement records (with
        ``image_path`` copied from ``img_path`` and ``TO_BE_DELETED_KEYS``
        removed) to the reasoning train records; 'valid' returns the
        reasoning valid records.

        Raises:
            ValueError: for any other phase.
            KeyError: if a kept record lacks a key to delete.
        """
        if self.phase == 'valid':
            with open(SLAKE_REASONING_VALID_FILE, 'r') as fh:
                return json.load(fh)
        if self.phase != 'train':
            raise ValueError('Invalid phase')

        with open(SLAKE_REASONING_TRAIN_FILE, 'r') as fh:
            reasoning_records = json.load(fh)

        short_name = model_name_converter[self.config.LLM.model_name]
        with open(SLAKE_SELFIMP_PROCESSED_FILE.format(model_name=short_name), "r") as fh:
            selfimp_records = json.load(fh)

        selfimp_records = [
            record for record in selfimp_records if record['is_correct'] == 1
        ]
        for record in selfimp_records:
            record['image_path'] = record['img_path']
            for field in TO_BE_DELETED_KEYS:
                del record[field]

        return reasoning_records + selfimp_records

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)



@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-dtw_distances_ratio_50')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'dtw_distances'``, keeping ``ratio=0.5`` (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='dtw_distances', ratio=0.5)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-dtw_distances_ratio_80')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'dtw_distances'``, keeping ``ratio=0.8`` (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='dtw_distances', ratio=0.8)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-frechet_distances_ratio_50')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'frechet_distances'``, keeping ``ratio=0.5`` (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='frechet_distances', ratio=0.5)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-frechet_distances_ratio_80')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'frechet_distances'``, keeping ``ratio=0.8`` (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='frechet_distances', ratio=0.8)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_10_ratio_50')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'Kmedoid_frdist_distances_ncluster_10'``, keeping ``ratio=0.5``
    (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='Kmedoid_frdist_distances_ncluster_10', ratio=0.5)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)



@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_10_ratio_80')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'Kmedoid_frdist_distances_ncluster_10'``, keeping ``ratio=0.8``
    (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='Kmedoid_frdist_distances_ncluster_10', ratio=0.8)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_20_ratio_50')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'Kmedoid_frdist_distances_ncluster_20'``, keeping ``ratio=0.5``
    (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='Kmedoid_frdist_distances_ncluster_20', ratio=0.5)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)



@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_frdist_distances_ncluster_20_ratio_80')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'Kmedoid_frdist_distances_ncluster_20'``, keeping ``ratio=0.8``
    (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='Kmedoid_frdist_distances_ncluster_20', ratio=0.8)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_10_ratio_50')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'Kmedoid_dtw_distances_ncluster_10'``, keeping ``ratio=0.5``
    (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='Kmedoid_dtw_distances_ncluster_10', ratio=0.5)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)




@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_10_ratio_80')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'Kmedoid_dtw_distances_ncluster_10'``, keeping ``ratio=0.8``
    (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='Kmedoid_dtw_distances_ncluster_10', ratio=0.8)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_10_ratio_25')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'Kmedoid_dtw_distances_ncluster_10'``, keeping ``ratio=0.25``
    (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='Kmedoid_dtw_distances_ncluster_10', ratio=0.25)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)




@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_20_ratio_50')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'Kmedoid_dtw_distances_ncluster_20'``, keeping ``ratio=0.5``
    (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='Kmedoid_dtw_distances_ncluster_20', ratio=0.5)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_20_ratio_80')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'Kmedoid_dtw_distances_ncluster_20'``, keeping ``ratio=0.8``
    (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='Kmedoid_dtw_distances_ncluster_20', ratio=0.8)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_10_ratio_50')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'SpectralClustering_frdist_distances_ncluster_10'``, keeping
    ``ratio=0.5`` (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='SpectralClustering_frdist_distances_ncluster_10', ratio=0.5)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)

@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_10_ratio_80')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'SpectralClustering_frdist_distances_ncluster_10'``, keeping
    ``ratio=0.8`` (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='SpectralClustering_frdist_distances_ncluster_10', ratio=0.8)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_20_ratio_50')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'SpectralClustering_frdist_distances_ncluster_20'``, keeping
    ``ratio=0.5`` (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='SpectralClustering_frdist_distances_ncluster_20', ratio=0.5)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)

@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_frdist_distances_ncluster_20_ratio_80')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'SpectralClustering_frdist_distances_ncluster_20'``, keeping
    ``ratio=0.8`` (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='SpectralClustering_frdist_distances_ncluster_20', ratio=0.8)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_10_ratio_50')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'SpectralClustering_dtw_distances_ncluster_10'``, keeping
    ``ratio=0.5`` (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='SpectralClustering_dtw_distances_ncluster_10', ratio=0.5)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)



@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_10_ratio_80')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'SpectralClustering_dtw_distances_ncluster_10'``, keeping
    ``ratio=0.8`` (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='SpectralClustering_dtw_distances_ncluster_10', ratio=0.8)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_20_ratio_50')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'SpectralClustering_dtw_distances_ncluster_20'``, keeping
    ``ratio=0.5`` (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='SpectralClustering_dtw_distances_ncluster_20', ratio=0.5)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


@DatasetRegistry.register('slake-vqa-COMCTS-SelfImp-iscorrect-SpectralClustering_dtw_distances_ncluster_20_ratio_80')
class SLAKEVQADataset(ImageTextDataset):
    """Reasoning SLAKE + correct self-improvement records ranked by
    ``'SpectralClustering_dtw_distances_ncluster_20'``, keeping
    ``ratio=0.8`` (see ``load_data_``)."""

    def load_data(self):
        return load_data_(config=self.config, phase=self.phase, _key='SpectralClustering_dtw_distances_ncluster_20', ratio=0.8)

    def get_item_data(self, idx):
        return get_reasoning_item_data(config=self.config, data=self.data, idx=idx)


