import os
import gc
import time
import argparse
import numpy as np
import random
import torch
import torch.cuda as cuda
import torch.backends.cudnn as cudnn
from copy import deepcopy
import torch.optim as optim
import torch.nn.functional as F

from utiles import *
from data_prepration.data import get_partitioned_data
from dnr import Masking


class Server():
    """Federated-learning coordinator.

    Each round it samples a fraction of the clients, sends them the current
    global weights, aggregates the weights/masks they return, evaluates the
    global model on the test loader, and then updates the learning rate and
    (for some experiment types) the global mask.

    ``self.args`` is handed by reference to every ``Client`` (see
    ``initialize``), so the fields this class rewrites on it -- ``lr`` after
    every round, and temporarily ``local_epoch``/``experiment_type`` inside
    ``train_global_mask`` -- are also seen by the clients.
    """

    def __init__(self, args):
        self.args = args
        self.frac = args.frac
        self.accuracies = []  # test accuracy (%) recorded after each round
        self.comm_round = args.fl_rounds
        self.num_clients = self.args.num_users
        self.key_path = self.args.key_path
        self.global_mask = None
        # At least one client participates in every round.
        self.num_participants = max(int(self.frac * self.num_clients), 1)
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')
        self.model = get_model(self.args)
        # Per-participant-slot density multipliers; set_round_param scales
        # args.density by these when args.hetero_client is enabled.
        self.hetero_densities = hetero_62 if self.args.dataset == FEMNIST else hetero
        self.initialize()

    def initialize(self):
        """Load data partitions, create one Client per user and, for sparse
        experiments, build the initial global mask and apply it to the model."""
        self.all_test, self.clients_train_data, self.client_weight = get_partitioned_data(self.args)
        self.clients = [Client(self.args, self.clients_train_data[i]) for i in range(self.num_clients)]
        if self.args.experiment_type == DENSE:
            return
        if self.args.experiment_type in [NST, PDST]:
            self.global_mask = get_rand_mask(self.args.density, self.key_path)
        else:
            # train_global_mask rewrites args.experiment_type SPDST -> PDST, so
            # the second NST/PDST check below can pass even if the first did not.
            self.train_global_mask()
        self.model = apply_mask(self.model, self.global_mask)
        if self.args.experiment_type in [NST, PDST]:
            # Every client starts from its own copy of the global mask.
            for client in self.clients:
                client.set_masks(deepcopy(self.global_mask))

    def get_weight(self):
        """Return a deep copy of the global model's state dict.

        Side effect: moves the global model to CPU.
        """
        return deepcopy(self.model.cpu().state_dict())

    def set_global_weight(self, weight):
        """Load a deep copy of ``weight`` into the global model."""
        self.model.load_state_dict(deepcopy(weight))

    def select_clients(self, num_participant):
        """Sample ``num_participant`` distinct clients uniformly at random.

        Returns:
            (clients, weights): the sampled ``Client`` objects and the matching
            entries of ``self.client_weight``, in the same order.
        """
        participant = np.random.choice(range(self.num_clients), num_participant, replace=False)
        return [self.clients[idx] for idx in participant], [self.client_weight[i] for i in participant]

    def train_global_mask(self):
        """Derive the initial global mask from a warm-up round.

        A sample of ``init_frac`` of the clients trains in NST mode for
        ``init_epoch`` epochs, starting from the same random mask. Their output
        masks are combined with ``get_avg_mask``. The global model is then
        re-created from scratch via ``get_model``.
        """
        saved_epoch, saved_type = self.args.local_epoch, self.args.experiment_type
        # args is shared with the clients, so this switches their mode too.
        self.args.local_epoch, self.args.experiment_type = self.args.init_epoch, NST
        try:
            mask = get_rand_mask(self.args.density, self.key_path)
            clients, _ = self.select_clients(max(int(self.args.init_frac * self.num_clients), 1))
            clients_mask = []
            for client in clients:
                client.set_masks(deepcopy(mask))
                _, output_mask = client.update(self.get_weight())
                clients_mask.append(output_mask)
        finally:
            # Always restore the shared args, even if warm-up training fails.
            # SPDST continues as PDST once its mask has been trained.
            self.args.local_epoch = saved_epoch
            self.args.experiment_type = PDST if saved_type == SPDST else saved_type
        self.global_mask = get_avg_mask(clients_mask, self.args.density, self.key_path)
        self.model = get_model(self.args)

    def update_global_mask(self, masks):
        """Recompute global (and, for heterogeneous clients, per-density) masks
        from this round's client masks and the current global weights."""
        if self.args.experiment_type == JMWST:
            self.global_mask = next_mask(masks, self.get_weight(), self.args.density, self.key_path, self.args.subsample)
        if self.args.hetero_client:
            # density_levels is not defined in this file; presumably it comes from utiles.
            self.hetero_mask = {d: next_mask(masks, self.get_weight(), d * self.args.density, self.key_path, self.args.subsample) for d in density_levels}
            self.hetero_mask[1] = self.global_mask

    def start_federated_learning(self):
        """Run ``self.comm_round`` rounds, then print the best test accuracy."""
        for self.round_idx in range(self.comm_round):
            weights, clients_mask = [], []
            participants, client_size = self.select_clients(self.num_participants)
            for i, client in enumerate(participants):
                self.set_round_param(client, i)
                w, mask = client.update(self.get_weight())
                weights.append(w)
                clients_mask.append(mask)
            new_weights = aggregate(weights, clients_mask, client_size, self.args.update_mode, self.args.dataset)
            self.set_global_weight(new_weights)
            self.test_inference()
            # NOTE(review): the lr computed after round r is the one clients use
            # in round r+1 (round 0 uses the initial args.lr), so the schedule
            # lags one round. Kept as-is to preserve existing results.
            self.args.lr = self.args.lr_start * np.exp(self.round_idx / self.comm_round * self.args.lr_gamma)
            self.update_global_mask(clients_mask)
        if self.accuracies:
            print(f'Highest accuracy {max(self.accuracies)}')
        else:
            # max() of an empty list would raise ValueError (e.g. fl_rounds == 0).
            print('Highest accuracy unavailable: no rounds were run')

    def set_round_param(self, client, i):
        """Configure ``client`` (the ``i``-th participant this round) before it trains.

        With heterogeneous clients (from round 1 on), the client gets a scaled
        density and the matching per-density mask. Otherwise, for JMWST it gets
        a copy of the current global mask.
        """
        if self.args.hetero_client and self.round_idx > 0:
            h_density = self.hetero_densities[i]
            client.client_density = h_density * self.args.density
            client.set_masks(deepcopy(self.hetero_mask[h_density]))
        elif self.args.experiment_type == JMWST:
            client.set_masks(deepcopy(self.global_mask))

    def test_inference(self):
        """Evaluate a copy of the global model and record its top-1 accuracy (%)."""
        model = deepcopy(self.model)
        density = get_density(model.state_dict())
        model.eval()
        model.to(self.device)
        correct, total = 0, len(self.all_test.dataset)
        with torch.no_grad():
            for images, labels in self.all_test:
                images, labels = images.to(self.device), labels.to(self.device)
                pred_labels = model(images)
                correct += (pred_labels.argmax(1) == labels).type(torch.float).sum().item()
        accuracy = 100 * correct / total
        self.accuracies.append(accuracy)
        print(f'ROUND {self.round_idx} | Density {density} | Accuracy {accuracy}')


