import argparse
import logging
import os
import random
import sys
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import wandb

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../../")))

# The sys.path entries above add the FedML root directory to the Python path so the fedml_api imports below resolve.
from fedml_api.distributed.fedssl.utils import get_global_model_path, get_personalized_model_path

from fedml_api.data_preprocessing.Landmarks_per.data_loader import load_partition_data_landmarks

from fedml_api.data_preprocessing.cifar10_per.data_loader import load_partition_data_cifar10_ssl_linear_eval

from fedml_api.model.cv.ssl import get_ssl_model


def add_args(parser):
    """Register the evaluation script's command-line options and parse them.

    Despite the name, this function does not only add arguments: it also calls
    ``parser.parse_args()`` (which reads ``sys.argv``) and returns the result.

    Args:
        parser: an ``argparse.ArgumentParser`` to which the options are added.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    # Training settings
    parser.add_argument('--run_id', type=int, default=0, metavar='N',
                        help='run id')

    parser.add_argument('--model', type=str, default='resnet18_cifar', metavar='N',
                        help='neural network used in training')

    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                        help='dataset used for training')

    parser.add_argument('--data_dir', type=str, default='./../../../data/cifar10',
                        help='data directory')

    parser.add_argument('--partition_method', type=str, default='hetero', metavar='N',
                        help='how to partition the dataset on local workers')

    parser.add_argument('--partition_alpha', type=float, default=0.5, metavar='PA',
                        help='partition alpha (default: 0.5)')

    parser.add_argument('--client_num_in_total', type=int, default=1000, metavar='NN',
                        help='number of workers in a distributed cluster')

    parser.add_argument('--client_num_per_round', type=int, default=4, metavar='NN',
                        help='number of workers')

    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')

    parser.add_argument('--accumulation_steps', type=int, default=8,
                        help='accumulation_steps')

    parser.add_argument('--path_of_local_model', type=str, default="./checking_points/",
                        help='path_of_local_model')

    parser.add_argument('--client_optimizer', type=str, default='adam',
                        help='SGD with momentum; adam')

    # The help text used to advertise a default of 0.001; the real default is 0.05.
    parser.add_argument('--lr', type=float, default=0.05, metavar='LR',
                        help='learning rate (default: 0.05)')

    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    parser.add_argument('--wd', help='weight decay parameter;', type=float, default=5e-4)

    parser.add_argument('--perFedAvg_is_first_order', help='perFedAvg_is_first_order', action='store_true')

    parser.add_argument('--epochs', type=int, default=1, metavar='EP',
                        help='how many epochs will be trained locally')

    parser.add_argument('--comm_round', type=int, default=800,
                        help='how many round of communications we should use')

    parser.add_argument('--is_mobile', type=int, default=0,
                        help='whether the program is running on the FedML-Mobile server side')

    parser.add_argument('--frequency_of_the_test', type=int, default=10,
                        help='the frequency of the algorithms')

    parser.add_argument('--gpu_server_num', type=int, default=1,
                        help='gpu_server_num')

    parser.add_argument('--gpu_num_per_server', type=int, default=4,
                        help='gpu_num_per_server')

    parser.add_argument('--gpu_mapping_file', type=str, default="gpu_mapping.yaml",
                        help='the gpu utilization file for servers and clients. If there is no \
                        gpu_util_file, gpu will not be used.')

    parser.add_argument('--gpu_mapping_key', type=str, default="mapping_default",
                        help='the key in gpu utilization file')

    # personalization
    # m1: model interpolation
    parser.add_argument('--personalized_model_path', type=str, default="./checkpoint",
                        help='path of the personalized model checkpoints')

    parser.add_argument('--pssl_optimizer', type=str, default="pFedMe",
                        help="personalization_method: None; pFedMe; ditto; perFedAvg")

    parser.add_argument('--pssl_lambda', type=float, default=1,
                        help='pssl_lambda coefficient')

    parser.add_argument('--personal_local_epochs', type=int, default=1, metavar='EP',
                        help='how many epochs will be trained locally for the personal model')

    # Must be a float: with type=int the documented default (0.001) could not
    # even be passed back explicitly on the command line.
    parser.add_argument('--pfedme_beta', type=float, default=0.001, metavar='B',
                        help='moving averaging beta value.')

    # self-supervised related
    parser.add_argument('--ssl_method', type=str, default="simsiam",
                        help='self-supervised learning method: byol, simclr, moco, etc')

    parser.add_argument('--ci', type=int, default=0,
                        help='CI')

    # eval
    parser.add_argument('--gpu_id', type=int, default=0,
                        help='GPU ID')
    parser.add_argument('--global_train', action='store_true')
    parser.add_argument('--global_test', action='store_true')

    args = parser.parse_args()
    return args


def load_data(args, dataset_name):
    """Load the federated data partitions for ``dataset_name``.

    Two dataset names are handled:

    * ``"gld23k_per"`` -- loaded through ``load_partition_data_landmarks``.
      As a side effect ``args.data_dir`` is redirected to its ``images``
      sub-directory and ``args.client_num_in_total`` is overwritten with the
      value returned by the loader.
    * ``"cifar10"`` -- loaded through
      ``load_partition_data_cifar10_ssl_linear_eval``.

    Any other name raises ``Exception("no such dataset")``.

    Returns:
        An 8-item list: ``[train_data_num, test_data_num, train_data_global,
        test_data_global, train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict, class_num]``.
    """
    if dataset_name == "gld23k_per":
        base_dir = args.data_dir
        train_map_csv = os.path.join(base_dir, 'data_user_dict/gld23k_user_dict_train.csv')
        test_map_csv = os.path.join(base_dir, 'data_user_dict/gld23k_user_dict_test.csv')
        # From here on the data directory points at the image folder.
        args.data_dir = os.path.join(base_dir, 'images')

        (n_train, n_test, global_train, global_test,
         local_counts, local_train, local_test,
         n_classes, args.client_num_in_total) = load_partition_data_landmarks(
            dataset=dataset_name,
            data_dir=args.data_dir,
            fed_train_map_file=train_map_csv,
            fed_test_map_file=test_map_csv,
            partition_method=None,
            partition_alpha=None,
            batch_size=args.batch_size,
        )
        logging.info("data point len = %d" % len(local_test))
        logging.info("args.client_num_in_total = %d" % args.client_num_in_total)

        return [n_train, n_test, global_train, global_test,
                local_counts, local_train, local_test, n_classes]

    if dataset_name != "cifar10":
        raise Exception("no such dataset")

    (n_train, n_test, global_train, global_test,
     local_counts, local_train, local_test,
     n_classes) = load_partition_data_cifar10_ssl_linear_eval(
        args.dataset, args.data_dir, args.ssl_method, args.partition_method,
        args.partition_alpha, args.client_num_in_total,
        args.batch_size, args.accumulation_steps)

    return [n_train, n_test, global_train, global_test,
            local_counts, local_train, local_test, n_classes]


def get_personalized_model_by_client_idx(class_num, client_idx):
    """Build the SSL model for one client and load its checkpoint weights.

    Reads the module-level ``args`` (set in the ``__main__`` section). When
    ``args.pssl_optimizer`` is ``"perFedAvg"``, ``"FedAvg"`` or
    ``"FedAvg_LocalAdaptation"`` the checkpoint path comes from
    ``get_global_model_path(args)``; otherwise it comes from
    ``get_personalized_model_path(args, client_idx)``.

    Args:
        class_num: forwarded as the first argument of ``get_ssl_model``.
        client_idx: client index used to resolve the personalized checkpoint.

    Returns:
        The ``model`` attribute of the SSL model after ``load_state_dict``.
    """
    shares_global_checkpoint = args.pssl_optimizer in (
        "perFedAvg", "FedAvg", "FedAvg_LocalAdaptation")

    if shares_global_checkpoint:
        checkpoint_path = get_global_model_path(args)
    else:
        checkpoint_path = get_personalized_model_path(args, client_idx)
        logging.info(checkpoint_path)

    logging.info("=> loading checkpoint '{}'".format(checkpoint_path))
    # Load onto the CPU; the caller decides which device the model ends up on.
    weights = torch.load(checkpoint_path, map_location="cpu")

    # NOTE(review): the SSL method and backbone are hard-coded here rather than
    # taken from args.ssl_method / args.model -- confirm this is intentional.
    ssl_model = get_ssl_model(class_num, 'simsiam', args.pssl_optimizer, 'resnet18_cifar')
    ssl_model.load_state_dict(weights)
    return ssl_model.model


def local_linear_eval(
        encoder: nn.Module,
        device,
        train_dataloader: DataLoader,
        test_dataloader: DataLoader,
        feat_dim: int,
        num_class: int,
        lr: float = 30,
        weight_decay: float = 0.0001,
        momentum: float = 0.9,
        num_epoch=200
):
    """Train a linear classifier on frozen encoder features and report accuracy.

    A ``Dropout -> Linear`` head is trained with SGD and a cosine learning-rate
    schedule on the flattened output of ``encoder``. After every epoch the head
    is scored on ``test_dataloader``; the best per-epoch accuracy is returned.

    The encoder is only used as a feature extractor: it is switched to eval
    mode for the duration of the call (so layers such as BatchNorm do not
    update their running statistics) and its forward passes run without
    autograd. Its previous train/eval mode is restored before returning. The
    encoder is moved to ``device`` and left there.

    Args:
        encoder: feature extractor; its output is flattened from dim 1.
        device: torch device used for the encoder, the head and the batches.
        train_dataloader: iterable of ``(x, y)`` batches used to fit the head.
        test_dataloader: iterable of ``(x, y)`` batches used for scoring.
        feat_dim: size of the flattened encoder output.
        num_class: number of target classes.
        lr: SGD learning rate.
        weight_decay: SGD weight decay.
        momentum: SGD momentum.
        num_epoch: number of epochs; also ``T_max`` of the cosine schedule.

    Returns:
        The highest test accuracy over all epochs, as a Python float in [0, 1].

    Raises:
        ZeroDivisionError: if ``test_dataloader`` yields no samples.
        ValueError: if ``num_epoch`` is 0 (there is no accuracy to return).
    """
    encoder.to(device)
    encoder_was_training = encoder.training
    encoder.eval()

    # The head emits raw logits: F.cross_entropy applies log-softmax itself, so
    # a Softmax layer here would normalise twice and flatten the gradients.
    linear_classifier = nn.Sequential(
        nn.Dropout(p=0.2),
        nn.Linear(feat_dim, num_class),
    )
    linear_classifier.to(device)

    # train linear classifier
    optim = torch.optim.SGD(linear_classifier.parameters(), lr=lr, weight_decay=weight_decay,
                            momentum=momentum)  # SimSiam also uses LARS
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=num_epoch)

    acc_history = []

    try:
        for epoch_id in range(num_epoch):
            linear_classifier.train()

            for x, y in train_dataloader:
                x, y = x.to(device), y.to(device)
                # No graph is needed through the frozen encoder.
                with torch.no_grad():
                    rep = torch.flatten(encoder(x), start_dim=1)
                logits = linear_classifier(rep)
                loss = F.cross_entropy(logits, y)

                optim.zero_grad()
                loss.backward()
                optim.step()

            lr_scheduler.step()

            # eval linear classifier
            linear_classifier.eval()
            num_match = 0
            num_sample = 0
            with torch.no_grad():
                for x, y in test_dataloader:
                    x, y = x.to(device), y.to(device)
                    rep = torch.flatten(encoder(x), start_dim=1)
                    logits = linear_classifier(rep)

                    num_match += (logits.argmax(dim=1) == y).sum().item()
                    num_sample += len(y)
            acc_history.append(num_match / num_sample)
            logging.warning("epoch_id = %d, acc = %.32f" % (epoch_id, acc_history[-1]))
    finally:
        # Leave the encoder in the mode the caller handed it over in.
        encoder.train(encoder_was_training)

    return max(acc_history)


if __name__ == "__main__":
    # Script entry point: for every client, load its model checkpoint, run a
    # linear evaluation on top of the encoder, and log the mean / std accuracy.

    # parse python script input parameters
    # NOTE: `args` must stay a module-level name because
    # get_personalized_model_by_client_idx reads it as a global.
    parser = argparse.ArgumentParser()
    args = add_args(parser)

    logging.basicConfig(level=logging.DEBUG,
                        format='%(process)s %(asctime)s.%(msecs)03d - {%(module)s.py (%(lineno)d)} - %(funcName)s(): %(message)s',
                        datefmt='%Y-%m-%d,%H:%M:%S')

    logging.warning(args)

    # Set the random seed. The np.random seed determines the dataset partition.
    # The torch_manual_seed determines the initial weight.
    # We fix these two, so that we can reproduce the result.
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)

    wandb.init(project="self-supervised-fl", entity="automl", config=args, name="SSFL-PER_LINEAR_EVAL")

    # load data
    dataset = load_data(args, args.dataset)
    [train_data_num, test_data_num, train_data_global, test_data_global,
     train_data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num] = dataset

    acc_list = []

    # Use the requested GPU when CUDA is present; otherwise fall back to the
    # CPU instead of crashing on the first `.to(device)` call.
    if torch.cuda.is_available():
        device = torch.device("cuda:" + str(args.gpu_id))
    else:
        logging.warning("CUDA is not available; running linear evaluation on CPU")
        device = torch.device("cpu")

    logging.warning("Start eval")
    for client_idx in range(args.client_num_in_total):
        logging.info(f"eval on client {client_idx}")
        model = get_personalized_model_by_client_idx(class_num=class_num, client_idx=client_idx)

        # --global_train / --global_test swap the client's local loader for the
        # global one on the corresponding side.
        acc = local_linear_eval(
            model,
            device,
            train_data_local_dict[client_idx][0] if not args.global_train else train_data_global,
            test_data_local_dict[client_idx][0] if not args.global_test else test_data_global,
            feat_dim=model.output_dim,
            num_class=class_num, lr=args.lr
        )
        acc_list.append(acc)
        logging.warning(f"client acc {acc}")

    mean, std = np.mean(acc_list), np.std(acc_list)

    logging.warning(f'mean_acc: {mean}')
    logging.warning(f'std_acc: {std}')
