from gnn_xai_common.datasets import BaseGraphDataset
from gnnboundary import *
from typing import List, Dict, Tuple
import torch
from dataclasses import dataclass

@dataclass
class TrainingParams:
    iterations: int
    target_probs: Dict[int, Tuple[float, float]] = None
    target_size: int = 30
    w_budget_init: int = 1
    w_budget_inc: float = 1.1
    w_budget_dec: float = 0.95
    k_samples: int = 16
    eval_threshold: float = 0.5

def get_model_kwargs(dataset: object, model_name: str):
    """Return the model constructor keyword arguments for ``model_name``.

    ``node_features`` and ``num_classes`` are taken from the lengths of
    ``dataset.NODE_CLS`` and ``dataset.GRAPH_CLS``. ``hidden_channels`` and
    ``num_layers`` are fixed per model name.

    Raises:
        KeyError: if ``model_name`` has no preset.
    """
    n_node_feats = len(dataset.NODE_CLS)
    n_graph_classes = len(dataset.GRAPH_CLS)

    # model name -> (hidden_channels, num_layers)
    architectures = {
        "COLLAB": (64, 5),
        "Motif": (6, 3),
        "ENZYMES": (32, 3),
        "REDDIT-MULTI-5K": (64, 5),
        "IMDB": (64, 5),
    }
    hidden, depth = architectures[model_name]

    return {
        "node_features": n_node_feats,
        "num_classes": n_graph_classes,
        "hidden_channels": hidden,
        "num_layers": depth,
    }

def get_default_criteria_dynamic_boundary(dataset_name: str, mean_embeds: List[torch.Tensor], cls_1:int , cls_2: int):
    """Build the dynamic-balancing boundary objective for the pair (cls_1, cls_2).

    Every preset combines, in this order:
    a ``DynamicBalancingBoundaryCriterion`` on ``logits``, two zero-weighted
    ``EmbeddingCriterion`` terms toward ``mean_embeds[cls_1]`` and
    ``mean_embeds[cls_2]``, a ``MeanPenalty`` on ``logits``, L1 and L2
    ``NormPenalty`` terms on ``omega``, and a zero-weighted
    ``KLDivergencePenalty`` on ``theta_pairs``. Only the weights and the
    boundary criterion's alpha/beta vary by dataset.

    Args:
        dataset_name: preset name ("Motif", "ENZYMES", "COLLAB",
            "REDDIT-MULTI-5K" or "IMDB").
        mean_embeds: sequence indexed by class id (must support
            ``mean_embeds[cls_1]`` / ``mean_embeds[cls_2]``).
        cls_1: first class index of the boundary pair.
        cls_2: second class index of the boundary pair.

    Returns:
        A ``WeightedCriterion`` over the terms above.

    Raises:
        ValueError: if ``dataset_name`` has no preset.
    """
    # dataset -> (extra boundary kwargs, boundary weight,
    #             MeanPenalty weight, L1 omega-norm weight)
    presets = {
        "Motif": (dict(alpha=1, beta=2), 25, 1, 1),
        "ENZYMES": (dict(alpha=1, beta=1), 5, 0, 1),
        # COLLAB passes no alpha/beta, so the criterion's own defaults apply.
        "COLLAB": ({}, 25, 0, 2),
        "REDDIT-MULTI-5K": (dict(alpha=1, beta=2), 25, 1, 1),
        "IMDB": (dict(alpha=1, beta=2), 25, 1, 1),
    }
    if dataset_name not in presets:
        raise ValueError("Invalid dataset name")
    boundary_kwargs, boundary_w, mean_w, l1_w = presets[dataset_name]

    return WeightedCriterion([
        dict(key="logits", criterion=DynamicBalancingBoundaryCriterion(
            classes=[cls_1, cls_2], **boundary_kwargs
        ), weight=boundary_w),
        dict(key="embeds", criterion=EmbeddingCriterion(target_embedding=mean_embeds[cls_1]), weight=0),
        dict(key="embeds", criterion=EmbeddingCriterion(target_embedding=mean_embeds[cls_2]), weight=0),
        dict(key="logits", criterion=MeanPenalty(), weight=mean_w),
        dict(key="omega", criterion=NormPenalty(order=1), weight=l1_w),
        dict(key="omega", criterion=NormPenalty(order=2), weight=1),
        dict(key="theta_pairs", criterion=KLDivergencePenalty(binary=True), weight=0),
    ])

