import json
import logging, random
import datetime, copy
import os
import itertools
from typing import List
from PIL import Image
import pandas as pd
from tqdm import tqdm
from Benchmarks.ChartQA.Eval import _compute_score, _remove_strings

class ChartQADataset:
    """Index-addressable view over one ChartQA question file and its chart images.

    The question file is expected to be a JSON list of dicts carrying the keys
    ``imgname``, ``query`` and ``label`` (these are the keys read by
    ``__getitem__``).
    """

    def __init__(self, data_dir: str, json_name: str = "test_human.json", sub_dir: str = "test", num_samples: int = None):
        """Load the question file of one split into memory.

        Args:
            data_dir: Root directory of the dataset.
            json_name: Name of the question file inside ``data_dir/sub_dir``.
            sub_dir: Split directory holding the question file and the
                ``tables`` / ``png`` folders.
            num_samples: When truthy, keep only the first ``num_samples``
                entries; ``None`` or ``0`` keeps everything.

        Raises:
            OSError: If the question file cannot be opened.
            json.JSONDecodeError: If the question file is not valid JSON.
        """
        self.data_dir = data_dir

        self.tables_dir = os.path.join(data_dir, sub_dir, "tables")
        self.png_dir = os.path.join(data_dir, sub_dir, "png")
        self.json_path = os.path.join(data_dir, sub_dir, json_name)

        with open(self.json_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)

        if num_samples:
            self.data = self.data[:num_samples]

    def __len__(self):
        """Return the number of loaded question entries."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` with its chart decoded as an RGB image.

        Returns:
            dict with ``image_id``, ``image`` (a PIL image in RGB mode),
            ``question_id``, ``question`` and ``answers``.
        """
        data_item = self.data[idx]
        imgname = data_item['imgname']
        # The id is the image name up to its first dot.
        # NOTE(review): question_id equals image_id, so if the question file
        # holds several entries for one image they share an id — confirm that
        # downstream consumers expect this.
        file_id = imgname.split('.')[0]

        # convert() returns an independent image, so the source file can be
        # closed deterministically instead of waiting for garbage collection.
        with Image.open(f"{self.png_dir}/{imgname}") as raw_image:
            image = raw_image.convert("RGB")

        return {
            "image_id": file_id,
            "image": image,
            "question_id": file_id,
            "question": data_item['query'],
            "answers": data_item['label']
        }

