import torch
import torchvision
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import random
from scipy.optimize import minimize
import torch.nn.functional as F
from net.models import build_model
import torchvision.transforms as transforms
from torch.utils.data import Subset
import torch.nn as nn
from server.collaboration import one_hot_encode, get_optimal_weights
from utils.process_data import *


def _train_party_model(model, train_loader, device, epochs):
    """Train ``model`` in place with Adam and cross-entropy loss.

    Args:
        model: Network to optimise; must already live on ``device``.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        epochs: Number of full passes over ``train_loader``.
    """
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=1e-3, betas=(0.9, 0.999))
    loss_fn = nn.CrossEntropyLoss()

    model.train()
    for _ in range(epochs):
        for train_X, train_y in train_loader:
            output = model(train_X.to(device))
            loss = loss_fn(output, train_y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


def _accuracy(model, loader, device):
    """Return the fraction of samples in ``loader`` that ``model`` gets right.

    The prediction is the arg-max over dimension 1 of the model output.
    Gradients are disabled for the duration of the pass.
    """
    corrects = 0
    counts = 0
    with torch.no_grad():
        for batch_X, batch_y in loader:
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)
            predictions = torch.argmax(model(batch_X), dim=1)
            corrects += torch.sum(predictions == batch_y).item()
            counts += len(batch_y)
    return corrects / counts


def get_values(args, split_ratio):
    """Train one model per party on MNIST and score each of them.

    The MNIST training set (downloaded to ``../datasets`` when missing) is
    shuffled and cut into a training part and a validation part of
    ``args.val_size`` samples. The training part is divided between the
    parties with ``torch.utils.data.random_split`` using ``split_ratio``.
    Each party trains ``build_model("MLP<k>", 10)`` on its share and is
    evaluated on the MNIST test set. Finally ``get_optimal_weights`` is
    called once per validation sample and the results are averaged.

    Args:
        args: Namespace providing ``val_size``, ``cuda_num``, ``num_users``,
            ``batch_size``, ``epochs`` and ``num_class``.
        split_ratio: Lengths/fractions forwarded to ``random_split``; needs at
            least ``args.num_users`` entries.

    Returns:
        A tuple ``(errors, values)`` where ``errors`` is a list with
        ``1 - test_accuracy`` for every party and ``values`` is the mean over
        the validation samples of the ``get_optimal_weights`` outputs
        (the caller treats these as per-party Shapley-style values; the exact
        semantics live in ``server.collaboration`` -- TODO confirm).

    Raises:
        ValueError: If ``split_ratio`` has fewer entries than
            ``args.num_users``. ``random_split`` may also raise ``ValueError``
            when the lengths do not add up.
    """
    if len(split_ratio) < args.num_users:
        # Fail early with a clear message instead of an IndexError after some
        # of the models have already been trained.
        raise ValueError(
            f"split_ratio has {len(split_ratio)} entries but "
            f"num_users is {args.num_users}")

    # Same preprocessing for the training and the test images.
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))])
    raw_data = torchvision.datasets.MNIST('../datasets', train=True,
                                          download=True, transform=transform)
    test_dataset = torchvision.datasets.MNIST('../datasets', train=False,
                                              download=True,
                                              transform=transform)

    # The whole test set is served as a single batch.
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=len(test_dataset), shuffle=False,
        num_workers=4)

    # Divide training and validation dataset
    val_size = args.val_size
    train_size = len(raw_data) - val_size

    # NOTE: this shuffle uses the global ``random`` state, which is not seeded
    # in this function, whereas the party split below uses a fixed seed (42).
    indices = list(range(len(raw_data)))
    random.shuffle(indices)
    train_data = Subset(raw_data, indices[:train_size])
    val_data = Subset(raw_data, indices[train_size:])
    # One sample per batch: get_optimal_weights is invoked per sample.
    val_loader = torch.utils.data.DataLoader(
        val_data, batch_size=1, shuffle=False, num_workers=4)

    # Get iid data. Named ``party_splits`` so the module alias ``datasets``
    # imported at file level is not shadowed.
    party_splits = torch.utils.data.random_split(
        train_data, split_ratio,
        generator=torch.Generator().manual_seed(42))

    device = torch.device(args.cuda_num)
    models = []
    errors = []
    for i in range(args.num_users):
        # Process data from each user
        train_loader = torch.utils.data.DataLoader(
            party_splits[i], batch_size=args.batch_size, shuffle=True,
            num_workers=4)

        # Build and train this party's model
        model = build_model("MLP" + str(i + 1), 10)
        model.to(device)
        _train_party_model(model, train_loader, device, args.epochs)

        # Switch to inference mode so layers such as dropout / batch-norm
        # behave deterministically for the test pass and for the ensemble
        # computation below.
        model.eval()
        acc = _accuracy(model, test_loader, device)

        models.append(model)
        print("Party         : ", i)
        print(f"Test Accuracy: {acc}")
        errors.append(1 - acc)

    # Calculate model values using optimal ensemble
    outputs = []
    with torch.no_grad():
        for val_X, val_y in val_loader:
            val_y = one_hot_encode(val_y, args.num_class)
            output = get_optimal_weights(models, val_X.to(device), val_y,
                                         args.num_class)
            outputs.append(output)
    outputs = np.array(outputs)
    return errors, outputs.mean(0)


