#!/usr/bin/env python
# -*-coding:utf-8 -*-
import torch 
import torch.nn as nn 
import fed_model.vgg as vgg 
import fed_model.resnet as resnet 


class CLSModel(nn.Module):
    """Fully-connected classifier applied to flattened inputs.

    Args:
        num_input: number of features after flattening each sample.
        num_class: number of output logits.
        num_hidden: ``1`` builds a single ``Linear(num_input, num_class)``;
            any other value builds ``Linear(num_input, 512)`` followed by
            ``Linear(512, num_class)``.
    """

    def __init__(self, num_input, num_class, num_hidden=1):
        super(CLSModel, self).__init__()
        if num_hidden == 1:
            self.fc_layer = nn.Linear(num_input, num_class)
        else:
            # NOTE(review): there is no activation between these two Linear
            # layers, so the stack is still a linear map of the input --
            # confirm that this is intended.
            self.fc_layer = nn.Sequential(
                nn.Linear(num_input, 512),
                nn.Linear(512, num_class))

    def forward(self, x):
        """Flatten every dimension after the first and apply ``fc_layer``."""
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors,
        # are accepted; for contiguous inputs the result is identical.
        x = x.reshape(x.size(0), -1)
        return self.fc_layer(x)
    
    
def get_model(conf, device):
    """Build the model selected by ``conf.arch`` and move it to ``device``.

    Args:
        conf: configuration object. ``conf.arch`` selects the family: a
            name containing "fc" builds a ``CLSModel`` whose depth flag is
            the integer after "fc" and whose sizes come from
            ``conf.dataset``; "VGG" delegates to ``vgg.vgg(conf)``;
            "resnet" delegates to ``resnet.resnet(conf)``.
        device: target passed to ``Module.to``.

    Returns:
        The constructed ``nn.Module``, moved to ``device``.

    Raises:
        ValueError: if ``conf.arch`` matches no supported family, or an "fc"
            architecture is requested for a dataset without known sizes
            (also raised by ``int`` if no integer follows "fc").
    """
    if "fc" in conf.arch:
        num_hidden = int(conf.arch.split("fc")[1])
        # (flattened input size, number of classes) per dataset.
        fc_dims = {
            "cifar10": (32 * 32 * 3, 10),
            "cifar100": (32 * 32 * 3, 100),
            "mnist": (28 * 28, 10),
            "dsprite": (64 * 64, 3),
        }
        try:
            num_input, num_class = fc_dims[conf.dataset]
        except KeyError:
            raise ValueError(
                f"Unsupported dataset for fc architecture: {conf.dataset!r}"
            ) from None
        model_obj = CLSModel(num_input, num_class, num_hidden)
    elif "VGG" in conf.arch:
        model_obj = vgg.vgg(conf)
    elif "resnet" in conf.arch:
        model_obj = resnet.resnet(conf)
    else:
        raise ValueError(f"Unsupported architecture: {conf.arch!r}")
    model_obj.to(device)
    return model_obj


def get_model_params(dataset, arch):
    """Return shapes and names of a model's trainable, non-bias parameters.

    A throwaway configuration (``dataset``, ``arch``, ``vgg_scaling=None``)
    is passed to ``get_model``, which builds the model on the CPU. Parameters
    are kept when they require gradients and their name does not contain
    "bias".

    Args:
        dataset: value forwarded as ``conf.dataset``.
        arch: value forwarded as ``conf.arch``.

    Returns:
        ``(param_size, title_group)``: parallel lists holding each kept
        parameter's shape and name, in ``named_parameters()`` order.
    """
    # A bare class object acts as the attribute namespace read by get_model.
    conf = type("PARAM", (), {"dataset": dataset, "arch": arch, "vgg_scaling": None})
    model_use = get_model(conf, torch.device("cpu"))
    kept = [
        (name, tensor.shape)
        for name, tensor in model_use.named_parameters()
        if tensor.requires_grad and "bias" not in name
    ]
    title_group = [name for name, _ in kept]
    param_size = [shape for _, shape in kept]
    return param_size, title_group