from env import Paper, Task, Reviewer
from protocols import CommonDirectReview
import numpy as np
import random
from assign import bipartite_graph_matching_with_tolerance
import networkx as nx
from bp_new import bp_modified
from kargers_model import kg_binary

# Default experiment settings read by the simulation loop in this file.
config = dict(
    # Number of simulated papers per trial.
    paper_num=20,
    # Total review budget per trial; the loop divides it by the paper and
    # reviewer counts to get the per-node degrees of the assignment graph.
    budget=120,
    # Number of simulated reviewers per trial.
    reviewer_num=20,
    # NOTE(review): the visible loop builds its own [prior, 1, 0.5] list and
    # does not read this entry — confirm whether it is still needed.
    reviewer_prior=[0.3, 1, 0.5],
)

def majority_vote(A: np.ndarray) -> np.ndarray:
    """Return the per-row vote total of a review matrix.

    Args:
        A: Array-like whose first axis indexes the items being voted on
            (one row per paper in this file). Every value in a row counts
            towards that row's total, so unreviewed cells should hold 0.

    Returns:
        A float array of length ``A.shape[0]`` holding the sum of each row.
        The sign of an entry is the majority decision; its magnitude is the
        vote margin.
    """
    votes = np.asarray(A)
    # Reduce over every axis except the first so each row collapses to a
    # single total, in one vectorised call instead of a Python-level loop.
    trailing_axes = tuple(range(1, votes.ndim))
    return votes.sum(axis=trailing_axes, dtype=float)

def compute_acc(papers, decision):
    """Return the fraction of papers whose decision sign matches their quality.

    ``decision[i]`` is compared, via ``np.sign``, with ``papers[i].quality``.
    A score of exactly 0 has sign 0 and therefore only matches a quality of 0.
    The hit count is divided by ``len(papers)``, so an empty ``papers``
    sequence raises ``ZeroDivisionError``.
    """
    hits = sum(
        1
        for idx, score in enumerate(decision)
        if np.sign(score) == papers[idx].quality
    )
    return hits / len(papers)

def _build_tasks(assignment, papers, reviewers, first_task_id=0):
    """Create one ``CommonDirectReview`` task per edge of ``assignment``.

    Each edge is unpacked as ``(paper_node, reviewer_node)``. The text after
    the first character of a node label is parsed as an index into ``papers``
    or ``reviewers`` respectively.
    NOTE(review): assumes every edge lists the paper node first — confirm
    against ``bipartite_graph_matching_with_tolerance``.

    Args:
        assignment: Graph object exposing an iterable ``edges`` attribute.
        papers: Sequence indexed by the paper position encoded in the label.
        reviewers: Sequence indexed by the reviewer position in the label.
        first_task_id: Id given to the first task; later tasks count up.

    Returns:
        The list of created tasks, in edge order.
    """
    tasks = []
    for offset, (paper_node, reviewer_node) in enumerate(assignment.edges):
        paper_idx = int(paper_node[1:])
        reviewer_idx = int(reviewer_node[1:])
        tasks.append(
            Task(
                first_task_id + offset,
                CommonDirectReview(),
                [papers[paper_idx]],
                [reviewers[reviewer_idx]],
                [],
            )
        )
    return tasks


def _collect_reviews(review_matrix, tasks):
    """Run every task and store its result at ``[paper.id, reviewer.id]``."""
    for task in tasks:
        review_matrix[task.papers[0].id, task.reviewers[0].id] = task.task_do()


def _max_review_degree(review_matrix):
    """Return the largest number of non-zero reviews on any paper or reviewer.

    The previous expression, ``np.max([len(v) for v in np.where(m != 0)])``,
    evaluated to the *total* number of non-zero entries: ``np.where`` returns
    one index array per dimension and all of them have the same length.
    NOTE(review): the one-stage call passes ``r + q`` (a per-node bound) in
    the same ``bp_modified`` argument, so a per-node maximum is assumed to be
    what is wanted here — confirm against ``bp_modified``.
    """
    if review_matrix.size == 0:
        return 0
    per_reviewer = np.count_nonzero(review_matrix, axis=0)
    per_paper = np.count_nonzero(review_matrix, axis=1)
    return int(max(per_reviewer.max(), per_paper.max()))