def main(args):
    """Run repeated trials and report the error/value correlation.

    Each successful trial draws a random Dirichlet partition with one share
    per party, calls ``get_values`` and stores the returned errors and values.
    After 100 successful trials the collected arrays are saved to ``X_.npy``
    and ``Y_.npy``, and the mean and standard deviation of the per-trial
    Pearson correlation between errors and values are printed.

    Args:
        args: Parsed command-line namespace; ``num_users`` is read here and
            the whole namespace is forwarded to ``get_values``.

    Raises:
        RuntimeError: If ``get_values`` fails too many times in a row, which
            indicates a persistent problem rather than an unlucky partition.
    """
    num_trials = 100
    # Upper bound on back-to-back failures before giving up; prevents a
    # persistent error from turning into a silent infinite loop.
    max_consecutive_failures = 50

    errors_collects = []
    shapleys_collects = []
    consecutive_failures = 0
    trial = 1
    while trial <= num_trials:
        print('------- Trial: ', trial, ' ---------')
        # One share per party (previously hard-coded to 5 regardless of
        # --num_users).
        partition = np.random.dirichlet(
            np.ones(args.num_users), size=1).reshape(args.num_users,).tolist()
        try:
            errors, shapleys = get_values(args, partition)
        except Exception as exc:
            # Retry with a fresh partition, but let KeyboardInterrupt /
            # SystemExit propagate and make the failure reason visible.
            consecutive_failures += 1
            print('Trial failed, retrying: ', type(exc).__name__, ': ', exc)
            if consecutive_failures >= max_consecutive_failures:
                raise RuntimeError(
                    f"get_values failed {consecutive_failures} times in a row"
                ) from exc
            continue
        consecutive_failures = 0
        trial += 1
        errors_collects.append(errors)
        shapleys_collects.append(shapleys)

    # define data
    x = np.array(errors_collects)
    y = np.array(shapleys_collects)
    np.save('X_.npy', x)
    np.save('Y_.npy', y)
    print('Shapleys: ', y.mean(0))

    # np.corrcoef returns a 2x2 matrix; entry [0, 1] is the correlation
    # coefficient between the two input arrays.
    correlations = [np.corrcoef(_x, _y)[0, 1] for _x, _y in zip(x, y)]

    samples = np.array(correlations)
    print('Means: ', np.mean(samples), '     Std:', np.std(samples))


if __name__ == "__main__":
    # Script entry point: parse the experiment settings and run the trials.
    # argparse is imported here because this file has no explicit top-level
    # import for it; relying on the star import from utils.process_data to
    # provide the name would be fragile.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_class", type=int, default=10)
    parser.add_argument("--num_users", type=int, default=5)
    # Device string handed to torch.device() in get_values.
    parser.add_argument("--cuda_num", type=str, default='cuda:6')
    # NOTE(review): --ensemble and --model are parsed but not read by
    # get_values or main in this file.
    parser.add_argument("--ensemble", type=str, default='opt')
    parser.add_argument("--model", type=str, default='MLP')
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=128)
    # Number of training samples held out for the validation split.
    parser.add_argument("--val_size", type=int, default=5000)

    args = parser.parse_args()

    main(args)
