"""
Direct Review Mode Implementation

Implements the first review mode on top of the existing conference simulation code:
1. Generate n papers and m reviewers
2. Give every paper the same number of reviews (λ) and spread the total
   workload (n·λ tasks) as evenly as possible across the m reviewers
3. All reviewers complete their reviews and produce review opinions
4. Use belief propagation to obtain the final decisions
"""

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np
import random
from general_env import Paper, Reviewer, conference, Review  # Reuse existing classes
from framework.efficient_assign import efficient_bipartite_matching_with_tolerance
from bp_new import bp_modified
from review_modes.review_metrics import calculate_all_metrics

class DirectReviewConference(conference):
    """
    Direct Review Conference System, inheriting from existing conference class.

    Every paper receives ``lambda_per_paper`` reviews and the total workload
    (``paper_num * lambda_per_paper`` tasks) is spread over the reviewers.
    Final decisions come from belief propagation (plain ``bp_modified`` or the
    reference-anchored ``bp_with_references``), with a majority-vote fallback
    if BP raises.
    """

    def __init__(self, paper_num, reviewer_num, reviewer_quality, lambda_per_paper=10, reviewer_prior_tuple=None, use_ref_bp=False, k_ref=3, ref_alpha0=1.0, ref_beta0=1.0, ref_damping=0.5, ref_temp=1.0):
        """
        Initialize direct review conference

        Parameters:
        - paper_num: number of papers (n)
        - reviewer_num: number of reviewers (m)
        - reviewer_quality: common reviewer quality q (probability of correct review)
        - lambda_per_paper: number of reviews per paper (lambda)
        - reviewer_prior_tuple: optional mixture prior (q_exp, frac_exp, q_base);
          when None, the prior (reviewer_quality, 1.0, reviewer_quality) is used
        - use_ref_bp: use reference-anchored BP; also enabled when the
          USE_REF_BP environment variable is '1', 'true', 'yes' or 'on'
          (case-insensitive, surrounding whitespace ignored)
        - k_ref: number of simulated reference tasks per reviewer (reference BP only)
        - ref_alpha0, ref_beta0, ref_damping, ref_temp: forwarded to bp_with_references
        """
        # Calculate task assignment parameters
        self.paper_num = paper_num
        self.reviewer_num = reviewer_num
        self.reviewer_quality = reviewer_quality
        self.lambda_per_paper = lambda_per_paper
        env_flag = os.getenv('USE_REF_BP', '0').strip().lower()
        self.use_ref_bp = use_ref_bp or (env_flag in ('1', 'true', 'yes', 'on'))
        self.k_ref = k_ref
        self.ref_alpha0 = ref_alpha0
        self.ref_beta0 = ref_beta0
        self.ref_damping = ref_damping
        self.ref_temp = ref_temp

        # Calculate balanced assignment scheme (must precede super().__init__,
        # which consumes avg_tasks_per_reviewer / reviews_per_paper)
        self._calculate_assignment_params()

        # Initialize parent class with direct review requirements
        self.reviewer_prior_tuple = reviewer_prior_tuple
        prior = (reviewer_quality, 1.0, reviewer_quality) if reviewer_prior_tuple is None else reviewer_prior_tuple
        super().__init__(
            paper_num=paper_num,
            reviewer_num=reviewer_num,
            reviewer_prior=prior,
            reviewer_workload=self.avg_tasks_per_reviewer,
            review_per_paper=self.reviews_per_paper,
            report_per_review=0
        )

    def _calculate_assignment_params(self):
        """
        Calculate assignment parameters where lambda_per_paper = tasks per paper.

        Sets ``reviews_per_paper``, ``total_tasks`` and the (possibly
        fractional) ``avg_tasks_per_reviewer``.
        """
        # Set reviews per paper equal to lambda per paper
        self.reviews_per_paper = self.lambda_per_paper

        # Calculate total tasks and average workload per reviewer
        self.total_tasks = self.paper_num * self.lambda_per_paper
        self.avg_tasks_per_reviewer = self.total_tasks / self.reviewer_num

        print("Direct Review Assignment:")
        print(f"  λ = {self.lambda_per_paper} tasks per paper")
        print(f"  Total tasks: {self.total_tasks}")
        print(f"  Avg tasks per reviewer: {self.avg_tasks_per_reviewer:.2f}")

    def generate_papers_and_reviewers(self):
        """
        Generate papers and reviewers, reusing existing logic with acceptance rate = 0.5
        """
        # Directly call parent's generate1 method, which already implements 50% acceptance rate
        self.generate1()

        print(f"Generated {self.paper_num} papers (acceptance rate = 0.5)")
        print(f"Generated {self.reviewer_num} reviewers")
        print(f"  Reviewer prior: {('uniform', self.reviewer_quality) if self.reviewer_prior_tuple is None else ('mixture', self.reviewer_prior_tuple)}")

    def _reset_reviews(self):
        """Clear the conference review list and every paper's and reviewer's review list."""
        self.reviews = []
        for reviewer in self.reviewers:
            reviewer.reviews = []
        for paper in self.papers:
            paper.reviews = []

    def _attach_review(self, review_id, reviewer, paper):
        """Create a Review for (reviewer, paper) and register it on all three owners."""
        review = Review(review_id, reviewer, paper)
        reviewer.reviews.append(review)
        paper.reviews.append(review)
        self.reviews.append(review)
        return review

    def assign_reviews_direct(self):
        """
        Direct review task assignment using efficient assignment algorithm.

        Falls back to ``_assign_reviews_fallback`` if the efficient matcher
        (or parsing of its result) raises, then validates the assignment.
        """
        print("\nStarting review task assignment...")

        # Allowed deviation from the target workload: 10% of the average
        # workload, but never less than one task.
        tolerance = max(1, int(0.1 * self.avg_tasks_per_reviewer))

        # Use efficient assignment algorithm
        try:
            assignment_graph = efficient_bipartite_matching_with_tolerance(
                m=self.paper_num,           # number of papers
                n=self.reviewer_num,        # number of reviewers
                l=self.reviews_per_paper,   # reviews per paper
                r=int(self.avg_tasks_per_reviewer), # tasks per reviewer
                q=tolerance,                # tolerance
                randomness_level=0.0        # randomness level (deterministic)
            )

            # Clear existing reviews and assignments
            self._reset_reviews()

            # Create Review objects from assignment graph; node names are
            # 'L<paper index>' and 'R<reviewer index>'.
            for review_id, (paper_node, reviewer_node) in enumerate(assignment_graph.edges()):
                # Normalise endpoint order in case the graph reports an edge
                # as (reviewer, paper) — NOTE(review): confirm against efficient_assign.
                if paper_node.startswith('R'):
                    paper_node, reviewer_node = reviewer_node, paper_node
                paper_id = int(paper_node[1:])       # Remove 'L' prefix
                reviewer_id = int(reviewer_node[1:]) # Remove 'R' prefix

                self._attach_review(review_id, self.reviewers[reviewer_id], self.papers[paper_id])

            print(f"Assignment completed: {len(self.reviews)} review tasks in total")

        except Exception as e:
            # Deliberate best-effort: any failure of the efficient path
            # degrades to the simple greedy assignment.
            print(f"Efficient assignment failed, using fallback: {e}")
            self._assign_reviews_fallback()

        # Validate assignment results
        self._validate_assignment()

    def _assign_reviews_fallback(self):
        """
        Fallback greedy assignment method.

        Each paper gets up to ``reviews_per_paper`` distinct reviewers; each
        slot goes to the least-loaded reviewer not yet assigned to that paper
        (ties broken by lowest reviewer index). If there are fewer reviewers
        than ``reviews_per_paper``, every reviewer reviews every paper.
        """
        # Start from a clean slate: a failed efficient attempt may have left
        # partially-built review lists on reviewers and papers.
        self._reset_reviews()
        reviewer_loads = [0] * self.reviewer_num

        # A paper cannot receive more distinct reviewers than exist; without
        # this cap the candidate search would run dry and never finish.
        per_paper = min(self.reviews_per_paper, self.reviewer_num)
        if per_paper < self.reviews_per_paper:
            print(f"Warning: only {self.reviewer_num} reviewers available; "
                  f"assigning {per_paper} reviews per paper instead of {self.reviews_per_paper}")

        review_id = 0
        for paper in self.papers:
            already_assigned = set()
            for _ in range(per_paper):
                # Least-loaded eligible reviewer; min() keeps the first
                # (lowest-index) reviewer among ties.
                rid = min(
                    (r for r in range(self.reviewer_num) if r not in already_assigned),
                    key=reviewer_loads.__getitem__,
                )
                self._attach_review(review_id, self.reviewers[rid], paper)
                already_assigned.add(rid)
                reviewer_loads[rid] += 1
                review_id += 1

    def _validate_assignment(self):
        """
        Validate assignment results: print review-count and workload
        statistics and warn about papers with fewer than reviews_per_paper reviews.
        """
        # Count reviews per paper
        paper_review_counts = [len(paper.reviews) for paper in self.papers]

        # Count workload per reviewer
        reviewer_loads = [len(reviewer.reviews) for reviewer in self.reviewers]

        # min()/max() of an empty list would raise
        if not paper_review_counts or not reviewer_loads:
            print("Warning: no papers or reviewers to validate")
            return

        print(f"Paper review counts: min={min(paper_review_counts)}, max={max(paper_review_counts)}, "
              f"mean={np.mean(paper_review_counts):.2f}")
        print(f"Reviewer workloads: min={min(reviewer_loads)}, max={max(reviewer_loads)}, "
              f"mean={np.mean(reviewer_loads):.2f}")

        # Check if assignment meets requirements
        under_reviewed = sum(1 for count in paper_review_counts if count < self.reviews_per_paper)
        if under_reviewed > 0:
            print(f"Warning: {under_reviewed} papers have insufficient reviews")

    def conduct_reviews(self):
        """
        Execute review process, reusing existing Review._operate logic
        """
        print("\nStarting review process...")

        # Each Review object performs review
        for review in self.reviews:
            review._operate()  # Call existing review logic

        print(f"Review completed: {len(self.reviews)} review opinions generated")

    def make_review_matrix(self):
        """
        Create review matrix for BP algorithm.

        Returns an int array of shape (paper_num, reviewer_num) with each
        review's rating at [paper.id, reviewer.id] and 0 where no review
        exists (assumes ids are 0-based indices).
        """
        review_matrix = np.zeros((self.paper_num, self.reviewer_num), dtype=int)

        for review in self.reviews:
            review_matrix[review.paper.id, review.reviewer.id] = review.rating

        return review_matrix

    def _simulate_reference_counts(self, seed=12345):
        """
        Simulate K_ref reference tasks per reviewer.

        Each reference task's label is drawn from {+1, -1} with probability
        0.5; the reviewer answers correctly with probability equal to its
        ``reliability``. Returns a list of (successes, failures) tuples, one
        per reviewer.
        """
        rng = np.random.RandomState(seed)
        n_ref_tasks = max(0, int(self.k_ref))
        ref = []
        for i in range(self.reviewer_num):
            q = self.reviewers[i].reliability
            s = f = 0
            for _ in range(n_ref_tasks):
                # The label does not affect correctness, but drawing it
                # consumes RNG state, so it is kept for reproducibility.
                y = 1 if rng.rand() < 0.5 else -1
                # reviewer response
                v = y if (rng.rand() < q) else -y
                if v == y:
                    s += 1
                else:
                    f += 1
            ref.append((s, f))
        return ref

    def apply_belief_propagation(self, max_iterations=50):
        """
        Apply Belief Propagation algorithm to obtain final decisions.

        Stores the result in ``self.final_decisions``. ``max_iterations`` is
        passed to whichever BP variant runs; on any BP exception the result
        of ``_majority_voting`` is stored instead.
        """
        print("\nApplying Belief Propagation algorithm...")

        # Create review matrix
        review_matrix = self.make_review_matrix()

        # Prepare BP prior; support single q or mixture (spammer-hammer)
        if self.reviewer_prior_tuple is not None:
            q_exp, frac_exp, q_base = self.reviewer_prior_tuple
            bp_prior = [[q_base, q_exp], [1.0 - frac_exp, frac_exp]]
        else:
            bp_prior = [[self.reviewer_quality], [1.0]]

        # Maximum papers per reviewer (for BP algorithm); initial=0 keeps
        # this defined even for an empty matrix.
        max_papers_per_reviewer = int(np.max(np.sum(review_matrix != 0, axis=0), initial=0))

        try:
            if self.use_ref_bp:
                from bp_ref import bp_with_references
                ref = self._simulate_reference_counts(seed=42)
                L, _ = bp_with_references(
                    review_matrix, ref,
                    max_iter=max_iterations,
                    alpha0=self.ref_alpha0, beta0=self.ref_beta0,
                    damping=self.ref_damping, temp=self.ref_temp,
                )
                self.final_decisions = L
                print("Reference-anchored BP executed successfully")
            else:
                # Call existing BP algorithm
                self.final_decisions = bp_modified(
                    review_matrix,
                    "norm11",  # initialization method
                    max_iterations,
                    max_papers_per_reviewer,
                    bp_prior
                )
                print("BP algorithm executed successfully")

        except Exception as e:
            print(f"BP algorithm failed, using majority voting: {e}")
            self.final_decisions = self._majority_voting(review_matrix)

    def _majority_voting(self, review_matrix):
        """
        Majority voting fallback - returns raw scores based on vote proportions.

        For each paper (row), returns the mean of its non-zero entries, or
        0.0 when the paper has no reviews.
        """
        decisions = np.zeros(self.paper_num)

        for paper_id in range(self.paper_num):
            votes = review_matrix[paper_id, :]
            non_zero_votes = votes[votes != 0]
            # Return the average vote as raw score (0.0 when unreviewed)
            decisions[paper_id] = np.mean(non_zero_votes) if len(non_zero_votes) else 0.0

        return decisions

    def evaluate_performance(self):
        """
        Evaluate system performance.

        Compares sign(final_decisions) with each paper's ``quality`` and
        returns a dict of accuracy / precision / recall / F1 / confusion
        counts plus the metrics from ``calculate_all_metrics``; returns None
        if no decisions have been produced yet.
        """
        if not hasattr(self, 'final_decisions') or self.final_decisions is None:
            print("Error: Final decisions not generated yet")
            return None

        # Get true qualities and predicted qualities
        true_quality_list = [paper.quality for paper in self.papers]
        true_qualities = np.array(true_quality_list)
        predicted_qualities = np.sign(self.final_decisions)

        # Calculate accuracy
        accuracy = np.mean(predicted_qualities == true_qualities)

        # Calculate confusion matrix
        tp = np.sum((predicted_qualities == 1) & (true_qualities == 1))
        tn = np.sum((predicted_qualities == -1) & (true_qualities == -1))
        fp = np.sum((predicted_qualities == 1) & (true_qualities == -1))
        fn = np.sum((predicted_qualities == -1) & (true_qualities == 1))

        # Calculate other metrics
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        results = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1_score,
            'confusion_matrix': {'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn}
        }

        # Calculate all metrics using the new metrics module
        all_metrics = calculate_all_metrics(self.final_decisions, true_quality_list)

        # Extract results from metrics module
        calibration_error_result = all_metrics['calibration_error']
        bp_f1_loss_result = all_metrics['bp_f1_loss']
        kl_divergence_result = all_metrics['kl_divergence']
        js_divergence_result = all_metrics['js_divergence']

        # Add new metrics to results
        results.update({
            'calibration_error': calibration_error_result['calibration_error'],
            'bp_raw_error': calibration_error_result['calibration_error'],  # Backward compatibility uses same value
            'bp_f1_loss': bp_f1_loss_result['f1_loss'],
            'bp_f1_score_50pct': bp_f1_loss_result['f1_score'],
            'bp_acceptance_rate': bp_f1_loss_result['acceptance_rate'],
            'kl_divergence': kl_divergence_result['kl_divergence'],
            'js_divergence': js_divergence_result['js_divergence']
        })

        print("\n=== Performance Evaluation ===")
        print(f"Accuracy: {accuracy:.3f}")
        print(f"Precision: {precision:.3f}")
        print(f"Recall: {recall:.3f}")
        print(f"F1-score: {f1_score:.3f}")
        print(f"Confusion Matrix: TP={tp}, TN={tn}, FP={fp}, FN={fn}")
        print("\n=== New Metrics ===")
        print(f"Calibration Error (avg): {calibration_error_result['calibration_error']:.3f}")
        print(f"BP Raw Error (avg): {calibration_error_result['calibration_error']:.3f}")  # Same value for compatibility
        print(f"BP F1 Loss (50% accept): {bp_f1_loss_result['f1_loss']:.3f}")
        print(f"BP F1 Score (50% accept): {bp_f1_loss_result['f1_score']:.3f}")
        print(f"BP Acceptance Rate: {bp_f1_loss_result['acceptance_rate']:.3f}")
        print(f"KL Divergence: {kl_divergence_result['kl_divergence']:.4f}")
        print(f"JS Divergence: {js_divergence_result['js_divergence']:.4f}")
        print(f"True Distribution - Accept: {kl_divergence_result['true_distribution']['accept']:.3f}, Reject: {kl_divergence_result['true_distribution']['reject']:.3f}")
        print(f"Review Distribution - Accept: {kl_divergence_result['review_distribution']['accept']:.3f}, Reject: {kl_divergence_result['review_distribution']['reject']:.3f}")

        return results

    def run_direct_review(self):
        """
        Run complete direct review process: generate, assign, review,
        aggregate with BP, then evaluate. Returns evaluate_performance()'s result.
        """
        print("=== Starting Direct Review Mode ===")

        # 1. Generate papers and reviewers
        self.generate_papers_and_reviewers()

        # 2. Assign review tasks
        self.assign_reviews_direct()

        # 3. Execute reviews
        self.conduct_reviews()

        # 4. Apply BP algorithm
        self.apply_belief_propagation()

        # 5. Evaluate performance
        results = self.evaluate_performance()

        print("=== Direct Review Mode Completed ===")
        return results


def main():
    """
    Test direct review mode.

    Runs DirectReviewConference on two small hard-coded configurations and
    prints a one-line accuracy/F1 summary for each run that returns results.
    """
    # Test configurations
    test_configs = [
        {
            'paper_num': 20,
            'reviewer_num': 15,
            'reviewer_quality': 0.8  # All reviewers have 80% accuracy
        },
        {
            'paper_num': 40,
            'reviewer_num': 25,
            'reviewer_quality': 0.9  # All reviewers have 90% accuracy
        }
    ]

    print("Starting Direct Review Mode Testing\n")

    for i, config in enumerate(test_configs):
        print(f"\n{'='*60}")
        print(f"Test Configuration {i+1}: {config}")
        print(f"{'='*60}")

        # Create and run direct review system (named so it does not shadow
        # the imported ``conference`` base class)
        review_system = DirectReviewConference(**config)
        results = review_system.run_direct_review()

        # Output summary
        if results:
            print(f"\n[Config {i+1} Summary] Accuracy: {results['accuracy']:.3f}, F1: {results['f1_score']:.3f}")


if __name__ == "__main__":
    # Run the test configurations in main() only when executed as a script,
    # not on import.
    main()