import torch
from torch import nn

from modules.common import ShuffleNetV2
from modules.attention_mil import DefaultAttentionModule, DefaultClassifier, DefaultMILGraph
from modules.additive_mil import AdditiveClassifier
from modules.additive_transmil import TransformerMILGraph, AdditiveTransMIL
from modules.transmil import TransMIL


def get_attention_mil_model_n_weights():
    """Build the attention-MIL graph and name its weights file.

    The graph is a ``DefaultMILGraph`` assembled from a ``ShuffleNetV2``
    featurizer, a ``DefaultAttentionModule`` pointer and a
    ``DefaultClassifier`` head. No weights are loaded here; the caller is
    expected to load them from the returned path.

    Returns:
        A ``(model, weights_path)`` tuple, where ``weights_path`` is the
        relative path string of the matching ``.pth`` file.
    """
    # Sub-modules are built in the same order as the graph's keyword
    # arguments (featurizer, pointer, classifier).
    backbone = ShuffleNetV2()
    attention = DefaultAttentionModule(
        hidden_activation=nn.LeakyReLU(0.2),
        hidden_dims=[256, 256],
        input_dims=1024,
    )
    head = DefaultClassifier(
        hidden_dims=[256, 256],
        input_dims=1024,
        output_dims=2,
    )
    graph = DefaultMILGraph(
        featurizer=backbone,
        pointer=attention,
        classifier=head,
    )
    return graph, "artifacts/nsclc/model_weights/wt_attention_mil.pth"


def get_additive_mil_model_n_weights():
    """Build the additive-MIL graph and name its weights file.

    The graph is a ``DefaultMILGraph`` assembled from a ``ShuffleNetV2``
    featurizer, a ``DefaultAttentionModule`` pointer and an
    ``AdditiveClassifier`` head. No weights are loaded here; the caller is
    expected to load them from the returned path.

    Returns:
        A ``(model, weights_path)`` tuple, where ``weights_path`` is the
        relative path string of the matching ``.pth`` file.
    """
    # Sub-modules are built in the same order as the graph's keyword
    # arguments (featurizer, pointer, classifier).
    backbone = ShuffleNetV2()
    attention = DefaultAttentionModule(
        hidden_activation=nn.LeakyReLU(0.2),
        hidden_dims=[256, 256],
        input_dims=1024,
    )
    head = AdditiveClassifier(
        hidden_dims=[256, 256],
        input_dims=1024,
        output_dims=2,
    )
    graph = DefaultMILGraph(
        featurizer=backbone,
        pointer=attention,
        classifier=head,
    )
    return graph, "artifacts/nsclc/model_weights/wt_additive_mil.pth"


def get_transmil_model_n_weights():
    """Build the TransMIL graph and name its weights file.

    The graph is a ``TransformerMILGraph`` that pairs a ``ShuffleNetV2``
    featurizer with a two-class ``TransMIL`` classifier. No weights are
    loaded here; the caller is expected to load them from the returned
    path.

    Returns:
        A ``(model, weights_path)`` tuple, where ``weights_path`` is the
        relative path string of the matching ``.pth`` file.
    """
    # Featurizer first, then classifier, matching the keyword order of the
    # graph constructor call.
    backbone = ShuffleNetV2()
    head = TransMIL(n_classes=2)
    graph = TransformerMILGraph(featurizer=backbone, classifier=head)
    return graph, "artifacts/nsclc/model_weights/wt_transmil.pth"


def get_additive_transmil_model_n_weights():
    """Build the additive TransMIL graph and name its weights file.

    The graph is a ``TransformerMILGraph`` that pairs a ``ShuffleNetV2``
    featurizer with a two-class ``AdditiveTransMIL`` classifier configured
    with ``additive_hidden_dims=[256]``. No weights are loaded here; the
    caller is expected to load them from the returned path.

    Returns:
        A ``(model, weights_path)`` tuple, where ``weights_path`` is the
        relative path string of the matching ``.pth`` file.
    """
    # Featurizer first, then classifier, matching the keyword order of the
    # graph constructor call.
    backbone = ShuffleNetV2()
    head = AdditiveTransMIL(additive_hidden_dims=[256], n_classes=2)
    graph = TransformerMILGraph(featurizer=backbone, classifier=head)
    return graph, "artifacts/nsclc/model_weights/wt_additive_transmil.pth"



def load_torch_model(model, weights):
    """Load a saved state dict from ``weights`` into ``model`` in place.

    Args:
        model: Object exposing ``load_state_dict``; it is mutated in place.
        weights: Location of the saved state dict, passed straight to
            ``torch.load``.

    Returns:
        None. The value returned by ``model.load_state_dict`` is printed,
        followed by a completion message.
    """
    # NOTE(review): torch.load unpickles the file, so only point this at
    # trusted checkpoints.
    cpu = torch.device('cpu')
    # Map every tensor onto the CPU regardless of where it was saved from.
    checkpoint = torch.load(weights, map_location=cpu)
    load_report = model.load_state_dict(checkpoint)
    print(load_report)
    print("Model loading complete ...")


if __name__ == '__main__':
    # Smoke check: build each model variant and load its weights file,
    # printing the result of each load_state_dict call.
    print("Loading AttentionMIL model ...")
    load_torch_model(*get_attention_mil_model_n_weights())

    print("Loading AdditiveMIL model ...")
    load_torch_model(*get_additive_mil_model_n_weights())

    print("Loading TransMIL model ...")
    load_torch_model(*get_transmil_model_n_weights())

    print("Loading AdditiveTransMIL model ...")
    # Use the additive TransMIL factory here; the plain TransMIL factory
    # would just reload the previous model and leave this one unchecked.
    load_torch_model(*get_additive_transmil_model_n_weights())

