from gnnboundary import *
from dataclasses import dataclass, field
import torch
from typing import List, Tuple, Dict
import torch.nn as nn

from scripts.utils import draw_matrix_adj


@dataclass
class AnalyzerConfig:
    """Settings that ``Analyzer`` uses to build and restore its model."""

    # The model class itself (not an instance): ``Analyzer`` calls it as
    # ``model_architecture(**model_kwargs)`` to construct the model.
    model_architecture: nn.Module = GCNClassifier
    # Checkpoint location handed to ``torch.load``; a falsy value (the
    # default ``None``) means no weights are loaded.
    model_checkpoint: str = None
    # Keyword arguments forwarded to the model constructor.
    # ``default_factory`` has to be a zero-argument callable; handing it the
    # dict directly raises ``TypeError`` as soon as the default is needed.
    # The lambda also gives every instance its own fresh dict.
    model_kwargs: Dict[str, int] = field(default_factory=lambda: {
        "node_features": 3,
        "num_classes": 2,
        "hidden_channels": 64,
        "num_layers": 5,
    })


class Analyzer:
    """Pair a dataset with a classifier and run analyses on them.

    On construction the model is built from the config, optionally restored
    from a checkpoint, put into eval mode, and used to split the dataset by
    prediction.
    """

    def __init__(self, dataset, config: "AnalyzerConfig"):
        """Build the model and compute the prediction-based dataset split.

        Args:
            dataset: Dataset object. It must provide ``split_by_pred(model)``;
                the other methods additionally use ``model_evaluate(model)``
                and the ``GRAPH_CLS`` mapping.
            config: Provides ``model_architecture``, ``model_kwargs`` and
                ``model_checkpoint``.
        """
        print("Getting data...")
        self.dataset = dataset
        self.model = config.model_architecture(**config.model_kwargs)

        if config.model_checkpoint:
            # Map the stored tensors to the CPU so a checkpoint written from
            # a GPU can still be restored on a CPU-only host.
            # ``load_state_dict`` copies the values into the model's existing
            # parameters, so the model's own device is left unchanged.
            state_dict = torch.load(config.model_checkpoint, map_location="cpu")
            self.model.load_state_dict(state_dict)

        self.model.eval()
        # NOTE(review): same status message as above; kept as-is.
        print("Getting data...")
        self.dataset_list_pred = dataset.split_by_pred(self.model)

    def analyze_adjacency(self, draw=False, rt_fig=False):
        """Run ``pairwise_boundary_analysis`` on the prediction split.

        Args:
            draw: When true, the ratio matrix is passed to
                ``draw_matrix_adj`` labelled with the ``GRAPH_CLS`` values.
            rt_fig: Forwarded to ``draw_matrix_adj`` as ``return_fig``.
                NOTE(review): whatever ``draw_matrix_adj`` returns is
                discarded here, so the caller never receives a figure.

        Returns:
            The ``(adj_ratio_mat, boundary_info)`` pair produced by
            ``pairwise_boundary_analysis``.
        """
        adj_ratio_mat, boundary_info = pairwise_boundary_analysis(self.model, self.dataset_list_pred)
        if draw:
            draw_matrix_adj(adj_ratio_mat, names=self.dataset.GRAPH_CLS.values(), fmt='.2f', return_fig=rt_fig)
        return adj_ratio_mat, boundary_info

    def evaluate(self, draw=False):
        """Evaluate the model through ``dataset.model_evaluate``.

        Args:
            draw: When true, the ``'cm'`` entry of the evaluation result is
                passed to ``draw_matrix_adj`` with integer formatting.

        Returns:
            The object returned by ``dataset.model_evaluate`` unchanged.
        """
        evaluation = self.dataset.model_evaluate(self.model)
        if draw:
            draw_matrix_adj(evaluation['cm'], self.dataset.GRAPH_CLS.values(), fmt='d')

        return evaluation




if __name__ == "__main__":
    # Script entry point: analyze a CollabDataset with a stored checkpoint.
    dataset = CollabDataset(seed=12345)

    # Size the model from the dataset's node/graph class tables.
    model_kwargs = {
        "node_features": len(dataset.NODE_CLS),
        "num_classes": len(dataset.GRAPH_CLS),
        "hidden_channels": 64,
        "num_layers": 5,
    }
    config = AnalyzerConfig(
        model_checkpoint='../ckpts/collab.pt',
        model_kwargs=model_kwargs,
    )

    analyzer = Analyzer(dataset, config)
    adjacency_result = analyzer.analyze_adjacency()
    print(adjacency_result)