class Client():
    """A single federated participant that trains a local copy of the model.

    ``args`` is the namespace shared with the server, so values the server
    rewrites between calls (``lr``, ``experiment_type``, ``local_epoch``) are
    picked up by the next ``update``.
    """

    def __init__(self, args, train_loader):
        self.args = args
        self.trainloader = train_loader
        # Batches per local epoch; used to find the mid-epoch mask update point.
        self.total_num_batch_size = len(train_loader)
        self.predefined_mask = None
        self.client_density = args.density
        self.train_criterion = F.cross_entropy
        self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

    def set_masks(self, masks, global_mask=None):
        """Store ``masks`` as the mask the next ``update`` starts from.

        ``global_mask`` is accepted for call compatibility and ignored.
        """
        self.predefined_mask = masks

    def get_optimizer(self, model):
        """Return plain SGD over ``model``'s parameters at the current ``args.lr``."""
        step_size = self.args.lr
        return optim.SGD(model.parameters(), lr=step_size)

    def get_parameters(self, model):
        """Put ``model`` in train mode, move it to CPU, and return a deep copy
        of its state dict."""
        model.train()
        model.cpu()
        snapshot = deepcopy(model.state_dict())
        return snapshot

    def get_mask(self, model, optimizer):
        """Build a ``Masking`` helper for ``model``, seeded with the stored mask.

        Its schedule length is the number of optimizer steps in this update
        (local epochs x batches per epoch).
        """
        steps_total = self.args.local_epoch * len(self.trainloader)
        masking = Masking(optimizer, steps_total, self.args.prune_rate)
        masking.add_module(model, density=self.client_density, init_masks=self.predefined_mask)
        return masking

    def init_model(self, weights):
        """Create a fresh model, load ``weights`` strictly, and move it to the device."""
        net = get_model(self.args)
        net.load_state_dict(weights, strict=True)
        net.to(self.device)
        return net

    def update(self, weights):
        """Train locally from ``weights`` and return ``(new_weights, mask_copy)``.

        DENSE trains without any mask, PDST trains under the fixed stored mask,
        and every other mode trains with a ``Masking`` helper whose final masks
        replace the stored mask. Masks are moved to the device for training and
        back to CPU afterwards. ``client_density`` is reset to ``args.density``
        before returning.
        """
        model = self.init_model(deepcopy(weights))
        optimizer = self.get_optimizer(model)
        epochs = self.args.local_epoch
        mode = self.args.experiment_type

        if self.args.hetero_client:
            to_dict(self.predefined_mask, self.device)
            apply_mask(model, self.predefined_mask)

        if mode == DENSE:
            assert self.predefined_mask is None
            self.update_model_mask(model, optimizer, epochs, mask=None)
        elif mode == PDST:
            assert self.predefined_mask is not None
            to_dict(self.predefined_mask, self.device)
            self.update_model_mask(model, optimizer, epochs, mask=None)
            to_dict(self.predefined_mask, 'cpu')
        else:
            to_dict(self.predefined_mask, self.device)
            masking = self.get_mask(model, optimizer)
            self.update_model_mask(model, optimizer, epochs, mask=masking)
            self.predefined_mask = deepcopy(masking.masks)
            to_dict(self.predefined_mask, 'cpu')

        trained = self.get_parameters(model)
        self.client_density = self.args.density
        return trained, deepcopy(self.predefined_mask)

    def update_model_mask(self, model, optimizer, total_epochs, mask):
        """Run ``total_epochs`` local epochs.

        When a ``Masking`` helper is given, its end-of-epoch update runs after
        every epoch except the last one.
        """
        for epoch in range(1, total_epochs + 1):
            self.train(epoch, model, optimizer, mask)
            is_final_epoch = epoch == total_epochs
            if mask is not None and not is_final_epoch:
                mask.at_end_of_epoch()

    def train(self, epoch, model, optimizer, mask):
        """Run one pass over the local data with cross-entropy loss.

        Without a ``Masking`` helper, each step is a plain optimizer step,
        followed by re-applying the stored mask when one exists. With a helper,
        ``mask.step()`` is called instead; presumably it performs the optimizer
        step under the mask (``Masking`` lives in ``dnr``; confirm). With a
        single local epoch, the helper's end-of-epoch update fires once at the
        middle batch, because ``update_model_mask`` skips it after the last epoch.
        ``epoch`` is not used.
        """
        model.train()
        midpoint = self.total_num_batch_size // 2
        for step, (inputs, targets) in enumerate(self.trainloader):
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            optimizer.zero_grad()
            loss = self.train_criterion(model(inputs), targets)
            loss.backward()
            if mask is None:
                optimizer.step()
                if self.predefined_mask is not None:
                    model = apply_mask(model, self.predefined_mask)
                continue
            mask.step()
            if self.args.local_epoch == 1 and step == midpoint:
                mask.at_end_of_epoch()