class ChartQA:
    """In-memory index over the ChartQA test split with a VQA-style helper API.

    Attributes:
        dataset: dict whose ``'data'`` entry holds the loaded question list
            (or, on an object returned by ``loadRes``, the prediction list).
        qa: file id -> numbers extracted from the chart's CSV table. On an
            object returned by ``loadRes`` it maps question id -> predicted
            answer instead.
        qq: file id -> dict with ``file_id``, ``table_path``, ``image_path``
            and, when the dataset has an entry for ``<file_id>.png``,
            ``question`` and ``label``.
        ia: file id -> one-element list holding the file id and its table
            numbers.
    """

    def __init__(self, data_dir: str = None, json_name: str = "test_human.json"):
        """Optionally load ``data_dir/test/json_name`` and build the index.

        Args:
            data_dir: Dataset root. When ``None`` an empty object is created
                and no path attributes are set.
            json_name: Question file name inside ``data_dir/test``.
        """
        self.dataset = {}
        self.qa = {}
        self.qq = {}
        self.ia = {}

        if data_dir is not None:
            print('\n🚀 Loading ChartQA data into memory')
            time_t = datetime.datetime.now(datetime.timezone.utc)

            self.data_dir = data_dir
            self.tables_dir = os.path.join(data_dir, "test", "tables")
            self.png_dir = os.path.join(data_dir, "test", "png")
            self.json_path = os.path.join(data_dir, "test", json_name)

            try:
                with open(self.json_path, 'r', encoding='utf-8') as f:
                    self.dataset['data'] = json.load(f)
            except Exception as exc:
                # Best effort: keep going with an empty dataset, but say why.
                logging.warning("Could not load ChartQA questions from %s: %s", self.json_path, exc)
                self.dataset['data'] = []

            print(datetime.datetime.now(datetime.timezone.utc) - time_t)
            self.createIndex()

    def createIndex(self):
        """Rebuild ``qa``, ``qq`` and ``ia`` from the question file and CSV tables.

        File ids are read from ``self.json_path``; question text and labels
        are taken from ``self.dataset['data']``. Any failure to read the file
        ids yields an empty index, and a table that cannot be loaded
        contributes an empty number list.
        """
        print('\n🚀 Creating index')
        try:
            with open(self.json_path, 'r', encoding='utf-8') as f:
                data_items = json.load(f)
            file_ids = [item['imgname'].split('.')[0] for item in data_items]
        except Exception as exc:
            # json_path may not exist as an attribute on a bare ChartQA().
            logging.warning(
                "Could not read ChartQA file ids from %s: %s",
                getattr(self, 'json_path', None), exc,
            )
            file_ids = []

        # One pass over the dataset instead of a scan per file id; setdefault
        # keeps the FIRST entry for each image name, as the scan did.
        first_item_by_imgname = {}
        for item in self.dataset.get('data', []):
            imgname = item.get("imgname") if isinstance(item, dict) else None
            if isinstance(imgname, str):
                first_item_by_imgname.setdefault(imgname, item)

        qa = {}
        qq = {}
        ia = {}

        # dict.fromkeys drops repeated ids (several questions can point at the
        # same image) while preserving first-seen order, so each CSV is read once.
        for file_id in dict.fromkeys(file_ids):
            table_path = os.path.join(self.tables_dir, f"{file_id}.csv")
            try:
                qa[file_id] = self._load_csv_numbers(table_path)
            except Exception as exc:
                logging.warning("Could not load table numbers from %s: %s", table_path, exc)
                qa[file_id] = []

            question_info = {
                'file_id': file_id,
                'table_path': table_path,
                'image_path': os.path.join(self.png_dir, f"{file_id}.png")
            }

            item = first_item_by_imgname.get(f"{file_id}.png")
            if item is not None:
                question_info.update({
                    'question': item.get("query", ""),
                    'label': item.get("label", "")
                })

            qq[file_id] = question_info

            ia[file_id] = [{
                'file_id': file_id,
                'table_numbers': qa[file_id]
            }]

        print('✅ Index created!')
        self.qa = qa
        self.qq = qq
        self.ia = ia

    def _load_csv_numbers(self, path: str) -> List[float]:
        """Read the CSV at ``path`` and return its flattened cells after ``_remove_strings``.

        NOTE(review): ``_remove_strings`` is imported from the Eval module;
        presumably it keeps only the numeric cells — confirm there.
        """
        df = pd.read_csv(path)
        flattened = list(itertools.chain.from_iterable(df.values))
        return _remove_strings(flattened)

    def info(self):
        """Print every ``dataset`` entry except the ``'data'`` payload."""
        for key, value in self.dataset.items():
            if key != 'data':
                print(f'{key}: {value}')

    def getQuesIds(self, imgIds=None, quesTypes=None, ansTypes=None):
        """Return all indexed ids; the filter arguments are accepted but ignored."""
        return list(self.qa)

    def loadQA(self, ids=None):
        """Return ``{'file_id', 'table_numbers'}`` records for the known ids.

        Args:
            ids: A single id (``int`` or ``str``) or a list/tuple/set of ids.
                Unknown ids are skipped.

        Returns:
            A list of records; empty when ``ids`` is omitted, unknown, or of
            an unsupported type.
        """
        if ids is None:
            return []
        if isinstance(ids, (list, tuple, set, frozenset)):
            return [{'file_id': fid, 'table_numbers': self.qa[fid]} for fid in ids if fid in self.qa]
        if isinstance(ids, (int, str)):
            return [{'file_id': ids, 'table_numbers': self.qa[ids]}] if ids in self.qa else []
        return []

    def showQA(self, anns):
        """Print file id, question (when indexed) and the first ten table numbers."""
        if not anns:
            return
        for ann in anns:
            fid = ann['file_id']
            print(f"File: {fid}")
            if fid in self.qq and 'question' in self.qq[fid]:
                print(f"Question: {self.qq[fid]['question']}")
            print(f"Table numbers: {ann['table_numbers'][:10]}...")

    def loadRes(self, resFile, quesFile=None):
        """Load a predictions file and wrap it in a new ``ChartQA`` object.

        Args:
            resFile: Path to a JSON list of dicts with ``question_id`` and
                ``answer`` keys.
            quesFile: Unused; kept for signature compatibility.

        Returns:
            A ``ChartQA`` whose ``qa`` maps question id -> predicted answer,
            whose ``qq`` is a shallow copy of this object's ``qq`` and whose
            ``dataset['data']`` is the normalised prediction list.

        Raises:
            AssertionError: If the file is not a JSON list, or its set of
                question ids differs from ``getQuesIds()``.
        """
        res = ChartQA()
        # 'data' is replaced below, so only the remaining metadata is copied.
        res.dataset = {
            key: copy.deepcopy(value)
            for key, value in self.dataset.items()
            if key != 'data'
        }

        print('\n🚀 Loading and preparing ChartQA results')
        time_t = datetime.datetime.now(datetime.timezone.utc)

        with open(resFile, 'r', encoding='utf-8') as f:
            anns = json.load(f)

        # Raised explicitly (instead of ``assert``) so the validation is not
        # stripped under ``python -O``; the exception type is unchanged.
        if not isinstance(anns, list):
            raise AssertionError('results is not an array of objects')

        ann_ids = {ann['question_id'] for ann in anns}
        expected_ids = set(self.getQuesIds())
        if ann_ids != expected_ids:
            raise AssertionError(
                'Results do not match ChartQA set. '
                'Missing predictions or unknown question IDs. '
                f"set(annsQuesIds) number: {len(ann_ids)}, set(self.getQuesIds()) number: {len(expected_ids)}"
            )

        res_data = [
            {
                'question_id': ann['question_id'],
                'answer': ann['answer'],
                'file_id': ann['question_id']
            }
            for ann in anns
        ]

        print(f'✅ DONE (t={(datetime.datetime.now(datetime.timezone.utc) - time_t).total_seconds():.2f}s)')
        res.dataset['data'] = res_data

        res.qa = {item['question_id']: item['answer'] for item in res_data}
        res.qq = dict(self.qq)

        return res

