"""
Two-Stage Review Mode Implementation

Implements the second review mode:
1. Stage 1: A subset of reviewers reviews ALL papers. Every paper receives the same
   number of reviews (λ1) and reviewer workloads are balanced. The opinions are
   integrated with belief propagation (BP).
2. Stage 2: The papers whose stage 1 scores are closest to 0 (the most ambiguous ones)
   are selected; they make up a fraction η of all papers. The remaining, disjoint set
   of reviewers reviews these papers, again with the same number of reviews per paper
   (λ2) and balanced reviewer workloads.
3. Integrate ALL review results from both stages using belief propagation.
   The per-paper review budget satisfies λ1 + η·λ2 = λ.
"""

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
import random
from general_env import Paper, Reviewer, conference, Review
from framework.efficient_assign import efficient_bipartite_matching_with_tolerance
from bp_new import bp_modified
from review_modes.review_metrics import calculate_all_metrics

class TwoStageReviewConference(conference):
    """
    Two-Stage Review Conference System
    """

    def __init__(self, paper_num, reviewer_num, reviewer_quality, eta=0.5,
                 lambda_per_paper=5, stage1_reviewer_ratio=0.6, lambda1=None,
                 stage1_quality=None, stage2_quality=None, reviewer_prior_tuple=None,
                 use_ref_bp=False, k_ref=3, ref_alpha0=1.0, ref_beta0=1.0, ref_damping=0.5, ref_temp=1.0):
        """
        Initialize two-stage review conference

        Parameters:
        - paper_num: number of papers (n)
        - reviewer_num: number of reviewers (m), split between the two stages
        - reviewer_quality: common reviewer quality q (probability of correct review)
        - eta: ratio of papers proceeding to stage 2 (default 0.5)
        - lambda_per_paper: overall lambda per paper (tasks per paper overall)
        - stage1_reviewer_ratio: ratio of reviewers intended for stage 1 (default 0.6);
          stored only - the actual split is computed by _calculate_stage_params
        - lambda1: tasks per paper in stage 1 (if None, auto-set to round(2/3 * lambda_per_paper))
        - stage1_quality: if not None, overrides the reliability of every stage 1
          reviewer in generate_papers_and_reviewers
        - stage2_quality: same as stage1_quality, for stage 2 reviewers
        - reviewer_prior_tuple: (q_exp, frac_exp, q_base); passed to the parent as
          reviewer_prior and used to build the mixture BP prior for stage 1.
          If None, (reviewer_quality, 1.0, reviewer_quality) is passed to the parent
          and a single-component BP prior with q = reviewer_quality is used.
        - use_ref_bp, k_ref, ref_alpha0, ref_beta0, ref_damping, ref_temp:
          reference-BP options, stored as attributes with the same names
        """
        self.paper_num = paper_num
        self.reviewer_num = reviewer_num
        self.reviewer_quality = reviewer_quality
        self.eta = eta
        self.lambda_per_paper = lambda_per_paper  # Overall lambda per paper
        self.stage1_reviewer_ratio = stage1_reviewer_ratio

        # Reference-BP options: these were accepted but silently dropped before;
        # keep them on the instance so they are inspectable and usable.
        self.use_ref_bp = use_ref_bp
        self.k_ref = k_ref
        self.ref_alpha0 = ref_alpha0
        self.ref_beta0 = ref_beta0
        self.ref_damping = ref_damping
        self.ref_temp = ref_temp

        # Auto-set lambda1 if not provided (roughly two thirds of the budget in stage 1)
        self.lambda1 = round(2/3 * lambda_per_paper) if lambda1 is None else lambda1
        self.stage1_quality = stage1_quality
        self.stage2_quality = stage2_quality

        # Derive lambda2, per-stage reviewer counts and workloads; this reads the
        # attributes set above (paper_num, reviewer_num, eta, lambda1, ...).
        self._calculate_stage_params()

        # Initialize parent class; the parent needs a workload cap that fits both stages
        self.reviewer_prior_tuple = reviewer_prior_tuple
        prior = (reviewer_quality, 1.0, reviewer_quality) if reviewer_prior_tuple is None else reviewer_prior_tuple
        super().__init__(
            paper_num=paper_num,
            reviewer_num=self.reviewer_num,
            reviewer_prior=prior,
            reviewer_workload=max(self.stage1_workload, self.stage2_workload),
            review_per_paper=0,
            report_per_review=0
        )

        # Stage-specific data structures
        self.stage1_reviewers = []
        self.stage2_reviewers = []
        self.stage1_reviews = []
        self.stage2_reviews = []
        self.stage2_papers = []
        self.stage1_decisions = None
        self.final_decisions = None

    def _calculate_stage_params(self):
        """
        Calculate parameters for both stages using robust reviewer allocation
        Constraint: lambda1 + eta * lambda2 = lambda_per_paper
        NEW STRATEGY: Optimally allocate reviewers to enable BP cross-validation
        """
        # Use robust reviewer allocation algorithm
        allocation = self._robust_reviewer_allocation(
            self.paper_num, self.lambda_per_paper, self.eta, stage1_ratio=0.5, lambda1_override=self.lambda1
        )

        # Extract allocation results
        self.lambda1 = allocation['lambda1']
        self.lambda2 = allocation['lambda2']
        self.stage1_reviewer_num = allocation['R1']
        self.stage2_reviewer_num = allocation['R2']
        self.stage1_total_tasks = allocation['stage1_tasks']
        self.stage2_total_tasks = allocation['stage2_tasks']
        self.stage2_paper_num = allocation['stage2_papers']

        # Calculate workloads
        self.stage1_workload = allocation['tasks_per_reviewer_1']
        self.stage2_workload = allocation['tasks_per_reviewer_2']
        self.stage1_extra_tasks = allocation['extra_tasks_1']
        self.stage2_extra_tasks = allocation['extra_tasks_2']

        # Paper numbers
        self.stage1_paper_num = self.paper_num

        # Validate reviewer allocation against provided reviewer_num
        self.total_reviewers_needed = self.stage1_reviewer_num + self.stage2_reviewer_num
        if self.total_reviewers_needed != self.reviewer_num:
            print(f"Info: Using {self.total_reviewers_needed} of {self.reviewer_num} available reviewers")
            if self.total_reviewers_needed > self.reviewer_num:
                print(f"Warning: Allocation needs {self.total_reviewers_needed} reviewers but only {self.reviewer_num} provided")
                print("Some reviewers may have higher workload than optimal")

        # # COMMENTED OUT: Original load balancing logic
        # # Allocate reviewers based on workload balance
        # self._allocate_reviewers_balanced()

        # Set reviews per paper for assignment
        self.stage1_reviews_per_paper = self.lambda1
        self.stage2_reviews_per_paper = int(round(self.lambda2))

        print(f"Two-stage parameters (ROBUST ALLOCATION WITH FIXED REVIEWERS):")
        print(f"  Input: λ={self.lambda_per_paper}, η={self.eta}, reviewers={self.reviewer_num}")
        print(f"  Stages: λ1={self.lambda1}, λ2={self.lambda2:.1f}")
        print(f"  Stage 1: {self.stage1_reviewer_num} reviewers, {self.stage1_paper_num} papers, {self.stage1_total_tasks} tasks")
        print(f"    -> Avg workload: {allocation['avg_workload_stage1']:.1f} papers/reviewer")
        print(f"    -> Distribution: {self.stage1_workload} base + {self.stage1_extra_tasks} reviewers get +1 extra")
        print(f"  Stage 2: {self.stage2_reviewer_num} reviewers, {self.stage2_paper_num} papers, {self.stage2_total_tasks:.0f} tasks")
        print(f"    -> Avg workload: {allocation['avg_workload_stage2']:.1f} papers/reviewer")
        print(f"    -> Distribution: {self.stage2_workload} base + {self.stage2_extra_tasks} reviewers get +1 extra")
        print(f"  Total: {self.total_reviewers_needed}/{self.reviewer_num} reviewers used")
        print(f"  -> Enables BP cross-validation: each reviewer sees multiple papers")

    def _allocate_reviewers_balanced(self):
        """
        COMMENTED OUT: Original load balancing logic

        Original method tried to balance workload between stages by:
        - Finding allocation that minimizes maximum workload
        - Calculating actual workloads for each stage

        NEW STRATEGY: Reviewer allocation is now calculated directly in _calculate_stage_params
        based on task numbers (each reviewer reviews exactly 1 paper)
        """
        # # ORIGINAL LOAD BALANCING LOGIC (COMMENTED OUT):
        # if self.stage2_total_tasks == 0:
        #     self.stage1_reviewer_num = self.reviewer_num
        #     self.stage2_reviewer_num = 0
        # elif self.stage1_total_tasks == 0:
        #     self.stage1_reviewer_num = 0
        #     self.stage2_reviewer_num = self.reviewer_num
        # else:
        #     # Find allocation that minimizes maximum workload
        #     best_allocation = None
        #     min_max_workload = float('inf')
        #
        #     # Try different allocations, ensuring at least 1 reviewer per active stage
        #     for stage1_reviewers in range(1, self.reviewer_num):
        #         stage2_reviewers = self.reviewer_num - stage1_reviewers
        #         if stage2_reviewers <= 0:
        #             continue
        #
        #         workload1 = self.stage1_total_tasks / stage1_reviewers
        #         workload2 = self.stage2_total_tasks / stage2_reviewers
        #         max_workload = max(workload1, workload2)
        #
        #         if max_workload < min_max_workload:
        #             min_max_workload = max_workload
        #             best_allocation = stage1_reviewers
        #
        #     # If no good allocation found, use simple ratio
        #     if best_allocation is None:
        #         task_ratio = self.stage1_total_tasks / (self.stage1_total_tasks + self.stage2_total_tasks)
        #         best_allocation = max(1, int(self.reviewer_num * task_ratio))
        #
        #     self.stage1_reviewer_num = best_allocation
        #     self.stage2_reviewer_num = self.reviewer_num - self.stage1_reviewer_num
        #
        # # Calculate actual workloads
        # self.stage1_workload = self.stage1_total_tasks / self.stage1_reviewer_num if self.stage1_reviewer_num > 0 else 0
        # self.stage2_workload = self.stage2_total_tasks / self.stage2_reviewer_num if self.stage2_reviewer_num > 0 else 0

        pass

    def _robust_reviewer_allocation(self, paper_num, lambda_per_paper, eta, stage1_ratio=0.5, lambda1_override=None, lambda2_override=None):
        """
        Robust reviewer allocation strategy for two-stage review using FIXED reviewer count

        Args:
            paper_num: number of papers
            lambda_per_paper: total reviews per paper across both stages
            eta: ratio of papers proceeding to stage 2
            stage1_ratio: proportion of total budget allocated to stage 1

        Returns:
            Dictionary with allocation parameters

        Key change: Uses self.reviewer_num (passed as parameter) instead of assuming paper_num
        """
        # Calculate reviews per stage using explicit lambda1/lambda2 if provided
        if lambda1_override is not None:
            lambda1 = int(round(lambda1_override))
            if eta > 0:
                lambda2_float = (lambda_per_paper - lambda1) / eta
                lambda2 = max(1, int(round(lambda2_float)))
            else:
                lambda2 = 0
        else:
            lambda1 = max(1, int(lambda_per_paper * stage1_ratio))
            lambda2 = lambda_per_paper - lambda1

        # If lambda2_override provided, it takes precedence
        if lambda2_override is not None:
            lambda2 = int(round(lambda2_override))

        # Ensure lambda2 is reasonable
        if lambda2 < 1:
            lambda1 = lambda_per_paper - 1
            lambda2 = 1

        # Calculate stage 2 papers (ensure at least 1)
        stage2_papers = max(1, int(np.ceil(eta * paper_num)))

        # Calculate total tasks for each stage
        stage1_tasks = paper_num * lambda1
        stage2_tasks = stage2_papers * lambda2
        total_tasks = stage1_tasks + stage2_tasks

        # KEY CHANGE: Use the provided reviewer_num instead of assuming paper_num
        total_reviewers = self.reviewer_num  # Use fixed reviewer count from parameter

        # Allocate reviewers proportionally to tasks, with constraints
        R1 = max(1, round(total_reviewers * stage1_tasks / total_tasks))
        R2 = total_reviewers - R1

        # Validation: Ensure we can actually handle the workload
        max_possible_stage1 = stage1_tasks
        max_possible_stage2 = stage2_tasks

        # If insufficient reviewers for reasonable workload, adjust allocation
        MIN_REVIEWERS_PER_STAGE = 1
        MAX_WORKLOAD_PER_REVIEWER = 20  # Reasonable upper limit

        # Ensure R1 can handle stage1 tasks reasonably
        if R1 > 0:
            required_workload_1 = stage1_tasks / R1
            if required_workload_1 > MAX_WORKLOAD_PER_REVIEWER:
                print(f"Warning: Stage 1 workload ({required_workload_1:.1f}) exceeds reasonable limit")

        # Ensure R2 can handle stage2 tasks reasonably
        if R2 > 0:
            required_workload_2 = stage2_tasks / R2
            if required_workload_2 > MAX_WORKLOAD_PER_REVIEWER:
                print(f"Warning: Stage 2 workload ({required_workload_2:.1f}) exceeds reasonable limit")

        # Ensure minimum reviewers per stage if tasks exist
        if stage1_tasks > 0:
            R1 = max(R1, MIN_REVIEWERS_PER_STAGE)
        if stage2_tasks > 0:
            R2 = max(R2, MIN_REVIEWERS_PER_STAGE)

        # Adjust if total exceeds available reviewers
        if R1 + R2 > total_reviewers:
            # Proportionally reduce while maintaining minimums
            if stage1_tasks > 0 and stage2_tasks > 0:
                # Both stages have work, split proportionally
                task_ratio = stage1_tasks / total_tasks
                R1 = max(MIN_REVIEWERS_PER_STAGE, int((total_reviewers - MIN_REVIEWERS_PER_STAGE) * task_ratio))
                R2 = total_reviewers - R1
            elif stage1_tasks > 0:
                # Only stage 1 has work
                R1 = total_reviewers
                R2 = 0
            else:
                # Only stage 2 has work (unusual)
                R1 = 0
                R2 = total_reviewers

        # Calculate task distribution within each stage
        if R1 > 0:
            tasks_per_reviewer_1 = stage1_tasks // R1
            extra_tasks_1 = stage1_tasks % R1
        else:
            tasks_per_reviewer_1 = 0
            extra_tasks_1 = 0

        if R2 > 0:
            tasks_per_reviewer_2 = stage2_tasks // R2
            extra_tasks_2 = stage2_tasks % R2
        else:
            tasks_per_reviewer_2 = 0
            extra_tasks_2 = 0

        return {
            'R1': R1, 'R2': R2,
            'stage1_tasks': stage1_tasks, 'stage2_tasks': stage2_tasks,
            'tasks_per_reviewer_1': tasks_per_reviewer_1, 'extra_tasks_1': extra_tasks_1,
            'tasks_per_reviewer_2': tasks_per_reviewer_2, 'extra_tasks_2': extra_tasks_2,
            'lambda1': lambda1, 'lambda2': lambda2,
            'stage2_papers': stage2_papers,
            'total_reviewers_used': R1 + R2,
            'avg_workload_stage1': stage1_tasks / R1 if R1 > 0 else 0,
            'avg_workload_stage2': stage2_tasks / R2 if R2 > 0 else 0
        }

    def generate_papers_and_reviewers(self):
        """
        Generate papers and reviewers, then split the reviewers into two
        non-overlapping groups.

        Stage 1 gets reviewers[0 : stage1_reviewer_num]; stage 2 gets the next
        stage2_reviewer_num reviewers. If stage1_quality / stage2_quality were
        given, they overwrite the reliability of every reviewer in that group.

        Raises:
            ValueError: if fewer reviewers exist than the two stages need.
        """
        # Generation is delegated to generate1() (defined outside this class);
        # per the original note it sizes the reviewer pool from self.reviewer_num.
        self.generate1()

        print(f"Generated {self.paper_num} papers (acceptance rate = 0.5)")
        print(f"Generated {len(self.reviewers)} reviewers (adjusted for non-overlapping stages)")

        stage1_end = self.stage1_reviewer_num
        stage2_end = self.stage1_reviewer_num + self.stage2_reviewer_num

        # Verify we have enough reviewers
        if len(self.reviewers) < stage2_end:
            raise ValueError(f"Not enough reviewers generated. Need {stage2_end}, but only have {len(self.reviewers)}")

        # Consecutive, non-overlapping slices of the reviewer pool
        self.stage1_reviewers = self.reviewers[:stage1_end]
        self.stage2_reviewers = self.reviewers[stage1_end:stage2_end]

        print(f"Stage 1 reviewers: {len(self.stage1_reviewers)} (non-overlapping)")
        print(f"Stage 2 reviewers: {len(self.stage2_reviewers)} (non-overlapping)")
        print(f"All reviewers have quality={self.reviewer_quality}")

        # Optional: override stage-specific qualities if provided
        if self.stage1_quality is not None:
            for r in self.stage1_reviewers:
                r.reliability = self.stage1_quality
            print(f"Stage 1 reviewer quality overridden to {self.stage1_quality}")
        if self.stage2_quality is not None:
            for r in self.stage2_reviewers:
                r.reliability = self.stage2_quality
            print(f"Stage 2 reviewer quality overridden to {self.stage2_quality}")

        # Internal invariant: the two groups share no reviewer id
        stage1_ids = {r.id for r in self.stage1_reviewers}
        stage2_ids = {r.id for r in self.stage2_reviewers}
        assert not (stage1_ids & stage2_ids), \
            "Error: Stage 1 and Stage 2 reviewers should not overlap"

    def assign_stage1_reviews(self):
        """
        Assign stage 1 review tasks: the stage 1 reviewers cover ALL papers.

        Tries efficient_bipartite_matching_with_tolerance first; if it (or the
        conversion of its result) raises, falls back to _assign_stage1_fallback.
        On the efficient path, stage1_reviews and the review lists of all papers
        and stage 1 reviewers are reset before the new reviews are attached.
        Always finishes by printing a validation summary.
        """
        print("\n=== Stage 1: Assigning Review Tasks ===")

        # Roughly 10% slack on the per-reviewer load, never below 1
        slack = max(1, int(0.1 * self.stage1_workload))

        try:
            graph = efficient_bipartite_matching_with_tolerance(
                m=self.paper_num,                      # all papers
                n=self.stage1_reviewer_num,            # stage 1 reviewers
                l=int(self.stage1_reviews_per_paper),  # reviews per paper
                r=int(self.stage1_workload),           # tasks per reviewer
                q=slack,                               # tolerance
                randomness_level=0.0,
            )

            # Start from a clean slate before attaching the new reviews
            self.stage1_reviews = []
            for reviewer in self.stage1_reviewers:
                reviewer.reviews = []
            for paper in self.papers:
                paper.reviews = []

            # Node names: 'L<paper id>' and 'R<index into stage1_reviewers>'
            for review_id, (paper_node, reviewer_node) in enumerate(graph.edges()):
                paper_id = int(paper_node[1:])
                reviewer_pos = int(reviewer_node[1:])
                reviewer = self.stage1_reviewers[reviewer_pos]
                paper = self.papers[paper_id]

                review = Review(review_id, reviewer, paper)
                reviewer.reviews.append(review)
                paper.reviews.append(review)
                self.stage1_reviews.append(review)

            print(f"Stage 1 assignment completed: {len(self.stage1_reviews)} review tasks")

        except Exception as e:
            print(f"Stage 1 efficient assignment failed, using fallback: {e}")
            self._assign_stage1_fallback()

        self._validate_stage1_assignment()

    def _assign_stage1_fallback(self):
        """
        Fallback assignment for stage 1, used when the efficient matching fails.

        Builds a shuffled pool of reviewer "slots" (stage1_workload per reviewer,
        plus one extra slot for the first stage1_extra_tasks reviewers) and hands
        them out paper by paper, never assigning the same reviewer to a paper
        twice. A slot that would duplicate a reviewer on the current paper is
        kept for later papers instead of being thrown away. Papers still short
        of stage1_reviews_per_paper reviews are topped up from the least-loaded
        reviewers (previously always from reviewer 0 upward, which piled the
        shortfall onto low-index reviewers).

        Resets stage1_reviews and the review lists of all papers and stage 1
        reviewers before assigning, then prints load statistics.
        """
        from collections import deque

        self.stage1_reviews = []
        for reviewer in self.stage1_reviewers:
            reviewer.reviews = []
        for paper in self.papers:
            paper.reviews = []

        num_reviewers = self.stage1_reviewer_num
        needed = self.stage1_reviews_per_paper

        # One entry per task slot, balanced across reviewers
        slots = []
        for reviewer_idx in range(num_reviewers):
            extra = 1 if reviewer_idx < self.stage1_extra_tasks else 0
            slots.extend([reviewer_idx] * (self.stage1_workload + extra))
        random.shuffle(slots)  # avoid ordering bias
        pool = deque(slots)

        # Track assignments to avoid assigning a reviewer to the same paper twice
        assignment_matrix = np.zeros((self.paper_num, num_reviewers), dtype=bool)
        reviewer_loads = [0] * num_reviewers
        review_id = 0

        def assign(paper_id, paper, reviewer_idx):
            # Create one Review and record it on the reviewer, paper and stage list
            nonlocal review_id
            reviewer = self.stage1_reviewers[reviewer_idx]
            review = Review(review_id, reviewer, paper)
            reviewer.reviews.append(review)
            paper.reviews.append(review)
            self.stage1_reviews.append(review)
            assignment_matrix[paper_id, reviewer_idx] = True
            reviewer_loads[reviewer_idx] += 1
            review_id += 1

        for paper_id in range(self.paper_num):
            paper = self.papers[paper_id]
            assigned_count = 0
            deferred = []

            while assigned_count < needed and pool:
                reviewer_idx = pool.popleft()
                if assignment_matrix[paper_id, reviewer_idx]:
                    # Reviewer already on this paper: keep the slot for a later paper
                    deferred.append(reviewer_idx)
                    continue
                assign(paper_id, paper, reviewer_idx)
                assigned_count += 1
            # Put deferred slots back at the front, preserving their order
            pool.extendleft(reversed(deferred))

            # Top up from the least-loaded reviewers
            if assigned_count < needed:
                for reviewer_idx in sorted(range(num_reviewers), key=reviewer_loads.__getitem__):
                    if assigned_count >= needed:
                        break
                    if not assignment_matrix[paper_id, reviewer_idx]:
                        assign(paper_id, paper, reviewer_idx)
                        assigned_count += 1

        print("Stage 1 fallback assignment:")
        if not reviewer_loads:
            print("  No stage 1 reviewers available")
            return
        print(f"  Average papers per reviewer: {np.mean(reviewer_loads):.1f}")
        print(f"  Min/Max papers per reviewer: {min(reviewer_loads)}/{max(reviewer_loads)}")
        papers_per_reviewer = [len({review.paper.id for review in reviewer.reviews})
                               for reviewer in self.stage1_reviewers]
        print(f"  Cross-validation enabled: {np.mean(papers_per_reviewer):.1f} papers/reviewer")

    def _validate_stage1_assignment(self):
        """
        Validate stage 1 assignment
        """
        paper_review_counts = [len(paper.reviews) for paper in self.papers]
        reviewer_loads = [len(reviewer.reviews) for reviewer in self.stage1_reviewers]

        print(f"Stage 1 validation:")
        print(f"  Paper review counts: min={min(paper_review_counts)}, "
              f"max={max(paper_review_counts)}, mean={np.mean(paper_review_counts):.2f}")
        print(f"  Reviewer workloads: min={min(reviewer_loads)}, "
              f"max={max(reviewer_loads)}, mean={np.mean(reviewer_loads):.2f}")

    def conduct_stage1_reviews(self):
        """
        Execute stage 1 review process
        """
        print(f"\n=== Stage 1: Conducting Reviews ===")

        # Each Review object performs review using existing logic
        for review in self.stage1_reviews:
            review._operate()

        print(f"Stage 1 reviews completed: {len(self.stage1_reviews)} opinions generated")

    def apply_stage1_bp(self):
        """
        Run belief propagation on the stage 1 reviews to obtain preliminary scores.

        Builds a (paper_num x reviewer_num) int matrix that holds each stage 1
        review's rating at [paper id, reviewer id] (0 where there is no review)
        and passes it to bp_modified with mode "norm11", 50 iterations and the
        largest number of reviews written by any single reviewer. The BP prior
        is a two-component mixture built from reviewer_prior_tuple when one was
        given, otherwise a single component with q = reviewer_quality.
        If bp_modified raises, _majority_voting_stage1 is used instead.
        The result is stored in self.stage1_decisions and |score| statistics
        are printed.
        """
        print("\n=== Stage 1: Applying Belief Propagation ===")

        stage1_matrix = np.zeros((self.paper_num, self.reviewer_num), dtype=int)
        for review in self.stage1_reviews:
            row, col = review.paper.id, review.reviewer.id
            stage1_matrix[row, col] = review.rating

        # BP prior: [[qualities], [weights]]
        prior_tuple = getattr(self, "reviewer_prior_tuple", None)
        if prior_tuple is None:
            bp_prior = [[self.reviewer_quality], [1.0]]
        else:
            q_exp, frac_exp, q_base = prior_tuple
            bp_prior = [[q_base, q_exp], [1.0 - frac_exp, frac_exp]]

        # Largest number of non-zero entries in any reviewer column
        max_papers_per_reviewer = int(np.max(np.count_nonzero(stage1_matrix, axis=0)))

        try:
            self.stage1_decisions = bp_modified(
                stage1_matrix,
                "norm11",
                50,  # max iterations
                max_papers_per_reviewer,
                bp_prior,
            )
            print("Stage 1 BP algorithm executed successfully")
        except Exception as e:
            print(f"Stage 1 BP failed, using majority voting: {e}")
            self.stage1_decisions = self._majority_voting_stage1(stage1_matrix)

        magnitudes = np.abs(self.stage1_decisions)
        print(f"Stage 1 decision scores: mean_abs={np.mean(magnitudes):.3f}, "
              f"min_abs={np.min(magnitudes):.3f}, max_abs={np.max(magnitudes):.3f}")

    def _majority_voting_stage1(self, review_matrix):
        """
        Majority voting fallback for stage 1
        """
        decisions = np.zeros(self.paper_num)

        for paper_id in range(self.paper_num):
            votes = review_matrix[paper_id, :]
            positive_votes = np.sum(votes == 1)
            negative_votes = np.sum(votes == -1)

            if positive_votes > negative_votes:
                decisions[paper_id] = 1
            elif negative_votes > positive_votes:
                decisions[paper_id] = -1
            else:
                decisions[paper_id] = 0

        return decisions

    def select_stage2_papers(self):
        """
        Select papers for stage 2 based on stage 1 scores (closest to 0 = most ambiguous)
        """
        print(f"\n=== Selecting Papers for Stage 2 ===")

        if self.stage1_decisions is None:
            print("Error: Stage 1 decisions not available")
            return

        # Calculate absolute scores (distance from 0 = ambiguity level)
        abs_scores = np.abs(self.stage1_decisions)

        # Sort papers by absolute score (ascending = most ambiguous first)
        paper_indices = np.argsort(abs_scores)

        # Select top eta proportion of most ambiguous papers
        num_stage2_papers = min(self.stage2_paper_num, len(paper_indices))
        stage2_paper_indices = paper_indices[:num_stage2_papers]

        self.stage2_papers = [self.papers[i] for i in stage2_paper_indices]
        self.stage2_paper_indices = stage2_paper_indices  # Store for later use

        print(f"Selected {len(self.stage2_papers)} most ambiguous papers for stage 2")
        sample_scores = [f"{self.stage1_decisions[i]:.3f}" for i in stage2_paper_indices[:5]]
        print(f"Stage 2 paper scores (stage 1): {sample_scores}...")

    def assign_stage2_reviews(self):
        """
        Assign stage 2 review tasks: the stage 2 reviewers review the selected
        (ambiguous) papers.

        Stage 2 reviews are appended to each paper's existing stage 1 reviews;
        review ids continue after the stage 1 ids. Tries
        efficient_bipartite_matching_with_tolerance first and falls back to
        _assign_stage2_fallback on any exception. The matching is fully
        converted into Review objects before anything is attached, so a failure
        part-way through no longer leaves orphaned stage 2 reviews on papers
        (which keep their stage 1 reviews and are not cleared by the fallback).
        Finishes by printing a validation summary.
        """
        print("\n=== Stage 2: Assigning Review Tasks ===")

        if len(self.stage2_papers) == 0 or len(self.stage2_reviewers) == 0:
            print("No stage 2 papers or reviewers available")
            self.stage2_reviews = []
            return

        # Roughly 10% slack on the per-reviewer load, never below 1
        tolerance = max(1, int(0.1 * self.stage2_workload))

        try:
            assignment_graph = efficient_bipartite_matching_with_tolerance(
                m=len(self.stage2_papers),           # selected papers
                n=self.stage2_reviewer_num,          # stage 2 reviewers
                l=self.stage2_reviews_per_paper,     # reviews per paper
                r=int(self.stage2_workload),         # tasks per reviewer
                q=tolerance,                         # tolerance
                randomness_level=0.0
            )

            # Build all Review objects first without touching papers/reviewers.
            # Node names: 'L<index into stage2_papers>', 'R<index into stage2_reviewers>'
            pending = []
            review_id = len(self.stage1_reviews)  # Continue from stage 1 review IDs
            for paper_node, reviewer_node in assignment_graph.edges():
                paper_idx = int(paper_node[1:])
                reviewer_idx = int(reviewer_node[1:])
                reviewer = self.stage2_reviewers[reviewer_idx]
                paper = self.stage2_papers[paper_idx]
                pending.append((reviewer, paper, Review(review_id, reviewer, paper)))
                review_id += 1

        except Exception as e:
            print(f"Stage 2 efficient assignment failed, using fallback: {e}")
            self._assign_stage2_fallback()

        else:
            # Commit: reset stage 2 reviewer lists, then attach the new reviews
            self.stage2_reviews = []
            for reviewer in self.stage2_reviewers:
                reviewer.reviews = []
            for reviewer, paper, review in pending:
                reviewer.reviews.append(review)
                paper.reviews.append(review)  # Add to existing stage 1 reviews
                self.stage2_reviews.append(review)

            print(f"Stage 2 assignment completed: {len(self.stage2_reviews)} review tasks")

        self._validate_stage2_assignment()

    def _assign_stage2_fallback(self):
        """
        Fallback assignment for stage 2, used when the efficient matcher fails.

        Builds a shuffled pool of reviewer "slots": each reviewer index is
        repeated ``stage2_workload`` times, plus one extra slot for the first
        ``stage2_extra_tasks`` reviewers. The slots are handed out paper by
        paper, and a reviewer never gets the same paper twice. A slot that
        conflicts with the current paper stays in the pool for later papers
        instead of being discarded, so reviewer loads stay close to their
        targets. A paper that still lacks reviews once the pool is exhausted is
        topped up with the least-loaded eligible reviewers (ties broken by
        reviewer index).

        Side effects: resets ``self.stage2_reviews`` and every stage 2
        reviewer's ``reviews`` list, appends new ``Review`` objects to the
        stage 2 papers' ``reviews``, and prints a short load summary.
        """
        # Reviews from an earlier (failed) attempt may already be attached to
        # papers. Drop them so paper.reviews holds only stage 1 reviews plus
        # the ones built here. Identity is used so Review equality is irrelevant.
        previous = getattr(self, "stage2_reviews", None) or []
        stale_ids = {id(review) for review in previous}
        if stale_ids:
            for paper in self.stage2_papers:
                paper.reviews[:] = [r for r in paper.reviews if id(r) not in stale_ids]
        self.stage2_reviews = []

        # Clear existing stage 2 assignments
        for reviewer in self.stage2_reviewers:
            reviewer.reviews = []

        num_reviewers = self.stage2_reviewer_num
        per_paper = self.stage2_reviews_per_paper
        first_review_id = len(self.stage1_reviews)  # stage 2 ids continue after stage 1

        # One slot per unit of workload; the first `stage2_extra_tasks` reviewers get one more.
        slots = []
        for reviewer_idx in range(num_reviewers):
            extra = 1 if reviewer_idx < self.stage2_extra_tasks else 0
            slots.extend([reviewer_idx] * (self.stage2_workload + extra))
        random.shuffle(slots)  # avoid systematic bias in who reviews which paper

        # assignment_matrix[p, r] is True once reviewer r reviews stage 2 paper p
        assignment_matrix = np.zeros((len(self.stage2_papers), num_reviewers), dtype=bool)
        reviewer_loads = [0] * num_reviewers

        def attach(paper_idx, paper, reviewer_idx):
            """Create one stage 2 review and record it in every tracking structure."""
            reviewer = self.stage2_reviewers[reviewer_idx]
            review = Review(first_review_id + len(self.stage2_reviews), reviewer, paper)
            reviewer.reviews.append(review)
            paper.reviews.append(review)  # added alongside the paper's stage 1 reviews
            self.stage2_reviews.append(review)
            assignment_matrix[paper_idx, reviewer_idx] = True
            reviewer_loads[reviewer_idx] += 1

        for paper_idx, paper in enumerate(self.stage2_papers):
            assigned = 0
            leftover = []
            for pos, reviewer_idx in enumerate(slots):
                if assigned >= per_paper:
                    leftover.extend(slots[pos:])
                    break
                if assignment_matrix[paper_idx, reviewer_idx]:
                    # Conflict with this paper: keep the slot for a later paper.
                    leftover.append(reviewer_idx)
                else:
                    attach(paper_idx, paper, reviewer_idx)
                    assigned += 1
            slots = leftover

            # Pool exhausted (or only conflicting slots left): top up with the
            # least-loaded reviewers not yet on this paper.
            if assigned < per_paper:
                eligible = [r for r in range(num_reviewers) if not assignment_matrix[paper_idx, r]]
                eligible.sort(key=lambda r: reviewer_loads[r])  # stable: ties keep index order
                for reviewer_idx in eligible:
                    if assigned >= per_paper:
                        break
                    attach(paper_idx, paper, reviewer_idx)
                    assigned += 1

        mean_load = np.mean(reviewer_loads) if reviewer_loads else 0.0
        print(f"Stage 2 fallback assignment:")
        print(f"  Average papers per reviewer: {mean_load:.1f}")
        print(f"  Min/Max papers per reviewer: {min(reviewer_loads) if reviewer_loads else 0}/{max(reviewer_loads) if reviewer_loads else 0}")
        if reviewer_loads:
            papers_per_reviewer = [len({review.paper.id for review in reviewer.reviews})
                                   for reviewer in self.stage2_reviewers]
            print(f"  Cross-validation enabled: {np.mean(papers_per_reviewer):.1f} papers/reviewer")

    def _validate_stage2_assignment(self):
        """
        Validate stage 2 assignment
        """
        if len(self.stage2_papers) == 0:
            print("Stage 2 validation: No papers assigned")
            return

        # Count only stage 2 reviews for each paper
        paper_review_counts = [len([r for r in paper.reviews if r in self.stage2_reviews])
                              for paper in self.stage2_papers]
        reviewer_loads = [len(reviewer.reviews) for reviewer in self.stage2_reviewers]

        print(f"Stage 2 validation:")
        mean_paper_reviews = np.mean(paper_review_counts) if paper_review_counts else 0
        mean_reviewer_load = np.mean(reviewer_loads) if reviewer_loads else 0
        print(f"  Paper review counts: min={min(paper_review_counts) if paper_review_counts else 0}, "
              f"max={max(paper_review_counts) if paper_review_counts else 0}, "
              f"mean={mean_paper_reviews:.2f}")
        print(f"  Reviewer workloads: min={min(reviewer_loads) if reviewer_loads else 0}, "
              f"max={max(reviewer_loads) if reviewer_loads else 0}, "
              f"mean={mean_reviewer_load:.2f}")

    def conduct_stage2_reviews(self):
        """
        Execute stage 2 review process
        """
        print(f"\n=== Stage 2: Conducting Reviews ===")

        if len(self.stage2_reviews) == 0:
            print("No stage 2 reviews to conduct")
            return

        # Each Review object performs review using existing logic
        for review in self.stage2_reviews:
            review._operate()

        print(f"Stage 2 reviews completed: {len(self.stage2_reviews)} opinions generated")

    def apply_final_bp(self):
        """
        Run belief propagation over the ratings from both review stages.

        Builds an integer ``paper_num x reviewer_num`` matrix. Cell
        ``[paper.id, reviewer.id]`` holds ``review.rating`` and 0 means "no
        review". The ids are presumably 0-based indices (TODO confirm). The
        matrix is passed to ``bp_modified`` and the result is stored in
        ``self.final_decisions``.

        Prior:
        - When ``reviewer_prior_tuple`` = (q_exp, frac_exp, q_base) is set, a
          two-component mixture: qualities [q_base, q_exp] with weights
          [1 - frac_exp, frac_exp].
        - Otherwise a single component at ``reviewer_quality`` with weight 1.0.

        Any exception raised by BP falls back to ``_majority_voting_final``.
        """
        print(f"\n=== Final: Applying Belief Propagation to Combined Results ===")

        combined_matrix = np.zeros((self.paper_num, self.reviewer_num), dtype=int)
        for stage_reviews in (self.stage1_reviews, self.stage2_reviews):
            for review in stage_reviews:
                combined_matrix[review.paper.id, review.reviewer.id] = review.rating

        prior_tuple = getattr(self, "reviewer_prior_tuple", None)
        if prior_tuple is None:
            bp_prior = [[self.reviewer_quality], [1.0]]
        else:
            q_exp, frac_exp, q_base = prior_tuple
            bp_prior = [[q_base, q_exp], [1.0 - frac_exp, frac_exp]]

        # Largest number of papers any single reviewer rated.
        reviews_per_reviewer = (combined_matrix != 0).sum(axis=0)
        max_papers_per_reviewer = int(reviews_per_reviewer.max())

        try:
            self.final_decisions = bp_modified(
                combined_matrix,
                "norm11",  # mode flag forwarded to bp_modified
                50,  # max iterations
                max_papers_per_reviewer,
                bp_prior,
            )
            print("Final BP algorithm executed successfully")
        except Exception as e:
            print(f"Final BP failed, using majority voting: {e}")
            self.final_decisions = self._majority_voting_final(combined_matrix)

        print(f"Total reviews used in final BP: {np.sum(combined_matrix != 0)}")

    def _majority_voting_final(self, review_matrix):
        """
        Majority voting fallback for final decision - returns raw scores based on vote proportions
        """
        decisions = np.zeros(self.paper_num)

        for paper_id in range(self.paper_num):
            votes = review_matrix[paper_id, :]
            non_zero_votes = votes[votes != 0]

            if len(non_zero_votes) == 0:
                decisions[paper_id] = 0.0
            else:
                # Return the average vote as raw score
                decisions[paper_id] = np.mean(non_zero_votes)

        return decisions

    def evaluate_performance(self):
        """
        Compare ``self.final_decisions`` with the papers' true qualities.

        Predictions are ``np.sign(self.final_decisions)``. A score of exactly 0
        therefore predicts 0. It is never counted in the confusion matrix,
        which only tallies +1/-1. Presumably true qualities are +/-1, in which
        case a 0 prediction also counts as wrong in ``accuracy``.

        Extended metrics come from ``calculate_all_metrics``. ``bp_raw_error``
        is kept for backward compatibility and equals ``calibration_error``.

        Returns:
            dict of metrics (confusion-matrix entries as plain ints), or None
            if the final decisions have not been generated yet.
        """
        if self.final_decisions is None:
            print("Error: Final decisions not generated yet")
            return None

        # Get true qualities (list form is what the metrics module receives)
        true_quality_list = [paper.quality for paper in self.papers]
        true_qualities = np.array(true_quality_list)
        predicted_qualities = np.sign(self.final_decisions)

        accuracy = np.mean(predicted_qualities == true_qualities)

        # Confusion matrix, as plain ints so the results dict is JSON-friendly
        pred_pos = predicted_qualities == 1
        pred_neg = predicted_qualities == -1
        true_pos = true_qualities == 1
        true_neg = true_qualities == -1
        tp = int(np.sum(pred_pos & true_pos))
        tn = int(np.sum(pred_neg & true_neg))
        fp = int(np.sum(pred_pos & true_neg))
        fn = int(np.sum(pred_neg & true_pos))

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        results = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'confusion_matrix': {'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn},
            'total_reviews': len(self.stage1_reviews) + len(self.stage2_reviews),
            'stage1_reviews': len(self.stage1_reviews),
            'stage2_reviews': len(self.stage2_reviews),
            'stage2_papers': len(self.stage2_papers)
        }

        print(f"\n=== Performance Evaluation ===")
        print(f"Accuracy: {accuracy:.3f}")
        print(f"Precision: {precision:.3f}")
        print(f"Recall: {recall:.3f}")
        print(f"F1-score: {f1_score:.3f}")
        print(f"Total Reviews: {results['total_reviews']} "
              f"(Stage1: {results['stage1_reviews']}, Stage2: {results['stage2_reviews']})")
        print(f"Stage 2 Papers: {results['stage2_papers']}/{self.paper_num}")
        print(f"Confusion Matrix: TP={tp}, TN={tn}, FP={fp}, FN={fn}")

        all_metrics = calculate_all_metrics(self.final_decisions, true_quality_list)
        calibration_error = all_metrics['calibration_error']['calibration_error']
        bp_f1_loss_result = all_metrics['bp_f1_loss']
        kl_divergence_result = all_metrics['kl_divergence']
        js_divergence_result = all_metrics['js_divergence']

        results.update({
            'calibration_error': calibration_error,
            'bp_raw_error': calibration_error,  # backward compatibility: same value
            'bp_f1_loss': bp_f1_loss_result['f1_loss'],
            'bp_f1_score_50pct': bp_f1_loss_result['f1_score'],
            'bp_acceptance_rate': bp_f1_loss_result['acceptance_rate'],
            'kl_divergence': kl_divergence_result['kl_divergence'],
            'js_divergence': js_divergence_result['js_divergence']
        })

        print(f"\n=== New Metrics ===")
        print(f"Calibration Error (avg): {calibration_error:.3f}")
        print(f"BP Raw Error (avg): {calibration_error:.3f}")  # Same value for compatibility
        print(f"BP F1 Loss (50% accept): {bp_f1_loss_result['f1_loss']:.3f}")
        print(f"BP F1 Score (50% accept): {bp_f1_loss_result['f1_score']:.3f}")
        print(f"BP Acceptance Rate: {bp_f1_loss_result['acceptance_rate']:.3f}")
        print(f"KL Divergence: {kl_divergence_result['kl_divergence']:.4f}")
        print(f"JS Divergence: {js_divergence_result['js_divergence']:.4f}")
        print(f"True Distribution - Accept: {kl_divergence_result['true_distribution']['accept']:.3f}, Reject: {kl_divergence_result['true_distribution']['reject']:.3f}")
        print(f"Review Distribution - Accept: {kl_divergence_result['review_distribution']['accept']:.3f}, Reject: {kl_divergence_result['review_distribution']['reject']:.3f}")

        return results

    def _simulate_reference_counts(self, seed=12345):
        rng = np.random.RandomState(seed)
        ref = []
        for i in range(self.reviewer_num):
            q = self.reviewers[i].reliability
            s=f=0
            for _ in range(max(0,int(self.k_ref))):
                y = 1 if rng.rand()<0.5 else -1
                v = y if (rng.rand()<q) else -y
                if v==y: s+=1
                else: f+=1
            ref.append((s,f))
        return ref

    def run_two_stage_review(self):
        """
        Run complete two-stage review process
        """
        print("=== Starting Two-Stage Review Mode ===")

        # 1. Generate papers and reviewers
        self.generate_papers_and_reviewers()

        # 2. Stage 1: Assign and conduct reviews
        self.assign_stage1_reviews()
        self.conduct_stage1_reviews()
        self.apply_stage1_bp()

        # 3. Select papers for stage 2
        self.select_stage2_papers()

        # 4. Stage 2: Assign and conduct reviews (if there are papers)
        if len(self.stage2_papers) > 0:
            self.assign_stage2_reviews()
            self.conduct_stage2_reviews()

        # 5. Final BP on combined results
        self.apply_final_bp()

        # 6. Evaluate performance
        results = self.evaluate_performance()

        print("=== Two-Stage Review Mode Completed ===")
        return results


