import argparse
import logging
import os
os.environ["OMP_NUM_THREADS"] = "8"

import random
import sys

import numpy as np
import torch
from torch import nn
from torch.utils.data import dataset
from torchvision import models
import wandb

from data_loader.load_data_json import load_data_fmnist, load_data_emnist, load_data_cifar10, load_data_shakespeare, load_data_fets21

from model.two_nn import Two_NN
from model.cnn import CNN # 28*28 or 32*32 -> 10/47
from model.logistic import LogisticRegression # 784 -> 10
from model.rnn import RNN
# from model.resnext import PretrainedResNext
from model.unet import UNet

from criterion.dice_loss import dice_loss

from accuracy.classification_accuracy import classification_accuracy, stream_accuracy
from accuracy.dice_score import dice_score

from api.vanilla_api import Local_API, FedAvg_API, FedProx_API, PerFL_API, Cho_API, Song_API, Chen_API 
from api.varsel_api import VARSEL_API


def add_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register every command-line option this script reads.

    Args:
        parser: an ``argparse.ArgumentParser``; it is modified in place.

    Returns:
        The same ``parser`` object, so callers can chain ``.parse_args()``.
    """
    # Training settings
    parser.add_argument('--model', type=str, default='logistic', metavar='MODEL',
                        help='neural network used in training: logistic, 2nn, cnn, rnn or unet '
                             '(default: logistic)')

    parser.add_argument('--dataset', type=str, default='fashion-mnist', metavar='DATASET',
                        help='dataset used for training: fashion-mnist, emnist, cifar10, '
                             'shakespeare or fets21 (default: fashion-mnist)')

    # NOTE(review): argparse does not expand environment variables, so this
    # default is the literal path '/home/$HOME/DATA/EMNIST'. In practice
    # --data_dir should be passed explicitly.
    parser.add_argument('--data_dir', type=str, default='/home/$HOME/DATA/EMNIST',
                        help='data directory containing the "internal" and "external" '
                             'subdirectories')

    parser.add_argument('--batch_size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')

    parser.add_argument('--optimizer', type=str, default='adam',
                        help='optimizer: SGD with momentum, or adam (default: adam)')

    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.1)')

    parser.add_argument('--wd', type=float, default=0.00001,
                        help='weight decay (default: 1e-05)')

    parser.add_argument('--rounds', type=int, default=200, metavar='R',
                        help='number of training rounds (default: 200)')

    parser.add_argument('--epochs', type=int, default=3, metavar='EP',
                        help='number of epochs within each round (default: 3)')

    parser.add_argument('--test_interval', type=int, default=5,
                        help='the test interval (default: 5)')

    parser.add_argument('--training_strategy', type=str, default='FedAvg',
                        help='the training strategy: Local, FedAvg, FedProx, PerFL, Cho, Song, '
                             'Chen or VARSEL (default: FedAvg)')

    parser.add_argument('--gpu', type=int, default=0,
                        help='index of the CUDA device used when CUDA is available (default: 0)')

    # The flag keeps its historical spelling so existing launch scripts still work.
    parser.add_argument('--expected_colaborators', type=int, default=10,
                        help='number of expected collaborators within each round (default: 10)')

    parser.add_argument('--internal_batch_train', type=str, default='mini',
                        help='minibatch or fullbatch training for internal clients (default: mini)')

    parser.add_argument('--external_batch_train', type=str, default='mini',
                        help='minibatch or fullbatch training for external clients (default: mini)')

    # The default stays the int 0, because argparse does not apply `type` to
    # non-string defaults. The value appears verbatim in the wandb run name
    # ("-lambda0"), so changing it to 0.0 would rename runs.
    parser.add_argument('--lambda_prox', type=float, default=0,
                        help='proximal term coefficient for clients (default: 0)')

    return parser


def load_data(args, dataset_name: str) -> dict:
    """Load the internal/external client splits for ``dataset_name``.

    Args:
        args: parsed CLI namespace; only ``args.data_dir`` is read.
        dataset_name: one of "fashion-mnist", "emnist", "cifar10",
            "shakespeare" or "fets21".

    Returns:
        A dict with keys "internal_cid", "internal_train_data",
        "internal_test_data", "external_cid" and "external_data". The values
        are the five items returned by the selected loader, in that order.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    logging.info("load_data. dataset_name = %s", dataset_name)

    if dataset_name == "fashion-mnist":
        data_loader = load_data_fmnist
    elif dataset_name == "emnist":
        data_loader = load_data_emnist
    elif dataset_name == "cifar10":
        data_loader = load_data_cifar10
    elif dataset_name == "shakespeare":
        data_loader = load_data_shakespeare
    elif dataset_name == "fets21":
        data_loader = load_data_fets21
    else:
        # Every unsupported name stops here, so the load below always has a loader.
        raise ValueError("Invalid_Dataset_Name")

    # Each loader is given the "<data_dir>/internal" and "<data_dir>/external" directories.
    (internal_cid, internal_train_data, internal_test_data,
     external_cid, external_data) = data_loader(args.data_dir + "/internal",
                                                args.data_dir + "/external")

    return {"internal_cid": internal_cid, "internal_train_data": internal_train_data,
            "internal_test_data": internal_test_data, "external_cid": external_cid,
            "external_data": external_data}


