import attr
from pathlib import Path
import pandas as pd

import logging
import time
from functools import partial
from typing import Callable, List, Optional

import torch
from torchbiggraph.batching import AbstractBatchProcessor, call, process_in_batches
from torchbiggraph.bucket_scheduling import create_buckets_ordered_lexicographically
from torchbiggraph.checkpoint_manager import CheckpointManager
from torchbiggraph.config import add_to_sys_path, ConfigFileLoader, ConfigSchema
from torchbiggraph.graph_storages import EDGE_STORAGES
from torchbiggraph.losses import LOSS_FUNCTIONS
from torchbiggraph.model import make_model, MultiRelationEmbedder
from torchbiggraph.stats import Stats
from torchbiggraph.eval import RankingEvaluator
from torchbiggraph.filtered_eval import FilteredRankingEvaluator
from torchbiggraph.types import EntityName, Partition, UNPARTITIONED
from torchbiggraph.util import (
    create_pool,
    EmbeddingHolder,
    get_async_result,
    get_num_workers,
    set_logging_verbosity,
    setup_logging,
    split_almost_equally,
    SubprocessInitializer,
    tag_logs_with_process_name,
)

from SEPAL import SEPAL_DIR
from SEPAL.downstream_evaluation import DATASETS_NAMES

# Messages from this module are emitted under the "torchbiggraph" logger name
# rather than under this module's own name.
logger = logging.getLogger("torchbiggraph")