def _build_arg_parser():
    """Build the command-line parser for this federated training script.

    Options typed as strings with ``choices=true_false`` behave as booleans;
    the ``__main__`` block compares them against ``TRUE`` after parsing.
    """
    ap = argparse.ArgumentParser()

    # Environment and reproducibility.
    ap.add_argument('--gpu', default='0')
    ap.add_argument('--path', type=str, default='cifar10')
    ap.add_argument('--seed', type=int, default=1)

    # Experiment variant.
    ap.add_argument('--experiment_type', type=str, choices=experiment_choices, default=SPDST)
    ap.add_argument('--hetero_client', type=str, default=FALSE, choices=true_false)

    # Local optimisation and data. NOTE(review): --lr only drives the first
    # round; the server overwrites args.lr from the lr_start -> lr_end
    # schedule after every round.
    ap.add_argument('--lr', type=float, default=0.1)
    ap.add_argument('--alpha', type=float, default=1)
    ap.add_argument('--local_bs', type=int, default=32)
    ap.add_argument('--lr_start', type=float, default=0.1)
    ap.add_argument('--lr_end', type=float, default=0.001)
    ap.add_argument('--local_epoch', type=int, default=1)

    # Model and dataset. NOTE(review): Server also has a FEMNIST branch,
    # which these --dataset choices do not expose.
    ap.add_argument('--trackBN', type=str, default=TRUE, choices=true_false)
    ap.add_argument('--model_type', type=str, default=RESNET18, choices=model_choices)
    ap.add_argument('--dataset', type=str, default=CIFAR10, choices=[CIFAR10, MNIST])

    # Warm-up mask training (used by Server.train_global_mask).
    ap.add_argument('--init_frac', type=float, default=0.1)
    ap.add_argument('--init_epoch', type=int, default=10)

    # Federation.
    ap.add_argument('--fl_rounds', type=int, default=600)
    ap.add_argument('--update_mode', type=int, default=0, choices=[0, 1], help="0: fedAvg, 1:masked_FedAvg")
    ap.add_argument('--frac', type=float, default=0.1)
    ap.add_argument('--num_users', type=int, default=100)

    # Sparsity.
    ap.add_argument('--density', type=float, default=0.1)
    ap.add_argument('--prune_rate', type=float, default=0.25)
    ap.add_argument('--conv_only', type=str, default=TRUE, choices=true_false)
    ap.add_argument('--subsample', type=str, default='avg', choices=['aggregate', 'avg'])
    return ap


