import copy
import os
import torch
import random
import wandb
import signal
import sys
import numpy as np
from torch import nn

from torch.utils.data import DataLoader
from utils import Logger, fed_args, read_config, log_config, AverageMeter
from utils.fed_utils import init_model, assign_dataset
from utils.models import ShufflePatches
import torchvision.transforms as transforms
import torchvision.models as tvmodels

# Build the run configuration: fed_args() supplies the initial namespace and
# read_config() is given the config path plus that namespace and returns the
# object used as `args` for the rest of the script.
args = fed_args()
args = read_config(args.config, args)
if Logger.logger is None:
    L = Logger()
    # exist_ok avoids the check-then-create race when several runs start together.
    os.makedirs("train_records/", exist_ok=True)
    L.set_log_name(os.path.join("train_records", "train_record_" + args.save_name + ".log"))
    logger = L.get_logger()
    log_config(args)
else:
    # A logger was already set up elsewhere; bind it so the module-level
    # `logger` name used below is always defined.
    # NOTE(review): assumes Logger.logger is the same object get_logger() returns - confirm.
    logger = Logger.logger
using_wandb = args.using_wandb
# Debug switch: when True, main() logs per-label accuracy of every client model.
print_specified = False
# NOTE(review): these project imports sit after the logger setup; presumably
# they rely on Logger being initialised at import time - confirm before moving.
from preprocessing.baselines_dataloader import divide_data_with_dirichlet, divide_data_with_local_cls
from fedsd2c.fedsd2c_client import FedClient
from fedsd2c.fedsd2c_server import FedSD2CServer
from fedsd2c.fedsd2c_utils import DistilledDataset, compute_psnr

torch.cuda.empty_cache()


def test_model(model, testset, batch_size, device):
    """
    Server tests the model on test dataset.
    """
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    model.to(device)
    accuracy_collector = 0
    for step, (x, y) in enumerate(test_loader):
        with torch.no_grad():
            b_x = x.to(device)  # Tensor on GPU
            b_y = y.to(device)  # Tensor on GPU

            test_output = model(b_x)
            pred_y = torch.max(test_output, 1)[1].to(device).data.squeeze()
            accuracy_collector = accuracy_collector + sum(pred_y == b_y)
    accuracy = accuracy_collector / len(testset)

    return accuracy.cpu().numpy()


def test_specified(model, testset, batch_size, device, specified_labels):
    """
    Server tests the model on test dataset and record specified label accuracy
    """
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    model.to(device)
    accuracy_collector = 0
    specified_acc_collector = {label: AverageMeter() for label in specified_labels}
    for step, (x, y) in enumerate(test_loader):
        with torch.no_grad():
            b_x = x.to(device)  # Tensor on GPU
            b_y = y.to(device)  # Tensor on GPU

            test_output = model(b_x)
            pred_y = torch.max(test_output, 1)[1].to(device).data.squeeze()
            accuracy_collector = accuracy_collector + sum(pred_y == b_y)

            for label in specified_labels:
                indexes = torch.where(b_y == label)[0]
                if indexes.shape[0] > 0:
                    specified_pred = pred_y[indexes]
                    specified_y = b_y[indexes]
                    specified_acc_collector[label].update(sum(specified_pred == specified_y) / int(indexes.shape[0]),
                                                          int(indexes.shape[0]))
    accuracy = accuracy_collector / len(testset)

    return accuracy.cpu().numpy(), specified_acc_collector


def get_specified_labels(client, ipc):
    """
    Collect the labels held by a client and those with more than ``ipc`` samples.

    Args:
        client: object whose ``_train_data`` exposes ``indices`` and
            ``dataset.targets`` (the layout of ``torch.utils.data.Subset``).
        ipc: sample-count threshold; a label qualifies only when the client
            holds strictly more than ``ipc`` samples of it.

    Returns:
        ``(distilled_data_labels, containing_labels)``: the set of labels above
        the threshold and the set of all labels present in the client's data.
    """
    subset = client._train_data
    picked = np.array(subset.indices)
    labels = np.array(subset.dataset.targets, dtype=np.int64)[picked]
    # One pass yields both the distinct labels and how often each occurs.
    classes, counts = np.unique(labels, return_counts=True)
    class_list = classes.tolist()
    containing_labels = set(class_list)
    distilled_data_labels = {c for c, n in zip(class_list, counts.tolist()) if n > ipc}
    return distilled_data_labels, containing_labels