# A slightly modified version of torchbiggraph.eval.do_eval_and_report_stats
# that returns the mean stats to the caller.
def do_eval_and_report_stats(
    config: ConfigSchema,
    model: Optional[MultiRelationEmbedder] = None,
    evaluator: Optional[AbstractBatchProcessor] = None,
    subprocess_init: Optional[Callable[[], None]] = None,
):
    """Computes eval metrics (mr/mrr/r1/r10/r50) for a checkpoint with trained
    embeddings.

    Every edge path in ``config.edge_paths`` is evaluated bucket by bucket. The
    edges of each bucket are split across a pool of worker processes, and the
    per-bucket and per-edge-path averages are logged along the way.

    Args:
        config: Evaluation configuration. This function reads ``loss_fn``,
            ``margin``, ``relations``, ``verbose``, ``checkpoint_path``,
            ``workers``, ``edge_paths`` and ``batch_size`` from it.
        model: Model to evaluate. When ``None`` it is built with
            ``make_model(config)``. Parameters found in the checkpoint are
            loaded into it with ``strict=False``.
        evaluator: Batch processor used to score the edges. When ``None`` a
            ``RankingEvaluator`` is built from the configured loss function,
            margin and relation weights.
        subprocess_init: Optional callable handed to ``create_pool`` as the
            initializer for the worker processes.

    Returns:
        ``Stats.sum(...).average()`` taken over the summed stats of all edge
        paths.

    Raises:
        Any exception raised during evaluation is re-raised after the worker
        pool has been terminated and joined.
    """
    tag_logs_with_process_name("Evaluator")

    if evaluator is None:
        evaluator = RankingEvaluator(
            loss_fn=LOSS_FUNCTIONS.get_class(config.loss_fn)(margin=config.margin),
            relation_weights=[relation.weight for relation in config.relations],
        )

    if config.verbose > 0:
        import pprint

        pprint.PrettyPrinter().pprint(config.to_dict())

    checkpoint_manager = CheckpointManager(config.checkpoint_path)

    def load_embeddings(entity: EntityName, part: Partition) -> torch.nn.Parameter:
        """Read one (entity type, partition) embedding table from the checkpoint."""
        embs, _ = checkpoint_manager.read(entity, part)
        # The tensor is expected to already live in shared memory; presumably
        # this is what lets the pool workers use it without a copy.
        assert embs.is_shared()
        return torch.nn.Parameter(embs)

    holder = EmbeddingHolder(config)

    num_workers = get_num_workers(config.workers)
    pool = create_pool(
        num_workers, subprocess_name="EvalWorker", subprocess_init=subprocess_init
    )

    try:
        if model is None:
            model = make_model(config)
        model.share_memory()

        state_dict, _ = checkpoint_manager.maybe_read_model()
        if state_dict is not None:
            model.load_state_dict(state_dict, strict=False)

        model.eval()

        # Unpartitioned embeddings are loaded once and reused for every bucket.
        for entity in holder.lhs_unpartitioned_types | holder.rhs_unpartitioned_types:
            embs = load_embeddings(entity, UNPARTITIONED)
            holder.unpartitioned_embeddings[entity] = embs

        all_stats: List[Stats] = []
        for edge_path_idx, edge_path in enumerate(config.edge_paths):
            logger.info(
                f"Starting edge path {edge_path_idx + 1} / {len(config.edge_paths)} "
                f"({edge_path})"
            )
            edge_storage = EDGE_STORAGES.make_instance(edge_path)

            all_edge_path_stats = []
            # FIXME This order assumes higher affinity on the left-hand side, as it's
            # the one changing more slowly. Make this adaptive to the actual affinity.
            for bucket in create_buckets_ordered_lexicographically(
                holder.nparts_lhs, holder.nparts_rhs
            ):
                tic = time.perf_counter()

                # Swap partitioned embeddings: drop the partitions this bucket
                # does not need and load the ones that are missing. Partitions
                # shared with the previous bucket stay in place.
                old_parts = set(holder.partitioned_embeddings.keys())
                new_parts = {(e, bucket.lhs) for e in holder.lhs_partitioned_types} | {
                    (e, bucket.rhs) for e in holder.rhs_partitioned_types
                }
                for entity, part in old_parts - new_parts:
                    del holder.partitioned_embeddings[entity, part]
                for entity, part in new_parts - old_parts:
                    embs = load_embeddings(entity, part)
                    holder.partitioned_embeddings[entity, part] = embs

                model.set_all_embeddings(holder, bucket)

                edges = edge_storage.load_edges(bucket.lhs, bucket.rhs)
                num_edges = len(edges)

                load_time = time.perf_counter() - tic
                tic = time.perf_counter()
                # One task per worker, each handling an almost equal slice of
                # the bucket's edges.
                future_all_bucket_stats = pool.map_async(
                    call,
                    [
                        partial(
                            process_in_batches,
                            batch_size=config.batch_size,
                            model=model,
                            batch_processor=evaluator,
                            edges=edges[s],
                        )
                        for s in split_almost_equally(num_edges, num_parts=num_workers)
                    ],
                )
                all_bucket_stats = get_async_result(future_all_bucket_stats, pool)

                compute_time = time.perf_counter() - tic
                logger.info(
                    f"{bucket}: Processed {num_edges} edges in {compute_time:.2g} s "
                    f"({num_edges / compute_time / 1e6:.2g}M/sec); "
                    f"load time: {load_time:.2g} s"
                )

                total_bucket_stats = Stats.sum(all_bucket_stats)
                all_edge_path_stats.append(total_bucket_stats)
                mean_bucket_stats = total_bucket_stats.average()
                logger.info(
                    f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}, "
                    f"bucket {bucket}: {mean_bucket_stats}"
                )

                model.clear_all_embeddings()

            total_edge_path_stats = Stats.sum(all_edge_path_stats)
            all_stats.append(total_edge_path_stats)
            mean_edge_path_stats = total_edge_path_stats.average()
            logger.info("")
            logger.info(
                f"Stats for edge path {edge_path_idx + 1} / {len(config.edge_paths)}: "
                f"{mean_edge_path_stats}"
            )
            logger.info("")

        mean_stats = Stats.sum(all_stats).average()
        logger.info("")
        logger.info(f"Stats: {mean_stats}")
        logger.info("")
    except BaseException:
        # Do not leave worker processes behind when evaluation fails.
        # terminate() is used instead of close() because outstanding work may
        # never complete after a failure, and waiting for it could block.
        pool.terminate()
        pool.join()
        raise

    pool.close()
    pool.join()

    return mean_stats


