import os
from typing import Dict, List, Optional, Sequence
import json
import mmengine
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from terminaltables import AsciiTable
import numpy as np
from embodiedqa.registry import METRICS
from embodiedqa.structures import EulerDepthInstance3DBoxes
from collections import defaultdict

@METRICS.register_module()
class ScanQAMetric(BaseMetric):
    """Answer-classification metric for ScanQA-style question answering.

    Each sample's ``pred_scores`` holds one score per answer candidate
    (indices into ``dataset_meta['answer_candidates']``). The scores are
    ranked and compared against the sample's ``gt_answer_labels``, which
    are treated as a mask over candidates (entries equal to 1 mark correct
    answers). The metric reports exact-match accuracy at several top-k
    cutoffs (``EM@k``) and, optionally, top-1 accuracy per question type.

    Predictions are always dumped to ``<result_dir>/test_results.json``.
    When samples carry ``pid_weights``, those are grouped by question type
    and dumped to ``<result_dir>/dam_weights_report.json``.

    Args:
        topk: Cutoffs k for which ``EM@k`` is reported. Defaults to (1, 10).
        collect_device: Device used to collect results across ranks.
        prefix: Prefix added to metric names.
        format_only: If True, only write ``test_results.json`` (ground-truth
            labels are not read) and return an empty metric dict.
        extra_pred_scores_suffix: Suffixes of additional score keys
            (``pred_scores<suffix>``) to evaluate separately; their metrics
            are reported with the suffix appended to each metric name.
            A single string is treated as one suffix.
        question_type_analysis: If True, also report top-1 accuracy broken
            down by ``question_type``.
        result_dir: Directory where the JSON reports are written (created if
            missing). Empty string means the current working directory.
    """

    # Question types reported individually; any other type is counted under
    # 'others'. Shared by the accuracy breakdown and the DAM weight report.
    _MAJOR_QUESTION_TYPES = ('what', 'where', 'how', 'is', 'which')

    def __init__(self,
                 topk: Sequence[int] = (1, 10),
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None,
                 format_only: bool = False,
                 extra_pred_scores_suffix: Optional[Sequence[str]] = None,
                 question_type_analysis: bool = True,
                 result_dir: str = '') -> None:
        super(ScanQAMetric, self).__init__(prefix=prefix,
                                           collect_device=collect_device)
        # Copy into a fresh list so instances never share or mutate a
        # default value.
        self.topk = list(topk)
        self.prefix = prefix
        self.format_only = format_only
        if isinstance(extra_pred_scores_suffix, str):
            # A bare string would otherwise be iterated character by
            # character below.
            extra_pred_scores_suffix = [extra_pred_scores_suffix]
        self.extra_pred_scores_suffix = extra_pred_scores_suffix
        self.result_dir = result_dir
        self.question_type_analysis = question_type_analysis

    @staticmethod
    def _to_python(value):
        """Convert a tensor/array value into plain, JSON-serializable Python.

        Objects exposing ``detach`` (e.g. torch tensors) are detached and
        moved to CPU first; anything exposing ``tolist`` is converted with
        it. All other values are returned unchanged.
        """
        if hasattr(value, 'detach'):
            value = value.detach().cpu()
        if hasattr(value, 'tolist'):
            return value.tolist()
        return value

    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
        """Collect the annotations and predictions of one batch.

        Score tensors are moved to CPU and PID weights are converted to plain
        Python values, so the results can be gathered across ranks and
        serialized to JSON later.

        Args:
            data_batch: The input batch (unused).
            data_samples: Per-sample output dicts. Each must contain
                ``eval_ann_info`` and ``pred_scores``; it may also contain
                ``pid_weights`` and ``pred_scores<suffix>`` entries.
        """
        for data_sample in data_samples:
            eval_ann_info = data_sample['eval_ann_info']
            cpu_pred = dict(pred_scores=data_sample['pred_scores'].to('cpu'))

            pid_weights = data_sample.get('pid_weights')
            if isinstance(pid_weights, dict):
                pid_weights = {
                    name: self._to_python(weight)
                    for name, weight in pid_weights.items()
                }
            cpu_pred['pid_weights'] = pid_weights

            for suffix in self.extra_pred_scores_suffix or ():
                key = f'pred_scores{suffix}'
                if key in data_sample:
                    cpu_pred[key] = data_sample[key].to('cpu')

            self.results.append((eval_ann_info, cpu_pred))

    def ground_eval(self, gt_annos, pred_annos, logger=None, type_suffix=''):
        """Compute exact-match accuracies for one set of prediction scores.

        Args:
            gt_annos: Per-sample annotations containing ``gt_answer_labels``
                (and ``question_type`` when question-type analysis is on).
            pred_annos: Per-sample prediction dicts aligned with
                ``gt_annos``.
            logger: Logger used to print the result table.
            type_suffix: Selects the score key ``pred_scores<type_suffix>``
                and is appended to every reported metric name.

        Returns:
            dict: Metric name -> accuracy in [0, 1].

        Raises:
            ValueError: If ``gt_annos`` and ``pred_annos`` differ in length.
        """
        if len(pred_annos) != len(gt_annos):
            raise ValueError(
                f'Got {len(pred_annos)} predictions but {len(gt_annos)} '
                'annotations.')

        score_key = 'pred_scores' + type_suffix
        metric_types = [f'EM@{k}' for k in self.topk]
        if self.question_type_analysis:
            metric_types += [*self._MAJOR_QUESTION_TYPES, 'others']
        metric_types = [t + type_suffix for t in metric_types]

        hits = dict.fromkeys(metric_types, 0)
        totals = dict.fromkeys(metric_types, 0)
        max_k = max(self.topk)

        for gt_anno, pred_anno in zip(gt_annos, pred_annos):
            gt_answer_labels = gt_anno['gt_answer_labels']
            top_index = pred_anno[score_key].argsort(
                dim=-1, descending=True)[:max_k]

            if self.question_type_analysis:
                question_type = str(gt_anno['question_type']).lower()
                if question_type not in self._MAJOR_QUESTION_TYPES:
                    # Unknown types would otherwise raise KeyError.
                    question_type = 'others'
                key = question_type + type_suffix
                totals[key] += 1
                # The per-type breakdown is top-1 accuracy.
                hits[key] += int(gt_answer_labels[top_index[:1]].any())

            for k in self.topk:
                key = f'EM@{k}{type_suffix}'
                totals[key] += 1
                # A hit when any of the top-k candidates is a correct answer.
                hits[key] += int(gt_answer_labels[top_index[:k]].any())

        ret_dict = {}
        row = ['results']
        for metric_type in metric_types:
            value = hits[metric_type] / max(totals[metric_type], 1)
            ret_dict[metric_type] = value
            row.append(f'{value:.4f}')

        table = AsciiTable([['Type', *metric_types], row])
        table.inner_footing_row_border = True
        print_log('\n' + table.table, logger=logger)
        return ret_dict

    def _dump_json(self, obj, filename: str) -> str:
        """Write ``obj`` as indented JSON to ``result_dir/filename``.

        Creates ``result_dir`` if it is set and missing. Returns the path
        that was written.
        """
        if self.result_dir:
            os.makedirs(self.result_dir, exist_ok=True)
        output_path = os.path.join(self.result_dir, filename)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(obj, f, indent=4)
        return output_path

    def compute_metrics(self, results: list) -> Dict[str, float]:
        """Dump predictions and compute the QA metrics.

        Args:
            results: ``(eval_ann_info, prediction)`` pairs collected by
                :meth:`process`.

        Returns:
            dict: Metric name -> value. Empty when ``format_only`` is True.
        """
        logger: MMLogger = MMLogger.get_current_instance()

        annotations, preds = zip(*results)
        answer_candidates = self.dataset_meta.get('answer_candidates')

        results_list = []
        for ann, pred in zip(annotations, preds):
            top10_index = pred['pred_scores'].argsort(
                dim=-1, descending=True)[:10].tolist()
            result = dict(
                question=ann['question'],
                question_id=ann['question_id'],
                answer_top10=[answer_candidates[k] for k in top10_index],
                scene_id=ann['scan_id'].split('/')[-1],
            )
            if not self.format_only:
                # Ground-truth labels are read only in evaluation mode; they
                # may be unavailable when predictions are merely formatted.
                gt_answer_id = np.where(ann['gt_answer_labels'] == 1)[0]
                result['gt_answer'] = [answer_candidates[k]
                                       for k in gt_answer_id]
            results_list.append(result)

        self._dump_json(results_list, 'test_results.json')

        if self.format_only:
            return {}

        ret_dict = self.ground_eval(annotations, preds, logger=logger)

        for suffix in self.extra_pred_scores_suffix or ():
            key = f'pred_scores{suffix}'
            # Only evaluate extra scores that every sample produced.
            if all(key in pred for pred in preds):
                ret_dict.update(
                    self.ground_eval(annotations, preds, logger=logger,
                                     type_suffix=suffix))

        self._generate_dam_report(results)
        return ret_dict

    def _generate_dam_report(self, results: list):
        """Group the logged DAM (PID) weights by question type and save them.

        Samples without ``pid_weights`` are skipped. The question type is the
        first word of the lower-cased question, counted as 'others' when it
        is not a major type. Writes ``dam_weights_report.json`` into
        ``result_dir``.
        """
        annotations, preds = zip(*results)

        atom_names = ["U_Point", "U_Image", "Redundancy", "Synergy"]
        report_data = defaultdict(lambda: defaultdict(list))

        for ann, pred in zip(annotations, preds):
            weights_dict = pred.get('pid_weights')
            if weights_dict is None:
                continue

            question_id = ann.get('question_id', 'unknown_id')
            question = ann.get('question', '').strip().lower()
            question_type = question.split(' ')[0]
            target_category = (question_type
                               if question_type in self._MAJOR_QUESTION_TYPES
                               else 'others')

            for atom_name, weight in weights_dict.items():
                # Convert defensively so json.dump never sees a tensor.
                report_data[target_category][atom_name].append(
                    (question_id, self._to_python(weight)))

        if not report_data:
            print_log("No DAM weights were found in any samples. Skipping report generation for ScanQA.",
                      logger='current')
            return

        final_report = {'atom_names': atom_names,
                        'weights_by_question_type': report_data}
        output_path = self._dump_json(final_report, 'dam_weights_report.json')
        print_log(f"Successfully generated DAM weights analysis report at: {output_path}",
                  logger='current')