import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import torch.optim
import argparse
import tracemalloc
import time
from mpi4py import MPI
from itertools import repeat
from torch.nn import functional as F
from torch.autograd.functional import vhp
from torch.autograd.functional import hvp
from torch.autograd.functional import hessian
from torchvision import datasets
from grad_y import *
import copy
import yaml
import matplotlib.pyplot as plt

# Process-wide MPI context: the world communicator, this process's rank and
# the total number of processes.
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# Running total of the memory measured around each neighbour exchange;
# ``gossip`` (nested in ``train_model``) adds to it via ``global com_blocks``.
com_blocks = 0.0

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', default=10, type=int)
    parser.add_argument('--batch_size', type=int, default=1000)
    parser.add_argument('--test_size', type=int, default=1000)
    parser.add_argument('--o_steps', type=int, default=50, help='K')
    parser.add_argument('--iterations', type=int, default=10, help='T')
    parser.add_argument('--N', type=int, default=20, help='total steps of HIGP oracle')
    parser.add_argument('--outer_lr', type=float, default=0.3, help='beta')
    parser.add_argument('--inner_lr', type=float, default=0.3, help='alpha')
    parser.add_argument('--higp_lr', type=float, default=0.01, help='lr for HIGP oracle')
    parser.add_argument('--mv_lr', type=float, default=0, help='lr for moving average')
    parser.add_argument('--data_path', default='data/', help='The temporary data storage path')
    parser.add_argument('--training_size', type=int, default=20000)
    parser.add_argument('--validation_size', type=int, default=5000)
    parser.add_argument('--save_folder', type=str, default='', help='path to save result')
    parser.add_argument('--model_name', type=str, default='', help='Experiment name')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--alg', type=str, default='SUN-GT', choices=['SUN-HR','SUN-GT', 'SUN-SE', 'DSGDA-GT'], help='Algorithm to use')
    parser.add_argument('--num_of_nodes', default=8, type=int, help='The number of total nodes')
    parser.add_argument('--network_weight', default=0.4, type=float)
    parser.add_argument('--alpha_x', type=float, default=0.3, help='alpha_x')
    parser.add_argument('--alpha_y', type=float, default=0.3, help='alpha_y')
    parser.add_argument('--alpha_theta', type=float, default=3, help='alpha_theta')
    parser.add_argument('--gamma', type=float, default=1, help='gamma')
    parser.add_argument('--ck_bar', type=float, default=2, help='ck_bar')
    parser.add_argument('--exp', type=float, default=0.3, help='exp')  
    parser.add_argument('--beta_x', type=float, default=0.01, help='alpha_x')
    parser.add_argument('--beta_y', type=float, default=0.2, help='alpha_y')
    parser.add_argument('--beta_theta', type=float, default=0.1, help='alpha_theta')
    parser.add_argument('--output', type=str, default='./output', help='Path to save the output results')
    args = parser.parse_args()
    args.validation_size = 60000 // args.num_of_nodes
    args.training_size = 60000 // args.num_of_nodes
    
    return args
    
    

