import types

from models.CNN import CNN
from models.LeNet import LeNet
from models.ResNet import ResNet18
from models.LeNet_5 import LENET
from models.LogisticRegression import LogisticRegression, Cifar10LogisticRegression


def ray_wrapper(obj):
    """Attach weight / parameter / gradient accessor methods to ``obj``.

    The object is mutated in place (six methods are bound onto it as
    instance attributes) and the same object is returned, so the call can be
    used inline: ``model = ray_wrapper(SomeModel())``.

    Bound methods:
        get_weights()            -> dict of state_dict entries, each ``.cpu()``'d
        set_weights(weights)     -> ``load_state_dict(weights, strict=False)``
        get_model()              -> list of detached clones of ``parameters()``
        set_model(param)         -> rebind each parameter's ``.data`` in order
        get_gradients()          -> list of ``grad.data`` (or None) per parameter
        set_gradients(gradients) -> assign each non-None entry as ``.grad``
    """

    def get_weights(self):
        # ``.cpu()`` returns the tensor itself (no copy) when it already lives
        # on the CPU, so CPU-resident entries alias the module's state.
        state = self.state_dict()
        return {name: tensor.cpu() for name, tensor in state.items()}

    def set_weights(self, weights):
        # strict=False: keys missing from / unexpected in ``weights`` are
        # tolerated rather than raising.
        self.load_state_dict(weights, strict=False)

    def get_model(self):
        # detach + clone gives independent snapshots that share no storage
        # or autograd history with the live parameters.
        return [tensor.detach().clone() for tensor in self.parameters()]

    def set_model(self, param):
        # Positional pairing with the module's parameter order; zip stops at
        # the shorter sequence, so a length mismatch is silently truncated.
        # NOTE(review): rebinding ``.data`` makes each parameter share storage
        # with the supplied tensor (and adopt its device) — confirm callers
        # never reuse one list across several models.
        for new_tensor, (_, target) in zip(param, self.named_parameters()):
            target.data = new_tensor

    def get_gradients(self):
        # Returns the live gradient storage (not a copy); parameters that
        # have no gradient yet contribute None.
        return [
            None if tensor.grad is None else tensor.grad.data
            for tensor in self.parameters()
        ]

    def set_gradients(self, gradients):
        # None entries leave the corresponding parameter's gradient untouched.
        for grad, (_, target) in zip(gradients, self.named_parameters()):
            if grad is None:
                continue
            target.grad = grad

    accessors = (
        get_weights,
        set_weights,
        get_model,
        set_model,
        get_gradients,
        set_gradients,
    )
    for accessor in accessors:
        # Bind to this particular instance so ``self`` is supplied implicitly.
        setattr(obj, accessor.__name__, types.MethodType(accessor, obj))
    return obj

def get_model(args):
    """Construct the model selected by ``args`` and wrap it with ``ray_wrapper``.

    Args:
        args: Object exposing a ``model`` string attribute and, for the model
            families that depend on it, a ``dataset`` string attribute
            (presumably an argparse namespace — confirm against the caller).

    Returns:
        The newly constructed model with the ``ray_wrapper`` accessor methods
        attached.

    Raises:
        ValueError: If the ``model`` / ``dataset`` combination is not
            supported. (Previously such combinations silently returned None,
            which only failed later at an unrelated call site.)
    """
    # These model names all share the same linear backbone here; only the
    # dataset decides which variant is built.
    if args.model in ('LR', 'svm', 'RR', 'Lasso'):
        if args.dataset in ('mnist', 'fmnist'):
            return ray_wrapper(LogisticRegression())
        if args.dataset == 'cifar10':
            return ray_wrapper(Cifar10LogisticRegression())
    elif args.model == 'CNN':
        return ray_wrapper(CNN())
    elif args.model == 'resnet':
        return ray_wrapper(ResNet18())
    elif args.model == 'LeNet':
        # NOTE(review): fmnist is accepted for the linear models above but not
        # here — confirm whether LeNet() should also serve fmnist.
        if args.dataset == 'mnist':
            return ray_wrapper(LeNet())
        if args.dataset == 'cifar10':
            return ray_wrapper(LENET())

    # getattr: ``dataset`` is only required by some branches, so don't let a
    # missing attribute mask the real error.
    raise ValueError(
        f"Unsupported model/dataset combination: "
        f"model={args.model!r}, dataset={getattr(args, 'dataset', None)!r}"
    )

