import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torch.nn.functional as F    

import numpy as np
import random
import copy
import matplotlib.pyplot as plt
from tqdm import tqdm

from scipy import io as spio

from scipy.stats import norm
import wandb
from rdp_accountant import *
import argparse
import os

# ------ command-line configuration ------ #
# Defaults are intentionally left exactly as they were: every value ends up in
# run_name, which is also the results-file name used to detect finished runs.
parser = argparse.ArgumentParser(
    description="Differentially private federated learning on EMNIST digits: "
                "each sampled writer trains locally, its model update is "
                "clipped (and optionally noised / sign-compressed), and the "
                "server averages the updates.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--clients_number_per_round', default=100, type=int,
                    help="number of clients (EMNIST writers) sampled each round")
parser.add_argument('--local_update_step_number', default=1, type=int,
                    help="local SGD steps per client per round (overridden by --OneEpoch)")
parser.add_argument('--num_rounds', default=500, type=int,
                    help="number of federated rounds")
parser.add_argument('--lr_client', default=0.05, type=float,
                    help="learning rate of the clients' local SGD")
parser.add_argument('--lr_server', default=1, type=float,
                    help="server step size applied to the averaged (momentum) update")
parser.add_argument('--local_bach_size', default=32, type=int,
                    help="client mini-batch size (capped at the client's sample count)")
parser.add_argument('--noise_multiplier_sigma', default=0, type=float,
                    help="Gaussian noise multiplier sigma; noise std is sigma * C, 0 disables noise")
parser.add_argument('--max_grad_norm_C', default=0.01, type=float,
                    help="L2 clipping bound C on each client's model update")
parser.add_argument('--use_sign', action="store_true",
                    help="aggregate the sign of each noisy clipped update instead of its value")
parser.add_argument('--OneEpoch', action="store_true",
                    help="each client runs one full epoch over its data per round")
parser.add_argument('--wandb', action="store_true",
                    help="log metrics to Weights & Biases")
parser.add_argument('--momentum', default=0, type=float,
                    help="server momentum coefficient")
parser.add_argument('--mark', default="", type=str,
                    help="free-form prefix prepended to the run name")

args = parser.parse_args()

# Compute device: GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# ------ federated-learning schedule ------ #
num_rounds = args.num_rounds
# Writers sampled per round, out of total_clients (defined below).
clients_number_per_round = args.clients_number_per_round
# --OneEpoch pins the step budget to 1: the training loop then makes one full
# pass over each client's data before its step check ends local training.
local_update_step_number = 1 if args.OneEpoch else args.local_update_step_number

# ------ local (client) optimisation and server step ------ #
lr_client, lr_server = args.lr_client, args.lr_server
local_bach_size = args.local_bach_size
momentum = args.momentum  # server momentum on the averaged client update

# ------ differential-privacy knobs ------ #
noise_multiplier_sigma = args.noise_multiplier_sigma  # noise std = sigma * C
max_grad_norm_C = args.max_grad_norm_C  # L2 clip bound C per client update
use_sign = args.use_sign

# ------ privacy accounting (Renyi DP) ------ #
# Renyi orders at which the RDP curve is evaluated before conversion to (eps, delta).
lmbds = np.arange(1.1,1000,0.5)
# NOTE(review): hard-coded writer count of EMNIST-digits; it must equal
# len(train_writers_set) computed after the data is loaded — confirm.
total_clients = 3579
if noise_multiplier_sigma > 0:
    # Per-round sampling rate = clients_number_per_round / total_clients,
    # composed over num_rounds rounds.
    # NOTE(review): the training loop samples a fixed-size set without
    # replacement (np.random.choice); confirm this matches the subsampling model
    # assumed by compute_rdp.
    rdp = compute_rdp(clients_number_per_round/total_clients, noise_multiplier_sigma, num_rounds, lmbds)
    if use_sign:
        # Sign-compressed updates: the RDP curve is scaled by 2/pi before conversion.
        # NOTE(review): the justification for this factor is not shown here — cite it.
        eps_spent = get_privacy_spent(lmbds, rdp * 2/np.pi, target_delta=1/total_clients)
    else:
        eps_spent = get_privacy_spent(lmbds, rdp, target_delta=1/total_clients)
    # NOTE(review): some rdp_accountant versions return a tuple (eps, delta,
    # opt_order) from get_privacy_spent; if so, run_name embeds the whole tuple.
else:
    # No noise: no finite privacy guarantee is reported.
    eps_spent = np.inf
    
# run_name encodes the whole configuration; it names the wandb run and the
# results file, whose existence is checked later to skip finished runs.
run_name = (
    f"{args.mark}eps:{eps_spent}_clients:{clients_number_per_round}"
    f"_lsteps:{local_update_step_number}_rounds:{num_rounds}"
    f"_lrc:{lr_client}_lrs:{lr_server}_bz:{local_bach_size}"
    f"_noise:{noise_multiplier_sigma}_clip:{max_grad_norm_C}"
    f"_sign:{use_sign}_OneEpoch:{args.OneEpoch}_momentum:{momentum}"
)
print(run_name)