def create_model(args, model_name):
    """Build the network registered for the pair (``model_name``, ``args.dataset``).

    Each model is tied to exactly one dataset:
    logistic/fashion-mnist, 2nn/emnist, cnn/cifar10, rnn/shakespeare and
    unet/fets21.

    Raises:
        NotImplementedError: if the pair is not one of those listed above.
    """
    logging.info("create_model. model_name = %s", model_name)
    dataset_name = args.dataset
    logging.info("%s + %s", model_name, dataset_name)

    # Factories are lambdas so that only the selected model is constructed.
    registry = {
        ("logistic", "fashion-mnist"): lambda: LogisticRegression(output_dim=10),
        ("2nn", "emnist"): lambda: Two_NN(only_digits=False),
        ("cnn", "cifar10"): lambda: CNN(in_channels=3, size=32, only_digits=True),
        ("rnn", "shakespeare"): lambda: RNN(),
        ("unet", "fets21"): lambda: UNet(1, 2, size=(120, 120)),
    }

    factory = registry.get((model_name, dataset_name))
    if factory is None:
        raise NotImplementedError("Unregistered_Pair_Of_Model_And_Dataset")
    return factory()


def select_criterion(args, dataset_name: str):
    """Return the training loss to use for ``dataset_name``.

    ``args`` is not read; it is accepted for symmetry with the other selectors.

    - "fashion-mnist", "emnist", "cifar10": a fresh ``torch.nn.CrossEntropyLoss()``
    - "shakespeare": ``torch.nn.CrossEntropyLoss(ignore_index=0)``
    - "fets21": the project's ``dice_loss`` callable

    Raises:
        ValueError: for any other dataset name.
    """
    if dataset_name in ("fashion-mnist", "emnist", "cifar10"):
        return torch.nn.CrossEntropyLoss()
    if dataset_name == "shakespeare":
        # Targets equal to 0 are excluded from the loss; presumably 0 is the
        # padding index of the Shakespeare loader. TODO: confirm.
        return torch.nn.CrossEntropyLoss(ignore_index=0)
    if dataset_name == "fets21":
        return dice_loss
    raise ValueError("Unspecified_Criterion")


def calculate_accuracy(args, dataset_name: str):
    """Return the evaluation-metric function for ``dataset_name``.

    Despite the name, nothing is computed here; the metric callable itself is
    returned. ``args`` is not read.

    - "fashion-mnist", "emnist", "cifar10": ``classification_accuracy``
    - "shakespeare": ``stream_accuracy``
    - "fets21": ``dice_score``

    Raises:
        ValueError: for any other dataset name.
    """
    if dataset_name in ("fashion-mnist", "emnist", "cifar10"):
        return classification_accuracy
    if dataset_name == "shakespeare":
        return stream_accuracy
    if dataset_name == "fets21":
        return dice_score
    raise ValueError("Unspecified_Accuracy")


def find_api(args):
    """Return the learning-API class selected by ``args.training_strategy``.

    Recognised strategies (case-sensitive) are Local, FedAvg, FedProx, PerFL,
    Cho, Song, Chen and VARSEL. The class is returned uninstantiated; the
    caller constructs it.

    Raises:
        NotImplementedError: for any other strategy name.
    """
    strategy = args.training_strategy

    if strategy == 'Local':
        return Local_API
    if strategy == 'FedAvg':
        return FedAvg_API
    if strategy == 'FedProx':
        return FedProx_API
    if strategy == 'PerFL':
        return PerFL_API
    if strategy == 'Cho':
        return Cho_API
    if strategy == 'Song':
        return Song_API
    if strategy == 'Chen':
        return Chen_API
    if strategy == 'VARSEL':
        return VARSEL_API

    raise NotImplementedError("Invalid_API")


if __name__ == "__main__":
    # Configure the root logger (basicConfig's default handler writes to
    # stderr) and let DEBUG-level records through.
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    cli = add_args(argparse.ArgumentParser(description='Selective-Learning'))
    args = cli.parse_args()
    logger.info(args)

    # Use the requested CUDA device when CUDA is available, otherwise the CPU.
    cuda_ok = torch.cuda.is_available()
    device = torch.device(f"cuda:{args.gpu}" if cuda_ok else "cpu")
    logger.info(device)

    # Encode the main hyper-parameters in the wandb run name, e.g.
    # "<dataset>-<model>-<optimizer>-b<bs>-r<rounds>-e<epochs>-lr<lr>-wd<wd>-lambda<prox>-<strategy>".
    run_name = "-".join([
        str(args.dataset),
        str(args.model),
        str(args.optimizer),
        "b" + str(args.batch_size),
        "r" + str(args.rounds),
        "e" + str(args.epochs),
        "lr" + str(args.lr),
        "wd" + str(args.wd),
        "lambda" + str(args.lambda_prox),
        str(args.training_strategy),
    ])
    wandb.init(
        project="Selfish-Federated-Learning",
        name=run_name,
        config=args,
        tags=["FINAL"],
    )

    # Fix every RNG so that runs can be reproduced. Per the original notes,
    # the np.random seed determines the dataset partition and the torch seed
    # determines the initial weights.
    SEED = 2021
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(SEED)

    data = load_data(args, args.dataset)
    model = create_model(args, model_name=args.model)
    criterion = select_criterion(args, args.dataset)
    accuracy = calculate_accuracy(args, args.dataset)

    # Instantiate and run the strategy chosen by --training_strategy.
    api_cls = find_api(args)
    learning_api = api_cls(data, device, args, model, criterion, accuracy)
    learning_api.run()
