import json
import sys
import logging, random
import datetime
from PIL import Image
import os

from Benchmarks.TextVQA.Eval import TextVQAAccuracyEvaluator

class TextVQADataset:
    """Index-addressable view over TextVQA question records and their images.

    Each item is a dict holding the image id, the decoded RGB image, the
    question id, the question text and the reference answers.
    """

    def __init__(self, imgs_dir: str, questions_json_path: str, ocr_json_path: str, num_samples: int = None):
        """Load the question and OCR JSON files and build the OCR lookup.

        Args:
            imgs_dir: Directory containing the ``<image_id>.jpg`` files.
            questions_json_path: JSON file whose ``data`` key holds the
                question records.
            ocr_json_path: JSON file whose ``data`` key holds the OCR records.
            num_samples: If truthy, keep only the first ``num_samples``
                question and OCR records.
        """
        self.imgs_dir = imgs_dir
        # Context managers close the files deterministically; JSON is read
        # as UTF-8 instead of relying on the platform default encoding.
        with open(questions_json_path, encoding="utf-8") as questions_file:
            self.questions = json.load(questions_file)['data']
        with open(ocr_json_path, encoding="utf-8") as ocr_file:
            self.ocrs = json.load(ocr_file)['data']
        # The index is built before any truncation, so it covers every OCR
        # record even when ``self.ocrs`` is shortened below.
        self._create_ocr_index()
        self.num_samples = num_samples

        if num_samples:
            self.questions = self.questions[:num_samples]
            self.ocrs = self.ocrs[:num_samples]

    def _create_ocr_index(self):
        """Build ``self.ocr_index``: image_id -> OCR record."""
        self.ocr_index = {item['image_id']: item for item in self.ocrs}

    def __len__(self):
        """Return the number of (possibly truncated) question records."""
        return len(self.questions)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        Raises:
            KeyError: If the question's image has no OCR record.
        """
        ques_item = self.questions[idx]
        image_id = ques_item['image_id']
        # The OCR record is only consumed by the disabled prompt variant
        # below; the lookup is kept so a missing record still raises KeyError.
        ocr_item = self.ocr_index[image_id]
        # convert() returns a new in-memory image, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(f"{self.imgs_dir}/{image_id}.jpg") as raw_image:
            image = raw_image.convert("RGB")
        return {
            "image_id": image_id,
            "image": image,
            "question_id": ques_item['question_id'],
            "question": f"{ques_item['question']}",
            # "question": f"{ques_item['question']}\nOCR tokens: {ocr_item['ocr_tokens']}",
            "answers": ques_item["answers"]
        }

class TextVQA:
    """In-memory index over a TextVQA question/annotation JSON file.

    Attributes are plain dicts:
        dataset: The loaded JSON document (metadata keys plus ``data``).
        qa: question_id -> answers.
        qq: question_id -> {question_id, image_id, question}.
        ia: image_id -> list of {question_id, answers}.
    """

    def __init__(self, questions_json_path: str = None):
        """Optionally load ``questions_json_path`` and build the indexes.

        Args:
            questions_json_path: Path to the question JSON file. When None,
                an empty instance is created (as ``loadRes`` does).
        """
        self.dataset = {}
        self.questions = {}  # never populated by this class
        self.qa = {}
        self.qq = {}
        self.ia = {}

        if questions_json_path is not None:
            print('\n🚀 Loading TextVQA answers and questions into memory')
            time_t = datetime.datetime.now(datetime.timezone.utc)
            # Context manager so the file handle is closed deterministically.
            with open(questions_json_path, 'r', encoding='utf-8') as questions_file:
                dataset = json.load(questions_file)
            print(datetime.datetime.now(datetime.timezone.utc) - time_t)
            self.dataset = dataset
            self.createIndex()

    def createIndex(self):
        """Build ``qa``, ``qq`` and ``ia`` from ``self.dataset['data']``."""
        print('\n🚀 Creating index')
        data_items = self.dataset['data']

        ia = {}
        qa = {}
        qq = {}

        for item in data_items:
            question_id = item['question_id']
            image_id = item['image_id']

            qa[question_id] = item['answers']
            qq[question_id] = {
                'question_id': question_id,
                'image_id': image_id,
                'question': item['question']
            }
            ia.setdefault(image_id, []).append({
                'question_id': question_id,
                'answers': item['answers']
            })

        print('✅ Index created!')
        self.qa = qa
        self.qq = qq
        self.ia = ia

    def info(self):
        """Print every top-level dataset key except ``data``."""
        for key, value in self.dataset.items():
            if key != 'data':
                print(f'{key}: {value}')

    def getQuesIds(self, imgIds=None, quesTypes=None, ansTypes=None):
        """Return question ids, optionally restricted to some images.

        Args:
            imgIds: A single image id or a list/tuple/set of image ids.
                None or an empty collection selects every question.
            quesTypes: Accepted for call compatibility; not used.
            ansTypes: Accepted for call compatibility; not used.

        Returns:
            Question ids in dataset order; an empty list when no data is
            loaded.
        """
        if imgIds is None:
            imgIds = []
        elif not isinstance(imgIds, (list, tuple, set, frozenset)):
            imgIds = [imgIds]
        wanted_images = set(imgIds)

        data_items = self.dataset.get('data', [])
        if not wanted_images:
            return [item['question_id'] for item in data_items]
        return [
            item['question_id']
            for item in data_items
            if item['image_id'] in wanted_images
        ]

    def loadQA(self, ids=None):
        """Return ``{'question_id', 'answers'}`` records for known ids.

        Args:
            ids: A single question id or a list/tuple/set of them. Unknown
                ids are skipped.

        Returns:
            A list of records (possibly empty); never None.
        """
        if ids is None:
            return []
        if not isinstance(ids, (list, tuple, set, frozenset)):
            ids = [ids]
        return [{'question_id': qid, 'answers': self.qa[qid]} for qid in ids if qid in self.qa]

    def showQA(self, anns):
        """Print the question and the numbered answers of each record."""
        if not anns:
            return
        for ann in anns:
            quesId = ann['question_id']
            print(f"Question: {self.qq[quesId]['question']}")
            for i, ans in enumerate(ann['answers']):
                print(f"Answer {i+1}: {ans}")

    def loadRes(self, resFile, quesFile=None):
        """Load a prediction file and wrap it in a new ``TextVQA`` object.

        Args:
            resFile: JSON file containing a list of
                ``{'question_id', 'answer'}`` objects.
            quesFile: Accepted for call compatibility; not used.

        Returns:
            A ``TextVQA`` whose ``qa`` maps question_id -> predicted answer
            and whose ``qq`` is a copy of this object's ``qq``.

        Raises:
            AssertionError: If the file is not a JSON list, or its question
                ids differ from this object's question ids.
        """
        res = TextVQA()
        res.dataset = {
            'dataset_type': self.dataset.get('dataset_type', 'val'),
            'dataset_name': self.dataset.get('dataset_name', 'textvqa'),
            'dataset_version': self.dataset.get('dataset_version', '0.5.1')
        }

        print('\n🚀 Loading and preparing results')
        time_t = datetime.datetime.now(datetime.timezone.utc)
        with open(resFile, encoding='utf-8') as results_file:
            anns = json.load(results_file)

        # Raised explicitly (rather than via `assert`) so the validation
        # still runs under `python -O`; the exception type is unchanged.
        if not isinstance(anns, list):
            raise AssertionError('results is not an array of objects')

        result_ids = {ann['question_id'] for ann in anns}
        expected_ids = set(self.getQuesIds())
        if result_ids != expected_ids:
            raise AssertionError(
                'Results do not match TextVQA set. '
                'Missing predictions or unknown question IDs. '
                f"set(annsQuesIds) number: {len(result_ids)}, "
                f"set(self.getQuesIds()) number: {len(expected_ids)}"
            )

        res_data = [
            {
                'question_id': ann['question_id'],
                'answer': ann['answer'],
                'image_id': self.qq[ann['question_id']]['image_id'],
                'question': self.qq[ann['question_id']]['question']
            }
            for ann in anns
        ]

        print(f'✅ DONE (t={(datetime.datetime.now(datetime.timezone.utc) - time_t).total_seconds():.2f}s)')
        res.dataset['data'] = res_data

        res.qa = {item['question_id']: item['answer'] for item in res_data}
        res.qq = dict(self.qq)

        return res

class TextVQAEval:
    """Score TextVQA predictions against ground truth.

    Results are stored on the instance:
        accuracy: ``{'overall': <percent>}`` after ``evaluate``.
        evalQA: question_id -> per-question score in percent.
    """

    def __init__(self, textqa, textqaRes, n=2):
        """Store the ground-truth and prediction objects.

        Args:
            textqa: Ground-truth object exposing ``qa`` and ``getQuesIds()``.
            textqaRes: Prediction object exposing ``qa``
                (question_id -> predicted answer).
            n: Number of decimals used when rounding stored scores.
        """
        self.n = n
        self.accuracy = {}
        self.evalQA = {}
        self.textqa = textqa
        self.textqaRes = textqaRes
        self.params = {'question_id': textqa.getQuesIds()}
        self.evaluator = TextVQAAccuracyEvaluator()

    def evaluate(self, quesIds=None):
        """Compute overall and per-question accuracy.

        Args:
            quesIds: Question ids to score; defaults to every id returned by
                ``textqa.getQuesIds()`` at construction time.

        Raises:
            KeyError: If an id is missing from the ground truth or the
                predictions.
        """
        if quesIds is None:
            quesIds = self.params['question_id']

        gts = {qid: self.textqa.qa[qid] for qid in quesIds}
        res = {qid: self.textqaRes.qa[qid] for qid in quesIds}

        print("\n🚀 Computing TextVQA accuracy")

        pred_list = [
            {
                'question_id': qid,
                'gt_answers': gts[qid],
                'pred_answer': res[qid]
            }
            for qid in quesIds
        ]

        overall_accuracy = self.evaluator.eval_pred_list(pred_list)

        # Per-question scores are recomputed with the evaluator's own helpers
        # so they can be stored individually in ``evalQA``.
        for entry in pred_list:
            answer_scores = self.evaluator._compute_answer_scores(entry['gt_answers'])
            processed_pred = self.evaluator.answer_processor(entry['pred_answer'])
            score = answer_scores.get(processed_pred, 0.0)

            self.setEvalQA(entry['question_id'], score * 100)

        self.setAccuracy(overall_accuracy)
        print("\n✅ Computing TextVQA accuracy complete")

    def setAccuracy(self, overall_accuracy):
        """Store ``overall_accuracy`` (a fraction) as a rounded percentage."""
        self.accuracy['overall'] = round(100 * overall_accuracy, self.n)

    def setEvalQA(self, quesId, acc):
        """Store the rounded per-question score ``acc`` under ``quesId``."""
        self.evalQA[quesId] = round(acc, self.n)

    def updateProgress(self, progress):
        """Write a 20-character text progress bar to stdout.

        Args:
            progress: Fraction completed. Values below 0 are clamped to 0
                and values of 1 or more to 1. Integers are accepted; any
                other non-float value is reported as an error and drawn as 0.
        """
        barLength = 20
        status = ""
        # Whole-number progress (e.g. 0 or 1) is valid input, not an error.
        if isinstance(progress, int):
            progress = float(progress)
        if not isinstance(progress, float):
            progress = 0
            status = "error: progress var must be float\r\n"
        if progress < 0:
            progress = 0
            status = "Halt...\r\n"
        if progress >= 1:
            progress = 1
            status = "Done...\r\n"
        block = int(round(barLength * progress))
        text = f"\r✅ Finished Percent: [{'#' * block + '-' * (barLength - block)}] {int(progress * 100)}% {status}"
        sys.stdout.write(text)
        sys.stdout.flush()


def load_textvqa_dataset(img_dir: str, questions_json_path: str, ocr_json_path: str, num_samples: int = None):
    """Build a ``TextVQADataset``, returning None if anything goes wrong.

    Args:
        img_dir: Directory holding the dataset images.
        questions_json_path: Path to the questions JSON file.
        ocr_json_path: Path to the OCR JSON file.
        num_samples: Optional cap on the number of samples.

    Returns:
        The constructed dataset, or None after logging the error.
    """
    loaded = None
    try:
        print("🚀 Loading TextVQA dataset")
        candidate = TextVQADataset(
            imgs_dir=img_dir,
            questions_json_path=questions_json_path,
            ocr_json_path=ocr_json_path,
            num_samples=num_samples,
        )
        sample_count = len(candidate)
        print(f"✅ Loading {sample_count} samples complete")
        # Only publish the dataset once every step above has succeeded.
        loaded = candidate
    except Exception as e:
        logging.error(f"⛔ Error loading dataset: {e}")
    return loaded

def evaluate_textvqa_results(
    data_dir,
    result_dir,
    filename_suffix,
    num_samples,
    model_name,
    data_subtype='val',
    n_show_error=2,
):
    """Score a model's TextVQA predictions and write the results to disk.

    Args:
        data_dir: Dataset root; questions are read from
            ``{data_dir}/questions/TextVQA_0.5.1_{data_subtype}.json``.
        result_dir: Root holding ``Inference/`` (predictions, read) and
            ``Eval/`` (scores, written).
        filename_suffix: Appended to ``{model_name}_TextVQA`` to form the
            prediction file name.
        num_samples: If truthy, evaluate only the first ``num_samples``
            questions; the subset file is cached under
            ``{data_dir}/eval/questions/`` and reused when it already exists.
        model_name: Used in the prediction and output file names.
        data_subtype: Dataset split name used in the question file name.
        n_show_error: Forwarded to ``TextVQAEval`` as ``n``, which uses it
            as the number of decimals when rounding scores.

    Returns:
        The accuracy dict produced by ``TextVQAEval``.
    """
    question_path = f'{data_dir}/questions/TextVQA_0.5.1_{data_subtype}.json'

    if not num_samples:
        ques_file = question_path
    else:
        ques_file = f'{data_dir}/eval/questions/TextVQA_0.5.1_{data_subtype}_{num_samples}.json'
        if not os.path.exists(ques_file):
            os.makedirs(os.path.dirname(ques_file), exist_ok=True)
            with open(question_path, encoding='utf-8') as f:
                questions_ori = json.load(f)
            questions = {
                'dataset_type': questions_ori['dataset_type'],
                'dataset_name': questions_ori['dataset_name'],
                'dataset_version': questions_ori['dataset_version'],
                'data': questions_ori['data'][:num_samples]
            }
            with open(ques_file, 'w', encoding='utf-8') as f:
                json.dump(questions, f)

    res_file = f'{result_dir}/Inference/{model_name}_TextVQA{filename_suffix}'

    textqa = TextVQA(ques_file)
    textqa_res = textqa.loadRes(res_file, ques_file)

    textqa_eval = TextVQAEval(textqa, textqa_res, n=n_show_error)
    textqa_eval.evaluate()

    print("\n⭐ Overall Accuracy is: %.02f\n" % textqa_eval.accuracy['overall'])

    # Show one randomly chosen question that scored below 35 (percent).
    evals = [qid for qid, score in textqa_eval.evalQA.items() if score < 35]
    if evals:
        print('\nExample of low-accuracy answers:')
        random_eval = random.choice(evals)
        random_ann = textqa.loadQA(random_eval)
        textqa.showQA(random_ann)
        print('\nGenerated answer (accuracy %.02f)' % textqa_eval.evalQA[random_eval])
        print("Answer: %s\n" % textqa_res.qa[random_eval])

    output_data = {"accuracy": textqa_eval.accuracy, "evalQA": textqa_eval.evalQA}
    output_path = f"{result_dir}/Eval/{model_name}_TextVQA_results.json"
    # Create the output directory like the question-subset directory above,
    # and close the file explicitly so the JSON is fully flushed to disk.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, "w", encoding='utf-8') as out_file:
        json.dump(output_data, out_file, indent=4)
    return textqa_eval.accuracy