class ChartQAEval:
    """Score predicted answers against the table numbers indexed by ``ChartQA``.

    After ``evaluate`` runs, ``evalQA`` maps each file id to its score in
    percent and ``accuracy['overall']`` holds the mean score in percent, both
    rounded to ``n`` decimal places.
    """

    def __init__(self, chartqa: ChartQA, chartqaRes, n=2):
        """Keep the ground-truth and result objects plus the rounding precision ``n``."""
        self.chartqa = chartqa
        self.chartqaRes = chartqaRes
        self.n = n
        self.evalQA = {}
        self.accuracy = {}
        # Default set of ids to evaluate: everything the ground truth indexes.
        self.params = {'file_id': chartqa.getQuesIds()}

    def evaluate(self, fileIds=None):
        """Compute per-file scores and the overall accuracy.

        Args:
            fileIds: Ids to score; defaults to ``params['file_id']``.

        Raises:
            KeyError: If an id is missing from the ground truth or the results.
        """
        target_ids = self.params['file_id'] if fileIds is None else fileIds

        # Resolve both sides up front so a missing id fails before any scoring.
        truth_by_id = {fid: self.chartqa.qa[fid] for fid in target_ids}
        answer_by_id = {fid: self.chartqaRes.qa[fid] for fid in target_ids}

        print("\n🚀 Computing ChartQA accuracy")

        collected = []
        for fid in tqdm(target_ids):
            try:
                parsed = _remove_strings([answer_by_id[fid]])
                if len(parsed) > 0:
                    value = _compute_score(truth_by_id[fid], parsed)
                else:
                    # Nothing usable came out of the prediction.
                    value = 0.0
            except Exception:
                # Any failure while parsing or scoring counts as a miss.
                value = 0.0

            self.evalQA[fid] = round(100 * value, self.n)
            collected.append(value)

        self.setAccuracy(collected)
        print("\n✅ Computing ChartQA accuracy complete")

    def setAccuracy(self, scores):
        """Store the mean of ``scores`` as a percentage under ``accuracy['overall']``."""
        # max(1, ...) keeps an empty score list from dividing by zero.
        denominator = max(1, len(scores))
        mean_score = sum(scores) / denominator
        self.accuracy['overall'] = round(100 * mean_score, self.n)


def load_chartqa_dataset(data_dir: str = "Benchmarks/ChartQA/ChartQA_Dataset",
                        json_name: str = "test_human.json",
                        num_samples: int = None):
    """Build a ``ChartQADataset``, returning ``None`` instead of raising on failure.

    Args:
        data_dir: Dataset root directory.
        json_name: Question file name handed to ``ChartQADataset``.
        num_samples: Optional cap on the number of samples, handed through
            unchanged.

    Returns:
        The constructed dataset, or ``None`` when anything in the loading
        sequence raises (the error is logged).
    """
    options = {
        "data_dir": data_dir,
        "json_name": json_name,
        "num_samples": num_samples,
    }
    try:
        print("🚀 Loading ChartQA dataset")
        loaded = ChartQADataset(**options)
        sample_count = len(loaded)
        print(f"✅ Loading {sample_count} samples complete")
    except Exception as exc:
        logging.error(f"⛔ Error loading ChartQA dataset: {exc}")
        return None
    return loaded

