
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, MNIST
from tqdm import tqdm
import json
import copy
import time
import os

"""
base_model
"""


class Model_CIFAR10(nn.Module):
    """Small two-conv / three-linear CNN producing 10 raw class scores.

    The flatten size ``16 * 5 * 5`` matches 3-channel 32x32 inputs
    (32 -> 28 -> 14 -> 10 -> 5 through the conv/pool stack).
    """

    def __init__(self) -> None:
        super(Model_CIFAR10, self).__init__()
        # Feature extractor: two 5x5 convolutions sharing one 2x2 max-pool.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return unnormalised class scores (logits) for the batch ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Collapse channel and spatial dimensions into one feature vector.
        x = x.view(-1, 16 * 5 * 5)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)


class Model_MNIST(nn.Module):
    """Two-conv / two-linear CNN returning log-probabilities over 10 classes.

    The flatten size ``4 * 4 * 50`` matches 1-channel 28x28 inputs
    (28 -> 24 -> 12 -> 8 -> 4 through the conv/pool stack).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        """Return ``log_softmax`` scores of shape ``(batch, 10)``."""
        # conv -> ReLU -> 2x2 max-pool, applied twice.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


"""
base data
"""


def load_data(dataset_name):
    """Download (if needed) and load the train/test splits of a dataset.

    :param dataset_name: ``"CIFAR10"`` or ``"MNIST"``.
    :return: ``(trainset, trainloader, testloader, num_examples)`` where both
        loaders use a batch size of 32 (only the training loader shuffles) and
        ``num_examples`` maps ``"trainset"`` / ``"testset"`` to their sizes.
    :raises ValueError: if ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == "CIFAR10":
        dataset_cls, root = CIFAR10, "./CIFAR10"
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
    elif dataset_name == "MNIST":
        dataset_cls, root = MNIST, "./data/MNIST"
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )
    else:
        # Previously an unknown name fell through to the return statement and
        # failed with an UnboundLocalError; fail early with a clear message.
        raise ValueError(
            "Unsupported dataset_name {!r}; expected 'CIFAR10' or 'MNIST'".format(
                dataset_name
            )
        )

    trainset = dataset_cls(root, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
    testset = dataset_cls(root, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
    num_examples = {"trainset": len(trainset), "testset": len(testset)}

    return trainset, trainloader, testloader, num_examples


def make_client_dataset(dataset, num_clients, batch_size):
    '''
    Randomly split a training set into ``num_clients`` disjoint subsets.

    When the dataset size is divisible by ``num_clients`` every client gets
    the same number of samples (e.g. 50,000 CIFAR images over 50 clients gives
    1,000 images each). Otherwise the remainder is spread one sample at a time
    over the first clients, so subset sizes differ by at most one and no sample
    is dropped.

    :param dataset: any sized dataset (only ``len(dataset)`` and indexing are used)
    :param num_clients: number of client subsets to create (must be positive)
    :param batch_size: batch size of each per-client shuffling DataLoader
    :return: ``(traindata_splits, train_loaders)``, two lists of length ``num_clients``
    :raises ValueError: if ``num_clients`` is not a positive integer
    '''
    if num_clients <= 0:
        raise ValueError("num_clients must be positive, got {}".format(num_clients))

    # len(dataset) works for every sized dataset, not only those exposing `.data`.
    base_size, remainder = divmod(len(dataset), num_clients)
    # The lengths must sum to len(dataset) exactly or random_split raises.
    lengths = [base_size + (1 if i < remainder else 0) for i in range(num_clients)]

    traindata_splits = torch.utils.data.random_split(dataset, lengths)
    train_loaders = [
        torch.utils.data.DataLoader(x, batch_size=batch_size, shuffle=True)
        for x in traindata_splits
    ]
    return traindata_splits, train_loaders


"""
base training
"""


def train(model, dataloader, epochs, lr, grad_clip=None, weight_decay=0, momentum=0, device=None):
    """Train ``model`` in place with SGD and cross-entropy loss.

    :param model: network to optimise; it is updated in place and also returned
    :param dataloader: iterable of batches whose first two items are
        ``(inputs, labels)``
    :param epochs: number of passes over ``dataloader``
    :param lr: SGD learning rate
    :param grad_clip: if not ``None``, maximum gradient norm applied before
        each optimizer step
    :param weight_decay: SGD weight decay
    :param momentum: SGD momentum
    :param device: device each batch is moved to (the model itself is not moved)
    :return: the same ``model`` object after training
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum)

    # Number of mini-batches accumulated between two progress reports.
    log_every = 100

    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader):
            images, labels = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()

            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

            running_loss += loss.item()
            if i % log_every == log_every - 1:
                # running_loss holds exactly `log_every` batch losses here, so
                # that is the divisor giving the mean (not len(dataloader)).
                print(
                    "[%d, %5d] loss: %f"
                    % (epoch + 1, i + 1, running_loss / log_every)
                )
                running_loss = 0.0
    return model


def test(model, testloader, device):
    """Evaluate ``model`` on every batch of ``testloader`` without gradients.

    :return: ``(mean_loss, accuracy)`` where ``mean_loss`` is the average of
        the per-batch cross-entropy losses and ``accuracy`` is the fraction of
        correctly classified samples.
    """
    loss_fn = nn.CrossEntropyLoss()
    n_correct, n_seen, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for batch in testloader:
            inputs = batch[0].to(device)
            targets = batch[1].to(device)
            scores = model(inputs)
            loss_sum += loss_fn(scores, targets).item()
            # Index of the highest score along the class dimension.
            predicted = torch.max(scores.data, 1)[1]
            n_seen += targets.size(0)
            n_correct += (predicted == targets).sum().item()
    accuracy = n_correct / n_seen
    mean_loss = loss_sum / len(testloader)
    return mean_loss, accuracy


"""
base federated (non-quantized)
"""


def FedAvg(w):
    """Return the element-wise, unweighted mean of a list of state dicts.

    ``w[0]`` is deep-copied and used as the accumulator, so none of the inputs
    is modified. Entries are summed in list order and then divided by
    ``len(w)`` with ``torch.div``.
    """
    n_models = len(w)
    averaged = copy.deepcopy(w[0])
    for name in averaged:
        total = averaged[name]
        for other in w[1:]:
            total += other[name]  # in-place on the deep copy
        averaged[name] = torch.div(total, n_models)
    return averaged


class OFL(object):
    """Plain (non-quantized) federated learning driven by ``FedAvg``.

    Each round, every client trains a deep copy of the current global model on
    its own loader; the resulting state dicts are averaged and loaded back
    into the global model.
    """

    def __init__(self, global_model, num_clients, rounds, batch_size, device):
        """
        :param global_model: model instance used as the shared global model
        :param num_clients: number of participating clients
        :param rounds: number of communication rounds run by ``global_train``
        :param batch_size: batch size used when ``delete`` rebuilds loaders
        :param device: device passed to ``train`` / ``test``
        """
        self.num_clients = num_clients
        self.rounds = rounds
        self.batch_size = batch_size
        self.device = device
        self.global_model = global_model

        print('Init model:', self.global_model)
        # Every slot initially references the same global state dict; slots
        # are replaced (never mutated) with per-client dicts during training.
        self.global_params = self.global_model.state_dict()
        self.local_params = [self.global_params for _ in range(self.num_clients)]

    def global_train(self, dataloaders, testloader, learning_rate, client_iters):
        """Run ``self.rounds`` rounds of local training followed by FedAvg.

        :param dataloaders: per-client training loaders, indexable by
            ``0 .. num_clients - 1``
        :param testloader: loader used to evaluate the global model each round
        :param learning_rate: learning rate for the local SGD runs
        :param client_iters: number of local epochs per client and round
        """
        start_time = time.time()

        for r in range(self.rounds):

            # local training
            for client_k in range(self.num_clients):
                local_model = train(
                    model=copy.deepcopy(self.global_model),
                    dataloader=dataloaders[client_k],
                    epochs=client_iters,
                    lr=learning_rate,
                    device=self.device,
                )
                print(
                    "finish round {} of local training at client {}".format(r, client_k)
                )
                self.local_params[client_k] = local_model.state_dict()

            # aggregation, then push the averaged weights into the global model
            global_params = FedAvg(self.local_params)
            self.global_model.load_state_dict(global_params)

            # global model loss
            loss, accu = test(self.global_model, testloader, self.device)
            print('Loss: {:.3f} | Accuracy: {:.3f}'.format(loss, accu))

        print("global_training finished!")
        spend_time = time.time() - start_time

        print("unlearning time = {}".format(spend_time))

    def delete(self, traindata_splits, client_id, delete_ratio):
        """Drop a random fraction of one client's data and reset the model.

        ``traindata_splits`` is modified in place: the entry of ``client_id``
        is replaced by the subset of samples that remain. The global model is
        replaced by a freshly constructed instance of the same class (built
        with no arguments) so a following ``global_train`` retrains from
        scratch.

        :param traindata_splits: per-client datasets, indexable by ``client_id``
        :param client_id: client whose samples are removed
        :param delete_ratio: fraction of that client's samples to remove
        :return: ``(permuted_indices, traindata_splits, train_loaders)``; the
            first ``int(len(local_set) * delete_ratio)`` entries of
            ``permuted_indices`` are the removed positions
        """
        local_set = traindata_splits[client_id]
        print('target_client:', client_id)
        print('num_samples of target_client:', len(local_set))

        delete_num = int(len(local_set) * delete_ratio)
        print("num deleted samples of target_client: {}".format(delete_num))

        permuted_indices = np.random.permutation(len(local_set))

        # Everything after the first `delete_num` shuffled positions is kept.
        remain_dateset = torch.utils.data.Subset(
            local_set, permuted_indices[delete_num:]
        )

        # prepare for retraining
        traindata_splits[client_id] = remain_dateset

        # new dataloaders
        train_loaders = [
            torch.utils.data.DataLoader(x, batch_size=self.batch_size, shuffle=True)
            for x in traindata_splits
        ]

        # reinitialize the model for retraining
        self.global_model = self.global_model.__class__().to(self.device)

        for i, loader in enumerate(train_loaders):
            print('client {}: {} batches'.format(i, len(loader)))

        return permuted_indices, traindata_splits, train_loaders

    def save_global_model(self, path):
        """Write the global model's state dict to ``path``."""
        torch.save(self.global_model.state_dict(), path)

    def load_global_model(self, path):
        """Load a state dict saved by ``save_global_model`` into the global model.

        Tensors are mapped onto ``self.device`` so that a checkpoint written on
        one device (e.g. a GPU) can be restored on another.
        """
        state_dict = torch.load(path, map_location=self.device)
        self.global_model.load_state_dict(state_dict)


"""
base quantization
"""


def _Quantize_params(weights, alpla, phase=0.5):
    """Snap ``weights`` onto a uniform grid with step ``alpla``.

    Values are scaled by ``1 / alpla``, shifted by ``phase - 0.5``, rounded
    with ``torch.round``, shifted back and rescaled. With the default
    ``phase`` of 0.5 the shift is zero and the grid points are integer
    multiples of ``alpla``.
    """
    shift = phase - 0.5
    grid_index = torch.round(weights * 1 / alpla + shift)
    return (grid_index - shift) * alpla


def qt(model_params, q_alpha):
    """Quantize every entry of ``model_params`` with step ``q_alpha``.

    The mapping is updated in place (each value is replaced by its quantized
    counterpart) and the same mapping object is returned.
    """
    for name, tensor in model_params.items():
        # Replacing values of existing keys is safe while iterating.
        model_params[name] = _Quantize_params(tensor, q_alpha)
    return model_params


class Q_FL(OFL):
    """Quantized federated learning with cached history for unlearning checks.

    After each FedAvg aggregation the global parameters are quantized with
    step ``q_alpha`` (see ``qt``). Every client's local model of every round
    is kept in ``clients_models_all`` and the quantized global parameters in
    ``q_params`` so that ``delete`` can test, round by round, whether removing
    data changes the quantized global model.
    """

    def __init__(self, global_model, num_clients, rounds, batch_size, q_alpha, log_path, device):
        """
        :param global_model: model instance used as the shared global model
        :param num_clients: number of participating clients
        :param rounds: number of communication rounds
        :param batch_size: batch size used when ``delete`` rebuilds loaders
        :param q_alpha: quantization step applied to aggregated parameters
        :param log_path: JSON file receiving per-round test metrics; an
            existing file at this path is removed
        :param device: device passed to ``train`` / ``test``
        """
        # OFL.__init__ is not called; all attributes are set here directly.
        self.num_clients = num_clients
        self.rounds = rounds
        self.batch_size = batch_size
        self.alpha = q_alpha
        # Placeholder lists with rounds + 1 slots; slot 0 holds the initial
        # parameters and slot r + 1 the result of round r.
        self.ofl_params = list(range(rounds + 1))
        self.q_params = list(range(rounds + 1))
        self.device = device
        self.log_path = log_path

        if os.path.exists(self.log_path):
            print(f'Found existing results in {self.log_path}. Will delete it')
            os.remove(self.log_path)

        # clients_models_all[r][k] is the k-th local model appended in round r
        self.clients_models_all = []

        self.global_model = global_model
        self.global_params = self.global_model.state_dict()
        # Keyed by client id; every entry starts as the shared global state dict.
        self.local_params = {i: self.global_params for i in range(self.num_clients)}

        self.ofl_params[0] = copy.deepcopy(self.global_params)
        self.q_params[0] = copy.deepcopy(self.global_params)

    def meta_data(self):
        """Reset ``ofl_params`` and ``q_params`` to empty lists and return them."""
        self.ofl_params = []
        self.q_params = []
        return self.ofl_params, self.q_params

    def global_train(self, dataloaders, testloader, client_iters, learning_rate, grad_clip=None, weight_decay=0.0, momentum=0.0):
        """Run ``self.rounds`` rounds of local training, FedAvg and quantization.

        :param dataloaders: mapping from client id to that client's training loader
        :param testloader: loader used to evaluate the global model each round
        :param client_iters: number of local epochs per client and round
        :param learning_rate: learning rate for the local SGD runs
        :param grad_clip: optional maximum gradient norm forwarded to ``train``
        :param weight_decay: SGD weight decay forwarded to ``train``
        :param momentum: SGD momentum forwarded to ``train``

        Side effects: appends one list of local models per round to
        ``clients_models_all``, fills ``ofl_params[r + 1]`` /
        ``q_params[r + 1]``, and appends the test metrics of each round to the
        JSON list stored at ``log_path``.
        """
        start_time = time.time()

        for r in tqdm(range(self.rounds), desc='Train'):

            # local models produced in this round, in loader iteration order
            client_models_round = []

            for client_k, client_loader in dataloaders.items():
                local_model = train(
                    model=copy.deepcopy(self.global_model),
                    dataloader=client_loader,
                    lr=learning_rate,
                    grad_clip=grad_clip,
                    epochs=client_iters,
                    weight_decay=weight_decay,
                    momentum=momentum,
                    device=self.device,
                )
                client_models_round.append(local_model)
                self.local_params[client_k] = local_model.state_dict()

            self.clients_models_all.append(client_models_round)

            # aggregation
            cur_global_params = FedAvg(list(self.local_params.values()))

            # Keep the un-quantized average; the copy is required because qt()
            # below overwrites cur_global_params in place.
            self.ofl_params[r + 1] = copy.deepcopy(cur_global_params)

            cur_q_params = qt(cur_global_params, q_alpha=self.alpha)
            self.q_params[r + 1] = cur_q_params

            # the global model continues from the quantized parameters
            self.global_model.load_state_dict(cur_q_params)

            # global model loss
            loss, accu = test(self.global_model, testloader, self.device)
            print('Loss: {:.3f} | Acc: {:.3f}'.format(loss, accu))

            # Append this round's metrics to the JSON log (read-modify-write).
            if os.path.exists(self.log_path):
                with open(self.log_path, 'r') as f:
                    curr_logs = json.load(f)
            else:
                curr_logs = []
            curr_logs.append({'test_acc': accu, 'test_loss': loss})
            with open(self.log_path, 'w') as f:
                json.dump(curr_logs, f, indent=2)

        print(
            "global_training finished! spend time {}s".format(time.time() - start_time)
        )

    def delete(self, traindata_splits, client_id, delete_ratio, testloader):
        """Remove a random fraction of one client's data and find the first
        round whose quantized global model is affected.

        ``traindata_splits`` is modified in place. For each round the cached
        history is checked with ``check_equality``; at the first round that is
        no longer equal, the global model is reset to the quantized parameters
        from before that round (``q_params[r]``) and the search stops.

        :param traindata_splits: mapping from client id to that client's dataset
        :param client_id: client whose samples are removed
        :param delete_ratio: fraction of that client's samples to remove
        :param testloader: unused here; kept for interface compatibility
        :return: ``(retrain, delete_data_subset, traindata_splits, train_loaders)``
            where ``retrain`` tells whether retraining is required and
            ``train_loaders`` maps the ids of clients that still have data to
            new loaders
        """
        local_set = traindata_splits[client_id]
        print("local_set_size=", len(local_set))

        delete_num = int(len(local_set) * delete_ratio)
        delete_indices = np.random.permutation(len(local_set))

        # The first `delete_num` shuffled positions are removed, the rest kept.
        local_set_new = torch.utils.data.Subset(local_set, delete_indices[delete_num:])
        delete_data_subset = torch.utils.data.Subset(
            local_set, delete_indices[0:delete_num]
        )

        print("local_set_new_size=", len(local_set_new))
        print("delete_data_size=", len(delete_data_subset))

        deleted_data = form_data_for_deletion(delete_data_subset)

        traindata_splits[client_id] = local_set_new

        # Clients left without any data get no loader.
        train_loaders = {
            c_idx: torch.utils.data.DataLoader(x, batch_size=self.batch_size, shuffle=True)
            for c_idx, x in traindata_splits.items()
            if len(x) > 0
        }
        if len(local_set_new) == 0:     # one client is completely removed
            self.num_clients = len(train_loaders)
            self.local_params.pop(client_id)

        print('Number of remaining clients:', self.num_clients)

        start_time = time.time()

        # Defined up front so the return value exists even when rounds == 0.
        retrain = False
        for r in range(self.rounds):
            if check_equality(
                r,
                client_id,
                deleted_data,
                self.q_params,
                self.clients_models_all,
                delete_ratio,
                q_alpha=self.alpha,
            ):
                print("========= round {} qt models is stable=================".format(r))
                continue

            # retraining needed
            print(
                "=============we need retrain from round {}======================".format(
                    r
                )
            )
            # load parameters from the round before the model becomes unstable
            self.global_model.load_state_dict(self.q_params[r])
            retrain = True
            break

        print("unlearing finished, spend time={}s".format(time.time() - start_time))
        return retrain, delete_data_subset, traindata_splits, train_loaders

    def _Quantize_params(self, weights, alpla, phase=0.5):
        """Snap ``weights`` onto a uniform grid with step ``alpla``.

        Same arithmetic as the module-level ``_Quantize_params``.
        """
        weights = weights * 1 / alpla
        weights = weights + (phase - 0.5)
        weights = torch.round(weights)
        weights = weights - (phase - 0.5)
        return weights * alpla

    def save_model(self, path):
        """Save the global model, all cached client models and ``q_params``.

        Client models are stored as state dicts, nested as
        ``clients_models_all[round][client]``.
        """
        all_state_dict = [
            [client_model.state_dict() for client_model in client_model_round]
            for client_model_round in self.clients_models_all
        ]

        data = {
            'global_model': self.global_model.state_dict(),
            'clients_models_all': all_state_dict,
            'q_params': self.q_params,
        }

        torch.save(data, path)

    def load_model(self, path):
        """Restore the state written by ``save_model``.

        The cached client models replace any previously held ones, so they
        stay aligned with the loaded ``q_params``. Tensors are mapped onto
        ``self.device`` so checkpoints can be restored on a different device
        than the one they were saved from.
        """
        data = torch.load(path, map_location=self.device)

        self.global_model.load_state_dict(data['global_model'])
        self.q_params = data['q_params']

        print('Number of cached rounds:', len(data['clients_models_all']))
        restored_rounds = []
        for round_state_dict in data['clients_models_all']:
            client_model_round = []
            for state_dict in round_state_dict:
                # Clone the architecture, then overwrite the weights.
                client_model = copy.deepcopy(self.global_model)
                client_model.load_state_dict(state_dict)
                client_model_round.append(client_model)
            restored_rounds.append(client_model_round)
        self.clients_models_all = restored_rounds
    
"""
utils
"""


def check_equality(
    r, client_id, deleted_data, q_global_params_ls, all_clients_models, delete_ratio, q_alpha
):
    """
    Apply one gradient-ascent step on the deleted data to a client's cached
    round-``r`` model and check whether the re-aggregated, re-quantized global
    parameters still equal the stored ones.

    The cached model ``all_clients_models[r][client_id]`` is updated in place.

    :param r: round index
    :param client_id: position of the client inside ``all_clients_models[r]``
    :param deleted_data: ``(images, labels)`` tensors of the removed samples
    :param q_global_params_ls: stored quantized global parameters; entry
        ``r + 1`` is the result of round ``r``
    :param all_clients_models: per-round lists of client models
    :param delete_ratio: fraction of deleted data; scales the ascent step size
    :param q_alpha: quantization step passed to ``qt``
    :return: True if the quantized global parameters of round ``r`` are unchanged
    """
    # current round global Qt params at server
    q_params_r = q_global_params_ls[r + 1]

    # client k's original model at local
    old_local_model = all_clients_models[r][client_id]

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        old_local_model.parameters(), lr=0.001 * delete_ratio, momentum=0.9
    )

    # train() leaves the gradients of its last mini-batch on the model; clear
    # them so the step below uses the deleted-data gradient only.
    optimizer.zero_grad()

    images, labels = deleted_data
    outputs = old_local_model(images)
    # Negated loss: minimising it performs gradient ascent on the real loss.
    loss_remain = -criterion(outputs, labels)

    loss_remain.backward()
    optimizer.step()
    optimizer.zero_grad()

    # update local model parameters list
    all_clients_models[r][client_id] = old_local_model

    # collect r-th round all client models
    local_param_list_r = [model.state_dict() for model in all_clients_models[r]]

    # call FedAvg to aggregation
    new_global_params_at_r = FedAvg(local_param_list_r)

    return compare_params(q_params_r, qt(new_global_params_at_r, q_alpha=q_alpha))


def form_data_for_deletion(subdataset, device=None):
    """Stack all samples of ``subdataset`` into one ``(images, labels)`` pair.

    :param subdataset: sized, indexable dataset yielding ``(image, label)``
        items whose images all share one ``(c, h, w)`` shape
    :param device: target device; defaults to ``"cuda"`` when CUDA is
        available and ``"cpu"`` otherwise
    :return: ``(images, labels)`` where ``images`` has shape ``(l, c, h, w)``
        in the default float dtype and ``labels`` is a ``torch.long`` tensor
        of length ``l``
    :raises ValueError: if ``subdataset`` is empty (the image shape is unknown)
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if len(subdataset) == 0:
        raise ValueError("cannot build deletion tensors from an empty dataset")

    # Fetch each sample exactly once (indexing may run dataset transforms).
    samples = [subdataset[i] for i in range(len(subdataset))]

    images = torch.stack([torch.as_tensor(sample[0]) for sample in samples])
    images = images.to(device=device, dtype=torch.get_default_dtype())
    labels = torch.tensor(
        [int(sample[1]) for sample in samples], dtype=torch.long
    ).to(device)

    print(images.shape)

    deleted_data = (images, labels)
    return deleted_data


def compare_params(model_param_1, model_param_2):
    """Return True when both parameter mappings hold identical tensors.

    Keys are taken from ``model_param_1``. For each key the number of equal
    and unequal elements is printed; the function returns False at the first
    key whose tensors differ.
    """
    for name in model_param_1.keys():
        left, right = model_param_1[name], model_param_2[name]
        num_matches = torch.eq(left, right).sum().item()
        num_not_matches = torch.ne(left, right).sum().item()
        print("Match:", num_matches)
        print("Not match:", num_not_matches)
        if not torch.equal(left, right):
            return False
    return True


"""
metrics
"""


def SPAE(w_u, w, metric, testloader, deleted_data_ls2set, device):
    """Compare the accuracy of an unlearned model with a retrained model.

    Computes ``|acc_u - acc_r| / (acc_u + acc_r)`` where the accuracies come
    from ``test``.

    :param w_u: unlearned model
    :param w: retrained model
    :param metric: ``"eff"`` evaluates on ``testloader``; any other value
        evaluates on ``deleted_data_ls2set`` (wrapped in a shuffling loader
        with batch size 32)
    :param testloader: loader over the test data
    :param deleted_data_ls2set: dataset holding the deleted samples
    :param device: device passed to ``test``
    :return: the normalised accuracy gap; ``0.0`` when both accuracies are
        zero (the models agree and the ratio would otherwise be 0 / 0)
    """
    if metric == "eff":
        test_data = testloader
        target = "test data"
    else:
        test_data = torch.utils.data.DataLoader(
            deleted_data_ls2set, batch_size=32, shuffle=True
        )
        target = "U"

    accu_w_u = test(w_u, test_data, device)[1]
    print("Unlearned model accuracy on {}=".format(target), accu_w_u)

    accu_w = test(w, test_data, device)[1]
    print("Retraining model accuracy on {}=".format(target), accu_w)

    denominator = accu_w_u + accu_w
    if denominator == 0:
        # Both accuracies are zero: identical performance, avoid 0 / 0.
        return 0.0
    return abs(accu_w_u - accu_w) / denominator
