import torch
from resnet import ResNet18, ResNet34
from resnet_cifar import ResNet20, ResNet32
from vgg import vgg11, vgg13, vgg11_bn, vgg13_bn
from densenet import densenet121
from mobilenet import mobilenetv2
import torch.nn as nn
from slimable_op import SlimmableConv2d,Scaler

def get_model_optimizer_scheduler(arch, cutting_layer, max_channel,
                        num_agent, num_class,
                        cloud_lr, local_lr,
                        logger, epochs, scaler_fixed=False, no_subnetwork=False):
    """Build the split model together with its optimizers and LR schedulers.

    The model is created from ``arch``, its server-side parts (``model.cloud``
    and ``model.classifier``) and every entry of ``model.local_list`` are moved
    to CUDA, and bottleneck modules are attached as ``model.local_b_list`` (one
    per client) and ``model.cloud_b_list``.  Every optimizer is SGD with
    ``weight_decay=5e-4`` and is paired with a ``CosineAnnealingLR`` scheduler
    using ``T_max=epochs``.

    Args:
        arch: One of ``"resnet20"``, ``"vgg11_bn"``, ``"mobilenetv2"`` or
            ``"densenet121"``.
        cutting_layer: Forwarded to the model constructor (not forwarded for
            ``"densenet121"``).
        max_channel: Output channels of each client bottleneck conv, input
            channels of each cloud bottleneck conv, and the first argument
            given to ``Scaler``.
        num_agent: Number of clients; forwarded to the model constructor and
            used as the number of client optimizers/bottlenecks to create.
        num_class: Forwarded to the model constructor.
        cloud_lr: Learning rate for the server-side and cloud-bottleneck
            optimizers.
        local_lr: Learning rate for the client-side and client-bottleneck
            optimizers.
        logger: Object whose ``debug`` method receives a textual dump of the
            assembled model.
        epochs: ``T_max`` for every cosine-annealing scheduler.
        scaler_fixed: Forwarded to ``Scaler`` as ``fixed``.
        no_subnetwork: When true, two cloud bottleneck modules are created
            instead of one.

    Returns:
        A 10-tuple ``(model, local_optimizer_list, local_b_optimizer_list,
        cloud_b_optimizer_list, optimizer, train_scheduler,
        train_local_scheduler_list, train_local_b_scheduler_list,
        train_cloud_b_scheduler_list, orig_channel)``.

    Raises:
        ValueError: If ``arch`` is not one of the supported architectures.
    """
    adds_bottleneck = True
    # initialize the local and server model and optimizer
    if arch == "resnet20":
        model = ResNet20(cutting_layer, num_agent=num_agent, num_class=num_class)
    elif arch == "vgg11_bn":
        model = vgg11_bn(cutting_layer, num_agent=num_agent, num_class=num_class)
    elif arch == "mobilenetv2":
        model = mobilenetv2(cutting_layer, num_agent=num_agent, num_class=num_class)
    elif arch == "densenet121":
        model = densenet121(num_agent=num_agent, num_class=num_class)
    else:
        # Raising a bare string is itself a TypeError in Python 3; raise a
        # real exception so the caller sees the intended message.
        raise ValueError("No such architecture! Got: {!r}".format(arch))

    # Defaults for everything that is only populated when the bottleneck is
    # enabled, so the return tuple is always fully defined.
    model.local_b_list = None
    model.cloud_b_list = None
    local_b_optimizer_list = None
    cloud_b_optimizer_list = None
    train_local_b_scheduler_list = None
    train_cloud_b_scheduler_list = None

    ########## server side model settings ##########
    f_tail = model.cloud
    classifier = model.classifier
    f_tail.cuda()
    classifier.cuda()
    # A single optimizer drives both the cloud body and the classifier head.
    params = list(f_tail.parameters()) + list(classifier.parameters())
    optimizer = torch.optim.SGD(params, lr=cloud_lr, momentum=0.9, weight_decay=5e-4)
    train_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)  # learning rate decay

    ########## client side model settings ##########
    local_optimizer_list = []
    train_local_scheduler_list = []
    for i in range(num_agent):
        local_model = model.local_list[i]
        local_model.cuda()
        local_optimizer = torch.optim.SGD(
            local_model.parameters(), lr=local_lr, momentum=0.9, weight_decay=5e-4)
        local_optimizer_list.append(local_optimizer)
        train_local_scheduler_list.append(
            torch.optim.lr_scheduler.CosineAnnealingLR(local_optimizer, T_max=epochs))  # learning rate decay

    ## bottleneck layer
    orig_channel = 0
    if adds_bottleneck:
        # check the input size of the server-side model; for densenet and
        # mobilenet the channel count is hard-coded rather than read from
        # the model.
        if 'vgg' in arch:
            orig_channel = model.cloud._modules['0'].in_channels
        elif 'resnet' in arch:
            orig_channel = model.cloud._modules['0']._modules['conv_a'].in_channels
        elif 'densenet' in arch:
            orig_channel = 128
        elif 'mobilenet' in arch:
            orig_channel = 32

        ########## client side bottleneck layer settings ##########
        model.local_b_list = []
        train_local_b_scheduler_list = []
        local_b_optimizer_list = []

        for _ in range(num_agent):
            local_b = nn.Sequential(
                SlimmableConv2d(orig_channel, max_channel, kernel_size=3, padding=1))
            model.local_b_list.append(local_b)
            local_b.cuda()
            # Bottleneck optimizers use momentum=0, unlike the main models.
            local_b_optimizer = torch.optim.SGD(
                list(local_b.parameters()), lr=local_lr, momentum=0, weight_decay=5e-4)
            local_b_optimizer_list.append(local_b_optimizer)
            train_local_b_scheduler_list.append(
                torch.optim.lr_scheduler.CosineAnnealingLR(local_b_optimizer, T_max=epochs))

        ########## server side bottleneck layer settings ##########
        model.cloud_b_list = []
        train_cloud_b_scheduler_list = []
        cloud_b_optimizer_list = []
        # only slimmable for the cloud one, since the local one can be selected using activations.
        output_channel = orig_channel
        # NOTE(review): no_subnetwork=True yields two cloud bottlenecks and
        # False yields one; confirm this mapping against the training loop.
        iteration = 2 if no_subnetwork else 1

        for _ in range(iteration):
            cloud_b = nn.Sequential(
                SlimmableConv2d(max_channel, output_channel, kernel_size=3, padding=1),
                Scaler(max_channel, fixed=scaler_fixed))
            model.cloud_b_list.append(cloud_b)
            cloud_b.cuda()
            cloud_b_optimizer = torch.optim.SGD(
                list(cloud_b.parameters()), lr=cloud_lr, momentum=0, weight_decay=5e-4)
            cloud_b_optimizer_list.append(cloud_b_optimizer)
            train_cloud_b_scheduler_list.append(
                torch.optim.lr_scheduler.CosineAnnealingLR(cloud_b_optimizer, T_max=epochs))


    logger.debug('An example of the full model is \n Client Model{}\n Client BN{}\n Cloud BN{}\n Server Model{}\n Classifier{}'.format(
                                                                                model.local_list[0],
                                                                                model.local_b_list,
                                                                                model.cloud_b_list,
                                                                                model.cloud,
                                                                                model.classifier))
    return (model,
            local_optimizer_list, local_b_optimizer_list, cloud_b_optimizer_list, optimizer,
            train_scheduler, train_local_scheduler_list, train_local_b_scheduler_list, train_cloud_b_scheduler_list,
            orig_channel
            )