from scripts.experiments import GraphGenerationAllClassPairsExperiment
from scripts.sampling_training import TrainingParams
from scripts.default_configs import *
from gnnboundary import *
from gnnboundary.utils.random_baseline import PAPER_CLASS_COMBINATIONS#
from gnnboundary.utils.boundary_generator import GraphGenerator
import numpy as np
from typing import Dict, Tuple
import logging


# Configure root logging for this script: INFO and above, each record
# prefixed with the emitting process id and level name.
logging.basicConfig(
    format="%(process)d - %(levelname)s - %(message)s",
    level=logging.INFO,
)


# Pretrained classifier checkpoint per dataset. Keys must match the dataset's
# ``name`` attribute, since lookups are done as ``CKPT_PATHS[dataset.name]``.
CKPT_PATHS = dict(
    CollabDataset="ckpts/collab.pt",
    MotifDataset="ckpts/motif.pt",
    ENZYMESDataset="ckpts/enzymes.pt",
    RedditDataset="ckpts/reddit.pt",
    IMDB="ckpts/IMDB.pt",
)



def run_training(
    fixed_params: dict,
    target_probs=None,
    target_size: int = 32,
    learning_rate: float = 0.5,
    temperature: float = 0.15,
    w_budget_inc: float = 1.1,
    w_budget_dec: float = 0.95,
    num_runs: int = 40,
    dataset=None,
) -> float:
    """
    Generate boundary graphs for every class pair and return the mean loss over runs.

    For each of ``num_runs`` runs, one graph is generated per class pair in
    ``DATASET_TO_CLS_PAIRS[dataset.name]``. The classifier's probabilities for
    the two classes of that pair are scored with ``get_loss_from_run_results``
    (distance from 0.5), and the per-run losses are averaged. Lower is better.

    Args:
        fixed_params: Extra keyword arguments forwarded to ``TrainingParams``.
            Must contain ``"iterations"`` and ``"k_samples"``, which are also
            forwarded to the graph generator.
        target_probs: Target-probability specification applied to both classes
            of every pair (presumably a ``(low, high)`` range — confirm against
            ``GraphGenerator``). Required: the old default was the ``dict`` type
            object itself, which is never a valid value.
        target_size: Forwarded to ``TrainingParams`` and the graph generator.
        learning_rate: SGD learning rate for the sampler parameters.
        temperature: Temperature of the ``GraphSampler``.
        w_budget_inc: Forwarded to ``TrainingParams`` and the graph generator.
        w_budget_dec: Forwarded to ``TrainingParams`` and the graph generator.
        num_runs: Number of generation runs to average over; must be >= 1.
        dataset: Dataset instance to evaluate on. Defaults to
            ``IMDBDataset(seed=12345)``.

    Returns:
        float: Average over runs of the per-run mean class-pair loss.

    Raises:
        ValueError: If ``target_probs`` is missing or ``num_runs`` < 1.
    """
    if target_probs is None:
        raise ValueError("target_probs must be provided")
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    if dataset is None:
        dataset = IMDBDataset(seed=12345)

    target_probs_for_all_classes = {
        class_idx: target_probs for class_idx in range(len(dataset.GRAPH_CLS))
    }

    # NOTE(review): training_params is not used below. The call is kept because
    # constructing TrainingParams may validate fixed_params — confirm if needed.
    training_params = TrainingParams(
        target_size=target_size,
        target_probs=target_probs_for_all_classes,
        w_budget_inc=w_budget_inc,
        w_budget_dec=w_budget_dec,
        **fixed_params
    )

    model = GCNClassifier(**get_model_kwargs(dataset, dataset.name))
    model.load_state_dict(torch.load(CKPT_PATHS[dataset.name]))

    cls_pairs = DATASET_TO_CLS_PAIRS[dataset.name]

    # NOTE(review): this single sampler (and its parameters) is shared by every
    # class pair's Trainer and optimizer; confirm that is intended.
    sampler = GraphSampler(
        max_nodes=25,
        temperature=temperature,
        num_node_cls=len(dataset.NODE_CLS),
        learn_node_feat=True,
    )

    # Mean embedding per class, used to build the boundary criteria.
    mean_embeds = [
        class_subset.model_transform(model, key="embeds").mean(dim=0)
        for class_subset in dataset.split_by_class()
    ]
    criterion = {
        cls_pair: get_default_criteria_dynamic_boundary(
            dataset.name, mean_embeds, cls_pair[0], cls_pair[1]
        )
        for cls_pair in cls_pairs
    }

    trainer = {}
    for cls_pair in cls_pairs:
        optimizer = torch.optim.SGD(sampler.parameters(), lr=learning_rate)
        trainer[cls_pair] = Trainer(
            sampler=sampler,
            discriminator=model,
            criterion=criterion[cls_pair],
            optimizer=optimizer,
            # gamma=1 keeps the learning rate constant.
            scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1),
            dataset=dataset,
            budget_penalty=BudgetPenalty(budget=10, order=2, beta=1),
        )

    graph_generators = {
        cls_pair: GraphGenerator(trainer[cls_pair].sampler, dataset, trainer[cls_pair], model)
        for cls_pair in cls_pairs
    }

    total_loss = 0.0
    for iter_count in range(num_runs):
        results = {}
        for cls_pair in cls_pairs:
            target_probs_cls = {cls_pair[0]: target_probs, cls_pair[1]: target_probs}
            _, _, graph = graph_generators[cls_pair](
                num_runs=1,
                num_graphs=1,
                strategy="boundary",
                save_graphs=False,
                cls=cls_pair,
                add_non_successful=True,
                target_size=target_size,
                iterations=fixed_params["iterations"],
                target_probs=target_probs_cls,
                w_budget_init=1,
                # Use the caller's budget factors. They were previously
                # hard-coded to 1.1 / 0.95, so tuning them had no effect here.
                w_budget_inc=w_budget_inc,
                w_budget_dec=w_budget_dec,
                k_samples=fixed_params["k_samples"],
            )

            # NOTE(review): assumes graph[0][0] is the single generated graph
            # (num_runs=1, num_graphs=1) — confirm against GraphGenerator.
            generated = graph[0][0]
            with torch.no_grad():
                probs = model(generated, edge_weight=generated.edge_weight)["probs"]
            logging.debug("probs for %s on %s: %s", cls_pair, probs.device, probs)
            # Keep only row 0's probabilities for the two classes of this pair.
            results[cls_pair] = probs.detach().cpu().numpy()[0, list(cls_pair)]

        avg_loss = get_loss_from_run_results(results)
        total_loss += avg_loss
        logging.info(f"Finished run {iter_count+1} with an average loss of {avg_loss}")

    logging.info("Total loss over %d runs: %s", num_runs, total_loss)
    return total_loss / num_runs


