import json
import os
import sys
from typing import Any, Dict, Iterable, List, Optional

import numpy as np
from dotenv import load_dotenv
from tqdm import tqdm

sys.path.append('.')
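# NOTE: a local logging.py must be renamed to logger.py before running this file
# as a script (presumably so it does not shadow the standard-library `logging`
# module -- confirm).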

from datasets import load_dataset

from src.utils.prompt_handler import PromptHandler
from src.utils.api import get_llm_output
from src.utils.extract_json_reliable import extract_json

# --- Prompt Templates ---
# Prompt for the initial clustering pass: the model is asked to split a list of
# indexed statements into disjoint groups and to describe each group's pattern.
# Placeholders: {cluster_criteria_desc}, {statements_with_indices} and
# {additional_requirements} (the same names are listed as input_keys of the
# "Cluster_Initialization" PromptHandler in cluster_statements_to_types).
# The doubled braces ({{ }}) are presumably escapes for literal braces under
# str.format-style substitution -- confirm against PromptHandler.
CLUSTER_INITIALIZATION_PROMPT = """You are a helpful assistant that groups a given list of statements. The grouping should be based on the following criteria: {cluster_criteria_desc}.  

Here are the statements to be grouped:
{statements_with_indices}

Task: Carefully read through the provided criteria and statements, and then group the statements into disjoint groups. The number of groups is not specified, but you should make sure the groups are as disjoint as possible.

Output format: Provide a sentence describing the common pattern for each group and categorize the given statements accordingly. Your output should be a JSON object with the following structure:

{{
  0: {{
    "pattern": "<pattern>",
    "statement_idx": [<statement_idx1>, <statement_idx2>, ...]}},
  1: {{...}},
  ...
}}
{additional_requirements}
Your output:
"""

# Prompt for the assignment pass: the model is asked to map each indexed
# statement to the index of one of the listed patterns and to answer with a
# JSON object of the form {statement index: pattern index}.
# Placeholders: {cluster_criteria_desc}, {statements_with_indices} and
# {group_patterns} (the same names are listed as input_keys of the
# "Cluster_Assignment" PromptHandler in cluster_statements_to_types).
CLUSTER_ASSIGNMENT_PROMPT = """You are a helpful assistant tasked with classifying each statement into one of the specified patterns, choosing the most appropriate one based on the criteria: {cluster_criteria_desc}.

These are the statements to be classified:
{statements_with_indices}

These are the descriptions of the patterns:
{group_patterns}

Your output should be in the form of a JSON object, where the key represents the index of the query and the value corresponds to the index of the pattern. 

Your output:
"""


# --- Subfunction: Create Clusters ---
def create_clusters(
    model_name: str,
    statements_with_indices: str,
    cluster_criteria_desc: str,
    additional_requirements: str,
    cluster_initiator: PromptHandler,
    **generation_kwargs
) -> Dict[int, Dict[str, Any]]:
    """Ask the LLM to group statements into clusters, retrying until the reply is well-formed.

    The prompt is rendered by ``cluster_initiator`` and sent through
    ``get_llm_output`` with ``json_object=True``. A reply is accepted once it
    is non-empty and every one of its values passes
    ``cluster_initiator.check_format``.

    Args:
        model_name: Model identifier forwarded to ``get_llm_output`` as ``model``.
        statements_with_indices: Pre-formatted statements, filled into the prompt.
        cluster_criteria_desc: Description of the grouping criterion.
        additional_requirements: Extra instructions filled into the prompt.
        cluster_initiator: Prompt handler used to render the prompt and to
            validate each returned cluster.
        **generation_kwargs: Forwarded unchanged to ``get_llm_output``.

    Returns:
        The accepted clusters, re-keyed with contiguous integers ``0..n-1`` in
        the order they appeared in the reply.
    """
    # NOTE(review): there is no retry limit, so a model that keeps producing a
    # malformed reply (e.g. deterministically at temperature 0) loops forever.
    while True:
        prompt = cluster_initiator(
            cluster_criteria_desc=cluster_criteria_desc,
            statements_with_indices=statements_with_indices,
            additional_requirements=additional_requirements,
        )

        clusters = get_llm_output(prompt, model=model_name, **generation_kwargs, json_object=True)
        if clusters and all(cluster_initiator.check_format(cluster) for cluster in clusters.values()):
            break

    # Keys of a decoded JSON object are strings, while the declared return type
    # (and integer-based lookups done by callers) expect ints. Re-key with
    # contiguous integers so both hold regardless of what the model emitted.
    return {new_idx: cluster for new_idx, cluster in enumerate(clusters.values())}

# --- Subfunction: Assign Clusters to Patterns ---
def assign_clusters_to_patterns(
    model_name: str,
    statement_types: str,
    statements_with_indices: str,
    cluster_criteria_desc: str,
    clusters: Dict[int, Dict[str, Any]],
    cluster_assigner: PromptHandler,
    **generation_kwargs
) -> Dict[str, str]:
    """Ask the LLM which pattern each statement belongs to.

    Args:
        model_name: Model identifier forwarded to ``get_llm_output`` as ``model``.
        statement_types: Pre-formatted pattern list. When non-empty it is used
            verbatim as the pattern descriptions and ``clusters`` is ignored.
        statements_with_indices: Pre-formatted statements, filled into the prompt.
        cluster_criteria_desc: Description of the classification criterion.
        clusters: Existing clusters; when ``statement_types`` is empty, one
            ``"<key>: <pattern>"`` line per cluster is used as the pattern list.
        cluster_assigner: Prompt handler used to render the prompt.
        **generation_kwargs: Forwarded unchanged to ``get_llm_output``.

    Returns:
        The LLM's assignment object (statement index -> pattern index), or an
        empty dict when the LLM output is not a dict.
    """
    if statement_types:
        group_patterns = statement_types
    else:
        group_patterns = "\n".join(f"{i}: {cluster['pattern']}" for i, cluster in clusters.items())
    prompt = cluster_assigner(
        cluster_criteria_desc=cluster_criteria_desc,
        statements_with_indices=statements_with_indices,
        group_patterns=group_patterns,
    )
    print(prompt)
    assignment = get_llm_output(prompt, model=model_name, **generation_kwargs, json_object=True)
    print(assignment)

    if not isinstance(assignment, dict):
        # Callers iterate over ``.items()``; hand back an empty assignment
        # instead of letting a failed/malformed reply crash the whole run.
        print(f"Cluster assignment output is not a JSON object ({type(assignment).__name__}); returning an empty assignment")
        return {}

    return assignment

# --- Main Function ---
def cluster_statements_to_types(
    model_name: str,
    statements: List[str],
    initial_batch_size: int = 50,
    batch_size: int = 20,
    add_none: bool = True,
    cluster_criteria_desc: str = 'The causes of the reported coding exception on unit tests',
    additional_requirements: str = '',
    statement_types: str = '',
    **generation_kwargs
) -> Dict[int, Dict[str, Any]]:
    """Cluster statements into types by initializing clusters and assigning remaining statements.

    Steps:
      1. The first ``initial_batch_size`` statements are clustered with
         ``create_clusters``.
      2. The remaining statements are assigned to those clusters in batches of
         ``batch_size`` with ``assign_clusters_to_patterns``.
      3. When ``add_none`` is true, statements that ended up in the
         "None of the above" cluster get one more assignment attempt against
         the real clusters; whatever is still unassigned is clustered from
         scratch and appended as new clusters.

    Args:
        model_name: Model identifier passed to the LLM helpers.
        statements: Statements to cluster; list positions are used as the
            statement indices shown to the LLM.
        initial_batch_size: Number of leading statements used to create the
            initial clusters.
        batch_size: Number of statements per assignment request.
        add_none: Whether to add a catch-all "None of the above" cluster.
        cluster_criteria_desc: Description of the clustering criterion.
        additional_requirements: Extra instructions for cluster creation.
        statement_types: Optional pre-formatted pattern list forwarded to
            ``assign_clusters_to_patterns``.
        **generation_kwargs: Forwarded to the LLM helpers.

    Returns:
        A dict mapping integer cluster ids to
        ``{"pattern": ..., "statement_idx": [...]}``. Empty when ``statements``
        is empty.
    """
    if not statements:
        # Nothing to cluster; avoid prompting the LLM with an empty list.
        return {}

    cluster_initiator = PromptHandler(
        template=CLUSTER_INITIALIZATION_PROMPT,
        name="Cluster_Initialization",
        input_keys=["cluster_criteria_desc", "statements_with_indices", "additional_requirements"],
        output_format={"pattern": str, "statement_idx": list},
        system_message="",
        strict_input=True
    )

    cluster_assigner = PromptHandler(
        template=CLUSTER_ASSIGNMENT_PROMPT,
        name="Cluster_Assignment",
        input_keys=["cluster_criteria_desc", "statements_with_indices", "group_patterns"],
        output_format=str,
        system_message="",
        strict_input=True
    )

    # Cluster Initialization with first batch
    initial_statements_with_indices = "\n".join(
        f"{i}: {s}" for i, s in enumerate(statements[:initial_batch_size])
    )
    # Everything below looks clusters up by int and derives fresh keys from
    # len(clusters), so force contiguous integer keys (decoded JSON keys are
    # strings).
    clusters = _reindex_clusters(create_clusters(
        model_name, initial_statements_with_indices, cluster_criteria_desc, additional_requirements, cluster_initiator, **generation_kwargs
    ))

    remaining_statements = statements[initial_batch_size:]
    if not remaining_statements:
        return clusters

    # Add "None of the above" cluster if specified
    none_idx = None
    if add_none:
        none_idx = len(clusters)
        clusters[none_idx] = {"pattern": "None of the above", "statement_idx": []}
    print(clusters)

    # NOTE(review): when ``statement_types`` is non-empty the LLM is shown that
    # list instead of the patterns in ``clusters``, yet the indices it returns
    # are applied to ``clusters`` -- confirm the two numberings are meant to match.

    # Assign remaining statements in batches
    for batch_start in tqdm(range(0, len(remaining_statements), batch_size)):
        batch = remaining_statements[batch_start:batch_start + batch_size]
        first_idx = initial_batch_size + batch_start
        batch_indices = range(first_idx, first_idx + len(batch))
        statements_with_indices = "\n".join(f"{i}: {s}" for i, s in zip(batch_indices, batch))
        assignment = assign_clusters_to_patterns(
            model_name=model_name,
            statement_types=statement_types,
            statements_with_indices=statements_with_indices,
            cluster_criteria_desc=cluster_criteria_desc,
            clusters=clusters,
            cluster_assigner=cluster_assigner,
            **generation_kwargs
        )
        unassigned = _apply_assignment(assignment, batch_indices, clusters)
        if not unassigned:
            continue
        if add_none:
            clusters[none_idx]["statement_idx"].extend(unassigned)
        else:
            # Without a catch-all cluster there is nowhere to put them.
            print(f"No valid cluster for statements {unassigned}; skipping them")

    # Handle "None of the above" reassignment and new cluster creation
    if add_none and clusters[none_idx]["statement_idx"]:
        none_statements_indices = clusters.pop(none_idx)["statement_idx"]
        statements_with_indices = "\n".join(f"{i}: {statements[i]}" for i in none_statements_indices)

        # Reassign "None of the above" statements to existing clusters
        assignment = assign_clusters_to_patterns(
            model_name=model_name,
            statement_types=statement_types,
            statements_with_indices=statements_with_indices,
            cluster_criteria_desc=cluster_criteria_desc,
            clusters=clusters,
            cluster_assigner=cluster_assigner,
            **generation_kwargs
        )
        # The catch-all cluster was popped above, so any index that does not
        # name a real cluster leaves the statement in the leftover list.
        still_none_indices = _apply_assignment(assignment, none_statements_indices, clusters)

        if still_none_indices:
            statements_with_indices = "\n".join(f"{i}: {statements[i]}" for i in still_none_indices)
            new_clusters = create_clusters(
                model_name, statements_with_indices, cluster_criteria_desc, additional_requirements, cluster_initiator, **generation_kwargs
            )
            for cluster in new_clusters.values():
                clusters[len(clusters)] = cluster

    return clusters


def _as_int(value: Any) -> Optional[int]:
    """Return ``value`` converted with ``int()``, or ``None`` when that fails."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return None


def _reindex_clusters(raw_clusters: Dict[Any, Dict[str, Any]]) -> Dict[int, Dict[str, Any]]:
    """Re-key clusters with contiguous ints ``0..n-1``, preserving their order."""
    return {new_idx: cluster for new_idx, cluster in enumerate(raw_clusters.values())}


def _apply_assignment(
    assignment: Any,
    expected_indices: Iterable[int],
    clusters: Dict[int, Dict[str, Any]],
) -> List[int]:
    """Append validly assigned statements to ``clusters`` and return the leftovers.

    A statement counts as assigned only when its index is one of
    ``expected_indices`` and the cluster index it maps to is a key of
    ``clusters``. Leftovers (missing from the assignment, or mapped to an
    unknown/unparseable cluster) are returned in their original order.
    """
    # dict.fromkeys gives an insertion-ordered set of the indices still pending.
    pending = dict.fromkeys(expected_indices)
    if isinstance(assignment, dict):
        for raw_statement_idx, raw_cluster_idx in assignment.items():
            statement_idx = _as_int(raw_statement_idx)
            if statement_idx not in pending:
                # Unparseable, duplicated, or not part of this request.
                print(f"Ignoring assignment for unexpected statement index {raw_statement_idx!r}")
                continue
            cluster_idx = _as_int(raw_cluster_idx)
            if cluster_idx is None or cluster_idx not in clusters:
                print(f"Cluster {raw_cluster_idx} not found in clusters, leaving statement {statement_idx} unassigned")
                continue
            clusters[cluster_idx]["statement_idx"].append(statement_idx)
            del pending[statement_idx]
    return list(pending)


if __name__ == "__main__":
    # Script entry point: load statements from a HuggingFace dataset, cluster
    # them with the LLM and write the clusters to outputs/clusters/<name>.json.
    import argparse

    parser = argparse.ArgumentParser(description='Cluster statements from a HuggingFace dataset')
    parser.add_argument('--dataset', type=str, required=True,
                      help='HuggingFace dataset name (e.g., anonymous/bigcodebench-complete_qwen7b_gpt-4o-mini_att_iter0_att20_sol5_fixed)')
    parser.add_argument('--model', type=str, default="gpt-4o",
                      help='Model name to use for clustering (default: gpt-4o)')
    parser.add_argument('--initial_batch_size', type=int, default=200,
                      help='Initial batch size for clustering (default: 200)')
    parser.add_argument('--batch_size', type=int, default=10,
                      help='Batch size for remaining statements (default: 10)')
    parser.add_argument('--temperature', type=float, default=0,
                      help='Temperature for model generation (default: 0)')
    args = parser.parse_args()

    model_name = args.model
    # Presumably loads API credentials from a .env file -- confirm.
    load_dotenv()

    additional_requirements = "Please provide at least 5 clusters."
    dataset = load_dataset(args.dataset, split="train")

    # Each dataset entry contributes its "mutation_explanation" field as one statement.
    statements = [entry["mutation_explanation"] for entry in dataset]
    print(len(statements))
    if not statements:
        # Fail with a clear message instead of an IndexError on statements[0].
        raise SystemExit(f"Dataset {args.dataset!r} has no statements to cluster.")
    print(statements[0])

    # Fixed pattern list shown to the LLM during the assignment passes.
    statement_types_list = [
        "Incorrect handling of data structures or data types leading to incorrect results or errors",
        "Incorrect or unexpected changes in function logic or return values",
        "Incorrect or missing function calls or method usage", 
        "Incorrect handling of plot or visualization elements",
        "Incorrect handling of file operations or I/O",
        "Incorrect handling of string or character data",
        "Incorrect handling of numerical data",
        "None of the above"
    ]
    statement_types = "\n".join(f"{i}: {s}" for i, s in enumerate(statement_types_list))
    clusters = cluster_statements_to_types(
        model_name, 
        statements, 
        initial_batch_size=args.initial_batch_size, 
        batch_size=args.batch_size, 
        temperature=args.temperature, 
        additional_requirements=additional_requirements, 
        statement_types=statement_types,
        add_none=True
    )
    print(clusters)
    print()
    # Summary: each cluster's pattern followed by the number of statements in it.
    for bug_type_dict in clusters.values():
        print(bug_type_dict["pattern"])
        print(len(bug_type_dict["statement_idx"]))

    # save to JSON (the file is named after the last path component of the dataset id)
    os.makedirs("outputs/clusters", exist_ok=True)
    dataset_name = args.dataset.split("/")[-1]
    with open(f"outputs/clusters/{dataset_name}.json", "w", encoding="utf-8") as f:
        json.dump(clusters, f)
    
