from pathlib import Path

import fire
from loguru import logger

from src.utils.io import (
    load_qa_dataset,
    save_qa_dataset,
    load_memory_dataset,
    save_memory_dataset,
    load_conflict_graph,
    save_conflict_graph,
    load_colored_graph,
    save_colored_graph,
)
from src.dataset_construction.conflict_detection import ConflictDetectionStage
from src.dataset_construction.distribution_mimicking import DistributionMimickingStage
from src.dataset_construction.graph_coloring import GraphColoringStage
from src.dataset_construction.memory_generation import MemoryGenerationStage
from src.dataset_construction.real_data_injection import RealDataInjectionStage
from src.dataset_construction.timestamping import TimestampingStage


def _stage_cache_path(path: str, ext: str, stage: str) -> str:
    """Return the per-stage cache file path derived from ``path``.

    The last occurrence of ``ext`` in ``path`` becomes ``_{stage}{ext}``,
    e.g. ``out/qa.jsonl`` -> ``out/qa_distribution_mimicking_stage.jsonl``.
    If ``path`` does not contain ``ext``, ``_{stage}{ext}`` is appended instead,
    so a cache file can never coincide with ``path`` itself (the pipeline
    overwrites the final output paths at the end of the run).
    """
    head, sep, tail = path.rpartition(ext)
    if sep:
        return f"{head}_{stage}{ext}{tail}"
    return f"{path}_{stage}{ext}"


def _stage_is_cached(use_cache: bool, *paths: str) -> bool:
    """Return True when caching is enabled and every file in ``paths`` exists."""
    return bool(use_cache) and all(Path(p).exists() for p in paths)


