import torch 
from torch import nn
from torch.nn import functional as F
import numpy as np
import json
from numpy.random import Generator, PCG64
import argparse
import os
from tqdm import tqdm
    
class Net(nn.Module):
    """Two-layer convolutional classifier with a 62-way linear head.

    Layout: two 5x5 convolutions (1->32->64 channels, padding 2), each
    followed by ReLU and 2x2 max-pooling, then a 2048-unit hidden linear
    layer and a 62-logit output layer. The first linear layer expects
    7*7*64 features, which is what 28x28 single-channel inputs produce after
    the two pooling stages.
    """

    def __init__(self):
        super().__init__()
        # Registration order (conv1, conv2, fc1, fc2) fixes the order of
        # model.parameters(), which callers index positionally.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=(2, 2))
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=(2, 2))

        self.fc1 = nn.Linear(7 * 7 * 64, 2048)
        self.fc2 = nn.Linear(2048, 62)

    def forward(self, x):
        """Return the raw (unnormalised) class logits for a batch ``x``."""
        # Each convolutional stage: conv -> ReLU -> 2x2 max-pool (halves H, W).
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), kernel_size=2, stride=2)
        # Keep the batch dimension, flatten everything else for the MLP head.
        hidden = F.relu(self.fc1(x.flatten(start_dim=1)))
        return self.fc2(hidden)

def get_train_data(path, n_files=36):
    """Load and merge the sharded training JSON files found under ``path``.

    Reads ``all_data_<i>_niid_0_train_9.json`` for ``i`` in
    ``range(n_files)``. Each file must contain a ``"user_data"`` mapping and
    a ``"users"`` list.

    Args:
        path: Directory that holds the shard files.
        n_files: Number of consecutive shards to read, starting at index 0.
            Defaults to 36, the count this script was written for.

    Returns:
        A ``(data_dict, user_names)`` tuple: the union of every shard's
        ``"user_data"`` (a later shard wins on a duplicate key) and the
        concatenation of every shard's ``"users"`` list, in shard order.

    Raises:
        OSError: If a shard file cannot be opened.
        KeyError: If a shard lacks ``"user_data"`` or ``"users"``.
    """
    data_dict = {}
    user_names = []
    for i in range(n_files):
        shard_path = os.path.join(path, "all_data_" + str(i) + "_niid_0_train_9.json")
        with open(shard_path, 'r') as f:
            data = json.load(f)
        # Merge in place rather than rebuilding the containers per shard,
        # which would re-copy everything accumulated so far each iteration.
        data_dict.update(data["user_data"])
        user_names.extend(data["users"])
    return data_dict,user_names

def get_test_data(path, n_files=36):
    """Load and merge the sharded test JSON files found under ``path``.

    Reads ``all_data_<i>_niid_0_test_9.json`` for ``i`` in
    ``range(n_files)``. Each file must contain a ``"user_data"`` mapping and
    a ``"users"`` list.

    Args:
        path: Directory that holds the shard files.
        n_files: Number of consecutive shards to read, starting at index 0.
            Defaults to 36, the count this script was written for.

    Returns:
        A ``(data_dict, user_names)`` tuple: the union of every shard's
        ``"user_data"`` (a later shard wins on a duplicate key) and the
        concatenation of every shard's ``"users"`` list, in shard order.

    Raises:
        OSError: If a shard file cannot be opened.
        KeyError: If a shard lacks ``"user_data"`` or ``"users"``.
    """
    data_dict = {}
    user_names = []
    for i in range(n_files):
        shard_path = os.path.join(path, "all_data_" + str(i) + "_niid_0_test_9.json")
        with open(shard_path, 'r') as f:
            data = json.load(f)
        # Merge in place rather than rebuilding the containers per shard,
        # which would re-copy everything accumulated so far each iteration.
        data_dict.update(data["user_data"])
        user_names.extend(data["users"])
    return data_dict,user_names

def test(model,test_data_dict,device):
    criterion = nn.CrossEntropyLoss()
    acc_dict = {}
    loss_dict = {}
    for key in test_data_dict: 
        inputs = torch.tensor(test_data_dict[key]["x"]).reshape((-1,1,28,28)) 
        labels = torch.tensor(test_data_dict[key]["y"])
        with torch.no_grad():
            outputs = model(inputs.to(device))
        _, predicted = torch.max(outputs.data, 1)
        acc_dict[key]=(predicted == labels.to(device)).float().mean().item()
        loss_dict[key]=criterion(outputs, labels.to(device)).item()
    return acc_dict,loss_dict 


