import os
import sys

import torch.optim

from models.vfedcd_model import VFedCDEncoder, VFedCDDecoder

sys.path.append(os.pardir)

import pickle


def load_models(args):
    """Unpickle one pre-trained model per party and move it to ``args.device``.

    For each party ``ik`` in ``range(args.k)`` the model is read from
    ``.././src/models/model_parameters/<type>/<path>.pkl``, where ``<type>`` and
    ``<path>`` are ``args.model_list[str(ik)]['type']`` and ``['path']``.

    Args:
        args: Config namespace; reads ``k``, ``model_list`` and ``device`` and
            sets ``net_list``.

    Returns:
        The same ``args`` object, with ``args.net_list[ik]`` holding party
        ``ik``'s model.
    """
    args.net_list = [None] * args.k
    for ik in range(args.k):
        model_cfg = args.model_list[str(ik)]
        model_file = ('.././src/models/model_parameters/'
                      + model_cfg['type'] + '/' + model_cfg['path'] + '.pkl')
        # SECURITY: pickle.load can execute arbitrary code stored in the file;
        # only load parameter files from a trusted source.
        # The `with` block closes the handle even if unpickling raises.
        with open(model_file, "rb") as fh:
            net = pickle.load(fh)
        args.net_list[ik] = net.to(args.device)
    # important: callers use the returned (mutated) args
    return args


def load_basic_models(args, index):
    """Build party ``index``'s local model and optimizer, plus the global model for the last party.

    Model classes are looked up by name in this module's ``globals()``, so every
    ``type`` in ``args.model_list`` and ``args.global_model`` must be a name
    available at module level here.

    Args:
        args: Config namespace. Reads ``model_list``, ``device``, ``main_lr``,
            ``k``, ``apply_trainable_layer``, ``global_model`` and ``num_classes``.
        index: Party index; ``args.model_list[str(index)]`` describes its model.

    Returns:
        ``(args, local_model, local_model_optimizer, global_model,
        global_model_optimizer)``. The global entries are ``None`` unless
        ``index == args.k - 1``. The global optimizer is also ``None`` when
        ``args.apply_trainable_layer == 0``.

    Raises:
        KeyError: if a required config key is missing, or a model class name
            is not defined in this module.
    """
    def _model_class(name):
        # Resolve a model class by name, with a more informative error than a bare KeyError.
        try:
            return globals()[name]
        except KeyError:
            raise KeyError(f"model class {name!r} is not imported in this module") from None

    model_cfg = args.model_list[str(index)]
    current_model_type = model_cfg['type']
    print(f"current_model_type={current_model_type}")
    # -1 marks a dimension the config does not specify.
    current_input_dim = model_cfg.get('input_dim', -1)
    current_hidden_dim = model_cfg.get('hidden_dim', -1)
    current_output_dim = model_cfg.get('output_dim', -1)
    current_vocab_size = model_cfg.get('vocab_size', -1)

    model_cls = _model_class(current_model_type)
    lowered = current_model_type.lower()
    if any(tag in lowered for tag in ('resnet', 'lenet', 'cnn', 'alexnet')):
        # These constructors take only the output dimension.
        local_model = model_cls(current_output_dim)
    elif 'gcn' in lowered:
        local_model = model_cls(nfeat=current_input_dim, nhid=current_hidden_dim,
                                nclass=current_output_dim, device=args.device,
                                dropout=0.0, lr=args.main_lr)
    elif 'lstm' in lowered:
        local_model = model_cls(current_vocab_size, current_output_dim)
    else:
        local_model = model_cls(current_input_dim, current_output_dim)
    local_model = local_model.to(args.device)
    print(f"local_model parameters: {sum(p.numel() for p in local_model.parameters())}")
    # For PMC checking, SGD(lr=main_lr, momentum=0.9, weight_decay=5e-4) was used instead.
    local_model_optimizer = torch.optim.Adam(list(local_model.parameters()), lr=args.main_lr, weight_decay=0.0)

    global_model = None
    global_model_optimizer = None
    # Only the last party holds the global (top) model.
    if index == args.k - 1:
        if args.apply_trainable_layer == 0:
            # Non-trainable aggregation: no constructor args, no optimizer.
            global_model = _model_class(args.global_model)()
            global_model = global_model.to(args.device)
        else:
            print("global_model", args.global_model)
            # The global model consumes the concatenated outputs of all parties.
            global_input_dim = sum(args.model_list[str(ik)]['output_dim'] for ik in range(args.k))
            global_model = _model_class(args.global_model)(global_input_dim, args.num_classes)
            global_model = global_model.to(args.device)
            global_model_optimizer = torch.optim.Adam(list(global_model.parameters()), lr=args.main_lr)

    return args, local_model, local_model_optimizer, global_model, global_model_optimizer

def load_defense_models(args, index, local_model, local_model_optimizer, global_model, global_model_optimizer):
    """Apply defense wrappers to a party's models; currently a no-op.

    No defense is configured here: ``args.encoder`` is set to ``None`` and all
    model/optimizer arguments are returned unchanged. ``index`` is accepted
    for signature compatibility but is not used.

    Returns:
        ``(args, local_model, local_model_optimizer, global_model,
        global_model_optimizer)``.
    """
    print('Load Defense models')
    # Without a defense there is no encoder to attach.
    setattr(args, 'encoder', None)
    passthrough = (local_model, local_model_optimizer, global_model, global_model_optimizer)
    return (args,) + passthrough


