import torch
import torchvision
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
import random
from scipy.optimize import minimize
import torch.nn.functional as F
from net.models import build_model
import torchvision.transforms as transforms
from torch.utils.data import Subset
import torch.nn as nn
from numpy import linalg as LA
from server.collaboration import *
from utils.process_data import *


def _load_datasets(dataset_name):
    """Return the ``(train, test)`` torchvision datasets for ``dataset_name``.

    Supported names are ``"CIFAR10"``, ``"MNIST"`` and ``"SVHN"``.  Data is
    downloaded into ``../datasets`` when missing, and the same normalising
    transform is applied to both splits.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    root = '../datasets'
    if dataset_name == "CIFAR10":
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        raw_data = torchvision.datasets.CIFAR10(root, train=True, download=True,
                                                transform=transform)
        test_dataset = torchvision.datasets.CIFAR10(root, train=False, download=True,
                                                    transform=transform)
    elif dataset_name == "MNIST":
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))])
        raw_data = torchvision.datasets.MNIST(root, train=True, download=True,
                                              transform=transform)
        test_dataset = torchvision.datasets.MNIST(root, train=False, download=True,
                                                  transform=transform)
    elif dataset_name == "SVHN":
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        raw_data = torchvision.datasets.SVHN(root, split="train", download=True,
                                             transform=transform)
        test_dataset = torchvision.datasets.SVHN(root, split="test", download=True,
                                                 transform=transform)
    else:
        # Without this branch an unknown name only fails later with an
        # unbound-local error on ``raw_data``.
        raise ValueError(
            "Unsupported dataset %r; expected 'CIFAR10', 'MNIST' or 'SVHN'"
            % (dataset_name,))
    return raw_data, test_dataset


def _test_error(model, test_loader, device):
    """Return ``1 - accuracy`` of ``model`` over every batch of ``test_loader``."""
    corrects = 0
    counts = 0
    with torch.no_grad():
        for test_X, test_y in test_loader:
            test_X = test_X.to(device)
            test_y = test_y.to(device)
            output = model(test_X)
            corrects += (torch.sum(torch.argmax(output, dim=1) == test_y)).item()
            counts += len(test_y)
    return 1 - corrects / counts


def get_ensemble_values(args):
    """Train one model per user and compare ensemble schemes on held-out data.

    The training split of ``args.dataset`` is randomly divided into a
    training part and a validation part of ``args.val_size`` samples.  The
    training part is distributed over ``args.num_users`` users through the
    project helpers ``get_dataset`` / ``process_user_data``; each user trains
    its own model for ``args.epochs`` epochs and its test error is recorded.
    Every validation sample is then scored with four ensemble schemes
    (average, majority vote, knowledge vote, MWU), yielding per-model
    contribution values and ensemble accuracies.

    Attributes read from ``args``: ``dataset``, ``val_size``, ``num_class``,
    ``num_users``, ``alpha``, ``cuda_num``, ``batch_size``, ``epochs``.

    Returns:
        A pair ``(accuracies, correlations)``; each is a list ordered as
        ``[average, majority vote, knowledge vote, MWU]``.  Accuracies are
        measured on the validation split; correlations are the Pearson
        coefficients between the per-model test errors and the per-model
        contribution values averaged over the validation samples.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    raw_data, test_dataset = _load_datasets(args.dataset)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                              shuffle=False, num_workers=4)

    # Divide training and validation dataset with a random index permutation.
    val_size = args.val_size
    train_size = len(raw_data) - val_size
    indices = list(range(len(raw_data)))
    random.shuffle(indices)
    train_indices = indices[:train_size]
    val_indices = indices[train_size:]
    train_data = Subset(raw_data, train_indices)
    val_data = Subset(raw_data, val_indices)

    # Batch size must stay 1: the valuation loop below works sample by sample.
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=1,
                                             shuffle=False, num_workers=4)
    print(len(val_loader))

    # Get non-iid data
    data_by_class, n_train_sample, SRC_N_CLASS = get_dataset(
        train_data, len(train_indices), num_class=args.num_class)
    SRC_CLASSES = list(range(SRC_N_CLASS))
    random.shuffle(SRC_CLASSES)
    datasets_X, datasets_y = process_user_data(data_by_class, n_train_sample, SRC_CLASSES,
                                               num_users=args.num_users,
                                               alpha=args.alpha)

    device = torch.device(args.cuda_num)
    models = []
    errors = []

    for user in range(args.num_users):
        # Process data from each user
        user_X = torch.tensor(datasets_X[user], dtype=torch.float32)
        user_y = torch.tensor(datasets_y[user], dtype=torch.int64)
        user_dataset = torch.utils.data.TensorDataset(user_X, user_y)
        train_loader = torch.utils.data.DataLoader(user_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True)

        # Build training models
        if args.dataset == 'MNIST':
            model = build_model("MLP", 10)
        else:
            model = build_model("CNN", 10)
        model.to(device)
        optimizer = torch.optim.Adam(params=model.parameters(),
                                     lr=1e-3, betas=(0.9, 0.999))
        loss_fn = nn.CrossEntropyLoss()

        # Start model training
        for _ in range(args.epochs):
            for train_X, train_y in train_loader:
                output = model(train_X.to(device))
                loss = loss_fn(output, train_y.to(device))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Training is finished for this model: switch to evaluation mode so
        # that train-only layers (dropout, batch-norm), if the architecture
        # has any, do not perturb the test error or the valuation below.
        model.eval()
        errors.append(_test_error(model, test_loader, device))
        models.append(model)

    # Per-sample contribution values of every model under each scheme.
    # NOTE(review): the optimal-ensemble values are collected but do not feed
    # into the returned results -- confirm whether that is intended.
    opt_shapley_set = []
    mv_shapley_set = []
    kv_shapley_set = []
    mwu_shapley_set = []

    avg_correct = 0
    mv_correct = 0
    kv_correct = 0
    mwu_correct = 0
    num_models = len(models)
    mwu_weights = np.ones(num_models) / num_models
    count_none = 0
    with torch.no_grad():
        print('[Valuation]', end='')
        for step, (X, y) in enumerate(val_loader):
            print("\r Iteration: [", step, "/", len(val_loader), "]", end='')
            one_hot_y = one_hot_encode(y, args.num_class)
            X_dev = X.to(device)
            opt_shapley = get_optimal_weights(models, X_dev, one_hot_y, args.num_class)

            avg_output = 0
            preds = []
            probs = []
            for model in models:
                output = F.softmax(model(X_dev).cpu(), dim=1)
                probs.append(output)
                avg_output += output
                preds.append(torch.argmax(output, dim=1))

            # Average Ensemble
            avg_correct += (torch.argmax(avg_output, dim=1) == y)

            # Majority Vote: the models agreeing with the most common
            # prediction share one unit of credit equally.
            mv_shapley = np.zeros(num_models)
            votes = [int(pred) for pred in preds]  # one label per model
            counts = np.bincount(votes)
            consensus_prediction = np.argmax(counts)
            mv_correct += (consensus_prediction == y)
            for m, vote in enumerate(votes):
                if vote == consensus_prediction:
                    mv_shapley[m] += 1 / counts[consensus_prediction]

            # Knowledge Vote
            probs = torch.cat(probs, 0)
            kv_shapley = np.zeros(num_models)
            indices, _, kv_pred, _ = knowledge_vote(probs, num_class=args.num_class)
            if kv_pred is not None:
                kv_shapley[indices] += 1 / len(indices)
                _correct = (kv_pred == y)
            else:
                count_none += 1
                # Replace KV with Average when it is not available
                kv_shapley = np.ones(num_models) / num_models
                _correct = (torch.argmax(avg_output, dim=1) == y)
            kv_correct += _correct

            # MWU: predict with the current weights, which also serve as the
            # contribution values recorded for this sample.
            mwu_shapley = np.zeros(num_models)
            mwu_pred = 0
            for m, weight in enumerate(mwu_weights):
                mwu_pred += weight * probs[m]
                mwu_shapley[m] = weight
            mwu_correct += (np.argmax(mwu_pred) == y)

            # Update MWU from each model's distance to the ensemble output.
            # NOTE(review): the updated weights are not renormalised and a
            # zero distance yields an infinite weight -- confirm intended.
            distances = [LA.norm(probs[m] - mwu_pred) for m in range(num_models)]
            weight_sum = sum(distances)
            for m, distance in enumerate(distances):
                mwu_weights[m] = -np.log(distance / weight_sum)

            opt_shapley_set.append(opt_shapley)
            mv_shapley_set.append(mv_shapley)
            kv_shapley_set.append(kv_shapley)
            mwu_shapley_set.append(mwu_shapley)

    print("kv count_none: ", count_none)
    opt_shapley_norm = np.array(opt_shapley_set).mean(0)
    avg_shapley_norm = np.ones(num_models) / num_models
    if num_models > 1:
        # Tiny perturbation so the vector is not constant; a constant vector
        # has zero variance and would make its correlation undefined.
        avg_shapley_norm[0] += 1e-8
        avg_shapley_norm[1] -= 1e-8
    mv_shapley_norm = np.array(mv_shapley_set).mean(0)
    kv_shapley_norm = np.array(kv_shapley_set).mean(0)
    mwu_shapley_norm = np.array(mwu_shapley_set).mean(0)

    # Get Correlations
    errors = np.array(errors)
    avg_corr = np.corrcoef(errors, avg_shapley_norm)[0, 1]
    mv_corr = np.corrcoef(errors, mv_shapley_norm)[0, 1]
    kv_corr = np.corrcoef(errors, kv_shapley_norm)[0, 1]
    mwu_corr = np.corrcoef(errors, mwu_shapley_norm)[0, 1]
    print('--------------Correlation--------------')
    print([avg_corr, mv_corr, kv_corr, mwu_corr])

    # get ACC
    avg_acc = avg_correct.item() / len(val_loader)
    mv_acc = mv_correct.item() / len(val_loader)
    kv_acc = kv_correct.item() / len(val_loader)
    mwu_acc = mwu_correct.item() / len(val_loader)
    print('--------------Accuracy--------------')
    print([avg_acc, mv_acc, kv_acc, mwu_acc])

    return [avg_acc, mv_acc, kv_acc, mwu_acc], [avg_corr, mv_corr, kv_corr, mwu_corr]


