"""
Heatmap curriculum generator
"""


import os
import json
import copy
from typing import List, Dict, Union, Tuple
from metachart import MetaGenerator


class HeatmapGenerator(MetaGenerator):
    def __init__(self, args, chart_id):
        super().__init__(args)
        self.chart_type = args.chart_type
        self.chart_id = chart_id
        self.all_qa_data_list = []
        self.round_num = 2
        self.qa_idx = 0
        self.detailed_reasoning = args.detailed_reasoning
        self.force_ground = args.force_ground
    


    ############################################################
    #   One-step Operator: h(list[Bar] | list[v]) → list[v]
    ############################################################
    
    def _one_step_statistics(self, chart_metadata: Dict):
        """
        Statistics: sum, mean, median, count for rows and columns.

        Builds the one-step QA pool over the whole heatmap. It covers the
        per-row, per-column and overall sum / mean / median, plus the
        total cell count.

        Args:
            chart_metadata: chart description. Reads 'x_label', 'y_label',
                'heatmap_category' ('singular' / 'plural'), 'heatmap_data'
                (indexed as [row][col]), 'x_labels' (column names) and
                'y_labels' (row names).

        Returns:
            Dict mapping each task key (e.g. "one_step__statistics__row_sum")
            to a dict with "question" variants, "reasoning" steps,
            "constraint", "answer" and a per-step "mask". Each mask is a
            list of (row, col) cell indices.
        """
        x_axis_title = chart_metadata['x_label']
        y_axis_title = chart_metadata['y_label']
        heatmap_category_singular = chart_metadata['heatmap_category']['singular']
        heatmap_category_plural = chart_metadata['heatmap_category']['plural']

        target_heatmap_data = copy.deepcopy(chart_metadata["heatmap_data"])  # matrix
        target_x_labels = copy.deepcopy(chart_metadata["x_labels"])
        target_y_labels = copy.deepcopy(chart_metadata["y_labels"])

        # The column count comes from the first row, so the matrix is
        # assumed to be rectangular.
        num_rows = len(target_heatmap_data)
        num_cols = len(target_heatmap_data[0]) if target_heatmap_data else 0

        # All cell indices for masking
        all_cell_indices = [(r, c) for r in range(num_rows) for c in range(num_cols)]

        def _plus(values):
            # Render values as "a + b + c" for the arithmetic traces.
            return ' + '.join(str(val) for val in values)

        def _comma(values):
            # Render values as "a, b, c".
            return ', '.join(str(val) for val in values)

        def _step(text, detail):
            # Append ":\n<detail>" only when detailed reasoning is enabled.
            return text + (f":\n{detail}" if self.detailed_reasoning else ".")

        # Row-wise statistics. The second value returned by
        # _compute_data_median is used here as positions within the row,
        # i.e. column indices.
        row_sums = [sum(row) for row in target_heatmap_data]
        row_means = [sum(row) / len(row) for row in target_heatmap_data]
        row_medians = []
        row_median_indices = []
        for row_idx, row in enumerate(target_heatmap_data):
            median_val, median_cols = self._compute_data_median(row)
            row_medians.append(median_val)
            row_median_indices.append([(row_idx, col) for col in median_cols])

        # Column-wise statistics. The median positions are row indices here.
        column_data = [[row[col] for row in target_heatmap_data] for col in range(num_cols)]
        col_sums = [sum(col) for col in column_data]
        col_means = [sum(col) / len(col) for col in column_data]
        col_medians = []
        col_median_indices = []
        for col_idx, col in enumerate(column_data):
            median_val, median_rows = self._compute_data_median(col)
            col_medians.append(median_val)
            col_median_indices.append([(row, col_idx) for row in median_rows])

        # Overall statistics over the row-major flattening of the matrix.
        all_values = [val for row in target_heatmap_data for val in row]
        overall_sum = sum(all_values)
        overall_mean = overall_sum / len(all_values)
        overall_median, overall_median_flat_idx = self._compute_data_median(all_values)
        # all_values is row-major, so divmod maps a flat index back to (row, col).
        overall_median_cells = [divmod(flat_idx, num_cols) for flat_idx in overall_median_flat_idx]

        row_median_cells = [cell for cells in row_median_indices for cell in cells]
        col_median_cells = [cell for cells in col_median_indices for cell in cells]

        # ---- Detailed reasoning traces ----
        read_all_reason = '\n'.join(
            f"* Row {target_y_labels[r]}: "
            + ', '.join(f'{target_x_labels[c]}={row[c]}' for c in range(num_cols))
            for r, row in enumerate(target_heatmap_data)
        ).strip()

        row_median_reason = '\n'.join(
            f"* {target_y_labels[r]}: Sort [{_comma(sorted(row))}] → Median: {row_medians[r]}"
            for r, row in enumerate(target_heatmap_data)
        )
        col_median_reason = '\n'.join(
            f"* {target_x_labels[c]}: Sort [{_comma(sorted(col))}] → Median: {col_medians[c]}"
            for c, col in enumerate(column_data)
        )

        row_sum_reason = '\n'.join(
            f"* {target_y_labels[r]}: {_plus(row)} = {row_sums[r]}"
            for r, row in enumerate(target_heatmap_data)
        )
        col_sum_reason = '\n'.join(
            f"* {target_x_labels[c]}: {_plus(col)} = {col_sums[c]}"
            for c, col in enumerate(column_data)
        )
        overall_sum_reason = f"{_plus(all_values)} = {overall_sum}"

        row_mean_reason = '\n'.join(
            f"* {target_y_labels[r]}: ({_plus(row)}) / {num_cols} = {row_sums[r]} / {num_cols} = {row_means[r]}"
            for r, row in enumerate(target_heatmap_data)
        )
        col_mean_reason = '\n'.join(
            f"* {target_x_labels[c]}: ({_plus(col)}) / {num_rows} = {col_sums[c]} / {num_rows} = {col_means[c]}"
            for c, col in enumerate(column_data)
        )
        overall_mean_reason = f"({_plus(all_values)}) / {len(all_values)} = {overall_sum} / {len(all_values)} = {overall_mean}"

        # List the actual middle value(s) of the sorted data (one for an odd
        # count, two for an even count), followed by the resulting median.
        sorted_all_values = sorted(all_values)
        half = len(sorted_all_values) // 2
        if len(sorted_all_values) % 2:
            middle_values = sorted_all_values[half:half + 1]
        else:
            middle_values = sorted_all_values[max(half - 1, 0):half + 1]
        overall_median_reason = (
            f"* Sort: {_comma(sorted_all_values)}\n"
            f"* Middle values: {_comma(middle_values)} → Median: {overall_median}"
        )

        # Step 1 is shared verbatim by every task except the cell count.
        read_all_step = _step(f"First, I need to read all the {heatmap_category_plural} in this heatmap", read_all_reason)

        # Multi-line answers: one "label: value" line per row / column.
        row_sum_answer = '\n'.join(f'{target_y_labels[r]}: {row_sums[r]}' for r in range(num_rows))
        col_sum_answer = '\n'.join(f'{target_x_labels[c]}: {col_sums[c]}' for c in range(num_cols))
        row_mean_answer = '\n'.join(f'{target_y_labels[r]}: {row_means[r]}' for r in range(num_rows))
        col_mean_answer = '\n'.join(f'{target_x_labels[c]}: {col_means[c]}' for c in range(num_cols))
        row_median_answer = '\n'.join(f'{target_y_labels[r]}: {row_medians[r]}' for r in range(num_rows))
        col_median_answer = '\n'.join(f'{target_x_labels[c]}: {col_medians[c]}' for c in range(num_cols))

        # NOTE(review): the mean/median questions ask for two-decimal rounding,
        # but the answers below are stored unrounded. Confirm that the
        # evaluator rounds or compares with tolerance.

        # Chart QA Pool
        easy_qa_pool = {
            "one_step__statistics__row_sum": {
                "question": [
                    f"What is the sum of {heatmap_category_plural} for each {y_axis_title} in this heatmap? Please provide the results in the format '{y_axis_title}: value' for each row.",
                    f"Can you calculate the total {heatmap_category_plural} for each {y_axis_title}? Please list the sum for each row.",
                    f"For each {y_axis_title}, what is the sum of all {heatmap_category_plural} across all {x_axis_title}?",
                ],
                "reasoning": [
                    {
                        "step_1": read_all_step,
                        "step_2": _step(f"Second, I need to calculate the sum for each {y_axis_title}", row_sum_reason),
                    },
                ],
                "constraint": "none",
                "answer": row_sum_answer,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "answer": all_cell_indices,
                },
            },
            "one_step__statistics__col_sum": {
                "question": [
                    f"What is the sum of {heatmap_category_plural} for each {x_axis_title} in this heatmap? Please provide the results in the format '{x_axis_title}: value' for each column.",
                    f"Can you calculate the total {heatmap_category_plural} for each {x_axis_title}? Please provide the results in the format '{x_axis_title}: value' for each column.",
                    f"For each {x_axis_title}, what is the sum of all {heatmap_category_plural} across all {y_axis_title}? Please provide the results in the format '{x_axis_title}: value' for each column.",
                ],
                "reasoning": [
                    {
                        "step_1": read_all_step,
                        "step_2": _step(f"Second, I need to calculate the sum for each {x_axis_title}", col_sum_reason),
                    },
                ],
                "constraint": "none",
                "answer": col_sum_answer,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "answer": all_cell_indices,
                },
            },
            "one_step__statistics__overall_sum": {
                "question": [
                    f"What is the total sum of all {heatmap_category_plural} in this heatmap?",
                    f"Can you calculate the sum of all {heatmap_category_plural} across all cells in this heatmap?",
                    f"What is the overall total of all {heatmap_category_plural} shown in this heatmap?",
                ],
                "reasoning": [
                    {
                        "step_1": read_all_step,
                        "step_2": _step("Second, I need to sum all values", overall_sum_reason),
                    },
                ],
                "constraint": "none",
                "answer": overall_sum,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "answer": all_cell_indices,
                },
            },
            "one_step__statistics__row_mean": {
                "question": [
                    f"What is the average {heatmap_category_singular} for each {y_axis_title} in this heatmap? Please round to two decimal places and provide results in the format '{y_axis_title}: value'.",
                    f"Can you calculate the mean {heatmap_category_singular} for each {y_axis_title}? Please round to two decimal places.",
                    f"For each {y_axis_title}, what is the average {heatmap_category_singular} across all {x_axis_title}? Please round to two decimal places.",
                ],
                "reasoning": [
                    {
                        "step_1": read_all_step,
                        "step_2": _step(f"Second, I need to calculate the mean for each {y_axis_title}", row_mean_reason),
                    },
                ],
                "constraint": "none",
                "answer": row_mean_answer,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "answer": all_cell_indices,
                },
            },
            "one_step__statistics__col_mean": {
                "question": [
                    f"What is the average {heatmap_category_singular} for each {x_axis_title} in this heatmap? Please round to two decimal places and provide results in the format '{x_axis_title}: value'.",
                    f"Can you calculate the mean {heatmap_category_singular} for each {x_axis_title}? Please round to two decimal places.",
                    f"For each {x_axis_title}, what is the average {heatmap_category_singular} across all {y_axis_title}? Please round to two decimal places.",
                ],
                "reasoning": [
                    {
                        "step_1": read_all_step,
                        "step_2": _step(f"Second, I need to calculate the mean for each {x_axis_title}", col_mean_reason),
                    },
                ],
                "constraint": "none",
                "answer": col_mean_answer,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "answer": all_cell_indices,
                },
            },
            "one_step__statistics__overall_mean": {
                "question": [
                    f"What is the overall average {heatmap_category_singular} across all cells in this heatmap? Please round to two decimal places.",
                    f"Can you calculate the mean {heatmap_category_singular} for all cells in this heatmap? Please round to two decimal places.",
                    f"What is the average {heatmap_category_singular} across the entire heatmap? Please round to two decimal places.",
                ],
                "reasoning": [
                    {
                        "step_1": read_all_step,
                        "step_2": _step("Second, I need to calculate the overall mean", overall_mean_reason),
                    },
                ],
                "constraint": "none",
                "answer": overall_mean,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "answer": all_cell_indices,
                },
            },
            "one_step__statistics__row_median": {
                "question": [
                    f"What is the median {heatmap_category_singular} for each {y_axis_title} in this heatmap? Please round to two decimal places and provide results in the format '{y_axis_title}: value'.",
                    f"Can you find the median value of {heatmap_category_plural} for each {y_axis_title}? Please round to two decimal places.",
                    f"For each {y_axis_title}, what is the median {heatmap_category_singular} across all {x_axis_title}? Please round to two decimal places.",
                ],
                "reasoning": [
                    {
                        "step_1": read_all_step,
                        "step_2": _step(f"Second, I need to calculate the median for each {y_axis_title} by sorting values in each row and finding the middle value(s)", row_median_reason),
                    },
                ],
                "constraint": "none",
                "answer": row_median_answer,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": row_median_cells,
                    "answer": row_median_cells,
                },
            },
            "one_step__statistics__col_median": {
                "question": [
                    f"What is the median {heatmap_category_singular} for each {x_axis_title} in this heatmap? Please round to two decimal places and provide results in the format '{x_axis_title}: value'.",
                    f"Can you find the median value of {heatmap_category_plural} for each {x_axis_title}? Please round to two decimal places.",
                    f"For each {x_axis_title}, what is the median {heatmap_category_singular} across all {y_axis_title}? Please round to two decimal places.",
                ],
                "reasoning": [
                    {
                        "step_1": read_all_step,
                        "step_2": _step(f"Second, I need to calculate the median for each {x_axis_title} by sorting values in each column", col_median_reason),
                    },
                ],
                "constraint": "none",
                "answer": col_median_answer,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": col_median_cells,
                    "answer": col_median_cells,
                },
            },
            "one_step__statistics__overall_median": {
                "question": [
                    f"What is the median {heatmap_category_singular} across all cells in this heatmap? Please round to two decimal places.",
                    f"Can you find the median value of all {heatmap_category_plural} in this heatmap? Please round to two decimal places.",
                    f"What is the middle value when all {heatmap_category_plural} in this heatmap are sorted? Please round to two decimal places.",
                ],
                "reasoning": [
                    {
                        "step_1": read_all_step,
                        "step_2": _step("Second, I need to calculate the median value across all cells by sorting all values and identifying the middle value(s)", overall_median_reason),
                    },
                ],
                "constraint": "none",
                "answer": overall_median,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": overall_median_cells,
                    "answer": overall_median_cells,
                },
            },
            "one_step__statistics__count_cells": {
                "question": [
                    "How many cells are there in this heatmap?",
                    f"What is the total number of {heatmap_category_plural} displayed in this heatmap?",
                    "How many data points are shown in this heatmap?",
                ],
                "reasoning": [
                    {
                        "step_1": "First, I need to count all cells in this heatmap" + (f": {num_rows} rows \u00D7 {num_cols} columns = {len(all_cell_indices)} cells." if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": "none",
                "answer": len(all_cell_indices),
                "mask": {
                    "step_1": all_cell_indices,
                    "answer": all_cell_indices,
                },
            },
        }

        return easy_qa_pool



    ############################################################
    #                     Two-step Operator
    ############################################################

    def _two_step_statistics(self, chart_metadata: Dict, target_cells: List[tuple], constraint: str):
        """
        Statistics: sum, mean, median, count for selected cells
        """
        x_axis_title = chart_metadata['x_label']
        y_axis_title = chart_metadata['y_label']
        heatmap_category_singular = chart_metadata['heatmap_category']['singular']
        heatmap_category_plural = chart_metadata['heatmap_category']['plural']
        
        target_heatmap_data = copy.deepcopy(chart_metadata["heatmap_data"])
        target_x_labels = copy.deepcopy(chart_metadata["x_labels"])
        target_y_labels = copy.deepcopy(chart_metadata["y_labels"])
        
        # Sort target cells for consistency
        target_cells = sorted(target_cells)
        
        # Extract values from target cells
        target_values = [target_heatmap_data[r][c] for r, c in target_cells]
        
        # Calculate statistics
        target_sum = sum(target_values)
        target_mean = target_sum / len(target_values) if target_values else 0
        target_median, target_median_flat_idx = self._compute_data_median(target_values)
        target_count = len(target_values)
        
        # Convert median flat indices to cell coordinates
        target_median_cells = []
        if target_median_flat_idx:
            for flat_idx in target_median_flat_idx:
                if 0 <= flat_idx < len(target_cells):
                    target_median_cells.append(target_cells[flat_idx])
        
        # Generate detailed reasoning strings
        target_cells_str = ", ".join([f"({target_y_labels[r]}, {target_x_labels[c]})" for r, c in target_cells])
        target_values_str = ", ".join([str(val) for val in target_values])
        target_values_with_pos = ", ".join([f"({target_y_labels[r]}, {target_x_labels[c]})={target_heatmap_data[r][c]}" for r, c in target_cells])
        
        # Sum calculation
        sum_calculation = f"{' + '.join([str(val) for val in target_values])} = {target_sum}"
        
        # Mean calculation
        mean_calculation = f"({' + '.join([str(val) for val in target_values])}) / {len(target_values)} = {target_sum} / {len(target_values)} = {target_mean}"
        
        # Detailed median reasoning
        sorted_target_values = sorted(target_values)
        sorted_target_str = ", ".join([str(val) for val in sorted_target_values])
        median_calculation = f"* Sort: [{sorted_target_str}] → Median: {target_median}"

        # Chart QA Pool
        medium_qa_pool = {
            "two_step__statistics__sum": {
                "question": [
                    f"What is the total sum of {heatmap_category_plural} for cells {constraint}?",
                    f"Can you calculate the sum of {heatmap_category_plural} for cells {constraint}?",
                    f"What is the total of all {heatmap_category_plural} in cells {constraint}?",
                    f"Please compute the sum of {heatmap_category_plural} for cells {constraint}.",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to identify the cells {constraint}" + (f": {target_cells_str}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to read their {heatmap_category_plural}" + (f": {target_values_with_pos}" if self.detailed_reasoning else "."),
                        "step_3": f"Third, I need to calculate their sum" + (f": {sum_calculation}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": constraint,
                "answer": target_sum,
                "mask": {
                    "step_1": target_cells,
                    "step_2": target_cells,
                    "step_3": target_cells if self.force_ground else [],
                    "answer": target_cells,
                },
            },
            "two_step__statistics__mean": {
                "question": [
                    f"What is the average {heatmap_category_singular} for cells {constraint}? Please round to two decimal places.",
                    f"Can you calculate the mean {heatmap_category_singular} for cells {constraint}? Please round to two decimal places.",
                    f"What is the mean value of {heatmap_category_plural} in cells {constraint}? Please round to two decimal places.",
                    f"Please compute the average {heatmap_category_singular} for cells {constraint}. Please round to two decimal places.",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to identify the cells {constraint}" + (f": {target_cells_str}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to read their {heatmap_category_plural}" + (f": {target_values_with_pos}" if self.detailed_reasoning else "."),
                        "step_3": f"Third, I need to calculate their mean" + (f": {mean_calculation}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": constraint,
                "answer": target_mean,
                "mask": {
                    "step_1": target_cells,
                    "step_2": target_cells,
                    "step_3": target_cells if self.force_ground else [],
                    "answer": target_cells,
                },
            },
            "two_step__statistics__median": {
                "question": [
                    f"What is the median {heatmap_category_singular} for cells {constraint}? Please round to two decimal places.",
                    f"Can you find the median value of {heatmap_category_plural} for cells {constraint}? Please round to two decimal places.",
                    f"What is the middle value of {heatmap_category_plural} in cells {constraint}? Please round to two decimal places.",
                    f"Please compute the median {heatmap_category_singular} for cells {constraint}. Please round to two decimal places.",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to identify the cells {constraint}" + (f": {target_cells_str}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to read their {heatmap_category_plural}" + (f": {target_values_with_pos}" if self.detailed_reasoning else "."),
                        "step_3": f"Third, I need to find the median value by sorting these values and identifying the middle value(s)" + (f":\n{median_calculation}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": constraint,
                "answer": target_median,
                "mask": {
                    "step_1": target_cells,
                    "step_2": target_cells,
                    "step_3": target_median_cells,
                    "answer": target_median_cells,
                },
            },
            "two_step__statistics__count": {
                "question": [
                    f"How many cells are there {constraint}?",
                    f"What is the number of cells {constraint}?",
                    f"Can you count the cells {constraint}?",
                    f"How many {heatmap_category_plural} are in cells {constraint}?",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to identify the cells {constraint}" + (f": {target_cells_str}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to count these cells" + (f": {target_count}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": constraint,
                "answer": target_count,
                "mask": {
                    "step_1": target_cells,
                    "step_2": target_cells if self.force_ground else [],
                    "answer": target_cells,
                },
            },
        }

        return medium_qa_pool



    ############################################################
    #                     Multi-step Operator
    ############################################################

    def _multi_step_threshold(self, chart_metadata: Dict):
        """
        Multi-step threshold: above/below mean, then extrema and statistics within those groups
        """
        x_axis_title = chart_metadata['x_label']
        y_axis_title = chart_metadata['y_label']
        heatmap_category_singular = chart_metadata['heatmap_category']['singular']
        heatmap_category_plural = chart_metadata['heatmap_category']['plural']
        
        target_heatmap_data = copy.deepcopy(chart_metadata["heatmap_data"])
        target_x_labels = copy.deepcopy(chart_metadata["x_labels"])
        target_y_labels = copy.deepcopy(chart_metadata["y_labels"])
        
        num_rows = len(target_heatmap_data)
        num_cols = len(target_heatmap_data[0]) if target_heatmap_data else 0
        
        # All cell indices and values
        all_cell_indices = [(r, c) for r in range(num_rows) for c in range(num_cols)]
        all_values = [target_heatmap_data[r][c] for r, c in all_cell_indices]
        
        # Calculate overall mean
        overall_sum = sum(all_values)
        overall_mean = overall_sum / len(all_values) if all_values else 0
        
        # Find cells above and below mean
        above_mean_cells = []
        below_mean_cells = []
        above_mean_values = []
        below_mean_values = []
        
        for r in range(num_rows):
            for c in range(num_cols):
                value = target_heatmap_data[r][c]
                if value > overall_mean:
                    above_mean_cells.append((r, c))
                    above_mean_values.append(value)
                elif value < overall_mean:
                    below_mean_cells.append((r, c))
                    below_mean_values.append(value)
        
        # Statistics for above-mean group
        above_mean_sum = sum(above_mean_values) if above_mean_values else 0
        above_mean_avg = above_mean_sum / len(above_mean_values) if above_mean_values else 0
        above_mean_min = min(above_mean_values) if above_mean_values else 0
        above_mean_max = max(above_mean_values) if above_mean_values else 0
        
        # Find positions of extrema in above-mean group
        above_mean_min_cells = [cell for cell in above_mean_cells if target_heatmap_data[cell[0]][cell[1]] == above_mean_min]
        above_mean_max_cells = [cell for cell in above_mean_cells if target_heatmap_data[cell[0]][cell[1]] == above_mean_max]
        
        # Statistics for below-mean group
        below_mean_sum = sum(below_mean_values) if below_mean_values else 0
        below_mean_avg = below_mean_sum / len(below_mean_values) if below_mean_values else 0
        below_mean_min = min(below_mean_values) if below_mean_values else 0
        below_mean_max = max(below_mean_values) if below_mean_values else 0
        
        # Find positions of extrema in below-mean group
        below_mean_min_cells = [cell for cell in below_mean_cells if target_heatmap_data[cell[0]][cell[1]] == below_mean_min]
        below_mean_max_cells = [cell for cell in below_mean_cells if target_heatmap_data[cell[0]][cell[1]] == below_mean_max]
        
        # Differences between groups
        sum_difference = abs(above_mean_sum - below_mean_sum)
        avg_difference = abs(above_mean_avg - below_mean_avg)
        
        # Generate detailed reasoning strings
        read_all_reason = '\n'.join([
            f"* Row {target_y_labels[r]}: {', '.join([f'{target_x_labels[c]}={target_heatmap_data[r][c]}' for c in range(num_cols)])}"
            for r in range(num_rows)
        ]).strip()
        
        # Overall mean calculation
        mean_calculation = f"({' + '.join([str(val) for val in all_values])}) / {len(all_values)} = {overall_sum} / {len(all_values)} = {overall_mean}"
        
        # Above mean cells and calculations
        above_mean_positions = ", ".join([f"({target_y_labels[r]}, {target_x_labels[c]})" for r, c in above_mean_cells])
        above_mean_sum_calc = f"{' + '.join([str(val) for val in above_mean_values])} = {above_mean_sum}" if above_mean_values else "0"
        above_mean_avg_calc = f"({' + '.join([str(val) for val in above_mean_values])}) / {len(above_mean_values)} = {above_mean_sum} / {len(above_mean_values)} = {above_mean_avg}" if above_mean_values else "0"
        
        # Below mean cells and calculations
        below_mean_positions = ", ".join([f"({target_y_labels[r]}, {target_x_labels[c]})" for r, c in below_mean_cells])
        below_mean_sum_calc = f"{' + '.join([str(val) for val in below_mean_values])} = {below_mean_sum}" if below_mean_values else "0"
        below_mean_avg_calc = f"({' + '.join([str(val) for val in below_mean_values])}) / {len(below_mean_values)} = {below_mean_sum} / {len(below_mean_values)} = {below_mean_avg}" if below_mean_values else "0"
        
        # Position strings for extrema
        above_mean_min_pos_str = ", ".join([f"({target_y_labels[r]}, {target_x_labels[c]})" for r, c in above_mean_min_cells])
        above_mean_max_pos_str = ", ".join([f"({target_y_labels[r]}, {target_x_labels[c]})" for r, c in above_mean_max_cells])
        below_mean_min_pos_str = ", ".join([f"({target_y_labels[r]}, {target_x_labels[c]})" for r, c in below_mean_min_cells])
        below_mean_max_pos_str = ", ".join([f"({target_y_labels[r]}, {target_x_labels[c]})" for r, c in below_mean_max_cells])

        # Chart QA Pool
        hard_qa_pool = {
            "multi_step__threshold__above_mean__max__value": {
                "question": [
                    f"What is the highest {heatmap_category_singular} among cells that have {heatmap_category_plural} above the overall average?",
                    f"Among cells with {heatmap_category_plural} above the mean, what is the maximum {heatmap_category_singular}?",
                    f"What is the largest {heatmap_category_singular} among cells that are above the overall mean?",
                    f"Among cells whose {heatmap_category_plural} are higher than the average, what is the highest value?",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to read all the {heatmap_category_plural} in this heatmap" + (f":\n{read_all_reason}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to calculate the overall mean {heatmap_category_singular}" + (f":\n{mean_calculation}" if self.detailed_reasoning else "."),
                        "step_3": f"Third, I need to identify cells with {heatmap_category_plural} above {overall_mean}" + (f": {above_mean_positions}" if self.detailed_reasoning else "."),
                        "step_4": f"Fourth, I need to find the maximum {heatmap_category_singular} among these cells" + (f": {above_mean_max}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": f"maximum among cells above mean",
                "answer": above_mean_max,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "step_3": above_mean_cells,
                    "step_4": above_mean_max_cells,
                    "answer": above_mean_max_cells,
                },
            },
            "multi_step__threshold__above_mean__min__value": {
                "question": [
                    f"What is the lowest {heatmap_category_singular} among cells that have {heatmap_category_plural} above the overall average?",
                    f"Among cells with {heatmap_category_plural} above the mean, what is the minimum {heatmap_category_singular}?",
                    f"What is the smallest {heatmap_category_singular} among cells that are above the overall mean?",
                    f"Among cells whose {heatmap_category_plural} are higher than the average, what is the lowest value?",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to read all the {heatmap_category_plural} in this heatmap" + (f":\n{read_all_reason}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to calculate the overall mean {heatmap_category_singular}" + (f":\n{mean_calculation}" if self.detailed_reasoning else "."),
                        "step_3": f"Third, I need to identify cells with {heatmap_category_plural} above {overall_mean}" + (f": {above_mean_positions}" if self.detailed_reasoning else "."),
                        "step_4": f"Fourth, I need to find the minimum {heatmap_category_singular} among these cells" + (f": {above_mean_min}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": f"minimum among cells above mean",
                "answer": above_mean_min,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "step_3": above_mean_cells,
                    "step_4": above_mean_min_cells,
                    "answer": above_mean_min_cells,
                },
            },
            "multi_step__threshold__below_mean__max__value": {
                "question": [
                    f"What is the highest {heatmap_category_singular} among cells that have {heatmap_category_plural} below the overall average?",
                    f"Among cells with {heatmap_category_plural} below the mean, what is the maximum {heatmap_category_singular}?",
                    f"What is the largest {heatmap_category_singular} among cells that are below the overall mean?",
                    f"Among cells whose {heatmap_category_plural} are lower than the average, what is the highest value?",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to read all the {heatmap_category_plural} in this heatmap" + (f":\n{read_all_reason}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to calculate the overall mean {heatmap_category_singular}" + (f":\n{mean_calculation}" if self.detailed_reasoning else "."),
                        "step_3": f"Third, I need to identify cells with {heatmap_category_plural} below {overall_mean}" + (f": {below_mean_positions}" if self.detailed_reasoning else "."),
                        "step_4": f"Fourth, I need to find the maximum {heatmap_category_singular} among these cells" + (f": {below_mean_max}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": f"maximum among cells below mean",
                "answer": below_mean_max,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "step_3": below_mean_cells,
                    "step_4": below_mean_max_cells,
                    "answer": below_mean_max_cells,
                },
            },
            "multi_step__threshold__below_mean__min__value": {
                "question": [
                    f"What is the lowest {heatmap_category_singular} among cells that have {heatmap_category_plural} below the overall average?",
                    f"Among cells with {heatmap_category_plural} below the mean, what is the minimum {heatmap_category_singular}?",
                    f"What is the smallest {heatmap_category_singular} among cells that are below the overall mean?",
                    f"Among cells whose {heatmap_category_plural} are lower than the average, what is the lowest value?",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to read all the {heatmap_category_plural} in this heatmap" + (f":\n{read_all_reason}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to calculate the overall mean {heatmap_category_singular}" + (f":\n{mean_calculation}" if self.detailed_reasoning else "."),
                        "step_3": f"Third, I need to identify cells with {heatmap_category_plural} below {overall_mean}" + (f": {below_mean_positions}" if self.detailed_reasoning else "."),
                        "step_4": f"Fourth, I need to find the minimum {heatmap_category_singular} among these cells" + (f": {below_mean_min}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": f"minimum among cells below mean",
                "answer": below_mean_min,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "step_3": below_mean_cells,
                    "step_4": below_mean_min_cells,
                    "answer": below_mean_min_cells,
                },
            },
            "multi_step__threshold__above_mean__sum": {
                "question": [
                    f"What is the total sum of {heatmap_category_plural} for cells that are above the overall average?",
                    f"What is the sum of all {heatmap_category_plural} in cells that have values above the mean?",
                    f"Can you calculate the total of {heatmap_category_plural} for cells above the overall mean?",
                    f"What is the combined sum of {heatmap_category_plural} for cells with above-average values?",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to read all the {heatmap_category_plural} in this heatmap" + (f":\n{read_all_reason}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to calculate the overall mean {heatmap_category_singular}" + (f":\n{mean_calculation}" if self.detailed_reasoning else "."),
                        "step_3": f"Third, I need to identify cells with {heatmap_category_plural} above {overall_mean}" + (f": {above_mean_positions}" if self.detailed_reasoning else "."),
                        "step_4": f"Fourth, I need to calculate their total sum" + (f":\n{above_mean_sum_calc}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": f"sum of cells above mean",
                "answer": above_mean_sum,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "step_3": above_mean_cells,
                    "step_4": above_mean_cells if self.force_ground else [],
                    "answer": above_mean_cells,
                },
            },
            "multi_step__threshold__below_mean__sum": {
                "question": [
                    f"What is the total sum of {heatmap_category_plural} for cells that are below the overall average?",
                    f"What is the sum of all {heatmap_category_plural} in cells that have values below the mean?",
                    f"Can you calculate the total of {heatmap_category_plural} for cells below the overall mean?",
                    f"What is the combined sum of {heatmap_category_plural} for cells with below-average values?",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to read all the {heatmap_category_plural} in this heatmap" + (f":\n{read_all_reason}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to calculate the overall mean {heatmap_category_singular}" + (f":\n{mean_calculation}" if self.detailed_reasoning else "."),
                        "step_3": f"Third, I need to identify cells with {heatmap_category_plural} below {overall_mean}" + (f": {below_mean_positions}" if self.detailed_reasoning else "."),
                        "step_4": f"Fourth, I need to calculate their total sum" + (f":\n{below_mean_sum_calc}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": f"sum of cells below mean",
                "answer": below_mean_sum,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "step_3": below_mean_cells,
                    "step_4": below_mean_cells if self.force_ground else [],
                    "answer": below_mean_cells,
                },
            },
            "multi_step__threshold__above_mean__avg": {
                "question": [
                    f"What is the average {heatmap_category_singular} for cells that are above the overall average? Please round to two decimal places.",
                    f"What is the mean of {heatmap_category_plural} in cells that have values above the overall mean? Please round to two decimal places.",
                    f"Can you calculate the average {heatmap_category_singular} for cells above the overall mean? Please round to two decimal places.",
                    f"What is the mean value of {heatmap_category_plural} for cells with above-average values? Please round to two decimal places.",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to read all the {heatmap_category_plural} in this heatmap" + (f":\n{read_all_reason}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to calculate the overall mean {heatmap_category_singular}" + (f":\n{mean_calculation}" if self.detailed_reasoning else "."),
                        "step_3": f"Third, I need to identify cells with {heatmap_category_plural} above {overall_mean}" + (f": {above_mean_positions}" if self.detailed_reasoning else "."),
                        "step_4": f"Fourth, I need to calculate their average" + (f":\n{above_mean_avg_calc}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": f"average of cells above mean",
                "answer": above_mean_avg,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "step_3": above_mean_cells,
                    "step_4": above_mean_cells if self.force_ground else [],
                    "answer": above_mean_cells,
                },
            },
            "multi_step__threshold__below_mean__avg": {
                "question": [
                    f"What is the average {heatmap_category_singular} for cells that are below the overall average? Please round to two decimal places.",
                    f"What is the mean of {heatmap_category_plural} in cells that have values below the overall mean? Please round to two decimal places.",
                    f"Can you calculate the average {heatmap_category_singular} for cells below the overall mean? Please round to two decimal places.",
                    f"What is the mean value of {heatmap_category_plural} for cells with below-average values? Please round to two decimal places.",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to read all the {heatmap_category_plural} in this heatmap" + (f":\n{read_all_reason}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to calculate the overall mean {heatmap_category_singular}" + (f":\n{mean_calculation}" if self.detailed_reasoning else "."),
                        "step_3": f"Third, I need to identify cells with {heatmap_category_plural} below {overall_mean}" + (f": {below_mean_positions}" if self.detailed_reasoning else "."),
                        "step_4": f"Fourth, I need to calculate their average" + (f":\n{below_mean_avg_calc}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": f"average of cells below mean",
                "answer": below_mean_avg,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "step_3": below_mean_cells,
                    "step_4": below_mean_cells if self.force_ground else [],
                    "answer": below_mean_cells,
                },
            },
            "multi_step__threshold__sum_difference": {
                "question": [
                    f"What is the absolute difference between the total sum of {heatmap_category_plural} for cells above the overall average and those below it?",
                    f"What is the absolute difference between the sum of above-average cells and below-average cells?",
                    f"Can you calculate the absolute difference between the total {heatmap_category_plural} of cells above and below the mean?",
                    f"What is the absolute value of the difference between the sums of above-mean and below-mean cells?",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to read all the {heatmap_category_plural} in this heatmap" + (f":\n{read_all_reason}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to calculate the overall mean {heatmap_category_singular}" + (f":\n{mean_calculation}" if self.detailed_reasoning else "."),
                        "step_3": f"Third, I need to identify and calculate sums for both groups.",
                        "step_4": f"Fourth, I need to calculate the absolute difference" + (f":\n|{above_mean_sum} - {below_mean_sum}| = {sum_difference}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": f"difference between above and below mean sums",
                "answer": sum_difference,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "step_3": above_mean_cells + below_mean_cells,
                    "step_4": above_mean_cells + below_mean_cells if self.force_ground else [],
                    "answer": above_mean_cells + below_mean_cells,
                },
            },
            "multi_step__threshold__avg_difference": {
                "question": [
                    f"What is the absolute difference between the average {heatmap_category_singular} of cells above the overall mean and those below it? Please round to two decimal places.",
                    f"What is the absolute difference between the mean of above-average cells and below-average cells? Please round to two decimal places.",
                    f"Can you calculate the absolute difference between the averages of cells above and below the overall mean? Please round to two decimal places.",
                ],
                "reasoning": [
                    {
                        "step_1": f"First, I need to read all the {heatmap_category_plural} in this heatmap" + (f":\n{read_all_reason}" if self.detailed_reasoning else "."),
                        "step_2": f"Second, I need to calculate the overall mean {heatmap_category_singular}" + (f":\n{mean_calculation}" if self.detailed_reasoning else "."),
                        "step_3": f"Third, I need to identify and calculate averages for both groups.",
                        "step_4": f"Fourth, I need to calculate the absolute difference" + (f":\n|{above_mean_avg} - {below_mean_avg}| = {avg_difference}" if self.detailed_reasoning else "."),
                    },
                ],
                "constraint": f"difference between above and below mean averages",
                "answer": avg_difference,
                "mask": {
                    "step_1": all_cell_indices,
                    "step_2": all_cell_indices if self.force_ground else [],
                    "step_3": above_mean_cells + below_mean_cells,
                    "step_4": above_mean_cells + below_mean_cells if self.force_ground else [],
                    "answer": above_mean_cells + below_mean_cells,
                },
            },
        }

        return hard_qa_pool
    