def run_pipeline(
    input_path: str,
    qa_dataset_path: str,
    memory_dataset_path: str,
    conflict_graph_path: str,
    activate_cache: bool = True,
    max_total_count: int = 10,  # TODO: check here (previously hard-coded)
    known_clique_size: int = 3,  # TODO: check here (previously hard-coded)
    session_multiplier: int = 2,  # TODO: check here (previously hard-coded)
) -> None:
    """Build the QA and memory datasets by running every construction stage in order.

    Stage order: ConflictDetection -> DistributionMimicking -> GraphColoring ->
    MemoryGeneration -> RealDataInjection -> Timestamping. Each stage saves its
    output to cache files derived from the path arguments. When ``activate_cache``
    is true and all of a stage's cache files exist, the stage is skipped and its
    outputs are loaded instead. Once any stage is recomputed, caching is turned
    off for every later stage, because their caches were built from upstream
    results that may have changed.

    Args:
        input_path: QA dataset file loaded as the pipeline input.
        qa_dataset_path: Destination of the final QA dataset; also the base
            name for QA-dataset cache files.
        memory_dataset_path: Destination of the final memory dataset; also the
            base name for memory-dataset cache files.
        conflict_graph_path: Base name for the conflict-graph and colored-graph
            cache files. Nothing is written to this exact path.
        activate_cache: Reuse existing stage caches when true.
        max_total_count: Forwarded to ``DistributionMimickingStage.run``.
        known_clique_size: Forwarded to ``GraphColoringStage.run``.
        session_multiplier: Forwarded to ``RealDataInjectionStage.run``.

    Note:
        Cache files are keyed only by path, not by the stage parameters. After
        changing a parameter, delete the affected cache files (or disable the
        cache) to recompute.
    """
    # Downstream caches become stale as soon as any stage recomputes.
    use_cache = activate_cache
    qa_dataset = load_qa_dataset(file_path=input_path)

    # ConflictDetectionStage
    cached_conflict_graph_path = _stage_cache_path(conflict_graph_path, ".json", "conflict_detection_stage")
    if _stage_is_cached(use_cache, cached_conflict_graph_path):
        conflict_graph = load_conflict_graph(cached_conflict_graph_path)
        logger.info("ConflictDetectionStage skipped.")
    else:
        use_cache = False
        conflict_detection_stage = ConflictDetectionStage()
        conflict_graph = conflict_detection_stage.run(qa_dataset)
        save_conflict_graph(file_path=cached_conflict_graph_path, conflict_graph=conflict_graph)
        logger.info("ConflictDetectionStage completed.")

    # DistributionMimickingStage
    cached_qa_dataset_path = _stage_cache_path(qa_dataset_path, ".jsonl", "distribution_mimicking_stage")
    cached_conflict_graph_path = _stage_cache_path(conflict_graph_path, ".json", "distribution_mimicking_stage")
    # Both outputs are loaded when skipping, so both cache files must exist.
    if _stage_is_cached(use_cache, cached_qa_dataset_path, cached_conflict_graph_path):
        qa_dataset = load_qa_dataset(file_path=cached_qa_dataset_path)
        conflict_graph = load_conflict_graph(cached_conflict_graph_path)
        logger.info("DistributionMimickingStage skipped.")
    else:
        use_cache = False
        distribution_mimicking_stage = DistributionMimickingStage()
        qa_dataset, conflict_graph = distribution_mimicking_stage.run(
            qa_dataset, conflict_graph, max_total_count=max_total_count
        )
        save_qa_dataset(file_path=cached_qa_dataset_path, qa_dataset=qa_dataset)
        save_conflict_graph(file_path=cached_conflict_graph_path, conflict_graph=conflict_graph)
        logger.info("DistributionMimickingStage completed.")

    # GraphColoringStage
    cached_colored_graph_path = _stage_cache_path(conflict_graph_path, ".json", "graph_coloring_stage")
    if _stage_is_cached(use_cache, cached_colored_graph_path):
        colored_graph = load_colored_graph(cached_colored_graph_path)
        logger.info("GraphColoringStage skipped.")
    else:
        use_cache = False
        graph_coloring_stage = GraphColoringStage()
        colored_graph = graph_coloring_stage.run(conflict_graph, known_clique_size=known_clique_size)
        save_colored_graph(file_path=cached_colored_graph_path, colored_graph=colored_graph)
        logger.info("GraphColoringStage completed.")

    # MemoryGenerationStage
    # NOTE: the suffix says "service" rather than "stage"; kept unchanged so
    # existing cache files remain valid.
    cached_qa_dataset_path = _stage_cache_path(qa_dataset_path, ".jsonl", "memory_generation_service")
    cached_memory_dataset_path = _stage_cache_path(memory_dataset_path, ".jsonl", "memory_generation_service")
    if _stage_is_cached(use_cache, cached_qa_dataset_path, cached_memory_dataset_path):
        qa_dataset = load_qa_dataset(file_path=cached_qa_dataset_path)
        memory_dataset = load_memory_dataset(file_path=cached_memory_dataset_path)
        logger.info("MemoryGenerationStage skipped.")
    else:
        use_cache = False
        memory_generation_stage = MemoryGenerationStage()
        qa_dataset, memory_dataset = memory_generation_stage.run(qa_dataset, colored_graph)
        save_qa_dataset(file_path=cached_qa_dataset_path, qa_dataset=qa_dataset)
        save_memory_dataset(file_path=cached_memory_dataset_path, memory_dataset=memory_dataset)
        logger.info("MemoryGenerationStage completed.")

    # RealDataInjectionStage
    cached_memory_dataset_path = _stage_cache_path(memory_dataset_path, ".jsonl", "real_data_injection_stage")
    if _stage_is_cached(use_cache, cached_memory_dataset_path):
        memory_dataset = load_memory_dataset(file_path=cached_memory_dataset_path)
        logger.info("RealDataInjectionStage skipped.")
    else:
        use_cache = False
        real_data_injection_stage = RealDataInjectionStage()
        memory_dataset = real_data_injection_stage.run(memory_dataset, session_multiplier=session_multiplier)
        save_memory_dataset(file_path=cached_memory_dataset_path, memory_dataset=memory_dataset)
        logger.info("RealDataInjectionStage completed.")

    # TimestampingStage
    cached_memory_dataset_path = _stage_cache_path(memory_dataset_path, ".jsonl", "timestamping_stage")
    if _stage_is_cached(use_cache, cached_memory_dataset_path):
        memory_dataset = load_memory_dataset(file_path=cached_memory_dataset_path)
        logger.info("TimestampingStage skipped.")
    else:
        use_cache = False
        timestamping_stage = TimestampingStage()
        memory_dataset = timestamping_stage.run(qa_dataset, memory_dataset)
        save_memory_dataset(file_path=cached_memory_dataset_path, memory_dataset=memory_dataset)
        logger.info("TimestampingStage completed.")

    # Save
    save_qa_dataset(file_path=qa_dataset_path, qa_dataset=qa_dataset)
    save_memory_dataset(file_path=memory_dataset_path, memory_dataset=memory_dataset)
    logger.info("Entire pipeline completed successfully.")


if __name__ == "__main__":
    # Hand run_pipeline to python-fire, which builds the command-line
    # interface from the function's signature.
    fire.Fire(component=run_pipeline)