def get_data_loaders(args):
    """Build the MNIST loaders consumed by ``train_model``.

    Returns:
        ``([train_loader, val_loader], test_loader)``. The first two iterate
        the MNIST *training* split in order (SequentialSampler), batched by
        ``args.training_size`` and ``args.validation_size`` respectively.
        ``test_loader`` iterates the MNIST test split in batches of
        ``args.test_size``.

    NOTE(review): the "validation" loader reads the same training split as
    ``train_loader``; confirm that this is intended.
    """
    def mnist(train):
        # Identical preprocessing for both splits: tensor conversion followed
        # by normalization with the fixed mean/std below.
        return datasets.MNIST(
            root=args.data_path,
            train=train,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]),
        )

    loader_kwargs = {'num_workers': 0, 'pin_memory': True}
    train_set = mnist(True)
    in_order = torch.utils.data.sampler.SequentialSampler(train_set)

    train_loader = torch.utils.data.DataLoader(
        train_set, sampler=in_order, batch_size=args.training_size, **loader_kwargs)
    val_loader = torch.utils.data.DataLoader(
        train_set, sampler=in_order, batch_size=args.validation_size, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(mnist(False), batch_size=args.test_size)

    return [train_loader, val_loader], test_loader

def train_model(args):
    """Run one decentralized bilevel-optimization experiment on this MPI rank.

    Every rank keeps its own lower-level weights ``parameters``
    (``num_classes x 785``: 784 pixel weights plus a bias column), an
    auxiliary copy ``parameters_theta`` and hyperparameters ``lambda_x``
    (length 785). Ranks mix state with their two ring neighbours through
    ``gossip``. ``args.alg`` selects the update rule.

    Rank 0 evaluates the initial network average. It then splits the
    module-level ``train_loader`` (``[train, val]``) into one batch per rank
    and scatters the batches. After every outer step it evaluates the
    network-averaged weights on ``train_loader[0]`` / ``test_loader``.
    Finally it saves the results as ``.npy`` under ``args.output``
    (fallback ``./output``).

    NOTE(review): assumes the MPI world size equals ``args.num_of_nodes``,
    so that each loader yields exactly one batch per rank.

    Args:
        args: namespace produced by ``parse_args``.

    Returns:
        ``numpy.ndarray`` of shape ``(args.o_steps + 1, 6)``. On rank 0 the
        columns are: train loss, test loss, test accuracy, hypergradient norm
        (never filled in, stays 0), wall time of the step, and the
        rank-averaged ``com_blocks`` counter. Other ranks return zeros.

    Raises:
        NotImplementedError: for an unsupported ``args.alg``.
    """
    # gossip() accumulates into this module-level counter. Reset it so each
    # call (one per seed in __main__) reports only its own communication.
    global com_blocks
    com_blocks = 0.0

    num_of_nodes = args.num_of_nodes
    batch_num = args.training_size // args.batch_size
    # Ring mixing weights: keep ``w`` of the local value and give
    # ``(1 - w) / 2`` to each of the two neighbours.
    w = args.network_weight
    w_neighbor = (1 - w) / 2
    n_features, n_classes = 785, args.num_classes
    parameters = torch.randn((n_classes, n_features), requires_grad=True)
    parameters_theta = copy.deepcopy(parameters)
    lambda_x = torch.zeros(n_features, requires_grad=True)

    # Sum of the initial weights over all ranks (None on ranks other than 0).
    parameters_mean = comm.reduce(parameters, op=MPI.SUM, root=0)

    # Filled on rank 0 only; scatter ignores the send object on other ranks.
    images_list_val, labels_list_val = [], []
    images_list_train, labels_list_train = [], []
    all_results = np.zeros((args.o_steps + 1, 6))

    def out_f(data, parameters):
        # Logits of the linear model. ``parameters`` is a one-element list
        # whose item holds weights in columns 0..783 and the bias in 784.
        output = torch.matmul(data, torch.t(parameters[0][:, 0:784])) + parameters[0][:, 784]
        return output

    def reg_f(params, hparams, loss):
        # Mean loss plus the mean of params[0] ** 2 weighted by
        # exp(hparams.unsqueeze(1)).
        loss_regu = torch.mean(loss) + ((params[0] ** 2) * torch.exp(hparams.unsqueeze(1))).mean()
        return loss_regu

    def hessian_vector_product(params, val_data_list, v):
        # Vector-Hessian product (torch vhp) of the unregularized lower-level
        # loss: cross-entropy of the bias-free logits, w.r.t. params[0].
        data_list, labels_list = val_data_list

        def g_loss(x):
            return F.cross_entropy(data_list[0] @ (x[:, 0:784]).t(), labels_list[0])
        return vhp(g_loss, params[0], v)[1]

    def eval_acc(params, x, y):
        # Fraction of rows of x whose arg-max logit equals the label in y.
        out = out_f(x, params)
        pred = out.argmax(dim=1, keepdim=True)
        acc = pred.eq(y.view_as(pred)).sum().item() / len(y)
        return acc

    def data_batch(batch_size, x, y):
        # Sample ``batch_size`` distinct rows of (x, y) with NumPy's global RNG.
        # x.shape[0] keeps the indices valid for whatever chunk this rank
        # holds; it equals args.training_size when the data splits evenly.
        batch_index = np.random.permutation(x.shape[0])[0:batch_size]
        return x[batch_index], y[batch_index]

    def gossip(info):
        """Return ``w * info + w_neighbor * (left + right)`` over the ring.

        ``info`` is exchanged with ranks ``rank - 1`` and ``rank + 1``
        (mod ``size``). The tracemalloc peak-minus-current memory measured
        around the exchange is added to the global ``com_blocks``.
        """
        global com_blocks
        tracemalloc.start()
        # NOTE(review): each rank issues two blocking sends before its
        # receives. This relies on MPI buffering the messages; comm.sendrecv
        # would be deadlock-free for larger payloads.
        comm.send(info, dest=(rank + 1) % size, tag=rank)
        comm.send(info, dest=(rank - 1) % size, tag=rank + size)
        info_for = comm.recv(source=(rank + 1) % size, tag=(rank + 1) % size + size)
        info_back = comm.recv(source=(rank - 1) % size, tag=(rank - 1) % size)
        output = w_neighbor * (info_for + info_back) + w * info

        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        com_blocks += peak - current
        return output

    def HIGP(params, hparams, val_data_list, args, out_f):
        """Decentralized Hessian-inverse-gradient-product oracle (not called).

        Runs ``args.N`` gradient-tracking iterations on
        ``s(z) = Hv(z) + H_temp * z - b``, where ``b`` is the upper-level
        gradient. This presumably drives ``z`` toward the solution of
        ``(H + H_temp) z = b``. Returns ``-sum(z * params[0]) * H_temp``.
        """
        data_list, labels_list = val_data_list
        # Fy gradient
        output = out_f(data_list[0], params)
        b = gradient_fy(args, labels_list[0], params, data_list[0], output).detach()

        # initialization
        z, s_new, s_old, d = torch.zeros_like(b), -b, -b, -b

        # Hessian of the regularization part
        H_temp = 2 * torch.exp(hparams) / (n_features * n_classes)

        for _ in range(args.N):
            # communication
            z = gossip(z) - args.higp_lr * d
            # local update for s
            s_new = hessian_vector_product(params, val_data_list, z) + H_temp * z - b
            # gradient tracking
            d = gossip(d) + s_new - s_old
            s_old = s_new.clone().detach()
        # z.shape = [10, 785]
        # The Jacobian-vector product has a closed form expression
        z = torch.sum(z * params[0]) * H_temp
        return -z

    def inner_dsgt(lambda_x, parameters, gamma, u, h, T, alpha=0):
        """Gradient-tracking inner loop (not called).

        With alpha > 0 the direction also includes the upper-level gradient
        (used to update y); otherwise only the lower-level gradient (z).
        hyperparam: theta; param: x.
        """
        for _ in range(T):
            images, labels = data_batch(args.batch_size, train_images, train_labels)
            images = torch.reshape(images, (images.size()[0], -1))
            output = out_f(images, [parameters])
            inner_grad = gradient_gy(args, labels, parameters, images, lambda_x, output, reg_f)

            # compute update direction
            if alpha > 0:
                upper_grad = gradient_fy(args, labels, parameters, images, out_f(images, [parameters]))
                h_new = upper_grad + alpha * inner_grad
            else:
                h_new = inner_grad

            # communication and update
            u = gossip(u) + h_new - h
            parameters = gossip(parameters) - gamma * u

            # update h
            h = h_new.clone().detach()

        return parameters, u, h_new

    def inner_dsgd(lambda_x, parameters, gamma, T):
        """Decentralized SGD inner loop on the lower-level loss (not called)."""
        for _ in range(T):
            images, labels = data_batch(args.batch_size, train_images, train_labels)
            images = torch.reshape(images, (images.size()[0], -1))
            output = out_f(images, [parameters])
            inner_grad = gradient_gy(args, labels, parameters, images, lambda_x, output, reg_f)

            # communication and update
            parameters = gossip(parameters) - gamma * inner_grad

        return parameters

    if rank == 0:
        params_history, hparams_history = [], []
        parameters_mean = parameters_mean / num_of_nodes
        train_loss_avg = loss_train_avg(train_loader[0], parameters_mean, batch_num)
        test_acc, test_loss_avg = loss_test_avg(test_loader, parameters_mean)
        # Same column layout as the per-step rows (column 2 = test accuracy).
        all_results[0, :] = [train_loss_avg, test_loss_avg, test_acc, 0.0, 0.0, 0.0]

        # Batch i of each loader becomes rank i's local chunk.
        for images, labels in train_loader[0]:
            images_list_train.append(images)
            labels_list_train.append(labels)
        for images, labels in train_loader[1]:
            images_list_val.append(images)
            labels_list_val.append(labels)

    # Distribute the data (one list entry per rank on the root).
    train_images = comm.scatter(images_list_train, root=0)
    train_labels = comm.scatter(labels_list_train, root=0)
    train_images = torch.reshape(train_images, (train_images.size()[0], -1))

    val_images = comm.scatter(images_list_val, root=0)
    val_labels = comm.scatter(labels_list_val, root=0)
    val_images = torch.reshape(val_images, (val_images.size()[0], -1))

    # Validation mini-batch for the disabled HIGP hypergradient. It is kept
    # so that the torch RNG stream is unchanged.
    val_indices = torch.randperm(args.validation_size)[0:args.batch_size]
    val_data_list = [[val_images[val_indices]], [val_labels[val_indices]]]

    gamma = args.gamma
    alpha_x = args.alpha_x
    alpha_y = args.alpha_y
    alpha_theta = args.alpha_theta
    ck_bar = args.ck_bar
    exp = args.exp

    # Tracking (u), previous-direction (h) and SUN-HR momentum (v) states.
    u_x = u_y = u_theta = 0
    h_x = h_y = h_theta = 0
    v_x = v_y = v_theta = 0
    if args.alg == 'SUN-HR':
        beta_x = args.beta_x
        beta_y = args.beta_y
        beta_theta = args.beta_theta
        parameters_old = torch.zeros((n_classes, n_features), requires_grad=True)
        parameters_theta_old = copy.deepcopy(parameters)
        lambda_x_old = torch.zeros(n_features, requires_grad=True)

    def sample_batches():
        # One validation and one training mini-batch from this rank's chunk.
        # Validation is sampled first; this order fixes the NumPy RNG stream.
        v_images, v_labels = data_batch(args.batch_size, val_images, val_labels)
        images, labels = data_batch(args.batch_size, train_images, train_labels)
        images = torch.reshape(images, (images.size()[0], -1))
        return images, labels, v_images, v_labels

    def local_grads(params_y, params_theta, hparams, images, labels, v_images, v_labels):
        # The five stochastic gradients every algorithm needs, evaluated at
        # (params_y, params_theta, hparams) on the given mini-batches.
        # NOTE(review): the theta gradients reuse ``output`` computed from
        # params_y; confirm this is what gradient_gy / gradient_gx expect.
        output = out_f(images, [params_y])
        upper_grad_y = gradient_fy(args, v_labels, params_y, v_images, out_f(v_images, [params_y]))
        inner_grad_y = gradient_gy(args, labels, params_y, images, hparams, output, reg_f)
        inner_grad_theta = gradient_gy(args, labels, params_theta, images, hparams, output, reg_f)
        inner_grad_x = gradient_gx(args, labels, params_y, images, hparams, output, reg_f)
        inner_grad_theta_x = gradient_gx(args, labels, params_theta, images, hparams, output, reg_f)
        return upper_grad_y, inner_grad_y, inner_grad_theta, inner_grad_x, inner_grad_theta_x

    def sun_directions(ck, params_y, params_theta, grads):
        # Local update directions shared by SUN-SE, SUN-GT and SUN-HR.
        upper_grad_y, inner_grad_y, inner_grad_theta, inner_grad_x, inner_grad_theta_x = grads
        h_y_dir = ck * upper_grad_y + inner_grad_y + gamma * (params_theta - params_y)
        h_theta_dir = inner_grad_theta + gamma * (params_theta - params_y)
        h_x_dir = inner_grad_x - inner_grad_theta_x
        return h_y_dir, h_theta_dir, h_x_dir

    for epoch in range(args.o_steps):
        start_time = time.time()
        if args.alg == 'SUN-SE':
            ck = ck_bar * 1 / ((1 + epoch) ** exp)
            images, labels, v_images, v_labels = sample_batches()
            grads = local_grads(parameters, parameters_theta, lambda_x,
                                images, labels, v_images, v_labels)
            h_y_new, h_theta_new, h_x_new = sun_directions(ck, parameters, parameters_theta, grads)

            # Local step followed by gossip averaging (no gradient tracking).
            parameters = gossip(parameters - alpha_y * h_y_new)
            parameters_theta = gossip(parameters_theta - alpha_theta * h_theta_new)
            lambda_x = gossip(lambda_x - alpha_x * h_x_new)

        elif args.alg == 'SUN-GT':
            ck = ck_bar * 1 / ((1 + epoch) ** exp)
            images, labels, v_images, v_labels = sample_batches()
            grads = local_grads(parameters, parameters_theta, lambda_x,
                                images, labels, v_images, v_labels)
            h_y_new, h_theta_new, h_x_new = sun_directions(ck, parameters, parameters_theta, grads)

            # Gradient tracking: u follows the network average of h.
            u_y = gossip(u_y + h_y_new - h_y)
            u_theta = gossip(u_theta + h_theta_new - h_theta)
            u_x = gossip(u_x + h_x_new - h_x)

            # Step along the tracked directions and remember h for the next
            # tracking correction. Without these, the iterates never move and
            # u keeps accumulating every h_new.
            parameters = gossip(parameters - alpha_y * u_y)
            parameters_theta = gossip(parameters_theta - alpha_theta * u_theta)
            lambda_x = gossip(lambda_x - alpha_x * u_x)
            h_y = h_y_new.clone().detach()
            h_theta = h_theta_new.clone().detach()
            h_x = h_x_new.clone().detach()

        elif args.alg == 'DSGDA-GT':
            ck = ck_bar
            images, labels, v_images, v_labels = sample_batches()
            upper_grad_y, inner_grad_y, inner_grad_theta, inner_grad_x, inner_grad_theta_x = local_grads(
                parameters, parameters_theta, lambda_x, images, labels, v_images, v_labels)

            h_y_new = upper_grad_y + ck * inner_grad_y
            h_theta_new = ck * (-inner_grad_theta)
            h_x_new = ck * (inner_grad_x - inner_grad_theta_x)

            # Gradient tracking, then a gossip-averaged step.
            u_y = gossip(u_y) + h_y_new - h_y
            u_theta = gossip(u_theta) + h_theta_new - h_theta
            u_x = gossip(u_x) + h_x_new - h_x

            parameters = gossip(parameters) - alpha_y * u_y
            parameters_theta = gossip(parameters_theta) - alpha_theta * u_theta
            lambda_x = gossip(lambda_x) - alpha_x * u_x
            h_y = h_y_new.clone().detach()
            h_theta = h_theta_new.clone().detach()
            h_x = h_x_new.clone().detach()

        elif args.alg == 'SUN-HR':
            ck = ck_bar * 1 / ((1 + epoch) ** exp)
            images, labels, v_images, v_labels = sample_batches()
            grads = local_grads(parameters, parameters_theta, lambda_x,
                                images, labels, v_images, v_labels)
            h_y_new, h_theta_new, h_x_new = sun_directions(ck, parameters, parameters_theta, grads)

            if epoch == 0:
                h_y = h_theta = h_x = 0
            else:
                # Same mini-batches evaluated at the previous iterate, with
                # the previous step's ck.
                grads_old = local_grads(parameters_old, parameters_theta_old, lambda_x_old,
                                        images, labels, v_images, v_labels)
                ck_1 = ck_bar * 1 / ((1 + epoch - 1) ** exp)
                h_y, h_theta, h_x = sun_directions(ck_1, parameters_old, parameters_theta_old, grads_old)

            # Momentum-corrected directions v, then gradient tracking on v.
            v_y_new = h_y_new - (1 - beta_y) * (v_y - h_y)
            u_y = gossip(u_y + v_y_new - v_y)
            v_theta_new = h_theta_new - (1 - beta_theta) * (v_theta - h_theta)
            u_theta = gossip(u_theta + v_theta_new - v_theta)
            v_x_new = h_x_new - (1 - beta_x) * (v_x - h_x)
            u_x = gossip(u_x + v_x_new - v_x)

            # Snapshot the current iterate for next step's correction term.
            parameters_old = parameters.clone().detach().requires_grad_(True)
            lambda_x_old = lambda_x.clone().detach().requires_grad_(True)
            parameters_theta_old = parameters_theta.clone().detach().requires_grad_(True)

            parameters = gossip(parameters - alpha_y * u_y)
            parameters_theta = gossip(parameters_theta - alpha_theta * u_theta)
            lambda_x = gossip(lambda_x - alpha_x * u_x)

            v_x = v_x_new.clone().detach()
            v_y = v_y_new.clone().detach()
            v_theta = v_theta_new.clone().detach()

        else:
            raise NotImplementedError
        end_time = time.time()

        # Network averages (sums arrive on rank 0 only).
        parameters_mean = comm.reduce(parameters, op=MPI.SUM, root=0)
        hparams_mean = comm.reduce(lambda_x, op=MPI.SUM, root=0)
        com_blocks_mean = comm.reduce(com_blocks, op=MPI.SUM, root=0)
        if rank == 0:
            parameters_mean /= num_of_nodes
            hparams_mean /= num_of_nodes
            com_blocks_mean /= num_of_nodes
            params_history.append(parameters_mean.detach().numpy())
            hparams_history.append(hparams_mean.detach().numpy())
            train_loss_avg = loss_train_avg(train_loader[0], parameters_mean, batch_num)
            test_acc, test_loss_avg = loss_test_avg(test_loader, parameters_mean)

            # Report the same rank-averaged counter that is stored in column 5.
            print('o_step={}({:.2f}s) Hypergrad Norm: {:4f} Train Loss: {:.4f} Test Loss: {:.4f} Test Accuracy: {:.4f} Com_blocks={:.2e}'
                  .format(epoch + 1, end_time - start_time, all_results[epoch + 1, 3],
                          train_loss_avg, test_loss_avg, test_acc, com_blocks_mean))
            sys.stdout.flush()

            all_results[epoch + 1, 0] = train_loss_avg
            all_results[epoch + 1, 1] = test_loss_avg
            all_results[epoch + 1, 2] = test_acc
            all_results[epoch + 1, 4] = (end_time - start_time)
            all_results[epoch + 1, 5] = com_blocks_mean

    if rank == 0:
        print(all_results)
        file_name = '{}_{}_bs_{}_gamma_{}_alphax_{}_alphay_{}_alphatheta_{}_ck_bar_{}_exp_{}_ite_{}_seed{}.npy'.format(
            args.alg, args.num_of_nodes, args.batch_size, gamma, alpha_x, alpha_y, alpha_theta, ck_bar, exp, args.iterations, args.seed)
        # Create the target directory so the save does not fail on a fresh run.
        output_dir = args.output if args.output else './output'
        os.makedirs(output_dir, exist_ok=True)
        file_addr = os.path.join(output_dir, file_name)
        with open(file_addr, 'wb') as f:
            np.save(f, all_results)
    return all_results  # ensure every process returns a value


def loss_train_avg(data_loader, parameters, batch_num):
    """Average the per-batch cross-entropy over the first ``batch_num`` batches.

    Args:
        data_loader: iterable of ``(images, labels)`` batches. Images are
            flattened to ``(batch, -1)`` before evaluation.
        parameters: weight matrix passed to ``loss_f_funciton``
            (columns 0..783 are weights, column 784 is the bias).
        batch_num: maximum number of batches to evaluate.

    Returns:
        Detached scalar tensor: the mean of the per-batch losses.

    Raises:
        ValueError: if no batch was evaluated (``batch_num <= 0`` or an empty
            loader). Averaging over zero batches is undefined.
    """
    loss_sum, num = 0.0, 0
    # Evaluation only; the result is detached, so skip building a graph.
    with torch.no_grad():
        # range() comes first in zip so no extra batch is loaded once
        # batch_num batches have been consumed.
        for _, (images, labels) in zip(range(batch_num), data_loader):
            images = torch.reshape(images, (images.size()[0], -1))
            loss_sum += loss_f_funciton(labels, parameters, images)
            num += 1
    if num == 0:
        raise ValueError(
            'loss_train_avg: no batches evaluated (batch_num={})'.format(batch_num))
    return (loss_sum / num).detach()

def loss_test_avg(data_loader, parameters):
    """Evaluate accuracy and mean cross-entropy of the linear classifier.

    Args:
        data_loader: iterable of ``(images, labels)`` batches. Images are
            flattened to ``(batch, -1)``.
        parameters: weight matrix with weights in columns 0..783 and the bias
            in column 784.

    Returns:
        ``(acc, loss)``. ``acc`` is the fraction of correctly classified
        samples over all batches (Python float). ``loss`` is the detached
        mean of the per-batch cross-entropy losses.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    weights, bias = parameters[:, 0:784], parameters[:, 784]
    loss_sum, num, correct, total_num = 0.0, 0, 0, 0
    # Evaluation only; the returned loss is detached anyway.
    with torch.no_grad():
        for images, labels in data_loader:
            images = torch.reshape(images, (images.size()[0], -1))
            out = torch.matmul(images, torch.t(weights)) + bias
            pred = out.argmax(dim=1, keepdim=True)
            correct += pred.eq(labels.view_as(pred)).sum().item()
            # Reuse the logits instead of a second forward pass; this is the
            # same formula as loss_f_funciton.
            loss_sum += F.cross_entropy(out, labels)
            num += 1
            total_num += images.size()[0]
    if num == 0:
        raise ValueError('loss_test_avg: data_loader yielded no batches')
    return correct / total_num, (loss_sum / num).detach()

def loss_f_funciton(labels, parameters, data):
    """Cross-entropy of the linear classifier ``data @ W.T + b``.

    ``W`` is ``parameters[:, :784]`` and ``b`` is ``parameters[:, 784]``.
    """
    weight = parameters[:, :784]
    bias = parameters[:, 784]
    logits = data @ weight.T + bias
    return F.cross_entropy(logits, labels)

if __name__ == '__main__':
    args = parse_args()
    torch.manual_seed(args.seed)
    if rank == 0:
        print(args)
        # Only the root rank loads data. train_model reads these two names as
        # module-level globals and scatters the batches to the other ranks.
        train_loader, test_loader = get_data_loaders(args)
    # Fixed meta-seed so that every rank derives the same list of run seeds.
    np.random.seed(0)
    torch.manual_seed(0)
    seeds = np.random.randint(0, 100, size=10)
    if rank == 0:
        print("Random seeds:", seeds)
    for seed in seeds:
        args.seed = int(seed)
        torch.manual_seed(args.seed)
        # Each run saves its own results file (see train_model).
        train_model(args)