def evaluate_chartqa_results(
    data_dir,
    result_dir,
    filename_suffix,
    num_samples,
    model_name,
    json_name="test_human.json",
    n_show_error=2,
):
    """Score a model's ChartQA predictions and write a JSON report.

    Predictions are read from
    ``{result_dir}/Inference/{model_name}_ChartQA{filename_suffix}`` and the
    report is written to ``{result_dir}/Eval/{model_name}_ChartQA_results.json``.

    Args:
        data_dir: Dataset root; questions are read from ``data_dir/test``.
        result_dir: Root directory for inference outputs and reports.
        filename_suffix: Appended verbatim to the predictions file name.
        num_samples: When truthy, evaluate only the first ``num_samples``
            questions; the subset is cached under ``data_dir/eval/test`` and
            reused if it already exists.
        model_name: Prefix of the predictions and report file names.
        json_name: Question file name inside ``data_dir/test``.
        n_show_error: Passed to ``ChartQAEval`` as ``n``, i.e. the number of
            decimal places used when rounding scores.

    Returns:
        The accuracy dict of the evaluator (key ``'overall'``).
    """
    question_path = f'{data_dir}/test/{json_name}'

    if not num_samples:
        ques_file = question_path
    else:
        base_name = os.path.splitext(json_name)[0]
        ques_file = f'{data_dir}/eval/test/{base_name}_{num_samples}.json'
        if not os.path.exists(ques_file):
            os.makedirs(os.path.dirname(ques_file), exist_ok=True)
            # Read as UTF-8, matching how the question file is read elsewhere.
            with open(question_path, 'r', encoding='utf-8') as f:
                questions = json.load(f)[:num_samples]
            with open(ques_file, 'w', encoding='utf-8') as f:
                json.dump(questions, f)

    # Create the report directory up front (this also creates result_dir), so
    # a missing "Eval" folder cannot fail the run after all scoring is done.
    eval_dir = f"{result_dir}/Eval"
    os.makedirs(eval_dir, exist_ok=True)

    res_file = f'{result_dir}/Inference/{model_name}_ChartQA{filename_suffix}'
    if num_samples:
        # The constructor always reads data_dir/test/<json_name>, so the
        # object is wired up by hand to index the cached subset file instead.
        chartqa = ChartQA()
        chartqa.data_dir = data_dir
        chartqa.tables_dir = os.path.join(data_dir, "test", "tables")
        chartqa.png_dir = os.path.join(data_dir, "test", "png")
        chartqa.json_path = ques_file

        try:
            with open(ques_file, 'r', encoding='utf-8') as f:
                chartqa.dataset = {'data': json.load(f)}
        except Exception as exc:
            # Best effort: continue with an empty dataset, but report why.
            logging.warning("Could not load ChartQA questions from %s: %s", ques_file, exc)
            chartqa.dataset = {'data': []}

        chartqa.createIndex()
    else:
        chartqa = ChartQA(data_dir, json_name=json_name)

    chartqa_res = chartqa.loadRes(res_file)

    chartqa_eval = ChartQAEval(chartqa, chartqa_res, n=n_show_error)
    chartqa_eval.evaluate()

    print("\n⭐ Overall Accuracy is: %.02f\n" % chartqa_eval.accuracy['overall'])

    # Show one randomly chosen example among the files scoring below 35 %.
    low_score_threshold = 35
    evals = [fid for fid, score in chartqa_eval.evalQA.items() if score < low_score_threshold]
    if evals:
        print('\nExample of low-accuracy file:')
        random_eval = random.choice(evals)
        random_ann = chartqa.loadQA(random_eval)
        chartqa.showQA(random_ann)
        print('\nGenerated table score: %.02f' % chartqa_eval.evalQA[random_eval])

    output_data = {"accuracy": chartqa_eval.accuracy, "evalQA": chartqa_eval.evalQA}
    # The context manager guarantees the report is flushed and closed.
    with open(f"{eval_dir}/{model_name}_ChartQA_results.json", "w", encoding='utf-8') as f:
        json.dump(output_data, f, indent=4)
    return chartqa_eval.accuracy
