import torch
import torchvision
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import random
from scipy.optimize import minimize
import torch.nn.functional as F
from net.models import build_model
import torchvision.transforms as transforms
from torch.utils.data import Subset
import torch.nn as nn
from server.collaboration import one_hot_encode, get_optimal_weights
from utils.process_data import *


def _load_datasets(dataset_name):
    """Return ``(train_dataset, test_dataset)`` for CIFAR10, MNIST or SVHN.

    Images are converted to tensors and normalised per channel; the train and
    test splits of a dataset share the same transform. Data lives under
    ``../datasets`` and is downloaded there when missing.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported datasets.
    """
    if dataset_name == "CIFAR10":
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        raw_data = torchvision.datasets.CIFAR10('../datasets', train=True, download=True,
                                                transform=transform)
        test_dataset = torchvision.datasets.CIFAR10('../datasets', train=False, download=True,
                                                    transform=transform)
    elif dataset_name == "MNIST":
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))])
        raw_data = torchvision.datasets.MNIST('../datasets', train=True, download=True,
                                              transform=transform)
        test_dataset = torchvision.datasets.MNIST('../datasets', train=False, download=True,
                                                  transform=transform)
    elif dataset_name == "SVHN":
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        raw_data = torchvision.datasets.SVHN('../datasets', split="train", download=True,
                                             transform=transform)
        test_dataset = torchvision.datasets.SVHN('../datasets', split="test", download=True,
                                                 transform=transform)
    else:
        # Previously an unknown name fell through and crashed later with NameError.
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; expected 'CIFAR10', 'MNIST' or 'SVHN'")
    return raw_data, test_dataset


