import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torch.nn.functional as F    

import numpy as np
import random
import copy
import matplotlib.pyplot as plt
from tqdm import tqdm

from scipy import io as spio

from scipy.stats import norm
import wandb
import argparse
import os

def _str2bool(value):
    """Convert a command-line token such as ``true``/``false``/``1``/``0`` to bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--rerun False`` would
    silently enable the flag.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was under ``bool("")``.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument("--gpu")
parser.add_argument("--run_name")
# Accepts "--rerun", "--rerun true" and "--rerun false"; a bare "--rerun"
# means True and omitting the flag means False.
parser.add_argument("--rerun", type=_str2bool, nargs="?", const=True, default=False,
                    help="re-run even if a result file for this run name exists")
parser.add_argument("--rep", type=int)

# Federated-learning schedule and optimisation hyper-parameters.
parser.add_argument('--clients_number_per_round', default=100, type=int)
parser.add_argument('--local_update_step_number', default=20, type=int)
parser.add_argument('--num_rounds', default=500, type=int)
parser.add_argument('--lr_client', default=0.05, type=float)
parser.add_argument('--lr_server', default=1, type=float)
parser.add_argument('--local_bach_size', default=16, type=int)
parser.add_argument('--noise_multiplier_sigma', default=0, type=float)
parser.add_argument('--use_sign', action="store_true")
parser.add_argument('--uniform', action="store_true")
parser.add_argument('--OneEpoch', action="store_true")
parser.add_argument('--wandb', action="store_true")
parser.add_argument('--momentum', default=0, type=float)
parser.add_argument('--mark', default="", type=str)
parser.add_argument('--qlevel', default=1, type=int)

args = parser.parse_args()
# Restrict the visible GPUs only when --gpu was supplied: os.environ values
# must be strings, so assigning the default None would raise TypeError.
if args.gpu is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
# --run_name prefixes every output name built below; fail with a usage message
# instead of a TypeError from None + str.
if args.run_name is None:
    parser.error("--run_name is required")

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# FL
clients_number_per_round = args.clients_number_per_round  # clients sampled in every round
local_update_step_number = args.local_update_step_number  # local SGD steps per client
num_rounds = args.num_rounds

# local model updating
lr_client = args.lr_client  # learning rate of client
lr_server = args.lr_server  # learning rate of server
local_bach_size = args.local_bach_size
momentum = args.momentum

# DP noise
# NOTE(review): in the visible code these two values only end up in run_name.
noise_multiplier_sigma = args.noise_multiplier_sigma
use_sign = args.use_sign

# --OneEpoch overrides the step count; this happens before run_name is built,
# so the run name then reports "lsteps:1".
if args.OneEpoch:
    local_update_step_number = 1

# NOTE(review): lmbds and total_clients are not referenced elsewhere in the
# visible code - confirm before removing.
lmbds = np.arange(1.1, 1000, 0.5)
total_clients = 3579

# The run name encodes every hyper-parameter; it is used as the result-file
# name and as the wandb run name.
run_name = args.run_name + args.mark + (
    "_clients:{}_lsteps:{}_rounds:{}_lrc:{}_lrs:{}_bz:{}_noise:{}_sign:{}_uni:{}_OneEpoch:{}_momentum:{}_qlevel:{}"
).format(
    clients_number_per_round,
    local_update_step_number,
    num_rounds,
    lr_client, lr_server,
    local_bach_size,
    noise_multiplier_sigma,
    use_sign, args.uniform, args.OneEpoch, momentum, args.qlevel,
)

print(run_name)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
if args.wandb:
    wandb.init(project="DP-FL", entity="t1773420638", name=run_name)

class DatasetSplit(Dataset):
    """Subset view over ``dataset`` restricted to the positions in ``idxs``.

    ``self[k]`` yields the ``(image, label)`` pair stored at
    ``dataset[idxs[k]]``; the length is the number of selected positions.
    """

    def __init__(self, dataset, idxs):
        # Normalise every position to a plain Python int (callers may pass
        # NumPy integers or floats).
        self.idxs = list(map(int, idxs))
        self.dataset = dataset

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        position = self.idxs[item]
        sample = self.dataset[position]
        # Unpacking enforces that the underlying sample has exactly two fields.
        image, label = sample
        return (image, label)

class NeuralNet(nn.Module):
    """Small CNN: two conv/max-pool stages followed by two linear layers.

    Maps a batch of ``[B, 1, 28, 28]`` images to ``[B, 10]`` class logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=8, stride=2, padding=3)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2)
        self.fc1 = nn.Linear(in_features=32 * 4 * 4, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=10)

    def forward(self, x):
        # Shapes for a [B, 1, 28, 28] input:
        #   conv1 + pool -> [B, 16, 14, 14] -> [B, 16, 13, 13]
        #   conv2 + pool -> [B, 32, 5, 5]   -> [B, 32, 4, 4]
        features = x
        for conv in (self.conv1, self.conv2):
            activated = F.relu(conv(features))
            features = F.max_pool2d(activated, kernel_size=2, stride=1)
        flat = features.view(-1, 32 * 4 * 4)  # -> [B, 512]
        hidden = F.relu(self.fc1(flat))  # -> [B, 32]
        return self.fc2(hidden)  # -> [B, 10]

    def name(self):
        return "NeuralNet"
    
