import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torch.nn.functional as F    

import numpy as np
import random
import copy
import matplotlib.pyplot as plt
from tqdm import tqdm

from scipy import io as spio

from scipy.stats import norm
import wandb
from rdp_accountant import *
import argparse
import os

def _str2bool(value):
    """Parse a command-line boolean value.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``,
    so ``--rerun False`` would silently enable rerunning. Accept the usual
    textual spellings instead.

    :param value: raw command-line string (or an already-parsed bool).
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: for unrecognised spellings.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y"):
        return True
    if text in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean (true/false), got {!r}".format(value))


parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument("--gpu")  # NOTE(review): not read anywhere else in this script
# Prefix of the run name; empty when omitted (previously None, which crashed
# the later string concatenation).
parser.add_argument("--run_name", default="")
parser.add_argument("--rerun", type=_str2bool, default=False)
parser.add_argument("--rep", type=int)  # NOTE(review): not read anywhere else in this script


parser.add_argument('--local_update_step_number', default = 1, type=int)
parser.add_argument('--num_rounds', default = 150, type=int)
parser.add_argument('--lr_server', default = 0.1, type=float)
# The original (misspelled) flag stays first so ``args.local_bach_size`` remains
# the destination; the correctly spelled flag is accepted as an alias.
parser.add_argument('--local_bach_size', '--local_batch_size', default = 16, type=int)
parser.add_argument('--noise_multiplier_sigma', default = 0, type=float)
parser.add_argument('--alg',type=str, default = "SGD", choices=["SGD","SignSGD","Sto-SignSGD","EF-SignSGD","SignSGD-Uni","QSGD"])
parser.add_argument('--wandb', action="store_true")
parser.add_argument('--client_momentum', default = 0, type=float)
parser.add_argument('--mark', default = "",type=str)
parser.add_argument('--qlevel', default = 1,type=int)

args = parser.parse_args()
# Compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# FL
num_rounds = args.num_rounds

# local model updating
lr_server = args.lr_server  # learning rate of server
local_batch_size = args.local_bach_size
client_momentum = args.client_momentum

# DP noise
noise_multiplier_sigma = args.noise_multiplier_sigma

# NOTE(review): `lmbds` and `total_clients` are not read anywhere else in this
# script; `lmbds` may be intended as RDP orders for rdp_accountant — confirm.
lmbds = np.arange(1.1, 1000, 0.5)
total_clients = 3579

# The run name is also the stem of the results file used to skip finished runs.
# `--run_name` may be absent (None); treat that as an empty prefix instead of
# failing on ``None + str``.
run_name = (args.run_name or "") + args.mark + "_MNIST_alg:{}_rounds:{}_lrs:{}_bz:{}_noise:{}_momentum:{}".format(
    args.alg,
    num_rounds,
    lr_server,
    local_batch_size,
    noise_multiplier_sigma,
    client_momentum,
)

if args.alg == "QSGD":
    run_name += "_qlevel:{}".format(args.qlevel)

print(run_name)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
if args.wandb:
    wandb.init(project="DP-FL", entity="t1773420638", name=run_name)

class NeuralNet(nn.Module):
    """Two convolutional layers followed by two fully connected layers, emitting 10 logits."""

    def __init__(self):
        super().__init__()
        # Construction order is kept fixed: it determines how parameter
        # initialisation consumes the RNG, and the attribute names are the
        # parameter-name keys used by the training loop.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=8, stride=2, padding=3)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2)
        self.fc1 = nn.Linear(in_features=32 * 4 * 4, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=10)

    def forward(self, x):
        """Map a batch ``x`` of shape [B, 1, 28, 28] to logits of shape [B, 10]."""
        h = F.max_pool2d(F.relu(self.conv1(x)), 2, 1)  # [B, 16, 14, 14] -> [B, 16, 13, 13]
        h = F.max_pool2d(F.relu(self.conv2(h)), 2, 1)  # [B, 32, 5, 5]   -> [B, 32, 4, 4]
        hidden = F.relu(self.fc1(h.view(-1, 32 * 4 * 4)))  # [B, 512] -> [B, 32]
        return self.fc2(hidden)  # [B, 10]

    def name(self):
        """Return the display name of this architecture."""
        return "NeuralNet"

def mnist_noniid(dataset, num_classes=10):
    """Split a labelled dataset into one client per class.

    Client ``i`` receives the indices of every sample whose label equals ``i``,
    so each client holds data from exactly one class.

    :param dataset: dataset exposing integer labels as ``dataset.targets``
        (tensor, array or list). The deprecated ``train_labels`` attribute is
        used as a fallback when ``targets`` is absent.
    :param num_classes: number of classes (and therefore clients); labels are
        expected in ``range(num_classes)``. Defaults to 10 (MNIST).
    :return: dict mapping client id ``i`` to a 1-D NumPy array of dataset
        indices whose label is ``i`` (empty if the class does not occur).
    """
    labels = getattr(dataset, "targets", None)
    if labels is None:
        # Older torchvision-style datasets only expose ``train_labels``.
        labels = dataset.train_labels
    # np.asarray accepts CPU tensors, arrays and plain lists alike.
    labels = np.asarray(labels)
    return {i: np.where(labels == i)[0] for i in range(num_classes)}


def index_generator(batch_size, class_index):
    """Yield shuffled mini-batches of indices taken from ``class_index``.

    One generator covers a single pass over ``class_index``. The positions are
    permuted with ``torch.randperm`` and yielded in consecutive chunks of
    ``batch_size``. The last chunk holds the remainder and may be shorter.

    :param batch_size: positive number of indices per batch.
    :param class_index: 1-D NumPy array of sample indices, as produced by
        ``mnist_noniid``.
    :yield: 1-D NumPy arrays of indices, always arrays even when a batch has a
        single element.
    :raises ValueError: if ``batch_size`` is not positive (raised on the first
        ``next()``).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))
    data_len = len(class_index)
    # Convert the permutation to NumPy before indexing a NumPy ``class_index``.
    # A one-element torch tensor would be accepted through ``__index__`` as a
    # scalar index, yielding a scalar and dropping the batch dimension.
    perm = torch.randperm(data_len).numpy()
    for start in range(0, data_len, batch_size):
        yield class_index[perm[start:start + batch_size]]
        