def train(data_dict,user_names,test_data_dict,test_use_names,T=10000,alpha_0=0,alpha_1=0,lr=0.06,save_name="",n_players=3,median=False):
    """Run the noisy-gradient federated training simulation and save results.

    At every step ``n_players`` users are drawn with ``np.random.choice``
    (called without ``replace=False``, so a user may appear more than once in
    a step). Each selected user computes the full-batch cross-entropy
    gradient on its own data and "sends" that gradient plus Gaussian noise
    scaled by ``alpha / sqrt(numel)`` per parameter tensor. The server
    aggregates the sent gradients (data-size-weighted mean, or coordinate-wise
    median) and takes one plain SGD step with learning rate ``lr``.

    Every user index is assigned a fixed type (0 or 1) from a seeded
    generator; type 0 uses noise level ``alpha_0`` and type 1 uses
    ``alpha_1``.

    Args:
        data_dict: Mapping user name -> ``{"x": ..., "y": ...}`` training data.
        user_names: List of user names that index into ``data_dict``.
        test_data_dict: Per-user test data, passed to ``test`` after training.
        test_use_names: Unused; kept for interface compatibility.
        T: Number of aggregation steps.
        alpha_0: Noise scale for type-0 users.
        alpha_1: Noise scale for type-1 users.
        lr: SGD learning rate.
        save_name: Output directory for all result files.
        n_players: Number of users sampled per step (coerced to ``int``).
        median: If true, aggregate with the coordinate-wise median instead of
            the size-weighted mean.

    Returns:
        None. Per-step arrays are written with ``np.save`` and the summary
        dictionaries with ``json.dump`` under ``save_name``.
    """
    # Array shapes and range() need an integer player count.
    n_players = int(n_players)

    os.makedirs("FeMNIST_Results", exist_ok=True)
    if save_name:
        if os.path.isdir(save_name):
            print("Folder exists, overwriting results")
        else:
            # Let real failures (permissions, bad path) surface here instead
            # of being reported as "folder exists" and failing later on save.
            os.makedirs(save_name)

    criterion = nn.CrossEntropyLoss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Net().to(device)
    params = list(model.parameters())

    # Fixed seed so the type assignment of each user index is reproducible;
    # a user index is type 1 when the uniform draw exceeds 2/3.
    rng = Generator(PCG64(42))
    alpha_typedict = {key: int(rng.uniform()>2/3) for key in range(len(user_names))}

    losses_global = np.zeros((T,n_players))
    mses_global = np.zeros((T,n_players))
    clients_global = {}
    alphas_global = np.zeros((T,n_players))
    alphatypes_global = np.zeros((T,n_players))
    gradsizes_global = np.zeros((T,n_players))

    for step in range(T):
        indexes = np.random.choice(np.arange(len(user_names)),n_players)
        clients = [user_names[index] for index in indexes]
        alphas = [alpha_typedict[index]*(alpha_1-alpha_0) + alpha_0 for index in indexes]
        grads = []        # what each player sends (gradient + noise)
        grads_real = []   # each player's noiseless gradient
        losses = []
        sizes = []
        for client, alpha in zip(clients, alphas):
            inputs = torch.tensor(data_dict[client]["x"]).reshape((-1,1,28,28))
            labels = torch.tensor(data_dict[client]["y"])
            sizes.append(len(labels))
            # Drop any existing .grad so backward() allocates fresh tensors.
            # Otherwise autograd accumulates into the same tensors in place:
            # later players would send the running sum of all earlier
            # gradients and every grads_real entry would alias that sum.
            for param in params:
                param.grad = None
            outputs = model(inputs.to(device))
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            grad = [param.grad for param in params]
            # Noise is normalised by sqrt(#elements) of each parameter tensor.
            sent = [g+alpha*torch.normal(torch.zeros_like(g))/np.sqrt(g.numel()) for g in grad]
            grads.append(sent)
            grads_real.append(grad)
            losses.append(loss.item())

        n_tensors = len(params)
        if not median:
            total_size = sum(sizes)
            # Weighted mean: each player weighted by its number of samples.
            mean = [torch.sum(torch.stack([grads[j][i]*sizes[j]/total_size for j in range(n_players)]),0)
                    for i in range(n_tensors)]
        else:
            mean = [torch.median(torch.stack([grads[j][i] for j in range(n_players)]),0).values
                    for i in range(n_tensors)]

        # Squared distance of each player's sent gradient to the aggregate.
        mses = [torch.sum(torch.stack([torch.sum((grads[j][i]-mean[i])**2)
                                       for i in range(n_tensors)])).item()
                for j in range(n_players)]

        # Squared norm of each player's noiseless gradient.
        gradsizes = [torch.sum(torch.stack([torch.sum((grads_real[j][i])**2)
                                            for i in range(n_tensors)])).item()
                     for j in range(n_players)]

        with torch.no_grad():
            for i,param in enumerate(params):
                param -= lr * mean[i]
                param.grad = None

        losses_global[step] = losses
        alphas_global[step] = alphas
        alphatypes_global[step] = [alpha_typedict[index] for index in indexes]
        clients_global[step] = clients
        mses_global[step] = mses
        gradsizes_global[step] = gradsizes

    final_accs,final_losses = test(model,test_data_dict,device)

    # Per-step quantity centred on the row mean, so each row sums to zero
    # (redist_checksum below should therefore be ~0).
    redist_global = 1.5 * mses_global - 1.5 * np.mean(mses_global,1,keepdims=True)

    quick_results = {"accs":np.mean([final_accs[key] for key in final_accs]),
                     "losses":np.mean([final_losses[key] for key in final_losses]),
                     "mses_a0":np.sum(mses_global*(1-alphatypes_global))/(np.sum(1-alphatypes_global)),
                     "mses_a1":np.sum(mses_global*alphatypes_global)/(np.sum(alphatypes_global)),
                     "redist_a0":np.sum(redist_global*(1-alphatypes_global))/(np.sum(1-alphatypes_global)),
                     "redist_a1":np.sum(redist_global*alphatypes_global)/(np.sum(alphatypes_global)),
                     "redist_checksum":np.mean(redist_global)
                     }

    np.save(save_name+"/losses_step",losses_global)
    np.save(save_name+"/mses_step",mses_global)
    np.save(save_name+"/alphas_step",alphas_global)
    np.save(save_name+"/alphatypes_step",alphatypes_global)
    np.save(save_name+"/gradsizes_step",gradsizes_global)

    # Use context managers so every JSON file is flushed and closed.
    json_outputs = (("/clients_step", clients_global),
                    ("/accs_final", final_accs),
                    ("/losses_final", final_losses),
                    ("/quick_results", quick_results))
    for suffix, payload in json_outputs:
        with open(save_name+suffix, 'w') as f:
            json.dump(payload, f)
    return None