def _train_local_model(args, train_loader, device):
    """Build a fresh model for one party and train it on ``train_loader``.

    MNIST uses the project's "MLP" architecture, every other dataset "CNN"
    (both built with ``10`` as the second ``build_model`` argument —
    presumably the number of output classes; confirm against ``net.models``).
    Training runs ``args.epochs`` passes of Adam (lr=1e-3) on cross-entropy.
    """
    arch = "MLP" if args.dataset == 'MNIST' else "CNN"
    model = build_model(arch, 10)
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, betas=(0.9, 0.999))
    loss_fn = nn.CrossEntropyLoss()

    for _ in range(args.epochs):
        for train_X, train_y in train_loader:
            output = model(train_X.to(device))
            loss = loss_fn(output, train_y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return model


def _test_accuracy(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode (and left there) so that layers such as
    dropout / batch-norm, if the architecture has any, behave deterministically.
    """
    model.eval()
    corrects = 0
    counts = 0
    with torch.no_grad():
        for test_X, test_y in test_loader:
            test_y = test_y.to(device)
            output = model(test_X.to(device))
            corrects += (torch.argmax(output, dim=1) == test_y).sum().item()
            counts += len(test_y)
    return corrects / counts


def get_values(args, split_ratio):
    """Train one model per party and return per-party test errors and values.

    Steps:
      1. Load ``args.dataset`` and hold out ``args.val_size`` randomly chosen
         training samples as a validation set.
      2. Split the remaining training data across ``args.num_users`` parties:
         IID mode (``args.iid_data == 1``) uses ``random_split`` with the
         fractions in ``split_ratio`` (fixed seed 42); otherwise the project
         helpers ``get_dataset`` / ``process_user_data`` build the partition
         using ``args.alpha`` and ``split_ratio`` is ignored.
      3. Train a model per party and record its test error ``1 - accuracy``.
      4. For every validation sample (batch size 1), call
         ``get_optimal_weights`` on all party models and average the results.

    Args:
        args: parsed CLI namespace (dataset, val_size, iid_data, num_class,
            num_users, alpha, cuda_num, epochs, batch_size are read).
        split_ratio: per-party fractions for the IID split; must contain one
            entry per party and sum to 1 (``random_split`` requirement).

    Returns:
        ``(errors, values)``: ``errors`` is a list of per-party test errors;
        ``values`` is the mean over validation samples of the
        ``get_optimal_weights`` outputs, as a NumPy array.

    Raises:
        ValueError: for an unsupported dataset, a ``val_size`` outside
            ``(0, len(train set))``, or (IID mode) a ``split_ratio`` whose
            length differs from ``args.num_users``.
    """
    raw_data, test_dataset = _load_datasets(args.dataset)
    # Bounded batch size: the accuracy is the same as evaluating the whole
    # test set at once, without materialising it on the device in one go.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000,
                                              shuffle=False, num_workers=4)

    # Divide training and validation data.
    val_size = args.val_size
    if not 0 < val_size < len(raw_data):
        raise ValueError(
            f"val_size must be in (0, {len(raw_data)}), got {val_size}")
    train_size = len(raw_data) - val_size

    indices = list(range(len(raw_data)))
    random.shuffle(indices)
    train_data = Subset(raw_data, indices[:train_size])
    val_data = Subset(raw_data, indices[train_size:])
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=1, shuffle=False, num_workers=4)

    # Generate iid / non-iid data for each party from the training data.
    if args.iid_data == 1:
        if len(split_ratio) != args.num_users:
            # A mismatch would either IndexError below or silently drop splits.
            raise ValueError(
                f"split_ratio has {len(split_ratio)} entries but num_users is {args.num_users}")
        user_splits = torch.utils.data.random_split(
            train_data, split_ratio, generator=torch.Generator().manual_seed(42))
    else:
        data_by_class, n_train_sample, src_n_class = get_dataset(
            train_data, train_size, num_class=args.num_class)
        src_classes = list(range(src_n_class))
        random.shuffle(src_classes)
        datasets_X, datasets_y = process_user_data(data_by_class, n_train_sample, src_classes,
                                                   num_users=args.num_users,
                                                   alpha=args.alpha)

    device = torch.device(args.cuda_num)
    models = []
    errors = []
    for i in range(args.num_users):
        # Loader is built right before training each party, as before.
        if args.iid_data == 1:
            train_loader = torch.utils.data.DataLoader(user_splits[i], batch_size=args.batch_size,
                                                       shuffle=True, num_workers=4)
        else:
            user_X = torch.tensor(datasets_X[i], dtype=torch.float32)
            user_y = torch.tensor(datasets_y[i], dtype=torch.int64)
            user_dataset = torch.utils.data.TensorDataset(user_X, user_y)
            train_loader = torch.utils.data.DataLoader(user_dataset, batch_size=args.batch_size,
                                                       shuffle=True)

        model = _train_local_model(args, train_loader, device)
        acc = _test_accuracy(model, test_loader, device)
        models.append(model)
        print("Party         : ", i)
        print(f"Test Accuracy: {acc}")
        errors.append(1 - acc)

    # Calculate model values using the optimal ensemble on each validation
    # sample; the models are already in eval mode from _test_accuracy.
    outputs = []
    with torch.no_grad():
        for val_X, val_y in val_loader:
            val_y = one_hot_encode(val_y, args.num_class)
            outputs.append(get_optimal_weights(models, val_X.to(device), val_y, args.num_class))
    return errors, np.array(outputs).mean(0)


def main(args, num_trials=100):
    """Run repeated trials and report how party test error correlates with value.

    Each trial draws a random partition from a flat Dirichlet over
    ``args.num_users`` parties (used as the IID split fractions by
    ``get_values``), then collects per-party test errors and values. The
    stacked results are saved as ``X_<name>.npy`` (errors) and
    ``Y_<name>.npy`` (values), and the mean / std of the per-trial Pearson
    correlation between errors and values is printed.

    Args:
        args: parsed CLI namespace forwarded to ``get_values``.
        num_trials: number of independent trials (default 100, as before).
    """
    print('Dataset: ', args.dataset, ' --- IID Data: ', (args.iid_data==1))
    errors_collects = []
    shapleys_collects = []
    for trial in range(num_trials):
        print('------- Trial: ', trial, ' ---------')
        # One fraction per party; this was hard-coded to 5 parties, which made
        # the IID split inconsistent with args.num_users whenever it was not 5.
        partition = np.random.dirichlet(np.ones(args.num_users)).tolist()
        errors, shapleys = get_values(args, partition)
        errors_collects.append(errors)
        shapleys_collects.append(shapleys)

    x = np.array(errors_collects)
    y = np.array(shapleys_collects)
    # NOTE(review): the name does not encode args.iid_data, so IID and non-IID
    # runs with the same dataset/alpha/num_users overwrite each other's files.
    file_name = f"{args.dataset}_{args.alpha}_{args.num_users}.npy"
    np.save('X_' + file_name, x)
    np.save('Y_' + file_name, y)

    # Per-trial Pearson correlation between party errors and party values
    # (off-diagonal entry of the 2x2 matrix returned by np.corrcoef).
    samples = np.array([np.corrcoef(trial_x, trial_y)[0, 1] for trial_x, trial_y in zip(x, y)])
    print('Means: ', np.mean(samples), '     Std:', np.std(samples))


if __name__ == "__main__":
    # argparse was previously used without being imported (it only worked if
    # the star import from utils.process_data happened to re-export it).
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_class", type=int, default=10)
    parser.add_argument("--num_users", type=int, default=5)
    parser.add_argument("--dataset", type=str, default='CIFAR10')  # CIFAR10 | MNIST | SVHN
    parser.add_argument("--iid_data", type=int, default=1)  # 1 = IID split, otherwise non-IID
    parser.add_argument("--alpha", type=float, default=0.1)  # only used for the non-IID split
    parser.add_argument("--cuda_num", type=str, default='cuda:6')  # any torch.device string
    # NOTE(review): --ensemble and --model appear unused in this module; the
    # architecture is chosen from --dataset inside get_values.
    parser.add_argument("--ensemble", type=str, default='opt')
    parser.add_argument("--model", type=str, default='MLP')
    parser.add_argument("--epochs", type=int, default=15)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--val_size", type=int, default=5000)

    args = parser.parse_args()

    main(args)