def load_teacher_model():
    """
    Build the teacher model and initialise it according to ``args.fedsd2c_m_path``.

    * ``"local"``: return the freshly initialised model.
    * ``"tv"``: copy the torchvision ImageNet ResNet18 weights into the model,
      non-strictly and without the ``conv1`` / ``fc`` entries.
    * anything else: treat the value as a checkpoint path and load it strictly.

    Returns:
        The model created by ``init_model`` for ``args.sys_model``.
    """
    num_class, img_dim, image_channel = assign_dataset(args.sys_dataset)
    model = init_model(args.sys_model, num_class, image_channel, im_size=img_dim)
    if args.fedsd2c_m_path == "local":
        return model
    if args.fedsd2c_m_path == "tv":
        # `weights` must be passed by keyword; positional use is deprecated in torchvision.
        pretrained = tvmodels.resnet18(weights=tvmodels.ResNet18_Weights.IMAGENET1K_V1)
        # nn.Identity has no parameters, so conv1/fc drop out of the state dict
        # and the model keeps its own weights for those layers.
        pretrained.conv1 = nn.Identity()
        pretrained.fc = nn.Identity()
        incompatible = model.load_state_dict(pretrained.state_dict(), strict=False)
        # Surface what the non-strict load skipped instead of discarding it.
        logger.info(f"Load teacher model from torchvision ResNet18 (IMAGENET1K_V1): {incompatible}")
        return model
    w = torch.load(args.fedsd2c_m_path, map_location="cpu")
    model.load_state_dict(w)
    logger.info(f"Load teacher model from: {args.fedsd2c_m_path}")
    return model


# Augmentation pipeline passed to DistilledDataset in main(): a resized crop to
# 64x64 whose sampled area fraction is fixed at 1 via scale=(1, 1), followed by
# a random horizontal flip.
augment = transforms.Compose([
    # ShufflePatches(args.fedsd2c_factor),  # currently disabled
    transforms.RandomResizedCrop(size=64, scale=(1, 1), antialias=True),
    transforms.RandomHorizontalFlip(),
])