def Stochastic_Quantization(scale, vec, qlevel):
    """Stochastically round ``vec`` onto a uniform grid relative to ``scale``.

    Every magnitude ``|vec| / scale`` is rounded to one of its two neighbouring
    multiples of ``1 / qlevel``. It is rounded up with probability equal to its
    fractional position between them, so the expected output equals ``vec``.
    The sign is then restored and the result is multiplied back by ``scale``.

    Args:
        scale: Normalisation constant, a Python number or a tensor that
            broadcasts against ``vec``. A zero scale yields an all-zero result.
        vec: Tensor to quantize.
        qlevel: Number of grid intervals per unit of normalised magnitude;
            must be non-zero (expected to be a positive integer).

    Returns:
        The quantized tensor, with the same shape as ``vec`` for a scalar
        ``scale``.
    """
    # Dividing by a zero scale would produce NaN/inf that survives the final
    # multiplication (0 * NaN is NaN). Divide by 1 instead: the multiplication
    # by ``scale`` at the end still zeroes the affected entries.
    if torch.is_tensor(scale):
        divisor = torch.where(scale == 0, torch.ones_like(scale), scale)
    else:
        divisor = scale if scale != 0 else 1

    magnitude = torch.abs(vec / divisor)
    sign = torch.sign(vec)
    interval = 1 / qlevel

    # Index of the grid point just below each magnitude, and the distance to it.
    lower_level = torch.div(magnitude, interval, rounding_mode='trunc')
    remainder = magnitude % interval

    # Bernoulli(remainder / interval) draw: 1.0 where we round up, else 0.0.
    round_up = ((torch.rand_like(remainder) * interval) < remainder).float()

    return scale * sign * (lower_level + round_up) * interval

# Load the EMNIST "digits" MATLAB file (path relative to the working directory).
emnist = spio.loadmat("data/EMNIST/emnist-digits.mat")

# Nested MATLAB structs: [0][0][0] is the training split and [0][0][1] the test
# split. Inside a split, [0][0] exposes the images (0), the labels (1) and, as
# read below for the training split, a per-sample writer field (2).
_train_split = emnist["dataset"][0][0][0][0][0]
_test_split = emnist["dataset"][0][0][1][0][0]

# ------ training images (divided by 255) ------ #
train_images = _train_split[0].astype(np.float32)
train_images /= 255

# ------ training labels ------ #
# reshape(-1) flattens without hard-coding the number of samples.
train_labels = _train_split[1].reshape(-1).tolist()

# ------ test images (divided by 255) ------ #
test_images = _test_split[0].astype(np.float32)
test_images /= 255

# ------ test labels ------ #
test_labels = _test_split[1].reshape(-1).tolist()

# ------ reshape using matlab order ------ #
train_images = train_images.reshape(train_images.shape[0], 1, 28, 28, order="A")
test_images = test_images.reshape(test_images.shape[0], 1, 28, 28, order="A")

# ------ calculate mean and standard deviation ------ #
# Statistics come from the training images only and are reused for the test set.
mean_px = train_images.mean().astype(np.float32)
std_px = train_images.std().astype(np.float32)

# normalize
train_images = (train_images - mean_px) / std_px
test_images = (test_images - mean_px) / std_px

# Each dataset is a list of [image_tensor, label] pairs.
train_dataset = [[image, label] for image, label in zip(torch.tensor(train_images), train_labels)]
test_dataset = [[image, label] for image, label in zip(torch.tensor(test_images), test_labels)]

# DataLoader
# Unshuffled loaders over the complete splits.
train_all_loader = DataLoader(train_dataset, batch_size=1280, shuffle=False)
test_all_loader = DataLoader(test_dataset, batch_size=1280, shuffle=False)

# Per-sample writer field of the training split and its distinct values.
train_writers = _train_split[2].reshape(-1)
train_writers_set = np.unique(train_writers)

# ------ random seed ------ #
# seed = int(123123)
# torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
# np.random.seed(seed)  # Numpy module.
# random.seed(seed)  # Python random module.
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

    
def _evaluate(model, loader):
    """Return ``(mean loss, accuracy)`` of ``model`` over all samples of ``loader``."""
    loss_sum, sample_count, correct_count = 0.0, 0.0, 0.0
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)

            # Inference
            outputs = model(batch_images)
            # Weight each batch loss by its size so the total is a per-sample mean.
            loss_sum += criterion(outputs, batch_labels).item() * len(batch_labels)
            sample_count += len(batch_labels)

            _, predictions = torch.max(outputs, 1)
            predictions = predictions.view(-1)
            correct_count += torch.sum(torch.eq(predictions, batch_labels)).item()
    return loss_sum / sample_count, correct_count / sample_count