if __name__ == '__main__':
    # Command-line entry point: parse the run configuration, load the
    # training and test shards, and run one training simulation.
    parser = argparse.ArgumentParser()
    parser.add_argument('id',type=int)
    parser.add_argument('-T',default=3550,type=int,required=False)
    parser.add_argument('-a0',default=0,type=float,required=False)
    parser.add_argument('-a1',default=0,type=float,required=False)
    parser.add_argument('-lr',default=0.06,type=float,required=False)
    # The player count sizes arrays and drives range(), so it must be an int.
    parser.add_argument('-np',default=3,type=int,required=False)
    parser.add_argument('-m', '--median',
                    action='store_true')
    args = parser.parse_args()

    # argparse stores the flag under its long name ("median"); args.m does
    # not exist and would raise AttributeError.
    median_suffix = "median" if args.median else ""
    save_name = "FeMNIST_Results/p_"+str(args.np)+"_"+str(args.id)+"a0_"+str(args.a0)+"a1_"+str(args.a1)+"T_"+str(args.T) +"lr_"+str(args.lr)+median_suffix

    data_dict,user_names = get_train_data("data/train")
    test_data_dict,test_user_names = get_test_data("data/test")
    train(data_dict,user_names,test_data_dict,test_user_names,T=args.T,alpha_0=args.a0,alpha_1=args.a1,lr=args.lr,
            save_name=save_name,n_players=args.np,median=args.median)


           

