import argparse
import heapq
import json
import logging
from collections import Counter, defaultdict
from pathlib import Path

import tqdm

from kge.dataset import get_triple_dataset

# Set up logging for the whole script: INFO threshold, and every record is
# rendered as "<timestamp> - <level name> - <message>".
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)


def parse_args() -> argparse.Namespace:
    """Parse this script's command-line options.

    Recognised options:
        --datasets       One or more dataset names (required).
        --data-folder    Base data folder; defaults to ``data/processed``.
        --count-inverse  Flag; when present, inverse relations are counted too.
        --output         Optional path for the combined results JSON.

    Returns:
        The ``argparse.Namespace`` holding ``datasets``, ``data_folder``,
        ``count_inverse`` and ``output``.
    """
    # Each entry is (flag, keyword arguments for ``add_argument``); declaring
    # the options as data keeps the registration loop below trivial.
    option_table = [
        ("--datasets", {"type": str, "nargs": "+", "required": True}),
        (
            "--data-folder",
            {
                "type": Path,
                "required": False,
                "default": Path("data", "processed"),
                "help": "Path to the data folder. Will save results for each dataset in <data-folder>/<dataset-name>/",
            },
        ),
        (
            "--count-inverse",
            {
                "action": "store_true",
                "help": "If set, also count inverse relations (e.g., if (s,r,o) exists, also count (o,r_inv,s))",
            },
        ),
        (
            "--output",
            {
                "type": Path,
                "required": False,
                "help": "Path to save combined results JSON",
            },
        ),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    namespace = cli.parse_args()
    return namespace


def summarize_sr_outdegrees(sr_to_o: dict, top_k: int = 5) -> dict:
    """Compute out-degree statistics over (subject, relation) pairs.

    The out-degree of a pair is the number of distinct objects recorded for it.

    Args:
        sr_to_o: Mapping from ``(subject, relation)`` tuples to the collection
            of distinct objects seen with that pair.
        top_k: How many of the highest out-degree pairs to report.

    Returns:
        A dict with, in this order, the keys ``max_outdegree``,
        ``avg_outdegree``, ``median_outdegree`` (the upper median for an even
        number of pairs), ``mode_outdegree`` (the smallest value among the
        most frequent ones), ``num_object_groups`` (number of distinct object
        sets) and ``top_outdegree_pairs`` (list of
        ``{"subject", "relation", "outdegree"}`` dicts, highest first).

    Raises:
        ValueError: If ``sr_to_o`` contains no pairs.
    """
    if not sr_to_o:
        raise ValueError("No (subject, relation) pairs found; cannot compute out-degree statistics")

    outdegrees = sorted(len(objects) for objects in sr_to_o.values())

    # A single counting pass replaces calling ``list.count`` once per distinct
    # value, which is quadratic when many distinct out-degrees exist. Ties are
    # resolved towards the smallest out-degree so the result is deterministic.
    frequency = Counter(outdegrees)
    highest_frequency = max(frequency.values())
    mode_outdegree = min(
        degree for degree, count in frequency.items() if count == highest_frequency
    )

    # nlargest yields the same result (including tie order) as a full
    # descending stable sort followed by a slice, without sorting every pair.
    top_pairs = heapq.nlargest(top_k, sr_to_o.items(), key=lambda item: len(item[1]))

    return {
        "max_outdegree": outdegrees[-1],
        "avg_outdegree": sum(outdegrees) / len(outdegrees),
        "median_outdegree": outdegrees[len(outdegrees) // 2],
        "mode_outdegree": mode_outdegree,
        "num_object_groups": len({frozenset(objects) for objects in sr_to_o.values()}),
        "top_outdegree_pairs": [
            {"subject": subject, "relation": relation, "outdegree": len(objects)}
            for (subject, relation), objects in top_pairs
        ],
    }


def _collect_sr_to_o(dataset, dataset_name: str) -> dict:
    """Build the (subject, relation) -> set of objects mapping over all splits."""
    sr_to_o = defaultdict(set)
    for s, r, o in tqdm.tqdm(dataset.train, desc="Counting SR-to-O (train)"):
        sr_to_o[(int(s), int(r))].add(int(o))

    # For datasets whose name starts with "ogb", validation/test items carry a
    # fourth element that is not needed for counting and is discarded.
    has_extra_element = dataset_name.startswith("ogb")
    for split in (dataset.valid, dataset.test):
        for item in tqdm.tqdm(split, desc=f"Counting SR-to-O ({split.split})"):
            if has_extra_element:
                s, r, o, _ = item
            else:
                s, r, o = item
            sr_to_o[(int(s), int(r))].add(int(o))
    return sr_to_o


def _dataset_output_path(data_folder: Path, dataset_name: str, count_inverse: bool) -> Path:
    """Return the per-dataset results path inside ``data_folder``."""
    filename = "sr_outdegrees_with_inverse.json" if count_inverse else "sr_outdegrees.json"
    # "ogb*" dataset folders use underscores where the dataset name has dashes.
    if dataset_name.startswith("ogb"):
        folder_name = dataset_name.replace("-", "_")
    else:
        folder_name = dataset_name
    return data_folder / folder_name / filename


def _write_json(path: Path, payload: dict) -> None:
    """Write ``payload`` as indented JSON to ``path``, creating parent folders."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def main() -> None:
    """Compute and store (s, r) out-degree statistics for each requested dataset.

    For every dataset named on the command line the statistics are logged and
    either written to a per-dataset JSON file inside the data folder (when
    ``--output`` is not given) or collected and written together to the
    ``--output`` path at the end. A failure in one dataset is logged with its
    traceback and does not stop the remaining datasets.
    """
    args = parse_args()
    results = {}

    for dataset_name in args.datasets:
        try:
            logging.info(f"Processing dataset: {dataset_name}")
            dataset = get_triple_dataset(
                dataset_name,
                data_folder=args.data_folder,
                add_inverse=args.count_inverse,
            )
            sr_to_o = _collect_sr_to_o(dataset, dataset_name)
            stats = summarize_sr_outdegrees(sr_to_o)

            # Count total number of triples
            num_triples = len(dataset.train) + len(dataset.valid) + len(dataset.test)

            # Store results for this dataset
            dataset_results = {
                "num_entities": dataset.num_entities,
                "num_relations": dataset.num_relations,
                "num_triples": num_triples,
                **stats,
            }
            results[dataset_name] = dataset_results

            logging.info(f"\nDataset: {dataset_name}")
            logging.info(f"Total number of triples: {num_triples}")
            logging.info(f"Max (s,r) out-degree: {stats['max_outdegree']}")
            logging.info(f"Avg (s,r) out-degree: {stats['avg_outdegree']:.2f}")
            logging.info(f"Median (s,r) out-degree: {stats['median_outdegree']}")
            logging.info(f"Mode (s,r) out-degree: {stats['mode_outdegree']}")
            logging.info(f"Num object groups: {stats['num_object_groups']}")
            logging.info("\nTop 5 (subject, relation) pairs by outdegree:")
            for entry in stats["top_outdegree_pairs"]:
                logging.info(
                    f"Subject: {entry['subject']}, Relation: {entry['relation']}, Outdegree: {entry['outdegree']}"
                )

            # Save results for individual dataset
            if args.output is None:
                output_path = _dataset_output_path(
                    args.data_folder, dataset_name, args.count_inverse
                )
                _write_json(output_path, dataset_results)
                logging.info(f"Results saved to {output_path}")

        except Exception as e:
            # Top-level boundary: record the failure (with traceback) and move
            # on to the next dataset instead of aborting the whole run.
            logging.exception(f"Error processing dataset {dataset_name}: {e!s}")

    # Save combined results to JSON file if output path is provided
    if args.output is not None:
        _write_json(args.output, results)
        logging.info(f"\nCombined results saved to {args.output}")


if __name__ == "__main__":
    main()
