# load data

def process_sqa(path):
    """Load a StrategyQA JSON file and normalise each record.

    Every record is mapped to a dict with keys ``id`` (from ``qid``),
    ``question`` (leading whitespace removed), ``answer`` (``"Yes"`` when the
    raw answer is truthy, else ``"No"``), ``facts`` and ``decomposition``.
    """
    with open(path, encoding='utf-8') as handle:
        records = json.load(handle)

    return [
        {
            "id": record["qid"],
            "question": record["question"].lstrip(),
            "answer": "Yes" if record["answer"] else "No",
            "facts": record["facts"],
            "decomposition": record["decomposition"],
        }
        for record in records
    ]

def process_gpqa(path):
    """Load a GPQA CSV and turn each row into a four-way multiple-choice item.

    The correct answer and the three distractors are shuffled with a per-row
    seed (row index + 42), so the option order is reproducible across runs.
    Note that this reseeds the global ``random`` generator for every row.

    Returns:
        list of dicts with keys ``question`` (stripped), ``correct_answer``
        (0-based position of the correct option after shuffling),
        ``text_answer``, ``explanation`` and ``choice1`` .. ``choice4``.
    """
    rows = pd.read_csv(path).to_dict(orient="records")

    examples = []
    for row_number, row in enumerate(rows):
        options = [
            row['Incorrect Answer 1'],
            row['Incorrect Answer 2'],
            row['Incorrect Answer 3'],
            row['Correct Answer'],
        ]
        random.seed(row_number + 42)
        random.shuffle(options)

        example = {
            "question": row["Question"].strip(),
            "correct_answer": options.index(row['Correct Answer']),
            "text_answer": row['Correct Answer'],
            "explanation": row["Explanation"],
        }
        for slot, option in enumerate(options, start=1):
            example[f"choice{slot}"] = option
        examples.append(example)

    return examples


def data_loading_main(args):
    """Load the demonstration pool and the test split for ``args.task``.

    Supported tasks: ``"pddl"``, ``"pddl-ood"``, ``"aime24"``, ``"sqa"`` and
    ``"gpqa"``. Any other task prints a warning and yields empty splits.

    Args:
        args: namespace with at least ``task`` and ``test_scope``
            (``-1`` keeps the whole test split, otherwise only the first
            ``test_scope`` test examples are kept).

    Returns:
        ``{"demo": [...], "test": [...]}``, both lists of per-example dicts.
        Keys per task: pddl -> inputs/targets; aime24 -> problem/solution
        (test also has answer); sqa -> id/question/answer/facts/decomposition;
        gpqa -> question/correct_answer/text_answer/explanation/choice1-4.
    """
    dataset = {"demo": [], "test": []}

    if args.task in ["pddl", "pddl-ood"]:
        DATA_DIR = "./data/pddl/"
        train_file = os.path.join(DATA_DIR, "training_file_pddl_plan_v6.json")
        # The OOD variant only swaps the evaluation split; demos always come
        # from the in-distribution training file.
        if "-ood" in args.task:
            dev_file = os.path.join(DATA_DIR, "dev_file_ood_8_20_v6.json")
        else:
            dev_file = os.path.join(DATA_DIR, "dev_file_pddl_plan_v6.json")

        plans_df = pd.read_json(train_file, lines=True, orient="records")
        test_df = pd.read_json(dev_file, lines=True, orient="records")

        dataset["demo"] = plans_df.to_dict(orient="records")  # [{inputs, targets}]
        dataset["test"] = test_df.to_dict(orient="records")

    elif args.task in ["aime24"]:
        s1k = load_dataset("simplescaling/s1K-1.1")
        # Demos use the "gemini_attempt" trace as the worked solution, restricted to math items.
        dataset["demo"] = [
            {"problem": row["question"], "solution": row["gemini_attempt"]}
            for row in s1k["train"]
            if row["cot_type"] == "math"
        ]

        temp_dataset = load_dataset("HuggingFaceH4/aime_2024", split="train")
        df = temp_dataset.to_pandas()
        examples = [row.to_dict() for _, row in df.iterrows()]
        dataset["test"] = [{k: row[k] for k in ["problem", "solution", "answer"]} for row in examples]

    elif args.task in ["sqa"]:
        # we use the train-test split in EXPLORA paper
        with open('./data/sqa/strategyqa_train.json', encoding='utf-8') as f:
            raw_records = json.load(f)

        examples = [
            {
                "id": d["qid"], "question": d["question"].strip(),
                "answer": "Yes" if d["answer"] else "No",
                "facts": d["facts"], "decomposition": d["decomposition"],
            }
            for d in raw_records
        ]
        data = pd.DataFrame(examples)

        train_num = 1800  # the remaining 490 examples form the test split
        labels = data["answer"].tolist()
        # Stratify on the Yes/No label with a fixed seed so the split is reproducible.
        train_set, test_set, _, _ = train_test_split(
            data, labels, train_size=train_num, stratify=labels, random_state=7
        )

        dataset["demo"] = train_set.to_dict(orient="records")
        dataset["test"] = test_set.to_dict(orient="records")  # [id,question,answer,facts,decomposition]

    elif args.task in ["gpqa"]:
        temp_demos = process_gpqa('./data/gpqa/gpqa_main.csv')  # 'dataset/gpqa_extended.csv'
        dataset["test"] = process_gpqa('./data/gpqa/gpqa_diamond.csv')

        # Exclude demo candidates whose question also appears in the test split
        # (avoids test leakage). A set keeps this O(n + m) instead of O(n * m).
        test_questions = {row["question"] for row in dataset["test"]}
        dataset["demo"] = [row for row in temp_demos if row["question"] not in test_questions]
    else:
        print("unknown task...")

    if args.test_scope != -1:
        dataset["test"] = dataset["test"][:args.test_scope]

    print("data stats:", len(dataset["demo"]), len(dataset["test"]))

    return dataset