def Stochastic_Quantization(scale, vec, qlevel):
    """Stochastically quantize ``vec`` onto a grid of ``qlevel`` levels per unit of ``scale``.

    Each entry's magnitude is normalized by ``scale``. It is then randomly
    rounded to one of the two neighbouring multiples of ``1 / qlevel``, rounding
    up with probability proportional to the remainder, so the result equals the
    input in expectation. Signs are preserved.

    :param scale: non-negative normalizer (tensor or number). In this script it
        is the L2 norm of a client's whole update, so ``|vec| <= scale``.
    :param vec: floating-point tensor to quantize.
    :param qlevel: positive number of quantization levels.
    :return: tensor shaped like ``vec`` with entries ``scale * sign * k / qlevel``
        for integers ``k``. It is all zeros when ``scale`` is 0.
    :raises ValueError: if ``qlevel`` is not positive.
    """
    if qlevel <= 0:
        raise ValueError("qlevel must be positive, got {}".format(qlevel))
    scale = torch.as_tensor(scale, dtype=vec.dtype, device=vec.device)
    # A zero scale (all-zero update) would make vec / scale = 0/0 = NaN and poison
    # the model. Divide by 1 instead; the final multiplication by ``scale`` still
    # returns exact zeros.
    safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)

    vec_normalized = torch.abs(vec / safe_scale)
    vec_sign = torch.sign(vec)
    interval = 1 / qlevel

    vec_level = torch.div(vec_normalized, interval, rounding_mode='trunc')
    res = vec_normalized % interval

    # Round up with probability res / interval (unbiased stochastic rounding).
    round_up = (torch.rand_like(res) * interval) < res
    return scale * vec_sign * (vec_level + round_up.to(res.dtype)) * interval
    
    
        