def get_loss_from_run_results(
    results: Dict[Tuple[int, int], np.ndarray], target: float = 0.5
) -> float:
    """
    Average, over class pairs, each pair's mean absolute distance from ``target``.

    For every class pair, the probabilities are compared elementwise to
    ``target`` and the absolute deviations are averaged. Those per-pair values
    are then averaged over all pairs. For the usual two probabilities per pair,
    this equals the previous ``sum(|p - 0.5|) / 2`` formulation.

    Args:
        results: Maps each class pair ``(cls_a, cls_b)`` to an array of the
            classifier's probabilities for those classes (any shape, e.g.
            ``(2,)`` or ``(1, 2)``).
        target: Value each probability should ideally equal. Defaults to 0.5,
            i.e. both classes of the pair predicted with probability 0.5.

    Returns:
        float: Mean over class pairs of ``mean(|p - target|)``; 0 is perfect.

    Raises:
        ValueError: If ``results`` is empty.
    """
    if not results:
        raise ValueError("results must contain at least one class pair")

    per_pair_losses = [
        float(np.mean(np.abs(np.asarray(probabilities) - target)))
        for probabilities in results.values()
    ]
    mean_loss = sum(per_pair_losses) / len(per_pair_losses)
    logging.debug("Mean class-pair loss: %s", mean_loss)
    return mean_loss