# Experiment driver: for each reviewer prior, run 20 trials comparing a
# one-stage review process with a two-stage one, and write the per-trial
# accuracies plus their mean to records.txt.
with open("records.txt", "w") as fw:
    for prior in [0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9]:
        acc1_s = []
        acc2_s = []
        for t in range(20):
            print(f"Processing...Iteration {t+1}")
            budget = config["budget"]
            paper_num = config["paper_num"]
            reviewer_num = config["reviewer_num"]
            # [value for "expert" reviewers, fraction of experts, value for
            # the remaining reviewers] — see how the entries are used below.
            reviewer_prior = [prior, 1, 0.5]

            # Ground truth: every paper is -1 except a random 20% set to +1.
            papers_quality = np.full(paper_num, -1)
            accept_indices = random.sample(range(paper_num), int(paper_num * 0.2))
            papers_quality[accept_indices] = 1
            papers = [
                Paper(paper_id, quality)
                for paper_id, quality in enumerate(papers_quality)
            ]

            # Per-reviewer parameter handed to Reviewer(): the default value
            # everywhere, overwritten for the randomly chosen experts.
            reviewers_noise = np.full(reviewer_num, reviewer_prior[2])
            expert_indices = random.sample(
                range(reviewer_num), int(reviewer_num * reviewer_prior[1])
            )
            reviewers_noise[expert_indices] = reviewer_prior[0]
            reviewers = [
                Reviewer(reviewer_id, noise)
                for reviewer_id, noise in enumerate(reviewers_noise)
            ]

            maxIter = 10
            q = 1  # tolerance passed to the assignment routine
            bp_prior = [
                [1 - reviewer_prior[2], 1 - reviewer_prior[0]],
                [1 - reviewer_prior[1], reviewer_prior[1]],
            ]

            # ---- one-stage baseline: spend the whole budget at once ----
            l = budget // paper_num
            r = budget // reviewer_num
            assignment = bipartite_graph_matching_with_tolerance(
                paper_num, reviewer_num, l, r, q
            )
            tasks = _build_tasks(assignment, papers, reviewers)

            review_matrix = np.zeros((paper_num, reviewer_num))
            _collect_reviews(review_matrix, tasks)
            # Alternative aggregators: majority_vote(review_matrix),
            # kg_binary(review_matrix, maxIter).
            decision = bp_modified(review_matrix, "norm11", maxIter, r + q, bp_prior)

            # ---- two-stage simulation ----
            review_rati = 0.75
            budget_rati = 0.75

            # Stage 1: a random 75% of the reviewers spend 75% of the budget
            # on all papers.
            reviewers1 = random.sample(reviewers, int(reviewer_num * review_rati))
            budget1 = int(budget * budget_rati)
            l1 = budget1 // paper_num
            r1 = budget1 // len(reviewers1)

            assignment1 = bipartite_graph_matching_with_tolerance(
                paper_num, len(reviewers1), l1, r1, q
            )
            tasks1 = _build_tasks(assignment1, papers, reviewers1)

            # Columns are addressed by reviewer.id, so the matrix keeps one
            # column per reviewer even though only a subset reviews here.
            review_matrix1 = np.zeros((paper_num, reviewer_num))
            _collect_reviews(review_matrix1, tasks1)
            decision1 = bp_modified(
                review_matrix1,
                "norm11",
                maxIter,
                _max_review_degree(review_matrix1),
                bp_prior,
            )
            print(decision1)

            # Stage 2: the half of the papers with the highest stage-1 score
            # get the remaining budget from the reviewers unused in stage 1.
            ranked_papers = sorted(
                papers, key=lambda paper: decision1[paper.id], reverse=True
            )
            papersII = ranked_papers[:int(len(ranked_papers) * 0.5)]

            reviewers2 = [
                reviewer for reviewer in reviewers if reviewer not in reviewers1
            ]
            budget2 = budget - budget1
            l2 = budget2 // len(papersII)
            r2 = budget2 // len(reviewers2)

            assignment2 = bipartite_graph_matching_with_tolerance(
                len(papersII), len(reviewers2), l2, r2, q
            )
            # Task ids continue after the stage-1 tasks.
            tasks2 = _build_tasks(
                assignment2, papersII, reviewers2, first_task_id=len(tasks1)
            )

            # Stage-2 reviews are added to the stage-1 matrix so the final
            # inference sees both rounds.
            _collect_reviews(review_matrix1, tasks2)
            # Alternative aggregators: majority_vote(review_matrix1),
            # kg_binary(review_matrix1, maxIter).
            decision2 = bp_modified(
                review_matrix1,
                "norm11",
                maxIter,
                _max_review_degree(review_matrix1),
                bp_prior,
            )

            acc1_s.append(compute_acc(papers, decision))
            acc2_s.append(compute_acc(papers, decision2))

        fw.write(f"prior: {prior}\n")
        fw.write(f"1-stage: {acc1_s}, mean: {np.mean(acc1_s)}\n")
        fw.write(f"2-stage: {acc2_s}, mean: {np.mean(acc2_s)}\n")