parser = _build_arg_parser()


if __name__ == "__main__":
    args = parser.parse_args()
    seed = args.seed

    # Set device-selection variables BEFORE anything can initialise CUDA.
    # torch.cuda.is_available() and the cuda.* calls below may initialise the
    # CUDA runtime, after which changes to CUDA_VISIBLE_DEVICES /
    # CUDA_DEVICE_ORDER are no longer picked up by this process. Setting them
    # afterwards (as before) could silently ignore --gpu. Setting them on a
    # CPU-only machine is harmless.
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    if args.gpu is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Reproducibility: seed every RNG and make cuDNN deterministic.
    torch.set_num_threads(4)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda.manual_seed_all(seed)
    cuda.empty_cache()
    gc.collect()
    cudnn.deterministic = True
    cudnn.benchmark = False
    # PYTHONHASHSEED is only read at interpreter start-up, so this affects
    # child processes, not str hashing in the current one.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Convert string-valued flags into real booleans.
    args.trackBN = args.trackBN == TRUE
    args.hetero_client = args.hetero_client == TRUE
    args.conv_only = args.conv_only == TRUE
    # Fixed sparse-training settings (not exposed on the command line).
    args.redistribution = 'magnitude'
    args.growth = 'random'
    args.prune = 'magnitude'
    data_name = args.dataset
    # NOTE(review): hard-coded; --dataset choices are CIFAR10 and MNIST,
    # presumably both 10-class. Revisit if more datasets are exposed.
    args.num_classes = 10
    # DENSE and PDST always use update mode 0 ("fedAvg" per --update_mode help).
    args.update_mode = 0 if args.experiment_type in [DENSE, PDST] else args.update_mode

    # Validate flag combinations with parser.error (prints usage and exits)
    # rather than assert, which is stripped under `python -O`.
    if args.hetero_client:
        if args.experiment_type not in [JMWST, SPDST]:
            parser.error(f'--hetero_client requires --experiment_type in {[JMWST, SPDST]}')
        if args.subsample == 'avg':
            parser.error("--hetero_client requires --subsample other than 'avg'")
    if args.model_type == RESNET18 and not args.conv_only:
        parser.error(f'--model_type {RESNET18} requires --conv_only {TRUE}')

    linear = '_linear' if not args.conv_only else ''
    # Exponent of the per-round exponential lr decay from lr_start towards lr_end.
    args.lr_gamma = np.log(args.lr_end / args.lr_start)
    args.key_path = f'{args.model_type}{linear}_{data_name}_mask_layers.pt'
    # Looks like a run identifier; not referenced anywhere else in this file.
    name = f'{args.experiment_type}|{args.dataset}|{args.density}|{args.alpha}{"_hetero" if args.hetero_client else ""}|{seed}|{int(time.time())}'
    server = Server(args)
    server.start_federated_learning()