data_dir = './data'
# Presumably the usual MNIST pixel mean/std used for normalization.
apply_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                               transform=apply_transform)
train_all_loader = DataLoader(train_dataset, batch_size=1000,
                              shuffle=False)
test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                              transform=apply_transform)
test_all_loader = DataLoader(test_dataset, batch_size=1000,
                             shuffle=False)

# client id -> indices of that client's samples (one class per client)
client_indexes = mnist_noniid(train_dataset)

# Materialize the transformed training set once, so client batches can be
# gathered by index. Stacking the [1, 28, 28] images gives [N, 1, 28, 28]
# without hard-coding N.
train_data_X = torch.stack([img for img, _ in train_dataset])
# as_tensor avoids the copy-construct warning torch.tensor(tensor) emits.
train_data_y = torch.as_tensor(train_dataset.targets, dtype=torch.long)

def _next_batch(generators, client, batch_size):
    """Return the next batch of sample indices for ``client``.

    ``generators[client]`` walks one shuffled pass over the client's data. When
    it is exhausted, a fresh pass with a new permutation is started.
    """
    try:
        return next(generators[client])
    except StopIteration:
        generators[client] = index_generator(batch_size, client_indexes[client])
        return next(generators[client])


def _evaluate(model, loader):
    """Return ``(mean loss, accuracy)`` of ``model`` over all batches in ``loader``."""
    total_loss, total, correct = 0.0, 0.0, 0.0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            total_loss += criterion(outputs, labels).item() * len(labels)
            total += len(labels)
            _, pred_labels = torch.max(outputs, 1)
            correct += torch.sum(torch.eq(pred_labels.view(-1), labels)).item()
    return total_loss / total, correct / total


if os.path.exists("results/"+run_name+".txt") and (not args.rerun):
    print("Run already exists:")
