import torch
import torch.nn as nn

from model import LeNet5, MLP, CNN1, CNN2, get_resnet18
from opacus import PrivacyEngine
from sklearn.model_selection import train_test_split


def init_model(model_type, in_channel, n_class):
    """Instantiate a fresh network chosen by its architecture name.

    Args:
        model_type: one of ``'LeNet5'``, ``'MLP'``, ``'CNN1'``, ``'CNN2'``,
            ``'ResNet18'``.
        in_channel: number of input channels; forwarded to LeNet5/CNN1/CNN2
            only (MLP and ResNet18 are built from ``n_class`` alone).
        n_class: number of output classes.

    Returns:
        The newly constructed model.

    Raises:
        ValueError: if ``model_type`` is not a known architecture name.
    """
    # Factories are lambdas so that only the selected architecture is built.
    factories = (
        ('LeNet5', lambda: LeNet5(in_channel, n_class)),
        ('MLP', lambda: MLP(n_class)),
        ('CNN1', lambda: CNN1(in_channel, n_class)),
        ('CNN2', lambda: CNN2(in_channel, n_class)),
        ('ResNet18', lambda: get_resnet18(n_class)),
    )
    for name, factory in factories:
        if model_type == name:
            return factory()
    raise ValueError(f"Unknown model type {model_type}")


def init_optimizer(model, args):
    """Build the local optimizer named by ``args.optimizer``.

    * ``'sgd'``  -> ``torch.optim.SGD`` with ``args.lr``, ``args.momentum``
      and a fixed weight decay of 5e-4.
    * ``'adam'`` -> ``torch.optim.Adam`` with ``args.lr`` and a fixed weight
      decay of 1e-4.

    Args:
        model: module whose ``parameters()`` are optimized.
        args: namespace providing ``optimizer``, ``lr`` and, for SGD,
            ``momentum``.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: for any other ``args.optimizer`` value.
    """
    choice = args.optimizer
    if choice == 'sgd':
        return torch.optim.SGD(
            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=5e-4
        )
    if choice == 'adam':
        return torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)
    raise ValueError("Unknown optimizer")

def init_dp_optimizer(model, data_size, args):
    """Build the usual optimizer for ``model`` and attach an Opacus privacy engine to it.

    Args:
        model: module to train with DP-SGD-style clipping and noise.
        data_size: number of local training samples; the sampling rate passed
            to the engine is ``args.batch_size / data_size``.
            NOTE(review): this exceeds 1 when batch_size > data_size, and
            raises ZeroDivisionError when data_size is 0. Confirm callers
            guarantee neither happens.
        args: namespace with ``batch_size``, ``noise_multiplier``,
            ``l2_norm_clip`` plus the fields read by ``init_optimizer``.

    Returns:
        The optimizer, after ``PrivacyEngine.attach`` has been called on it.
    """
    base_optimizer = init_optimizer(model, args)

    # Renyi-DP orders for privacy accounting: 1.1, 1.2, ..., 10.9, then the integers 12..63.
    rdp_orders = [1 + step / 10. for step in range(1, 100)]
    rdp_orders.extend(range(12, 64))

    # PrivacyEngine(model, ...) followed by .attach(optimizer) is the pre-1.0
    # Opacus API. NOTE(review): confirm the pinned opacus version. The engine
    # object itself is not returned; presumably it stays reachable through
    # whatever attach() stores on the optimizer, so verify that before relying on it.
    engine = PrivacyEngine(
        model,
        sample_rate=args.batch_size / data_size,
        alphas=rdp_orders,
        noise_multiplier=args.noise_multiplier,
        max_grad_norm=args.l2_norm_clip,
    )
    engine.attach(base_optimizer)
    return base_optimizer


class Client(nn.Module):
    """One federated-learning participant: its private data splits and private model(s).

    What gets built depends on ``args.algorithm``:

    * ``'DFedEM'``: ``args.n_components`` private models
      (``private_components``), uniform mixture weights
      (``private_component_weights``) and one optimizer per component
      (``private_opts``).
    * ``'Federico'``: one ``private_model`` (no optimizer is created here),
      ``args.n_neighbors`` ``dummy_models`` into which sampled models are
      loaded, and a validation split ``private_val_data``.
    * anything else: one ``private_model`` with ``private_opt``. For
      ``'FedFomo'``, a validation split ``private_val_data`` and a single
      ``dummy_model`` are added as well.
    """

    def __init__(self, data, args, major_class=-1):
        """Split the client's data and create its model(s).

        Args:
            data: indexable pair where ``data[0]`` holds the inputs and
                ``data[1]`` the targets. Kept unchanged as ``private_data``.
            args: run configuration namespace. Always read: ``eval_ratio``,
                ``algorithm``, ``private_model_type``, ``in_channel``,
                ``n_class``, ``device``. Read depending on the algorithm:
                ``n_components``, ``n_neighbors``, ``cw_ratio`` and the
                fields used by ``init_optimizer``.
            major_class: stored unchanged as ``self.major_class``.
        """
        super().__init__()

        self.private_data = data
        # Optionally hold out a local test split (fraction args.eval_ratio).
        if args.eval_ratio > 0:
            x_train, x_test, y_train, y_test = train_test_split(data[0], data[1], test_size=args.eval_ratio)
            self.private_test_data = (x_test, y_test)
        else:
            x_train = data[0]
            y_train = data[1]
            self.private_test_data = None
        self.private_train_data = (x_train, y_train)
        self.major_class = major_class

        def model():
            # Fresh private-architecture model placed on the run's device.
            return init_model(args.private_model_type, args.in_channel, args.n_class).to(args.device)

        def split_off_validation():
            # Carve a validation set (fraction args.cw_ratio) out of the current
            # training split, and keep both parts on the client.
            x_tr, y_tr = self.private_train_data
            x_tr, x_val, y_tr, y_val = train_test_split(x_tr, y_tr, test_size=args.cw_ratio)
            self.private_train_data = (x_tr, y_tr)
            self.private_val_data = (x_val, y_val)

        if args.algorithm == 'DFedEM':
            # NOTE: plain Python lists are not registered as submodules of this
            # nn.Module, so these models do not appear in self.parameters() or
            # self.state_dict(). The weights tensor is likewise not a registered buffer.
            self.private_components = [model() for _ in range(args.n_components)]
            self.private_component_weights = (torch.ones(args.n_components) / args.n_components).to(args.device)
            self.private_opts = [init_optimizer(component, args) for component in self.private_components]

        elif args.algorithm == 'Federico':
            self.private_model = model()
            # Keep the held-out part as private_val_data (as FedFomo does) instead
            # of discarding it. Otherwise these samples would be used neither for
            # training nor for validation.
            split_off_validation()
            # Dummy models into which sampled models will be loaded.
            self.dummy_models = [model() for _ in range(args.n_neighbors)]

        else:
            self.private_model = model()
            self.private_opt = init_optimizer(self.private_model, args)

        if args.algorithm == 'FedFomo':
            split_off_validation()
            self.dummy_model = model()

        self.device = args.device
        self.tot_epochs = 0
        self.privacy_budget = 0.