# Loss shared by local client training and by evaluation.
criterion = nn.CrossEntropyLoss()

if args.wandb:
    wandb.init(project="DP-FL", entity="t1773420638", name=run_name)

class DatasetSplit(Dataset):
    """Expose the samples of ``dataset`` at positions ``idxs`` as a Dataset.

    Used to give each client (an EMNIST writer) a view of its own samples;
    only a reference to ``dataset`` is stored. Indexes are converted to plain
    ``int`` so numpy integer arrays can be passed in.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(map(int, idxs))

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        # Unpack and re-pack so callers always receive an (image, label)
        # tuple, even when the underlying sample is stored as a list.
        sample = self.dataset[self.idxs[item]]
        image, label = sample
        return (image, label)

class NeuralNet(nn.Module):
    """Small CNN classifier for 1x28x28 inputs producing 10 class logits.

    Layout: conv(8x8, stride 2, pad 3) -> ReLU -> maxpool(2, stride 1)
    -> conv(4x4, stride 2) -> ReLU -> maxpool(2, stride 1)
    -> flatten (32*4*4 = 512) -> fc(32) -> ReLU -> fc(10).
    """

    def __init__(self):
        super().__init__()
        # Registration order (conv1, conv2, fc1, fc2) fixes the parameter
        # names / order seen by state_dict() and named_parameters().
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=8, stride=2, padding=3)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2)
        self.fc1 = nn.Linear(in_features=32 * 4 * 4, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=10)

    def forward(self, x):
        # [B,1,28,28] -> conv1 [B,16,14,14] -> pool [B,16,13,13]
        feats = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=1)
        # -> conv2 [B,32,5,5] -> pool [B,32,4,4]
        feats = F.max_pool2d(F.relu(self.conv2(feats)), kernel_size=2, stride=1)
        flat = feats.view(-1, 32 * 4 * 4)  # [B, 512]
        return self.fc2(F.relu(self.fc1(flat)))  # [B, 10]

    def name(self):
        return "NeuralNet"

# ------ load EMNIST digits (MATLAB .mat export) ------ #
emnist = spio.loadmat("data/EMNIST/emnist-digits.mat")
# emnist["dataset"][0][0][k][0][0] is the train (k=0) or test (k=1) record;
# its fields are [0] images, [1] labels and, for train, [2] writer ids.
emnist_train = emnist["dataset"][0][0][0][0][0]
emnist_test = emnist["dataset"][0][0][1][0][0]

# ------ images: float32, divided by 255 (assumes 8-bit pixel values) ------ #
train_images = emnist_train[0].astype(np.float32)
train_images /= 255
test_images = emnist_test[0].astype(np.float32)
test_images /= 255

# ------ labels and writer ids ------ #
# reshape(-1) rather than the former hard-coded 240000 / 40000, so other EMNIST
# .mat splits with the same layout load too; sizes are validated explicitly.
train_labels = emnist_train[1].reshape(-1).tolist()
test_labels = emnist_test[1].reshape(-1).tolist()
train_writers = emnist_train[2].reshape(-1)
if not len(train_labels) == len(train_writers) == train_images.shape[0]:
    raise ValueError("EMNIST train split: images, labels and writer ids differ in length")
if len(test_labels) != test_images.shape[0]:
    raise ValueError("EMNIST test split: images and labels differ in length")

# ------ reshape using matlab order ------ #
train_images = train_images.reshape(train_images.shape[0], 1, 28, 28, order="A")
test_images = test_images.reshape(test_images.shape[0], 1, 28, 28, order="A")

# ------ standardise both splits with training-set statistics ------ #
mean_px = train_images.mean().astype(np.float32)
std_px = train_images.std().astype(np.float32)
train_images = (train_images-mean_px)/std_px
test_images = (test_images-mean_px)/std_px

# Each sample is an [image_tensor, label] list.
train_dataset = list(map(list, zip(torch.tensor(train_images), train_labels)))
test_dataset = list(map(list, zip(torch.tensor(test_images), test_labels)))

# Whole-split loaders in fixed order, used for per-round evaluation.
train_all_loader = DataLoader(train_dataset, batch_size= 1280, shuffle=False)
test_all_loader = DataLoader(test_dataset, batch_size=1280, shuffle=False)

# Clients are the distinct EMNIST writers of the training split.
train_writers_set = np.unique(train_writers)
if len(train_writers_set) != total_clients:
    # total_clients feeds the privacy accountant; a mismatch makes the reported eps wrong.
    print("WARNING: found {} writers but total_clients = {}; the reported epsilon "
          "assumes total_clients".format(len(train_writers_set), total_clients))

# ------ random seed ------ #
# seed = int(123123)
# torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
# np.random.seed(seed)  # Numpy module.
# random.seed(seed)  # Python random module.
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

def _evaluate(model, loader):
    """Evaluate ``model`` on every batch of ``loader`` without tracking gradients.

    Returns:
        (mean_loss, accuracy): the ``criterion`` loss averaged per sample (each
        batch's mean loss is weighted by its size) and the fraction of samples
        whose arg-max prediction equals the label.
    """
    model.eval()  # NeuralNet has no dropout/batch-norm; this only makes intent explicit
    loss_sum, n_seen, n_correct = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item() * len(labels)
            n_seen += len(labels)
            _, pred_labels = torch.max(outputs, 1)
            n_correct += torch.sum(torch.eq(pred_labels.view(-1), labels)).item()
    return loss_sum / n_seen, n_correct / n_seen


def _local_update(model, client_dataset):
    """Return a copy of ``model`` trained with SGD on one client's samples.

    The client takes ``local_update_step_number`` mini-batch steps, cycling over
    its data if needed. With ``--OneEpoch`` (which pins the step budget to 1)
    the inner loop is never cut short, so exactly one full pass is made.
    """
    local_model = copy.deepcopy(model)  # avoids re-initialising a NeuralNet just to overwrite it
    local_model.train()
    local_optimizer = torch.optim.SGD(local_model.parameters(), lr_client, momentum=0, weight_decay=0)
    client_loader = DataLoader(client_dataset,
                               batch_size=min(local_bach_size, len(client_dataset)),
                               shuffle=True)
    local_step = 0
    while local_step < local_update_step_number:
        for images, target in client_loader:
            local_step += 1
            images, target = images.to(device), target.to(device)
            loss = criterion(local_model(images), target)
            local_optimizer.zero_grad()
            loss.backward()
            local_optimizer.step()
            if local_step >= local_update_step_number and not args.OneEpoch:
                break
    return local_model


def _clipped_update(local_model, model):
    """Return the per-parameter client update ``local - global``, privatised.

    The whole update is scaled so its global L2 norm is at most
    ``max_grad_norm_C`` (never enlarged). Gaussian noise with std
    ``noise_multiplier_sigma * max_grad_norm_C`` is then added; it is skipped
    when that std is 0, which would only add zeros. With ``--use_sign``, only
    the sign of each entry is kept.
    """
    local_state = local_model.state_dict()
    noise_std = noise_multiplier_sigma * max_grad_norm_C
    update = {}
    with torch.no_grad():
        diffs = {n: local_state[n] - p for n, p in model.named_parameters()}
        total_norm = torch.sqrt(sum(torch.norm(d) ** 2 for d in diffs.values()))
        # Same as dividing by max(1, ||diff|| / C).
        scale = torch.clamp(total_norm / max_grad_norm_C, min=1.0)
        for n, d in diffs.items():
            d = d / scale
            if noise_std > 0:
                # randn_like allocates on d's device (the old .cuda() call crashed on CPU).
                d = d + torch.randn_like(d) * noise_std
            update[n] = torch.sign(d) if use_sign else d
    return update


os.makedirs("results", exist_ok=True)  # open(..., "w") below fails if the folder is missing
if os.path.exists("results/"+run_name+".txt"):
    print("Run already exists:")
else:
    print("start training:")
    global_model = NeuralNet().to(device)
    model_diff_average_momentum_dict = {n: torch.zeros_like(p) for n, p in global_model.named_parameters()}
    # Cache of writer -> sample indexes, so each writer's samples are located
    # once instead of scanning every sample for each sampled client every round.
    writer_indexes = {}
    # Reported to total_results.txt; stays NaN if num_rounds == 0 (previously a NameError).
    test_correct = float("nan")
    with open("results/"+run_name+".txt", "w") as f:
        for round_idx in range(num_rounds):
            # NOTE(review): fixed-size sampling without replacement; confirm it
            # matches the subsampling model assumed by the privacy accountant.
            chosen_clients = np.random.choice(train_writers_set, replace=False, size=clients_number_per_round)
            model_diff_average_dict = {n: torch.zeros_like(p) for n, p in global_model.named_parameters()}
            for client in chosen_clients:
                if client not in writer_indexes:
                    writer_indexes[client] = np.flatnonzero(train_writers == client)
                client_dataset = DatasetSplit(train_dataset, writer_indexes[client])
                local_model = _local_update(global_model, client_dataset)
                for n, d in _clipped_update(local_model, global_model).items():
                    model_diff_average_dict[n] += d

            # Average the client updates, apply server momentum, take the global step.
            with torch.no_grad():
                for n, p in global_model.named_parameters():
                    model_diff_average_dict[n] /= clients_number_per_round
                    model_diff_average_momentum_dict[n] = (
                        momentum * model_diff_average_momentum_dict[n] + model_diff_average_dict[n])
                    p.add_(lr_server * model_diff_average_momentum_dict[n])

            train_loss, train_correct = _evaluate(global_model, train_all_loader)
            test_loss, test_correct = _evaluate(global_model, test_all_loader)

            print("round:", round_idx, "train loss:", train_loss, "train acc", train_correct, "test loss:", test_loss, "test acc", test_correct)
            if args.wandb:
                wandb.log({"train loss": train_loss,
                           "train acc": train_correct,
                           "test loss": test_loss,
                           "test acc": test_correct})
            stats = "{},{},{},{},{}\n".format(round_idx, train_loss, train_correct, test_loss, test_correct)
            f.write(stats)

    with open("total_results.txt", "a") as f:
        f.write(run_name+"->"+str(test_correct)+"\n")
        