def load_basic_models_vfedcd(args, owned_by_party):
    """Build the VFedCD encoders and decoder owned by party ``owned_by_party``.

    One ``VFedCDEncoder`` is created per party block ``j`` in ``range(args.k)``,
    with ``in_dim = dims[j]`` and ``out_dim = d = sum(dims)``. Encoder ``j``
    receives rows ``[from_row, to_row)`` of the current stage's mask, where
    ``from_row = sum(dims[:j])``. A single ``VFedCDDecoder`` with output size
    ``dims[owned_by_party]`` is created as the global model.

    Args:
        args: Config namespace. Reads ``optim``, ``model_list``,
            ``dataset_split['dims']``, ``causal`` (``'mask'``, ``'adj_p'``,
            ``'dag_penalty_flavor'``), ``stage``, ``k``, ``device`` and
            ``main_lr`` (indexed by ``stage``).
        owned_by_party: Index of the party that owns these models.

    Returns:
        ``(args, local_models, local_model_optimizers, global_model,
        global_model_optimizer)``. The two middle entries are lists with one
        entry per encoder (``args.k`` entries).

    Raises:
        ValueError: if ``args.optim`` is not ``'SGD'``/``'Adam'``, or the
            party's model type is not ``'VFedCD'``.
    """
    # Explicit checks rather than `assert`: asserts vanish under `python -O`.
    # With asserts removed, an unsupported optimizer name silently fell back to Adam.
    optimizer_classes = {"SGD": torch.optim.SGD, "Adam": torch.optim.Adam}
    if args.optim not in optimizer_classes:
        raise ValueError("Unsupported optim {}".format(args.optim))
    optimizer_cls = optimizer_classes[args.optim]

    party_cfg = args.model_list[str(owned_by_party)]
    current_model_type = party_cfg['type']
    if current_model_type != "VFedCD":
        raise ValueError("unrecognized model type:{} for VFedCD models".format(current_model_type))
    dims = args.dataset_split['dims']
    d = sum(dims)
    current_hidden_dim = party_cfg.get('hidden_dim', [10])

    # Stage-dependent settings are the same for every encoder; read them once.
    stage = args.stage
    lr = args.main_lr[stage]
    stage_mask = args.causal['mask'][stage]
    adjacency_p = args.causal['adj_p']
    dag_penalty_flavor = args.causal['dag_penalty_flavor'][stage]
    # `None` or the string 'none' means "no mask". The isinstance guard avoids
    # comparing an array/tensor mask against a string elementwise.
    no_mask = stage_mask is None or (isinstance(stage_mask, str) and stage_mask == 'none')

    local_models = []
    local_model_optimizers = []
    from_row = 0  # running offset == sum(dims[:belong_to_model])
    for belong_to_model in range(args.k):
        to_row = from_row + dims[belong_to_model]
        mask = None if no_mask else stage_mask[from_row:to_row, :]
        local_model = VFedCDEncoder(owned_by_party=owned_by_party,
                                    belong_to_model=belong_to_model,
                                    in_dim=dims[belong_to_model],
                                    out_dim=d,
                                    hidden_dims=current_hidden_dim,
                                    self_reconstruction_row=[from_row, to_row],
                                    adjacency_p=adjacency_p,
                                    mask=mask,
                                    dag_penalty_flavor=dag_penalty_flavor)
        local_model = local_model.to(args.device)
        local_models.append(local_model)
        print(f"local_model parameters: {sum(p.numel() for p in local_model.parameters())}")
        local_model_optimizers.append(optimizer_cls(list(local_model.parameters()), lr=lr, weight_decay=0.0))
        from_row = to_row

    global_output_dim = dims[owned_by_party]
    global_model = VFedCDDecoder(current_hidden_dim, global_output_dim, k=args.k)
    global_model = global_model.to(args.device)
    # For PMC checking, SGD(lr=main_lr, momentum=0.9, weight_decay=5e-4) was used instead.
    global_model_optimizer = optimizer_cls(list(global_model.parameters()), lr=lr)

    return args, local_models, local_model_optimizers, global_model, global_model_optimizer


def load_models_per_party_vfedcd(args, index):
    """Load the VFedCD models owned by party ``index``.

    Thin wrapper around ``load_basic_models_vfedcd``. Unlike
    ``load_models_per_party``, it does not call ``load_defense_models``.

    Returns:
        ``(args, local_models, local_model_optimizers, global_model,
        global_model_optimizer)``; the two local entries are lists.
    """
    # The party's model type is validated inside load_basic_models_vfedcd.
    args, local_models, local_model_optimizers, global_model, global_model_optimizer = \
        load_basic_models_vfedcd(args, index)
    # important: callers use the returned (possibly mutated) args
    return args, local_models, local_model_optimizers, global_model, global_model_optimizer


def load_models_per_party(args, index):
    """Load party ``index``'s basic models, then pass them through ``load_defense_models``.

    Returns:
        ``(args, local_model, local_model_optimizer, global_model,
        global_model_optimizer)`` as returned by ``load_defense_models``.
    """
    args, local_model, local_model_optimizer, global_model, global_model_optimizer = \
        load_basic_models(args, index)
    args, local_model, local_model_optimizer, global_model, global_model_optimizer = \
        load_defense_models(args, index, local_model, local_model_optimizer,
                            global_model, global_model_optimizer)
    # important: callers use the returned (possibly mutated) args
    return args, local_model, local_model_optimizer, global_model, global_model_optimizer


if __name__ == '__main__':
    # Running this module directly does nothing; it only defines loader functions.
    pass