# Skip the whole run when its result file already exists, unless --rerun is set.
if os.path.exists("results/"+run_name+".txt") and (not args.rerun):
    print("Run already exists:")
else:
    print("start training:")
    # open(..., "w") does not create missing directories.
    os.makedirs("results", exist_ok=True)
    global_model = NeuralNet().to(device)
    # Server-side momentum buffer, one tensor per model parameter.
    model_diff_average_momentum_dict = {n:torch.zeros_like(p) for n,p in global_model.named_parameters()}
    # Stays None when num_rounds == 0, so the summary line below is skipped
    # instead of raising NameError.
    test_correct = None
    with open("results/"+run_name+".txt","w") as f:
        for round_idx in range(num_rounds):
            # Sample this round's clients (distinct writers) without replacement.
            chosen_clients = np.random.choice(train_writers_set,replace=False,size=clients_number_per_round)
            model_diff_average_dict = {n:torch.zeros_like(p) for n,p in global_model.named_parameters()}
            global_state_dict = copy.deepcopy(global_model.state_dict())
            for client in chosen_clients:
                # A client's data is every training sample written by that writer.
                client_writer_indexes = np.where(train_writers==client)[0]
                train_dataset_each_client = DatasetSplit(train_dataset, client_writer_indexes)

                # mini-batch samples
                client_train_loader = torch.utils.data.DataLoader(
                    train_dataset_each_client,
                    batch_size=min([local_bach_size, len(train_dataset_each_client)]),
                    shuffle=True)

                # Every client starts from a copy of the current global weights.
                local_model = NeuralNet().to(device)
                local_model.load_state_dict(global_state_dict)
                local_optimizer = torch.optim.SGD(local_model.parameters(),lr_client,momentum = 0,weight_decay = 0)

                # local update: keep cycling through the client's data until
                # enough steps are done; with --OneEpoch the early break is
                # disabled so exactly one full pass is made.
                local_step = 0
                while local_step < local_update_step_number:
                    local_model.train()
                    for images, target in client_train_loader:
                        local_step += 1

                        input_var = images.to(device)
                        target_var = target.to(device)

                        # compute output
                        output = local_model(input_var)
                        loss = criterion(output, target_var)

                        # compute gradient and do SGD step
                        local_optimizer.zero_grad()
                        loss.backward()
                        local_optimizer.step()

                        if (local_step >= local_update_step_number) and (not args.OneEpoch):
                            break

                local_model_state_dict = local_model.state_dict()

                # Client update = local weights - global weights; it is quantized
                # relative to its L2 norm taken over all parameters together.
                temp_diff = {}
                two_norm = 0
                with torch.no_grad():
                    for n,p in global_model.named_parameters():
                        temp_diff[n] = (local_model_state_dict[n] - p.data)
                        two_norm += torch.pow(torch.norm(temp_diff[n],p=2),2)

                    two_norm = torch.sqrt(two_norm)
                    for n in temp_diff:
                        model_diff_average_dict[n] += Stochastic_Quantization(two_norm, temp_diff[n], args.qlevel)

            # average
            for n in model_diff_average_dict:
                model_diff_average_dict[n] /= clients_number_per_round

            # update momentum
            for n in model_diff_average_dict:
                model_diff_average_momentum_dict[n] = momentum * model_diff_average_momentum_dict[n] + model_diff_average_dict[n]

            # global model step
            with torch.no_grad():
                for n,p in global_model.named_parameters():
                    p.data += lr_server * model_diff_average_momentum_dict[n]

            # check: loss/accuracy of the updated global model on both splits
            train_loss, train_correct = _evaluate(global_model, train_all_loader)
            test_loss, test_correct = _evaluate(global_model, test_all_loader)

            print("round:", round_idx, "train loss:", train_loss, "train acc", train_correct, "test loss:", test_loss, "test acc", test_correct)
            if args.wandb:
                wandb.log({"train loss": train_loss,
                        "train acc": train_correct,
                        "test loss": test_loss,
                        "test acc": test_correct})
            stats = "{},{},{},{},{}\n".format(round_idx,train_loss,train_correct,test_loss,test_correct)
            f.write(stats)
            # Rounds are slow; flush so finished rows survive an interrupted job.
            f.flush()

    # Append the final test accuracy of this run to the shared summary file.
    if test_correct is not None:
        with open("total_results.txt","a") as summary_file:
            summary_file.write(run_name+"->"+str(test_correct)+"\n")
        