import torch
from torch.utils.data import DataLoader
import numpy as np
import random
import copy

from client import Client
from dataloader import load_data, DatasetSplit, get_mask, get_mask2, get_mask3
from torch.utils.tensorboard import SummaryWriter
from utils import all_combinations, aggregate_models, valid_list, valid_model, get_models_weights_list, draw_scatter, \
    valid_clients, add_noise_to_model


def process(name, args, experiment_epo, device):
    """Run one federated experiment and return the accuracy of the last round.

    Steps:
      1. Load the dataset partitioned across ``args.num_users`` clients and
         record ``class_num`` / ``input_dims`` on ``args`` (``args`` is mutated).
      2. Assign each client its available views, via ``get_mask2`` when
         ``args.missing_rate > -0.1`` and ``get_mask3`` otherwise.
      3. Build one ``Client`` per user, run ``train1`` on client 0 only and copy
         its weights into every client as a shared starting point.
      4. For ``args.main_epochs`` rounds, call ``train2`` on each client with
         the list of peer networks, replace that client's entry with a noised
         copy of its updated network, then evaluate all clients and log
         acc / NMI / ARI to TensorBoard.

    Args:
        name: Experiment name; logs are written to ``runs/<name>/<experiment_epo>``.
        args: Configuration namespace. ``class_num`` and ``input_dims`` are set on it.
        experiment_epo: Repetition index, used only in the log directory path.
        device: Device handed to every ``Client``.

    Returns:
        The accuracy reported by ``valid_clients`` in the final round, or
        ``None`` when ``args.main_epochs`` is 0 (no evaluation round runs).
    """
    log_dir = f"runs/{name}/{experiment_epo}"
    writer = SummaryWriter(log_dir)
    try:
        clients = []

        dataset, dims, all_view, data_size, class_num = load_data(
            args.dataset, args.num_users, args.Dirichlet_alpha)
        args.class_num = class_num
        args.input_dims = dims

        num_users = args.num_users
        # A missing_rate of about -1 (anything <= -0.1) selects the
        # mask_rate-based view assignment instead of the missing-rate one.
        if args.missing_rate > -0.1:
            num_views_glob = get_mask2(len(all_view), num_users, args.missing_rate)
        else:
            num_views_glob = get_mask3(num_users, args.mask_rate)

        for nu in range(num_users):
            # sample_num acts as a per-client sample cap; values <= 0.01 (e.g. 0) disable it.
            if args.sample_num > 0.01:
                dataset.user_data[nu] = dataset.user_data[nu][:args.sample_num]
            data_loader = DataLoader(
                DatasetSplit(dataset.X, dataset.Y, dataset.user_data[nu], dims, num_views_glob[nu]),
                batch_size=args.batch_size, shuffle=False)
            clients.append(Client(nu, all_view, num_views_glob[nu], data_loader, args, device, writer))

        print(num_views_glob)

        # Warm-start: only client 0 runs train1; every client then starts from its weights.
        clients[0].train1()
        init_state = clients[0].net.state_dict()
        nets = []
        for nu in range(num_users):
            clients[nu].net.load_state_dict(init_state)
            nets.append(copy.deepcopy(clients[nu].net))

        # Stays None if no round runs (main_epochs == 0) instead of raising UnboundLocalError.
        accuracy = None
        for i in range(args.main_epochs):
            print(i)
            for nu in range(num_users):
                print(nu, 'train2')
                clients[nu].train2(nets, i)
                # nets is updated in place, so clients later in this round already
                # see the noised, freshly-trained networks of earlier clients.
                nets[nu] = copy.deepcopy(clients[nu].net)
                # NOTE(review): presumably a privacy perturbation scaled by
                # args.epsilon; confirm in utils.add_noise_to_model.
                nets[nu].load_state_dict(add_noise_to_model(nets[nu], epsilon=args.epsilon))

            _correct_count, accuracy, nmi, ari = valid_clients(clients)
            writer.add_scalar('all/acc', accuracy, i)
            writer.add_scalar('all/nmi', nmi, i)
            writer.add_scalar('all/ari', ari, i)
            print(accuracy)

        return accuracy
    finally:
        # Flush pending TensorBoard events and release the file even if training fails.
        writer.close()

    # for nu in range(num_users):
    #     print(nu, 'train14')
    #     clients[nu].train14()
    #     nets[nu] = copy.deepcopy(clients[nu].net)

    # acc = valid_clients(clients)
    # writer.add_scalar('all/acc', acc, i)