def eval_model(config, subprocess_init, subset, filtered, sampled, side):
    """Evaluate link prediction for one setting and append the scores to a parquet file.

    Args:
        config: Base configuration. Its ``edge_paths`` must hold exactly three
            paths, unpacked in the order train, validation, test.
        subprocess_init: Forwarded to ``do_eval_and_report_stats``.
        subset: ``"val"`` or ``"test"`` restricts the evaluated edges to that
            split and points ``checkpoint_path`` at the ``/train``
            sub-directory of ``config.checkpoint_path``. Any other value leaves
            ``edge_paths`` and ``checkpoint_path`` as given in ``config``. The
            value is also used in the name of the results file.
        filtered: When true, a ``FilteredRankingEvaluator`` is used, built from
            the evaluation config and the test, validation and train edge
            paths.
        sampled: When true, the config is evolved to 10000 uniform negatives,
            0 batch negatives and a batch size of 1000. Otherwise every
            relation gets ``all_negs=True``, with 0 uniform negatives and a
            batch size of 100.
        side: ``"tail"`` sets ``disable_lhs_negs=True``, ``"head"`` sets
            ``disable_rhs_negs=True``; any other value leaves both untouched.

    Raises:
        ValueError: If ``filtered`` and ``sampled`` are both true. Filtered
            evaluation can only be done with all negatives.
    """
    # Fail before any work is done: filtered evaluation can only be done with
    # all negatives (i.e. sampled=False).
    if filtered and sampled:
        raise ValueError(
            "Filtered evaluation can only be done with all negatives "
            "(i.e. sampled=False)."
        )

    output_train_path, output_val_path, output_test_path = config.edge_paths

    if sampled:
        eval_config = attr.evolve(
            config,
            num_uniform_negs=10000,
            num_batch_negs=0,
            batch_size=1000,
        )
    else:
        relations = [attr.evolve(r, all_negs=True) for r in config.relations]
        eval_config = attr.evolve(
            config,
            relations=relations,
            num_uniform_negs=0,
            batch_size=100,
        )

    # Both held-out subsets are scored against the checkpoint stored under
    # "<checkpoint_path>/train"; only the evaluated edge path differs.
    subset_edge_paths = {"val": output_val_path, "test": output_test_path}
    if subset in subset_edge_paths:
        eval_config = attr.evolve(
            eval_config,
            edge_paths=[subset_edge_paths[subset]],
            checkpoint_path=config.checkpoint_path + "/train",
        )

    if side == "tail":
        eval_config = attr.evolve(eval_config, disable_lhs_negs=True)
    elif side == "head":
        eval_config = attr.evolve(eval_config, disable_rhs_negs=True)

    if filtered:
        filter_paths = [output_test_path, output_val_path, output_train_path]
        mean_stats = do_eval_and_report_stats(
            eval_config,
            evaluator=FilteredRankingEvaluator(eval_config, filter_paths),
            subprocess_init=subprocess_init,
        )
    else:
        mean_stats = do_eval_and_report_stats(
            eval_config, subprocess_init=subprocess_init
        )

    results = mean_stats.to_dict()
    results["filtered"] = filtered
    results["sampled"] = sampled
    results["eval_subset"] = subset
    results["side"] = side
    # The dataset name is the last component of the entity path. Path.name
    # ignores a trailing slash, which would otherwise yield an empty name.
    results["data"] = Path(config.entity_path).name
    results["id"] = f"PBG - {DATASETS_NAMES[results['data']]} train"

    # Append this run as a new row to the per-subset results file, creating
    # the file on the first run.
    results_file = SEPAL_DIR / f"baselines/PBG/{subset}_lp_scores_pbg.parquet"
    new_df_res = pd.DataFrame([results])
    if Path(results_file).is_file():
        df_res = pd.read_parquet(results_file)
        df_res = pd.concat([df_res, new_df_res]).reset_index(drop=True)
    else:
        df_res = new_df_res
    df_res.to_parquet(results_file, index=False)


def main():
    """Run the sampled, unfiltered evaluation for every side, dataset and subset.

    For each combination the dataset's config file is loaded from the
    ``configs`` directory next to this script, logging is set up for the main
    process and the worker subprocesses, and ``eval_model`` is called.
    """
    sides = ["tail", "both", "head"]
    datasets = [
        # Disabled datasets, kept for reference:
        # "mini_yago3_lcc",
        # "yago3_lcc",
        # "rel_core_yago4",
        # "rel_core_yago4.5",
        # "core_yago4",
        # "core_yago4.5",
        # "yago4_lcc",
        # "yago4.5_lcc",
        "yago4_with_full_ontology",
    ]
    subsets = ["val", "test"]

    for side in sides:
        for data in datasets:
            for subset in subsets:
                print(f"Evaluating {data} embeddings on {subset} set")

                # A fresh loader and config are created for every run.
                loader = ConfigFileLoader()
                script_dir = Path(__file__).absolute().parent
                config_file = script_dir / f"configs/{data}_config.py"
                config = loader.load_config(config_file)

                set_logging_verbosity(config.verbose)

                # Steps each worker subprocess runs on start-up.
                subprocess_init = SubprocessInitializer()
                subprocess_init.register(setup_logging, config.verbose)
                subprocess_init.register(add_to_sys_path, loader.config_dir.name)

                eval_model(
                    config,
                    subprocess_init,
                    subset,
                    filtered=False,
                    sampled=True,
                    side=side,
                )


# Run the evaluation sweep only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