def main():
    """
    Main function for fedsd2c.

    Validates the dataset and model names, seeds the random generators,
    partitions the data across clients, obtains one model per client (loaded
    from ``args.client_model_root`` when set, otherwise trained locally),
    gathers each client's data according to ``args.client_instance``, hands
    model and data to the server and finally runs ``server.train_distill()``.

    Raises:
        AssertionError: if ``args.sys_dataset`` or ``args.sys_model`` is not
            in the supported lists.
        NotImplementedError: if neither partition option is configured or
            ``args.client_instance`` is not a known mode.
    """

    dataset_list = ['MNIST', 'CIFAR10', 'FashionMNIST', 'SVHN', 'CIFAR100', 'TINYIMAGENET', 'Imagenette', 'openImg']
    assert args.sys_dataset in dataset_list, "The dataset is not supported"

    model_list = ["LeNet", 'AlexCifarNet', "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152", "CNN", "Conv4", "Conv5", "Conv6"]
    assert args.sys_model in model_list, "The model is not supported"

    random.seed(args.sys_i_seed)
    np.random.seed(args.sys_i_seed)
    torch.manual_seed(args.sys_i_seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed(args.sys_i_seed)
    torch.set_num_threads(10)

    client_dict = {}

    logger.info('======================Setup Clients==========================')
    if args.sys_n_local_class is not None and args.sys_dataset_dir_alpha is None:
        logger.info('Using divide data with local class')
        trainset_config, testset, cls_record = divide_data_with_local_cls(n_clients=args.sys_n_client,
                                                                          n_local_cls=args.sys_n_local_class,
                                                                          dataset_name=args.sys_dataset,
                                                                          seed=42,
                                                                          aug=args.client_instance_aug)
    elif args.sys_dataset_dir_alpha is not None:
        logger.info('Using divide data with dirichlet')
        trainset_config, testset, cls_record = divide_data_with_dirichlet(n_clients=args.sys_n_client,
                                                                          beta=args.sys_dataset_dir_alpha,
                                                                          dataset_name=args.sys_dataset,
                                                                          seed=42,
                                                                          aug=args.client_instance_aug)
    else:
        raise NotImplementedError("sys_n_local_class and sys_dataset_dir_alpha are both None")
    logger.info('Clients in Total: %d' % len(trainset_config['users']))

    server = FedSD2CServer(args, trainset_config['users'], epoch=args.server_n_epoch, batch_size=args.server_bs,
                          lr=args.server_lr, momentum=args.server_momentum, num_workers=args.server_n_worker,
                          dataset_id=args.sys_dataset, model_name=args.sys_model,
                          i_seed=args.sys_i_seed)
    server.load_testset(testset)

    teacher_model = load_teacher_model()

    # PSNR values gathered in the "load_latents" mode; reported after the loop.
    psnrs = []
    for client_id in trainset_config['users']:
        client = FedClient(args, client_id, dataset_id=args.sys_dataset)
        client_dict[client_id] = client
        client.load_train(trainset_config['user_data'][client_id])
        client.load_cls_record(cls_record[client_id])

        # Every client starts from its own copy of the teacher model.
        model = copy.deepcopy(teacher_model)
        specified_labels, containing_labels = get_specified_labels(client, client.ipc * client.factor ** 2)
        if args.client_model_root is not None:
            weight_path = os.path.join(args.client_model_root, f"c{client_id}.pt")
            weight = torch.load(weight_path, map_location="cpu")
            if "CWT" in args.client_model_root:
                # Checkpoints under a "CWT" root keep the weights under the "model" key.
                weight = weight["model"]
            logger.info("Load Client {} from {}".format(client_id, weight_path))
            model.load_state_dict(weight)
            model = model.to(client._device)
        else:
            logger.info("Client {} local training".format(client_id))
            model, loss = client.train(model)

        if print_specified:
            # Debug only: the per-label accuracies have to be computed here,
            # otherwise `specified_acc` is undefined when the flag is enabled.
            acc, specified_acc = test_specified(model, testset, 200, client._device, containing_labels)
            logger.info("===============Using Client {} local model - acc: {}===================".format(client_id, acc))
            for sl in sorted(specified_labels):
                logger.info("sl: {} - cnt: {} acc: {}".format(sl, specified_acc[sl].cnt, specified_acc[sl].avg))
            for cl in containing_labels:
                if cl not in specified_labels:
                    logger.info("cl: {} - cnt: {} acc: {}".format(cl, specified_acc[cl].cnt, specified_acc[cl].avg))

        if args.client_instance == "coreset":
            ret_x, ret_y, ret_score = client.coreset_stage(model)
        elif args.client_instance == "random":
            ret_x, ret_y, ret_score = client.random_stage(model)
        elif args.client_instance == "coreset+featmixup":
            # The scores come from the selection stage; samples and labels
            # come from the mix-up stage that follows it.
            ret_x, ret_y, ret_score = client.coreset_stage(model)
            ret_x, ret_y = client.feat_mixup_stage(model)
        elif args.client_instance == "random+featmixup":
            ret_x, ret_y, ret_score = client.random_stage(model)
            ret_x, ret_y = client.feat_mixup_stage(model)
        elif args.client_instance == "load_latents":
            dir_path = args.fedsd2c_syn_root
            logger.info(f"Load Synthetic data from {os.path.join(dir_path)}")
            ret_x = torch.load(os.path.join(dir_path, f"client{client_id}_images.pt"))
            ret_y = torch.load(os.path.join(dir_path, f"client{client_id}_labels.pt"))
            ret_z = torch.load(os.path.join(dir_path, f"client{client_id}_latents.pt"))
            ret_score = [0] * len(ret_x)
            syn_x = client.decode_latents(ret_z)
            # Compare the stored images with the ones decoded from the latents.
            psnr = compute_psnr(ret_x, torch.stack(syn_x))
            psnrs.extend(psnr.numpy().tolist())
            ret_x = syn_x
        elif "server" in args.client_instance:
            # Nothing is sent by the client here; server.syn_data() is called after the loop.
            ret_x, ret_y, ret_score = [], [], []
        else:
            raise NotImplementedError("Not implemented yet.")
        # Server receive client model and data
        server.rec_distill(client._id,
                           model,
                           DistilledDataset(ret_x, ret_y, augment),
                           list(specified_labels),
                           ret_score)

        # makedirs(exist_ok=True) also creates a missing sys_res_root and does
        # not race with a concurrent run creating the same directory.
        result_dir = os.path.join(args.sys_res_root, args.save_name)
        os.makedirs(result_dir, exist_ok=True)
        if args.save_client_model and args.client_model_root is None:
            torch.save(model.state_dict(), os.path.join(result_dir, f"c{client_id}.pt"))

    if psnrs:
        # Report the collected reconstruction quality instead of discarding it.
        logger.info(f"Mean PSNR: {np.mean(psnrs)} Max PSNR: {np.max(psnrs)}")
    if args.fedsd2c_clip_client_data:
        server.clip_client_data()
    if "server" in args.client_instance:
        server.syn_data()
    server.train_distill()


def term_sig_handler(signum, frame):
    """
    Signal handler: report the signal, close the wandb run if one is in use,
    then terminate the process.

    Args:
        signum: number of the received signal.
        frame: current stack frame supplied by the signal module (unused).

    Raises:
        SystemExit: always, to end the program.
    """
    print(f'catched singal: {signum}')
    if using_wandb:
        # Finish the wandb run before the interpreter goes away.
        wandb.finish()
    raise SystemExit


if __name__ == "__main__":
    # Route `kill <pid>` (SIGTERM) and Ctrl-C (SIGINT) through the same handler.
    for _sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(_sig, term_sig_handler)
    main()
    # Normal completion: close the wandb run if one is in use.
    if using_wandb:
        wandb.finish()