def save_results(data, path, filename):
    """Write ``data`` as indented JSON to ``path/filename``.

    ``path`` (and any missing parents) is created if needed; an empty ``path``
    means the current working directory. An existing file is overwritten.
    """
    if path:
        # exist_ok avoids the check-then-create race when several runs share
        # an output directory.
        os.makedirs(path, exist_ok=True)

    with open(os.path.join(path, filename), "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)

class Step(BaseModel):
    # One element of ``MathReasoning.steps`` in the structured-output schema.
    # Documented with comments rather than a docstring on purpose: pydantic may
    # emit a class docstring as the JSON-schema "description", which would
    # change the schema sent to the model as ``response_schema``.
    explanation: str  # per-step text field of the schema
    output: str  # per-step text field of the schema

class MathReasoning(BaseModel):
    # Response schema passed to Gemini (``response_schema``) when
    # ``args.structure_output`` is set; the model then returns JSON of this shape.
    # No docstring on purpose: pydantic may turn it into the schema "description".
    steps: list[Step]
    final_answer: str  # read back by process_answer via parsed JSON["final_answer"]

def gen_with_together(prompt, args):
    """Sample chat completions for ``prompt`` from the Together API.

    ``args.client`` picks the model ("qwen2.5" -> Qwen2.5-7B-Instruct-Turbo,
    "qwen70" -> Qwen2.5-72B-Instruct-Turbo); any other value raises KeyError.
    ``args.num_samples`` completions are requested with ``args.max_tokens``
    and ``args.temperature``; nucleus/top-k settings are fixed. Uses the
    module-level ``together_client``.

    Returns:
        list[str]: the message content of every returned choice.
    """
    model_by_client = {
        "qwen2.5": "Qwen/Qwen2.5-7B-Instruct-Turbo",
        "qwen70": "Qwen/Qwen2.5-72B-Instruct-Turbo",
    }
    create = together_client.chat.completions.create

    completion = create(
        model=model_by_client[args.client],
        messages=[{"role": "user", "content": prompt}],
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        n=args.num_samples,
        top_p=0.7,
        top_k=50,
        repetition_penalty=1,
        stop=["<|eot_id|>", "<|eom_id|>"],
        stream=False,
    )

    texts = []
    for option in completion.choices:
        texts.append(option.message.content)
    return texts

def gen_with_gemini(prompt, client, args):
    """Generate ``args.num_samples`` candidates for ``prompt`` with a Gemini model.

    Returns:
        list[str]: ``[response.text]`` for a single sample, otherwise the text
        of the first part of every returned candidate.
    """
    generate = client.generate_content
    config = genai.types.GenerationConfig(
        candidate_count=args.num_samples,
        max_output_tokens=args.max_tokens,
        temperature=args.temperature,
    )
    result = generate(prompt, generation_config=config)

    if args.num_samples != 1:
        return [cand.content.parts[0].text for cand in result.candidates]
    return [result.text]

def generate_with_retries(prompt, client, args):
    """Generate completions for ``prompt``, retrying failed API calls.

    Dispatch:
      * ``args.structure_output`` -> Gemini ``client`` constrained to a JSON
        response following the ``MathReasoning`` schema;
      * ``args.client == "pro"`` -> :func:`gen_with_gemini`;
      * ``args.client in ("qwen2.5", "qwen70")`` -> :func:`gen_with_together`.

    An attempt that raises is retried after ``RETRY_INTERVAL`` seconds, for at
    most ``MAX_RETRIES`` attempts in total.

    Returns:
        list[str]: one text per returned sample. If every attempt fails,
        ``[" "]`` — still a list, so callers can index ``[0]`` and iterate
        the result the same way as on success.

    Raises:
        ValueError: if ``args.client`` is not a supported backend on the
            non-structured path. Checked up front so it is neither swallowed
            by the retry loop nor able to spin forever.
    """
    if not args.structure_output and args.client not in ("pro", "qwen2.5", "qwen70"):
        raise ValueError(f"unsupported client: {args.client!r}")

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            if args.structure_output:
                response = client.generate_content(
                    prompt,
                    generation_config=genai.types.GenerationConfig(
                        candidate_count=args.num_samples,
                        max_output_tokens=args.max_tokens,
                        temperature=args.temperature,
                        response_mime_type="application/json",
                        response_schema=MathReasoning,
                    ),
                )
                # Return the result (previously dropped, so a successful call was
                # re-issued forever). Same shape as gen_with_gemini.
                if args.num_samples == 1:
                    return [response.text]
                return [candidate.content.parts[0].text for candidate in response.candidates]

            if args.client == "pro":
                return gen_with_gemini(prompt, client, args)
            return gen_with_together(prompt, args)

        except Exception:  # SDK/network errors vary by backend; treat all as retryable
            if attempt < MAX_RETRIES:
                time.sleep(RETRY_INTERVAL)

    return [" "]


def get_one_step_prompt(example, args, prompt_type="demo"):
    """Render one example as prompt text for ``args.task``.

    Args:
        example: per-example dict as produced by ``data_loading_main``
            (pddl: inputs/targets; aime24: problem/solution; sqa: question/
            facts/decomposition/answer; gpqa: question/choice1-4/explanation/
            correct_answer).
        args: namespace providing ``task``.
        prompt_type: ``"demo"`` renders the example together with its gold
            solution; ``"test"`` renders only the question part. Any other
            value yields ``""``.

    Returns:
        The rendered prompt string.

    Raises:
        ValueError: if ``args.task`` is not a supported task.
    """
    prompt = ""

    if args.task in ["pddl", "pddl-ood"]:
        # f-strings instead of chained str.replace: a literal "<plan>" inside the
        # problem text must not be substituted with the plan.
        if prompt_type == "demo":
            prompt = (
                "Please solve the problem:\n"
                f"{example['inputs']}\n"
                "Your plan as plain text without formatting:\n"
                f"{example['targets']}\n"
                " done.\n"
            )
        elif prompt_type == "test":
            prompt = (
                "Please solve the problem:\n"
                f"{example['inputs']}\n"
                "Your plan as plain text without formatting:"
            )

    elif args.task in ["aime24"]:
        if prompt_type == "demo":
            prompt = f"Please solve the problem: {example['problem']}\nAnswer:{example['solution']}"
        elif prompt_type == "test":
            prompt = f"Please solve the problem: {example['problem']}"

    elif args.task in ["sqa"]:
        if prompt_type == "demo":
            prompt = "\nQuestion:" + example["question"]

            facts = list(example["facts"])
            if facts:
                # NOTE(review): facts are concatenated with no separator (kept from
                # the original prompt format) — confirm this is intended.
                prompt += "\nFacts:" + "".join(facts)

            prompt += "\nAnswer:\n"
            prompt += "".join(
                f"Sub-question {i}: {sub}\n"
                for i, sub in enumerate(example["decomposition"], start=1)
            )
            prompt += "The answer is:" + example["answer"]

        elif prompt_type == "test":
            prompt = "Question:" + example["question"]

    elif args.task in ["gpqa"]:
        # we use the official implementation from GPQA github
        letter_answer_choices = ['A', 'B', 'C', 'D']
        choices = (
            f'\nChoices:\n(A) {example["choice1"]}\n(B) {example["choice2"]}'
            f'\n(C) {example["choice3"]}\n(D) {example["choice4"]}'
        )

        if prompt_type == "demo":
            prompt = f'Question: {example["question"]}' + choices
            prompt += f"\nLet's think step by step: \n{example['explanation']}\n"
            prompt += f'The correct answer is ({letter_answer_choices[example["correct_answer"]]})\n'  # (A)...(D)

        elif prompt_type == "test":
            prompt = "Question:" + example["question"] + choices

    else:
        raise ValueError(f"unknown task: {args.task!r}")

    return prompt


def prompt_for_generation(example, all_demos, sim_index, item_scores, args):
    """Assemble the full prompt for one test example.

    ``args.icl_sim_mode == "zero"`` returns just the rendered test question;
    ``"cot"`` appends "Let's think step by step" to it. Otherwise the first
    ``args.num_demo`` entries of ``sim_index`` select demos from ``all_demos``.
    Each demo is rendered with ``get_one_step_prompt(..., "demo")`` and the
    demos are joined in that order. The result is wrapped as: task header +
    demos + test question.

    ``item_scores`` is accepted for interface compatibility but not used.

    Raises:
        AssertionError: if fewer than ``args.num_demo`` demos could be selected.
    """
    mode = args.icl_sim_mode
    if mode == "zero":
        return get_one_step_prompt(example, args, prompt_type="test")
    if mode == "cot":
        return get_one_step_prompt(example, args, prompt_type="test") + "\nLet's think step by step"

    picked = sim_index[:args.num_demo]
    demos = [all_demos[k] for k in picked]

    if args.debug_mode:
        print(picked)

    assert len(demos) == args.num_demo

    # Separator between demos (and before the test question); header is task-specific.
    separator = "\n\n\n"
    if args.task in ["pddl", "pddl-ood", "aime24"]:
        header = ""
        separator = "\n"
    elif args.task in ["sqa"]:
        header = """\n\nFollow the given examples that use the facts to answer a question by decomposing into sub questions first and then predicting the final answer as "Yes" or "No" only.\n\n\n"""
    elif args.task in ["gpqa"]:
        header = "Here are some example questions from experts. Answer the final question yourself, following the format of the previous questions exactly.\n\n\n"

    rendered = [get_one_step_prompt(demo, args, prompt_type="demo") for demo in demos]
    demo_section = "".join(separator + text for text in rendered)

    full_prompt = header + demo_section.strip() + separator + get_one_step_prompt(example, args, prompt_type="test")
    return full_prompt.strip()


def process_answer(response, example, args):
    """Extract a prediction from ``response`` and score it against ``example``.

    With ``args.structure_output`` the response is first parsed as JSON and its
    ``final_answer`` field is scored; if that fails the raw text is used.

    Per task:
      * sqa: the first non-empty line after "The answer is:" (or the whole
        response if the marker is absent) matches when it contains, or is
        contained in, the gold "Yes"/"No" (case-insensitive). An empty
        prediction never matches.
      * aime24: matches when ``str(answer)`` occurs anywhere in the response.
      * gpqa: the letter from the first pattern (most specific first) that
        yields A-D is mapped to 0-3 and compared with ``correct_answer``;
        ``-1`` when nothing parses.

    Returns:
        ``(answer, ground_truth, this_match)`` with ``this_match`` in {0, 1}.

    Raises:
        ValueError: for tasks without online scoring (e.g. PDDL, scored offline).
    """
    if args.structure_output:
        try:
            response = json.loads(response)["final_answer"]
        except (json.JSONDecodeError, KeyError, TypeError):
            pass  # not valid structured output: score the raw text instead

    if args.task in ["sqa"]:
        marker = "The answer is:"
        if marker in response:
            # Skip whitespace/newlines after the marker so "The answer is:\nYes"
            # yields "Yes" rather than an empty string.
            answer = response.split(marker)[1].strip().split("\n")[0].strip()
        else:
            answer = response

        ground_truth = example["answer"]
        pred = answer.lower()
        gold = ground_truth.lower()
        # An empty prediction is a substring of every label, so it must not count.
        this_match = 1 if pred.strip() and (gold in pred or pred in gold) else 0

    elif args.task in ["aime24"]:
        ground_truth = example["answer"]
        answer = response
        this_match = 1 if str(ground_truth) in answer else 0

    elif args.task in ["gpqa"]:
        LETTER_TO_INDEX = {'A': 0, 'B': 1, 'C': 2, 'D': 3}
        # Ordered from most to least specific. Stop at the first pattern that
        # yields a valid letter (as the official GPQA parser does); otherwise the
        # bare "\((.)\)" fallback would override an explicit "answer is (X)" with,
        # e.g., a restated "(A)" option earlier in the text.
        patterns = [r'answer is \((.)\)', r'Answer: \((.)\)', r'answer: \((.)\)', r'answer \((.)\)', r'\((.)\)']
        answer = -1
        ground_truth = example["correct_answer"]
        for pattern in patterns:
            match = re.search(pattern, response)
            if match and match.group(1) in LETTER_TO_INDEX:
                answer = LETTER_TO_INDEX[match.group(1)]
                break

        this_match = 1 if answer == ground_truth else 0

    else:
        raise ValueError(f"process_answer does not support task {args.task!r}")

    return answer, ground_truth, this_match


def icl_main(client, dataset, pair_sim_map, args):
    """Run the in-context-learning evaluation loop and save per-example outputs.

    Demo ranking per test example depends on ``args.icl_sim``:
      * ``"static"``  - one ranking for all test examples: a seeded random
        permutation ("random"/"zero"/"cot"), a fixed "trial" list, or a
        precomputed ranking loaded from ``./resources/<task>/``;
      * ``"dynamic"`` - per-example ranking: descending ``pair_sim_map[i]``
        scores, or for "mmr" the pre-ranked index list ``pair_sim_map[i]``;
      * ``"pivot"``   - per-example demos whose similarity is at least
        ``global_mean + 2 * global_std``, back-filled from the static
        "authority" ranking when fewer than ``args.num_demo`` qualify.
    The selected demos are turned into a prompt, sent to the model (skipped in
    ``args.debug_mode``) and scored with ``process_answer`` (PDDL is scored
    offline and recorded as ``-1``).

    Args:
        client: model client handed to ``generate_with_retries``.
        dataset: ``{"demo": [...], "test": [...]}`` from ``data_loading_main``.
        pair_sim_map: similarity data; replaced by the index file loaded from
            ``./resources`` for dynamic/pivot runs.
        args: run configuration. NOTE: ``args.test_scope`` / ``args.data_scope``
            are overwritten in place when ``-1``, and ``dataset["demo"]`` is
            truncated to ``args.data_scope``.

    Returns:
        list of per-example dicts (metadata, prompt, response, accuracy,
        question, answer), or ``False`` when a dynamic/pivot similarity index
        cannot be loaded. Unless ``args.debug_mode``, the list is also written
        as JSON under ``./ablation_results/trial/<task>``.
    """
    print("We are working on:", args.task, "| Num Demo", args.num_demo, "|", args.icl_sim_mode)

    output_collection = []  # metadata, prompt, response, accuracy, question, answer
    demo_count = []
    acc_record = []
    acc_record_at_k = []

    input_key_mapper = {"pddl": "inputs", "pddl-ood": "inputs", "aime24": "problem", "sqa": "question", "gpqa": "question"}
    ans_key_mapper = {"pddl": "targets", "pddl-ood": "targets", "aime24": "answer", "sqa": "answer", "gpqa": "correct_answer"}

    if args.test_scope == -1:
        args.test_scope = len(dataset["test"])
        if args.task in ["pddl", "pddl-ood"]:
            args.test_scope = 300

    suffix = ""
    if args.icl_sim_type == "pg":
        suffix = "-prompt-gs"
    if args.icl_sim_type == "pq":
        suffix = "-prompt-question"
    if args.iters > 0:
        suffix += "-iter" + str(args.iters)

    resource_dir = f"./resources/{args.task}"

    if args.icl_sim in ["dynamic", "pivot"]:
        if args.icl_sim_mode != "mmr":
            index_file = f"{resource_dir}/sim_index-{args.task}-{args.icl_sim_mode}{suffix}.json"
        else:
            index_file = f"{resource_dir}/mmr_index-{args.task}-gecko{suffix}.json"
        try:
            with open(index_file, "r", encoding="utf-8") as f:
                pair_sim_map = json.load(f)
        except (OSError, ValueError):
            # Missing or unreadable index file, most likely an unsupported mode.
            print("unknown icl sim mode...")
            return False

    elif args.icl_sim == "static" and args.icl_sim_mode != "random":
        ablation_list = [
            f"{sim_mode}-{graph_mode}"
            for sim_mode in ['weighted', 'eigenvector', 'authority', 'pagerank', 'modularity']
            for graph_mode in ["bi", "tr", "bi+tr"]
        ]

        static_file = None
        if args.icl_sim_mode in ["auto-cot", "majority-test", "cumulation-test", "majority-train", "cumulation-train"]:
            # representation based
            static_file = f"{resource_dir}/static_index-{args.task}-gecko{suffix}.json"
        elif args.icl_sim_mode in ['weighted', 'eigenvector', 'authority', 'pagerank', 'modularity', 'ensemble']:
            static_file = f"{resource_dir}/static_graph_index-{args.task}-gecko{suffix}.json"
        elif args.icl_sim_mode in ablation_list:
            static_file = f"{resource_dir}/ablation_graph_index-{args.task}-gecko{suffix}.json"

        if static_file is not None:
            with open(static_file, "r", encoding="utf-8") as f:
                static_sim_collection = json.load(f)

    if args.icl_sim == "pivot":
        # Corpus-wide statistics that define the per-example similarity cut-off.
        global_mean = float(np.mean(pair_sim_map))
        global_std = float(np.mean([np.std(item) for item in pair_sim_map]))

        try:
            with open(f"{resource_dir}/static_graph_index-{args.task}-gecko{suffix}.json", "r", encoding="utf-8") as f:
                static_sim_collection = json.load(f)
        except (OSError, ValueError):
            # Fall back to the prompt-question graph index.
            with open(f"{resource_dir}/static_graph_index-{args.task}-gecko-prompt-question.json", "r", encoding="utf-8") as f:
                static_sim_collection = json.load(f)

        # All listed tasks use a 2-sigma cut-off; unlisted tasks (e.g. "pddl-ood")
        # default to 2 instead of raising KeyError.
        std_offset = {"aime24": 2, "gpqa": 2, "sqa": 2, "pddl": 2}
        threshold = global_mean + std_offset.get(args.task, 2) * global_std

    # Truncate the demo pool once (previously repeated for every test example).
    if args.data_scope != -1:
        dataset["demo"] = dataset["demo"][:args.data_scope]
    else:
        args.data_scope = len(dataset["demo"])
    # Read-only from here on, so no per-example deepcopy is needed.
    all_demos = dataset["demo"]

    # Slicing guards against test_scope exceeding the available test examples.
    for i, example in enumerate(tqdm(dataset["test"][:args.test_scope])):

        if args.icl_sim == "static":
            if args.icl_sim_mode in ["random", "zero", "cot"]:
                # Shuffle the list itself: previously a throwaway copy was shuffled,
                # so "random" silently used the demos in dataset order.
                sim_index = list(range(args.data_scope))
                random.Random(7).shuffle(sim_index)
            elif args.icl_sim_mode in ["trial"]:
                sim_index = [422, 707, 337, 697, 355, 858, 102, 182, 792, 760, 55, 491, 492, 34]
            else:
                sim_index = static_sim_collection[args.icl_sim_mode]
            item_scores = [0] * len(sim_index)

        elif args.icl_sim == "dynamic":
            if args.icl_sim_mode != "mmr":
                # Rank demos by descending similarity to this test example.
                sim_index = (-np.array(pair_sim_map[i])).argsort().tolist()
                item_scores = pair_sim_map[i]
            else:
                # MMR indices are stored already ranked.
                sim_index = pair_sim_map[i]
                item_scores = None

        elif args.icl_sim == "pivot":
            sim_index = []
            for demo_idx in (-np.array(pair_sim_map[i])).argsort().tolist():
                if pair_sim_map[i][demo_idx] < threshold:
                    break
                sim_index.append(demo_idx)

            if args.debug_mode:
                print(len(sim_index))

            demo_count.append(len(sim_index))

            if len(sim_index) < args.num_demo:
                # Back-fill from the static "authority" ranking, skipping demos
                # already chosen (a set keeps this linear). Earlier variants
                # extended without de-duplication or prepended the top authorities.
                chosen = set(sim_index)
                for demo_idx in static_sim_collection["authority"]:
                    if demo_idx not in chosen:
                        sim_index.append(demo_idx)
                        chosen.add(demo_idx)

            item_scores = [0] * len(sim_index)

        else:
            print("unknown icl_sim")

        # icl
        prompt = prompt_for_generation(example, all_demos, sim_index, item_scores, args)

        if args.debug_mode:
            response = [""]
        else:
            response = generate_with_retries(prompt, client, args)

        # evaluation
        if args.task in ["pddl", "pddl-ood"]:
            # PDDL plans need offline validation; record a placeholder.
            acc_record.append(-1)
            acc_record_at_k.append(-1)
        else:
            _, _, this_match = process_answer(response[0], example, args)
            acc_record.append(1 if this_match == 1 else 0)

            if args.passk:
                # pass@k: correct if any sampled response matches.
                matches = [process_answer(resp, example, args)[2] for resp in response]
                acc_record_at_k.append(1 if 1 in matches else 0)

        output_collection.append({
            "metadata": {"temperature": args.temperature, "max_answer_tokens": args.max_tokens},
            "prompt": prompt,
            "response": response,
            "accuracy": acc_record[-1],
            "question": example[input_key_mapper[args.task]],
            "answer": example[ans_key_mapper[args.task]],
        })

    if args.debug_mode:
        print(np.mean(demo_count))

    print("EM:", np.mean(acc_record))
    if args.passk:
        print("EM@k:", np.mean(acc_record_at_k))

    segs = [args.client, args.task, str(args.test_scope), str(args.data_scope), args.client, str(args.num_demo), args.icl_sim, args.icl_sim_mode]

    if args.icl_sim_type != "qq":
        segs.append(args.icl_sim_type)

    if args.passk:
        segs.append("pass@" + str(args.num_samples))

    if args.iters > 0:
        segs.append("iter" + str(args.iters))

    filename = "-".join(segs) + ".json"
    path = f"./ablation_results/trial/{args.task}"

    if not args.debug_mode:
        save_results(output_collection, path, filename)

    return output_collection


def trial_bipartite_weighted_hits(G1, G2, connections, max_iter=100, tol=1.0e-8, normalized=True):
    """Experimental weighted HITS over a bipartite G1 -> G2 connection set.

    G1 nodes receive hub scores and G2 nodes receive authority scores (both
    start at 1.0). Each iteration recomputes authorities from the previous
    iteration's hubs and hubs from the previous iteration's authorities (both
    sides use the old values). It then optionally rescales each side to sum
    to 1. It stops once the summed absolute change on both sides is below
    ``tol``, or after ``max_iter`` rounds.

    Args:
        G1: graph whose nodes form the hub side (sources).
        G2: graph whose nodes form the authority side (targets).
        connections: iterable of ``(source, target, weight)``. Triples whose
            source is not a G1 node or whose target is not a G2 node are
            ignored; for a repeated (source, target) pair the last weight wins;
            only weights > 0 contribute to the scores.
        max_iter: maximum number of power iterations.
        tol: convergence threshold on the summed absolute score change.
        normalized: if True, divide each side's scores by their sum (only when
            that sum is positive).

    Returns:
        ``(hubs, authorities)``: dicts mapping node -> score.
    """

    # Initialize hubs for G1 nodes (sources) and authorities for G2 nodes (targets)
    hubs = dict([(n, 1.0) for n in G1.nodes()])
    authorities = dict([(n, 1.0) for n in G2.nodes()])
    
    # Create the connection matrix between the two graphs
    nodes_G1 = list(G1.nodes())
    nodes_G2 = list(G2.nodes())
    
    # Create mapping from node to index
    node_to_idx_G1 = {node: i for i, node in enumerate(nodes_G1)}
    node_to_idx_G2 = {node: i for i, node in enumerate(nodes_G2)}

    
    # Initialize weight matrix (rows: G1 nodes, columns: G2 nodes)
    weight_matrix = np.zeros((len(nodes_G1), len(nodes_G2)))
    
    # Fill in weight matrix; later duplicates overwrite earlier ones
    for source, target, weight in connections:
        if source in G1 and target in G2:
            weight_matrix[node_to_idx_G1[source], node_to_idx_G2[target]] = weight

    # NOTE(review): ``our_clean`` is defined elsewhere. It runs after the weight
    # matrix and node lists are fixed, so its only visible effect here is on the
    # node sets iterated in the convergence check below. If it drops or adds
    # nodes, that check no longer lines up with ``hubs``/``authorities`` (an
    # added node would raise KeyError) — confirm intent.
    G1 = our_clean(G1)
    G2 = our_clean(G2)
    
    # Power iteration
    # NOTE(review): the two nested Python loops cost O(|G1| * |G2|) per
    # iteration; a NumPy matrix-vector product over the positive part of
    # weight_matrix would compute the same sums (up to float summation order).
    for _ in range(max_iter):
        # Save previous hubs and authorities for convergence check
        old_hubs = hubs.copy()
        old_authorities = authorities.copy()
        
        # Update authorities (G2 nodes) using weighted connections from hubs (G1 nodes)
        for j, node in enumerate(nodes_G2):
            auth_score = 0.0
            for i, source in enumerate(nodes_G1):
                if weight_matrix[i, j] > 0:  # If there's an edge from i to j
                    auth_score += old_hubs[source] * weight_matrix[i, j]
            authorities[node] = auth_score
            
        # Update hubs (G1 nodes) using weighted connections to authorities (G2 nodes);
        # uses the previous iteration's authorities, not the ones just computed
        for i, node in enumerate(nodes_G1):
            hub_score = 0.0
            for j, target in enumerate(nodes_G2):
                if weight_matrix[i, j] > 0:  # If there's an edge from i to j
                    hub_score += old_authorities[target] * weight_matrix[i, j]
            hubs[node] = hub_score
        
        # Normalize each side to sum to 1 (skipped when the sum is 0)
        if normalized:
            auth_sum = sum(v for v in authorities.values())
            hub_sum = sum(v for v in hubs.values())
            
            if auth_sum > 0:
                authorities = {k: v/auth_sum for k, v in authorities.items()}
            if hub_sum > 0:
                hubs = {k: v/hub_sum for k, v in hubs.items()}
        
        # Check for convergence (L1 change on each side; iterates the cleaned graphs)
        err_hubs = sum([abs(hubs[n] - old_hubs[n]) for n in G1.nodes()])
        err_authorities = sum([abs(authorities[n] - old_authorities[n]) for n in G2.nodes()])
        
        if err_hubs < tol and err_authorities < tol:
            break
            
    return hubs, authorities



def visualize_bipartite_hits(G1, G2, connections, hubs, authorities):
    """Visualize the bipartite graph with hub and authority scores.

    G1 (hub) nodes are placed in a column at x=1 and G2 (authority) nodes in a
    column at x=2. Node size is 1000 x score and edge width is 2 x weight;
    each node carries its score as a small label. The figure is saved to
    ``bipartite_weighted_hits.png`` and then shown.
    """
    plt.figure(figsize=(12, 8))

    left_nodes = list(G1.nodes())
    right_nodes = list(G2.nodes())

    graph = nx.Graph()
    graph.add_nodes_from(left_nodes, bipartite=0)
    graph.add_nodes_from(right_nodes, bipartite=1)
    for src, dst, w in connections:
        graph.add_edge(src, dst, weight=w)

    # Two vertical columns: hubs on the left, authorities on the right.
    pos = {node: (1, row) for row, node in enumerate(left_nodes)}
    pos.update({node: (2, row) for row, node in enumerate(right_nodes)})

    size_of = {node: hubs[node] * 1000 for node in left_nodes}
    size_of.update({node: authorities[node] * 1000 for node in right_nodes})

    color_of = dict.fromkeys(left_nodes, 'skyblue')
    color_of.update(dict.fromkeys(right_nodes, 'lightgreen'))

    widths = [graph[u][v]['weight'] * 2 for u, v in graph.edges()]

    plt.title('Bipartite Weighted HITS')

    ordered = list(graph.nodes())
    nx.draw_networkx_nodes(graph, pos,
                           node_size=[size_of[node] for node in ordered],
                           node_color=[color_of[node] for node in ordered])
    nx.draw_networkx_edges(graph, pos, width=widths, edge_color='gray', alpha=0.7)
    nx.draw_networkx_labels(graph, pos, font_weight='bold')

    # Score annotations, nudged below-left (hubs) / below-right (authorities).
    hub_labels = {node: f"{hubs[node]:.3f}" for node in left_nodes}
    auth_labels = {node: f"{authorities[node]:.3f}" for node in right_nodes}
    hub_pos = {node: (pos[node][0] - 0.1, pos[node][1] - 0.1) for node in left_nodes}
    auth_pos = {node: (pos[node][0] + 0.1, pos[node][1] - 0.1) for node in right_nodes}
    nx.draw_networkx_labels(graph, hub_pos, labels=hub_labels, font_size=8, font_color='blue')
    nx.draw_networkx_labels(graph, auth_pos, labels=auth_labels, font_size=8, font_color='green')

    # Hand-made legend.
    plt.text(0.8, 0.05, 'Hub Nodes (G1)', color='skyblue', fontsize=12,
             bbox=dict(facecolor='white', alpha=0.7))
    plt.text(0.8, 0.1, 'Authority Nodes (G2)', color='lightgreen', fontsize=12,
             bbox=dict(facecolor='white', alpha=0.7))

    plt.axis('off')
    plt.tight_layout()
    plt.savefig('bipartite_weighted_hits.png')
    plt.show()


def run_comparison(G1, G2, connections):
    """Run and compare different variants of the algorithm.

    Runs ``bipartite_weighted_hits`` with normalisation and then without,
    printing the hub (G1) and authority (G2) scores of each run, sorted by
    node. Returns the ``(hubs, authorities)`` of the normalised run.
    """
    def _report(title, hub_scores, auth_scores):
        # Print one run's scores in a fixed, node-sorted layout.
        print(title)
        print("Hub scores (G1 nodes):")
        for node, score in sorted(hub_scores.items()):
            print(f"Node {node}: {score:.4f}")
        print("\nAuthority scores (G2 nodes):")
        for node, score in sorted(auth_scores.items()):
            print(f"Node {node}: {score:.4f}")

    hubs, authorities = bipartite_weighted_hits(G1, G2, connections)
    _report("Basic Bipartite Weighted HITS:", hubs, authorities)

    raw_hubs, raw_authorities = bipartite_weighted_hits(G1, G2, connections, normalized=False)
    _report("\nNon-normalized Bipartite Weighted HITS:", raw_hubs, raw_authorities)

    return hubs, authorities



def create_example_bipartite_graph():
    """Create two example graphs with weighted connections between them.

    Returns:
        ``(G1, G2, connections)``: G1 has nodes A1-A4 (hub side), G2 has nodes
        B1-B5 (authority side), and ``connections`` lists twelve
        ``(A-node, B-node, weight)`` triples.
    """
    hub_side = nx.Graph()
    hub_side.add_nodes_from([f"A{k}" for k in range(1, 5)])

    authority_side = nx.Graph()
    authority_side.add_nodes_from([f"B{k}" for k in range(1, 6)])

    weighted_links = [
        ('A1', 'B1', 0.9), ('A1', 'B2', 0.7), ('A1', 'B3', 0.3),
        ('A2', 'B1', 0.5), ('A2', 'B2', 0.8), ('A2', 'B4', 0.4),
        ('A3', 'B2', 0.6), ('A3', 'B3', 1.0), ('A3', 'B5', 0.7),
        ('A4', 'B3', 0.4), ('A4', 'B4', 0.9), ('A4', 'B5', 0.6),
    ]

    return hub_side, authority_side, weighted_links


def create_specific_bipartite_graph(top_k=800):
    """Build a BM25-weighted graph between test questions and demo problems.

    Reads the module-level ``test_questions`` and ``s1k_math`` (defined
    elsewhere); each item is used via its "problem" and "solution" keys.

    Args:
        top_k: number of BM25 hits kept per query; clamped to the corpus size
            because bm25s rejects a ``k`` larger than the number of documents.

    Returns:
        ``(G1, G2, connections)``:
          * G1 has one node ``"A<i>"`` per test question;
          * G2 has one node ``"B<j>"`` per demo;
          * ``connections`` holds ``(source, target, bm25_score)`` triples:
            ``("A<i>", "B<j>", s)`` for each test question's top hits, and
            ``("B<i>", "B<j>", s)`` for each demo's top hits excluding itself.
    """
    G1 = nx.Graph()
    G1.add_nodes_from(["A" + str(i) for i in range(len(test_questions))])

    G2 = nx.Graph()
    G2.add_nodes_from(["B" + str(i) for i in range(len(s1k_math))])

    corpus = [example["problem"] + "\n\n" + example["solution"] for example in s1k_math]

    retriever = bm25s.BM25()
    retriever.index(bm25s.tokenize(corpus))

    k = min(top_k, len(corpus))

    def _top_hits(query):
        # (corpus index, score) pairs for the k best BM25 matches of one query.
        results, scores = retriever.retrieve(bm25s.tokenize(query), k=k)
        return zip(results.tolist()[0], scores.tolist()[0])

    connections = []

    for i, example in enumerate(test_questions):
        query = example["problem"] + "\n\n" + example["solution"]
        for index, score in _top_hits(query):
            connections.append(("A" + str(i), "B" + str(index), score))

    for i, query in enumerate(corpus):
        for index, score in _top_hits(query):
            if i != index:
                # i indexes the demo corpus, so the source is demo node "B<i>".
                # It used to be labelled "A<i>", which mis-attributed demo-demo
                # similarities to test question i (or to non-existent G1 nodes).
                connections.append(("B" + str(i), "B" + str(index), score))

    return G1, G2, connections