from __future__ import print_function
import argparse
import torch
import os
import numpy as np
from utils.model_utils import read_data, read_user_data
from server.base_server import Server
from party.base_party import Party
import torchvision
from torchvision import transforms
from utils.process_data import get_dataset, process_user_data
import random
from torch.utils.data import Subset
import numpy as np
import matplotlib.pyplot as plt

def load_datasets(dataset):
    """Load the full training split and the test split of *dataset*.

    Both splits are downloaded into ``../datasets`` if not already present
    and share the same normalising transform.

    Args:
        dataset: One of ``"CIFAR10"``, ``"MNIST"`` or ``"SVHN"``.

    Returns:
        ``(all_data, test_data)`` torchvision datasets.

    Raises:
        ValueError: if *dataset* is not one of the supported names.
    """
    if dataset == "CIFAR10":
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        all_data = torchvision.datasets.CIFAR10('../datasets', train=True, download=True,
                                                transform=transform)
        test_data = torchvision.datasets.CIFAR10('../datasets', train=False, download=True,
                                                 transform=transform)
    elif dataset == "MNIST":
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))])
        all_data = torchvision.datasets.MNIST('../datasets', train=True, download=True,
                                              transform=transform)
        test_data = torchvision.datasets.MNIST('../datasets', train=False, download=True,
                                               transform=transform)
    elif dataset == "SVHN":
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        all_data = torchvision.datasets.SVHN('../datasets', split="train", download=True,
                                             transform=transform)
        test_data = torchvision.datasets.SVHN('../datasets', split="test", download=True,
                                              transform=transform)
    else:
        # Previously an unknown name fell through and crashed later with
        # NameError on `all_data`; fail early with an explicit message.
        raise ValueError("Unsupported dataset: {!r}".format(dataset))
    return all_data, test_data


def subset_labels(subset):
    """Return a numpy array with the label of every sample in *subset*.

    When *subset* wraps a dataset exposing its labels as ``targets``
    (CIFAR10, MNIST) or ``labels`` (SVHN) via ``.dataset``/``.indices``
    (e.g. a ``torch.utils.data.Subset`` from ``random_split``), the labels
    are read directly, without loading or transforming any image. Otherwise
    it falls back to indexing every sample and taking element ``[1]``.
    """
    base = getattr(subset, "dataset", None)
    indices = getattr(subset, "indices", None)
    if base is not None and indices is not None:
        for attr in ("targets", "labels"):
            all_labels = getattr(base, attr, None)
            if all_labels is not None:
                # dtype pinned so an empty index list still indexes correctly.
                return np.asarray(all_labels)[np.asarray(indices, dtype=np.int64)]
    # Slow generic path: decodes each sample just to read its label.
    return np.asarray([subset[j][1] for j in range(len(subset))])


def party_indices(subset, classes):
    """Return, in ascending order, the positions in *subset* whose label is in *classes*."""
    labels = subset_labels(subset)
    return np.flatnonzero(np.isin(labels, list(classes))).tolist()


def party_label(classes):
    """Format a party's class set as ``"[c]"`` or ``"[lo-hi]"`` for plot ticks.

    Assumes the classes form a contiguous range (true for every entry of the
    party configuration in ``main``); gaps would not be visible in the label.
    """
    lo, hi = min(classes), max(classes)
    return "[{}]".format(lo) if lo == hi else "[{}-{}]".format(lo, hi)


def plot_accuracies(original_acc, fair_acc, labels, title, file_name):
    """Plot per-party accuracy before/after collaboration and save it to *file_name*."""
    positions = np.arange(len(labels))
    fig, ax = plt.subplots(1, 1, figsize=(2, 2), dpi=200)
    ax.plot(positions, original_acc, 'o--c', label=r'$ACC(h_i)$')
    ax.plot(positions, fair_acc, 'o-m', label=r"$ACC(h_i')$")
    # NOTE(review): fixed upper limit clips any accuracy above 0.7.
    ax.set_ylim(0, 0.7)
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    plt.legend(loc='upper left')
    plt.rcParams.update({'font.size': 11})
    plt.title(title)
    plt.ylabel('Accuracy')
    plt.xlabel('Party Label')
    plt.savefig(file_name)
    # Release the figure; pyplot otherwise keeps it alive for the process.
    plt.close(fig)