def main():
    """
    Smoke-test the two-stage review mode on two small configurations.

    Every key in a configuration dict is forwarded as a keyword argument to
    ``TwoStageReviewConference.__init__``, so the keys must match its
    parameter names exactly (e.g. ``lambda_per_paper``).
    """
    test_configs = [
        {
            'paper_num': 30,
            'reviewer_num': 20,
            'reviewer_quality': 0.8,
            'eta': 0.3,  # 30% of papers go to stage 2
            'lambda_per_paper': 5,  # overall review tasks per paper
            'stage1_reviewer_ratio': 0.6  # 60% reviewers in stage 1
        },
        {
            'paper_num': 40,
            'reviewer_num': 25,
            'reviewer_quality': 0.9,
            'eta': 0.4,  # 40% of papers go to stage 2
            'lambda_per_paper': 6,  # overall review tasks per paper
            'stage1_reviewer_ratio': 0.5  # 50% reviewers in stage 1
        }
    ]

    print("Starting Two-Stage Review Mode Testing\n")

    for config_no, config in enumerate(test_configs, start=1):
        print(f"\n{'='*60}")
        print(f"Test Configuration {config_no}: {config}")
        print(f"{'='*60}")

        # Named review_system so the imported `conference` base class is not shadowed.
        review_system = TwoStageReviewConference(**config)
        results = review_system.run_two_stage_review()

        if results:
            print(f"\n[Config {config_no} Summary] Accuracy: {results['accuracy']:.3f}, "
                  f"F1: {results['f1_score']:.3f}, Total Reviews: {results['total_reviews']}")


if __name__ == "__main__":
    main()