def get_default_criteria_interpreter(dataset_name: str, mean_embeds: List[torch.Tensor], cls_1:int):
    """Build the single-class interpretation objective for class ``cls_1``.

    The objective maximizes the ``cls_1`` class score on ``logits`` and adds
    embedding, mean, and norm/KL regularizers with per-dataset weights.

    Args:
        dataset_name: preset name ("Motif", "ENZYMES", "COLLAB" or
            "REDDIT-MULTI-5K").
        mean_embeds: sequence indexed by class id (must support
            ``mean_embeds[cls_1]``).
        cls_1: class index whose score is maximized.

    Returns:
        A ``WeightedCriterion`` over the configured terms.

    Raises:
        ValueError: if ``dataset_name`` has no preset (this includes "IMDB").
    """
    if dataset_name == "ENZYMES":
        # ENZYMES balances the terms differently and has no L2 term on "xi".
        return WeightedCriterion([
            dict(key="logits", criterion=ClassScoreCriterion(class_idx=cls_1, mode='maximize'), weight=1),
            dict(key="embeds", criterion=EmbeddingCriterion(target_embedding=mean_embeds[cls_1]), weight=1),
            dict(key="logits", criterion=MeanPenalty(), weight=1),
            dict(key="omega", criterion=NormPenalty(order=1), weight=2),
            dict(key="omega", criterion=NormPenalty(order=2), weight=2),
            dict(key="xi", criterion=NormPenalty(order=1), weight=0),
            dict(key="theta_pairs", criterion=KLDivergencePenalty(binary=True), weight=5),
        ])

    # The remaining presets share one layout and differ only in the weight
    # of the L1 penalty on "omega".
    omega_l1_weight = {"Motif": 1, "COLLAB": 2, "REDDIT-MULTI-5K": 2}.get(dataset_name)
    if omega_l1_weight is None:
        raise ValueError("Invalid dataset name")

    return WeightedCriterion([
        dict(key="logits", criterion=ClassScoreCriterion(class_idx=cls_1, mode='maximize'), weight=50),
        dict(key="embeds", criterion=EmbeddingCriterion(target_embedding=mean_embeds[cls_1]), weight=0),
        dict(key="logits", criterion=MeanPenalty(), weight=0),
        dict(key="omega", criterion=NormPenalty(order=1), weight=omega_l1_weight),
        dict(key="omega", criterion=NormPenalty(order=2), weight=1),
        dict(key="xi", criterion=NormPenalty(order=1), weight=0),
        dict(key="xi", criterion=NormPenalty(order=2), weight=0),
        dict(key="theta_pairs", criterion=KLDivergencePenalty(binary=True), weight=1),
    ])

def get_default_criteria_entropy(dataset_name: str, mean_embeds: List[torch.Tensor], cls_1:int , cls_2: int):
    """Build the cross-entropy boundary objective for the pair (cls_1, cls_2).

    Every preset combines, in this order:
    a ``CrossEntropyBoundaryCriterion`` on ``logits``, two zero-weighted
    ``EmbeddingCriterion`` terms toward ``mean_embeds[cls_1]`` and
    ``mean_embeds[cls_2]``, a ``MeanPenalty`` on ``logits``, L1 and L2
    ``NormPenalty`` terms on ``omega``, and a zero-weighted
    ``KLDivergencePenalty`` on ``theta_pairs``. Only the weights vary by
    dataset.

    Args:
        dataset_name: preset name ("Motif", "ENZYMES", "COLLAB",
            "REDDIT-MULTI-5K" or "IMDB").
        mean_embeds: sequence indexed by class id (must support
            ``mean_embeds[cls_1]`` / ``mean_embeds[cls_2]``).
        cls_1: first class index of the boundary pair.
        cls_2: second class index of the boundary pair.

    Returns:
        A ``WeightedCriterion`` over the terms above.

    Raises:
        ValueError: if ``dataset_name`` has no preset.
    """
    # dataset -> (boundary weight, MeanPenalty weight, L1 omega-norm weight).
    # All presets here, including REDDIT-MULTI-5K and IMDB, use the
    # cross-entropy boundary criterion. The dynamic-balancing variant is
    # built by get_default_criteria_dynamic_boundary.
    presets = {
        "Motif": (25, 1, 1),
        "ENZYMES": (5, 0, 1),
        "COLLAB": (25, 0, 2),
        "REDDIT-MULTI-5K": (25, 1, 1),
        "IMDB": (25, 1, 1),
    }
    if dataset_name not in presets:
        raise ValueError("Invalid dataset name")
    boundary_w, mean_w, l1_w = presets[dataset_name]

    return WeightedCriterion([
        dict(key="logits", criterion=CrossEntropyBoundaryCriterion(class_a=cls_1, class_b=cls_2), weight=boundary_w),
        dict(key="embeds", criterion=EmbeddingCriterion(target_embedding=mean_embeds[cls_1]), weight=0),
        dict(key="embeds", criterion=EmbeddingCriterion(target_embedding=mean_embeds[cls_2]), weight=0),
        dict(key="logits", criterion=MeanPenalty(), weight=mean_w),
        dict(key="omega", criterion=NormPenalty(order=1), weight=l1_w),
        dict(key="omega", criterion=NormPenalty(order=2), weight=1),
        dict(key="theta_pairs", criterion=KLDivergencePenalty(binary=True), weight=0),
    ])