else:
    print("start training:")
    # mnist_noniid creates one client per class. Every client participates in
    # every round.
    num_clients = len(client_indexes)
    global_model = NeuralNet().to(device)
    optimizer = torch.optim.SGD(global_model.parameters(), lr=lr_server, momentum=0)
    # Per-client state keyed by parameter name: momentum buffers, plus the
    # error-feedback residual for EF-SignSGD.
    clients_momentum_dict_list = [{n: torch.zeros_like(p) for n, p in global_model.named_parameters()}
                                  for _ in range(num_clients)]
    if args.alg == "EF-SignSGD":
        clients_EF_dict_list = [{n: torch.zeros_like(p) for n, p in global_model.named_parameters()}
                                for _ in range(num_clients)]

    local_index_generator = {cl: index_generator(local_batch_size, client_indexes[cl])
                             for cl in range(num_clients)}
    test_correct = None  # stays None when num_rounds == 0
    # open(..., "w") below fails if the results directory does not exist yet.
    os.makedirs("results", exist_ok=True)
    with open("results/"+run_name+".txt", "w") as f:
        for round_idx in range(num_rounds):
            average_grad_dict = {n: torch.zeros_like(p) for n, p in global_model.named_parameters()}
            for client in range(num_clients):
                batch_idx = _next_batch(local_index_generator, client, local_batch_size)
                images, target = train_data_X[batch_idx], train_data_y[batch_idx]
                input_var = images.to(device)
                target_var = target.to(device)

                # Clear gradients BEFORE backward. Otherwise the first client of
                # each round would accumulate onto the aggregated update installed
                # as p.grad at the end of the previous round.
                optimizer.zero_grad()
                output = global_model(input_var)
                loss = criterion(output, target_var)
                loss.backward()

                with torch.no_grad():
                    momentum = clients_momentum_dict_list[client]
                    if args.alg == "EF-SignSGD":
                        one_norm = 0
                        var_num = 0
                        temp_grad = {n: torch.zeros_like(p) for n, p in global_model.named_parameters()}
                    if args.alg in ("Sto-SignSGD", "QSGD"):
                        two_norm = 0

                    for n, p in global_model.named_parameters():
                        # Exponential moving average of this client's gradients.
                        momentum[n] = client_momentum * momentum[n] + (1 - client_momentum) * p.grad

                        if args.alg == "SGD":
                            average_grad_dict[n] += momentum[n]
                        elif args.alg == "SignSGD":
                            # Gaussian noise is still drawn from the CPU generator, then
                            # moved to the parameter's device instead of a hard-coded
                            # .cuda(), which crashed on CPU-only machines.
                            noise = torch.empty(p.shape).normal_(0, noise_multiplier_sigma).to(p.device)
                            average_grad_dict[n] += torch.sign(momentum[n] + noise)
                        elif args.alg == "SignSGD-Uni":
                            # Uniform noise in [-sigma, sigma), generated on the momentum's device.
                            noise = noise_multiplier_sigma * (torch.rand_like(momentum[n]) * 2 - 1)
                            average_grad_dict[n] += torch.sign(momentum[n] + noise)
                        elif args.alg in ("Sto-SignSGD", "QSGD"):
                            # Accumulate the squared L2 norm over all parameters.
                            two_norm += torch.pow(torch.norm(momentum[n], p=2), 2)
                        elif args.alg == "EF-SignSGD":
                            temp_grad[n] = momentum[n] + clients_EF_dict_list[client][n]
                            one_norm += torch.norm(temp_grad[n], p=1)
                            var_num += p.numel()

                    if args.alg == "EF-SignSGD":
                        for n, p in global_model.named_parameters():
                            # Sign scaled by the mean |entry|. The compression error is
                            # stored and added back in this client's next update.
                            temp_sign = one_norm * torch.sign(temp_grad[n]) / var_num
                            clients_EF_dict_list[client][n] = temp_grad[n] - temp_sign
                            average_grad_dict[n] += temp_sign

                    if args.alg in ("Sto-SignSGD", "QSGD"):
                        two_norm = torch.sqrt(two_norm)
                        if args.alg == "Sto-SignSGD":
                            for n, p in global_model.named_parameters():
                                # Keep each sign with probability 1/2 + |m| / (2 * ||m||_2), else flip it.
                                keep_prob = 1/2 + torch.abs(momentum[n]) / two_norm/2
                                signs = torch.sign(momentum[n])
                                flip = torch.rand_like(signs)
                                signs[flip > keep_prob] = -signs[flip > keep_prob]
                                average_grad_dict[n] += signs
                        else:
                            for n, p in global_model.named_parameters():
                                average_grad_dict[n] += Stochastic_Quantization(two_norm, momentum[n], args.qlevel)

            # Global model step: install the averaged client update as the gradient.
            # Assign p.grad itself: p.grad may be None here (zero_grad defaults to
            # set_to_none=True since PyTorch 2.0), so ``p.grad.data = ...`` fails.
            with torch.no_grad():
                for n, p in global_model.named_parameters():
                    p.grad = average_grad_dict[n] / num_clients
            optimizer.step()

            # check
            train_loss, train_correct = _evaluate(global_model, train_all_loader)
            test_loss, test_correct = _evaluate(global_model, test_all_loader)

            print("round:", round_idx, "train loss:", train_loss, "train acc", train_correct, "test loss:", test_loss, "test acc", test_correct)
            if args.wandb:
                wandb.log({"train loss": train_loss,
                           "train acc": train_correct,
                           "test loss": test_loss,
                           "test acc": test_correct})
            stats = "{},{},{},{},{}\n".format(round_idx, train_loss, train_correct, test_loss, test_correct)
            f.write(stats)

    # Record the final test accuracy, only if at least one round ran.
    if test_correct is not None:
        with open("total_results.txt", "a") as f:
            f.write(run_name+"->"+str(test_correct)+"\n")