def main(args):
    """Repeat the ensemble experiment, save the raw results and print summaries.

    ``get_ensemble_values`` is called ``args.total_trials`` times.  The
    per-trial accuracies and correlations are stacked into arrays with one
    row per trial and saved to ``acc_<dataset>_<alpha>_<num_users>.npy`` and
    ``corre_<dataset>_<alpha>_<num_users>.npy`` in the working directory.
    The mean and standard deviation over trials are then printed for each of
    the four ensemble schemes (AVG, MV, KV, MWU).
    """
    print('Dataset: ', args.dataset)
    trial_accs = []
    trial_corrs = []
    for trial in range(args.total_trials):
        print('------- Trial: ', trial, ' ---------')
        acc_array, corr_array = get_ensemble_values(args)
        trial_corrs.append(corr_array)
        trial_accs.append(acc_array)

    acc_collects = np.array(trial_accs)
    corr_collects = np.array(trial_corrs)
    file_name = '_'.join([args.dataset, str(args.alpha), str(args.num_users)]) + '.npy'
    np.save('acc_' + file_name, acc_collects)
    np.save('corre_' + file_name, corr_collects)

    # Column k of each table belongs to the k-th label of its section.
    sections = (
        (corr_collects, ('Means AVG Corr  : ',
                         'Means MV Corr   : ',
                         'Means KV Corr   : ',
                         'Means MWU Corr   : ')),
        (acc_collects, ('Means AVG ACC : ',
                        'Means MV ACC  : ',
                        'Means KV ACC  : ',
                        'Means MWU ACC  : ')),
    )
    for table, labels in sections:
        print('--------------------------')
        for col, label in enumerate(labels):
            column = table[:, col]
            print(label, column.mean(), ' Std: ', column.std())



if __name__ == "__main__":
    # Command-line entry point: parse the experiment settings and run main().
    # argparse is imported here explicitly; the file has no top-level
    # ``import argparse``, so the name would otherwise only be available if
    # one of the star imports happened to re-export it.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--num_class", type=int, default=10)
    parser.add_argument("--num_users", type=int, default=5)
    parser.add_argument("--total_trials", type=int, default=100)
    parser.add_argument("--dataset", type=str, default='CIFAR10')
    parser.add_argument("--alpha", type=float, default=0.1)
    parser.add_argument("--cuda_num", type=str, default='cuda:6')
    parser.add_argument("--ensemble", type=str, default='opt')
    parser.add_argument("--model", type=str, default='MLP')
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--val_size", type=int, default=5000)

    args = parser.parse_args()

    main(args)