# Class-index pairs whose decision boundary is studied, per dataset.
# Keys must match the dataset names used by the other lookup tables in this
# module (e.g. "REDDIT-MULTI-5K"), otherwise lookups by dataset name fail.
DATASET_TO_CLS_PAIRS = {
    "Motif": [(0, 1), (0, 2), (1, 3)],
    "ENZYMES": [(0, 3), (0, 4), (0, 5), (1, 2), (3, 4), (4, 5)],
    "COLLAB": [(0, 1), (0, 2)],
    "IMDB": [(0, 1), (0, 2)],
    "REDDIT-MULTI-5K": [(1, 2), (1, 4), (2, 4), (3, 4)],
}


# Checkpoint file path for each dataset, keyed by dataset name.
# Each entry is "./ckpts/<stem>.pt"; note the IMDB stem keeps its upper case.
CKPT_PATHS = {
    name: f"./ckpts/{stem}.pt"
    for name, stem in (
        ("COLLAB", "collab"),
        ("Motif", "motif"),
        ("ENZYMES", "enzymes"),
        ("IMDB", "IMDB"),
        ("REDDIT-MULTI-5K", "reddit"),
    )
}


def get_default_training_params(dataset_name: str, cls1: int, cls2: int):
    """Return the default ``TrainingParams`` for ``dataset_name``.

    Every preset runs 2000 iterations with ``w_budget_init=1`` and
    ``w_budget_inc=1.1``. Both ``cls1`` and ``cls2`` get the same
    ``target_probs`` band.

    Raises:
        KeyError: if ``dataset_name`` has no preset.
    """
    # dataset -> (target_probs band, target_size, w_budget_dec, k_samples)
    settings = {
        "COLLAB": ((0.4, 0.6), 30, 0.95, 16),
        "ENZYMES": ((0.4, 0.6), 40, 0.95, 16),
        "Motif": ((0.4, 0.6), 40, 0.95, 16),
        "IMDB": ((0.45, 0.55), 57, 0.97, 32),
        "REDDIT-MULTI-5K": ((0.4, 0.6), 40, 0.95, 16),
    }
    band, size, budget_dec, samples = settings[dataset_name]

    return TrainingParams(
        2000,
        target_probs={cls1: band, cls2: band},
        target_size=size,
        w_budget_init=1,
        w_budget_inc=1.1,
        w_budget_dec=budget_dec,
        k_samples=samples,
    )


# Output directory for boundary graphs, keyed by dataset name and then by
# (class_a, class_b) pair. Every path has the form
# "./graphs/boundary/<dataset>/<class_a>-<class_b>".
BOUNDARY_GRAPHS_DIRs_TO_DATASET = {
    dataset: {
        (a, b): f"./graphs/boundary/{dataset}/{a}-{b}"
        for a, b in pairs
    }
    for dataset, pairs in (
        ("COLLAB", ((0, 1), (0, 2), (1, 0))),
        ("Motif", ((0, 1), (0, 2), (1, 3))),
        ("ENZYMES", ((0, 3), (0, 4), (0, 5), (1, 2), (3, 4), (4, 5))),
        ("IMDB", ((0, 1), (0, 2))),
    )
}


# Per-dataset, per-class maximum node counts for the interpreter, keyed by
# dataset name and then by class index. Only COLLAB class 0 differs (25);
# every other class is capped at 20.
DATASET_TO_MAX_NODES_INTERPRETER = {
    "COLLAB": {0: 25, 1: 20, 2: 20},
    "Motif": dict.fromkeys(range(4), 20),
    "ENZYMES": dict.fromkeys(range(6), 20),
}