import torch
from net_FCN import FullyConnectedNN_preact_CND, NewsMLP
from net_resnet import ResNet, BasicBlock, Bottleneck

# ResNet variants supported by model_definition, as
#   name -> (uses Bottleneck blocks?, blocks per stage, output width per stage)
# Pure data only: the block classes are resolved at call time.
_RESNET_SPECS = {
    "ResNet9": (False, (1, 1, 1, 1), (64, 128, 256, 512)),
    "ResNet18": (False, (2, 2, 2, 2), (64, 128, 256, 512)),
    "ResNet34": (False, (3, 4, 6, 3), (64, 128, 256, 512)),
    # ResNet-50 uses Bottleneck blocks, so the stage widths are 4x larger.
    "ResNet50": (True, (3, 4, 6, 3), (256, 512, 1024, 2048)),
}


def model_definition(device, args, embedding_weights=None):
    """Build the network selected by ``args.network`` and move it to ``device``.

    Supported values of ``args.network``: ``"Preact_Base_FCN"``, ``"NewsMLP"``,
    and the ResNet variants listed in ``_RESNET_SPECS``.

    Args:
        device: Target device; passed to the model constructors and to ``.to()``.
        args: Configuration namespace. Always reads ``network`` and
            ``num_classes``. Depending on the network it also reads
            ``input_channels``, ``input_size``, ``n_neurons_x_layer``,
            ``n_layers``, ``dropout``/``dropout_rate``,
            ``batch_normalization_flag``, ``activation_fn`` and ``n_GPUs``
            (the optional ones fall back to defaults via ``getattr``).
        embedding_weights: Forwarded to the ``Preact_Base_FCN`` and ``NewsMLP``
            constructors; ignored by the ResNet variants.

    Returns:
        ``(model, args)``. The model is wrapped in ``torch.nn.DataParallel``
        when ``args.n_GPUs > 1``. ``args`` is the same object passed in,
        mutated in place: ``args.neurs_x_hid_lyr`` maps each hidden-layer
        index to its width (as ``int``).

    Raises:
        ValueError: if ``args.network`` is not a supported name.
    """
    network = args.network

    if network == "Preact_Base_FCN":
        model = FullyConnectedNN_preact_CND(
            device,
            input_channels=args.input_channels,
            input_size=args.input_size,
            num_classes=args.num_classes,
            N=args.n_neurons_x_layer,
            L=args.n_layers,
            batch_normalization_flag=getattr(args, "batch_normalization_flag", False),
            dropout_rate=getattr(args, "dropout_rate", 0),
            embedding_weights=embedding_weights,
            activation_fn=getattr(args, "activation_fn", 'relu'),
        ).to(device)
        neurs_x_hid_lyr = {i: args.n_neurons_x_layer for i in range(args.n_layers)}

    elif network == "NewsMLP":
        # Fixed (non-modular) architecture for the NEWS dataset, based on the
        # Yu '19 article. Historically this branch read ``args.dropout``; keep
        # honouring it, but fall back to ``dropout_rate`` (used by every other
        # branch) instead of crashing when ``dropout`` is absent.
        dropout = getattr(args, "dropout", getattr(args, "dropout_rate", 0))
        model = NewsMLP(
            device,
            embedding_weights,
            num_classes=args.num_classes,
            dropout_rate=dropout,
        ).to(device)
        neurs_x_hid_lyr = {
            0: 300 * 20,
            1: 300 * 4,
            2: 300,
        }

    elif network in _RESNET_SPECS:
        use_bottleneck, layers, widths = _RESNET_SPECS[network]
        block = Bottleneck if use_bottleneck else BasicBlock
        model = ResNet(
            block,
            list(layers),  # fresh list so the shared spec can never be mutated
            input_channels=args.input_channels,
            num_classes=args.num_classes,
            dropout_rate=getattr(args, "dropout_rate", 0.0),
            network=network,
        ).to(device)
        neurs_x_hid_lyr = dict(enumerate(widths))

    else:
        supported = ", ".join(["Preact_Base_FCN", "NewsMLP", *_RESNET_SPECS])
        raise ValueError(
            f"Unsupported network: {network!r} (expected one of: {supported})"
        )
    print(f"Model Selected {network}")

    args.neurs_x_hid_lyr = {key: int(value) for key, value in neurs_x_hid_lyr.items()}

    if getattr(args, "n_GPUs", 1) > 1:
        print("PARALLEL MODE ACTIVATED")
        model = torch.nn.DataParallel(model)

    return model, args