def main(args):
    """Train one model per party, run server collaboration, and report results.

    Each party ``i`` receives the training samples whose class is listed in
    the ``i``-th entry of the party configuration below. After local training
    the server runs ``collab(args.ensemble)`` and ``reward()``. Then each
    party's test accuracy before (``test``) and after (``train_new`` +
    ``test_new``) is printed, together with their Pearson correlation
    ("Fairness"). Finally, a plot is saved to ``<dataset>_<val_size>.png``.

    Args:
        args: Parsed command-line namespace. Reads ``dataset``, ``val_size``,
            ``num_users`` and ``ensemble``, and is passed whole to
            ``Server``/``Party``.

    Raises:
        ValueError: if ``args.dataset`` is unsupported or ``args.num_users``
            exceeds the number of configured parties.
    """
    party_data_config = {"5_parties": [[0], [1, 2], [0, 1, 2, 3], [3, 4, 5], [6, 7, 8, 9]],}
    party_classes = party_data_config["5_parties"]
    # Validate before downloading anything, instead of an IndexError mid-run.
    if args.num_users > len(party_classes):
        raise ValueError("num_users={} exceeds the {} configured parties".format(
            args.num_users, len(party_classes)))

    all_data, test_data = load_datasets(args.dataset)

    # Fixed seed so the train/validation split is reproducible across runs.
    train_size = len(all_data) - args.val_size
    train_data, val_data = torch.utils.data.random_split(
        all_data, [train_size, args.val_size],
        generator=torch.Generator().manual_seed(42))

    server = Server(args, val_data)

    # Local Training
    for i in range(args.num_users):
        subset_dataset = Subset(train_data, party_indices(train_data, party_classes[i]))
        print('Model Training for Party: ', i)
        party = Party(args, i, subset_dataset, test_data)
        party.train()
        server.add_party(party)

    # Black-box Model Sharing
    server.collab(args.ensemble)
    server.reward()
    original_acc = []
    fair_acc = []
    for i in range(args.num_users):
        party = server.parties[i]
        test_acc = party.test()
        print('Original Test Acc: ', test_acc)

        party.train_new()

        test_new_acc = party.test_new()
        print('Improved Test Acc: ', test_new_acc)
        print(party.get_emp_optimal_alpha())

        original_acc.append(test_acc)
        fair_acc.append(test_new_acc)

    # Pearson correlation between per-party accuracies before and after.
    fairness = np.corrcoef(np.array(original_acc), np.array(fair_acc))[0, 1]
    print('Fairness: ', fairness)
    print("Finished.")

    labels = [party_label(classes) for classes in party_classes[:args.num_users]]
    file_name = args.dataset + '_' + str(args.val_size) + '.png'
    plot_accuracies(original_acc, fair_acc, labels, args.dataset, file_name)


def build_parser():
    """Build the command-line parser for this experiment.

    Restricting ``--dataset`` to the supported names makes argparse reject a
    typo with a usage message, instead of failing later inside ``main``.
    """
    parser = argparse.ArgumentParser(
        description="Per-party training followed by server-side collaboration.")
    parser.add_argument("--num_classes", type=int, default=10,
                        help="number of label classes")
    parser.add_argument("--num_users", type=int, default=5,
                        help="number of parties to train")
    parser.add_argument("--dataset", type=str, default='CIFAR10',
                        choices=["CIFAR10", "MNIST", "SVHN"],
                        help="dataset to use")
    parser.add_argument("--cuda_num", type=str, default='cuda:6',
                        help="device string passed to Server/Party")
    parser.add_argument("--ensemble", type=str, default='opt',
                        help="mode passed to Server.collab")
    parser.add_argument("--epochs", type=int, default=20,
                        help="training epochs")
    parser.add_argument("--batch_size", type=int, default=128,
                        help="training batch size")
    parser.add_argument("--val_size", type=int, default=5000,
                        help="samples held out from the training split for the server")
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())

