import os
import sys
import time
import torch
import argparse
import tracemalloc
import copy
import numpy as np
import torch.nn.functional as F
from mpi4py import MPI
import matplotlib.pyplot as plt
# Process-wide MPI handles: the world communicator, this process's rank in it,
# and the total number of participating processes.
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# Running total of the "communicated blocks" metric; the gossip helper inside
# train_model adds to it (via `global com_blocks`) after every exchange.
com_blocks = 0.0

def parse_args():
    """Parse the command-line options of the experiment.

    Besides the algorithm hyper-parameters, the namespace always carries an
    ``epochs`` attribute (the number of outer iterations read by
    ``train_model``).  When ``--epochs`` is not given it falls back to
    ``--o_steps`` so existing command lines keep working.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--test_size', type=int, default=10000)
    parser.add_argument('--training_size', type=int, default=50000)
    parser.add_argument('--dimension', type=int, default=100)
    parser.add_argument('--o_steps', type=int, default=200, help='K')
    # train_model reads args.epochs; previously no such option existed, which
    # made every run fail with AttributeError.
    parser.add_argument('--epochs', type=int, default=None,
                        help='number of outer iterations (defaults to --o_steps)')
    parser.add_argument('--iterations', type=int, default=10, help='total steps of inner loop iteration')
    parser.add_argument('--total_steps', type=int, default=20, help='total steps of iteration of GBDSBO')
    parser.add_argument('--N', type=int, default=20, help='total steps of JHIP or HIGP oracle')
    parser.add_argument('--b', type=int, default=20, help='total steps of HI oracle in GBDSBO')
    parser.add_argument('--outer_lr', type=float, default=0.03, help='alpha')
    parser.add_argument('--inner_lr', type=float, default=0.03, help='beta')
    parser.add_argument('--jhip_lr', type=float, default=0.01, help='lr for JHIP oracle')
    parser.add_argument('--higp_lr', type=float, default=0.01, help='lr for HIGP oracle')
    parser.add_argument('--Lg', type=float, default=10.0, help='smoothness constant in GBDSBO')
    parser.add_argument('--mu', type=float, default=1.0, help='mu_g in GBDSBO')
    parser.add_argument('--heter_rate', type=float, default=1.0)
    parser.add_argument('--noise_rate', type=float, default=0.1)
    parser.add_argument('--seed', type=int, default=6)
    parser.add_argument('--alg', type=str, default='MADSBO', choices=['MADSBO', 'DSBO_JHIP', 'GBDSBO',
                                                                       'SUN-SE', 'SUN-GT', 'SUN-HR','DSGDA-GT'])
    parser.add_argument('--num_of_nodes', type=int, default=8, help='The number of total nodes')
    parser.add_argument('--network_weight', type=float, default=0.4)
    parser.add_argument('--network_topology', default='complete', choices=['ring', 'ER', 'complete'])
    parser.add_argument('--eval_interval', type=int, default=1)
    parser.add_argument('--output', type=str, default='results', help='Output directory for results')
    parser.add_argument('--alpha_x', type=float, default=0.01, help='alpha_x')
    parser.add_argument('--alpha_y', type=float, default=0.03, help='alpha_y')
    parser.add_argument('--alpha_theta', type=float, default=0.05, help='alpha_theta')
    parser.add_argument('--gamma', type=float, default=0.9, help='gamma')
    parser.add_argument('--ck_bar', type=float, default=1, help='ck_bar')
    parser.add_argument('--exp', type=float, default=0.0001, help='exp')
    parser.add_argument('--beta_x', type=float, default=0.00, help='beta_x')
    parser.add_argument('--beta_y', type=float, default=0.3, help='beta_y')
    parser.add_argument('--beta_theta', type=float, default=0.01, help='beta_theta')

    args = parser.parse_args()
    if args.epochs is None:
        # keep --o_steps (the "K" outer iteration count) as the source of truth
        args.epochs = args.o_steps
    return args

def save_results(args, all_results):
    """Save the per-step result matrix as ``<args.output>/<name>.npy``.

    The file name encodes the algorithm, topology and the hyper-parameters
    relevant to that algorithm family, followed by ``_seed_<args.seed>``.
    The output directory is created if it does not exist yet.

    Args:
        args: namespace with ``alg``, ``network_topology``, ``seed``, ``output``
            and the algorithm-specific hyper-parameters used in the name.
        all_results: array to store (passed straight to ``np.save``).

    Raises:
        ValueError: if ``args.alg`` is not one of the supported algorithms.
    """
    if args.alg in {'DSBO_JHIP', 'MADSBO'}:
        filename = f'{args.alg}_{args.network_topology}_inner_{args.iterations}_inner_lr_{args.inner_lr}_outer_lr_{args.outer_lr}_hstep_{args.N}_seed_{args.seed}'
    elif args.alg == 'GBDSBO':
        filename = f'{args.alg}_{args.network_topology}_inner_{args.iterations}_inner_lr_{args.inner_lr}_outer_lr_{args.outer_lr}_b_{args.b}_seed_{args.seed}'
    elif args.alg in {'SUN-SE','SUN-GT', 'SUN-HR', 'DSGDA-GT'}:
        filename = f'{args.alg}_{args.network_topology}_N_{args.N}_alpha_x_{args.alpha_x}_alpha_y_{args.alpha_y}_alpha_theta_{args.alpha_theta}_gamma_{args.gamma}_ck_bar_{args.ck_bar}_exp_{args.exp}_seed_{args.seed}'
    else:
        # previously fell through and crashed with UnboundLocalError on `filename`
        raise ValueError(f'unknown algorithm: {args.alg}')
    # np.save does not create missing directories; an empty string means cwd
    if args.output:
        os.makedirs(args.output, exist_ok=True)
    filepath = os.path.join(args.output, filename)
    np.save(filepath, all_results)

def train_model(args):
    """Run one decentralized bilevel-optimization experiment on this MPI rank.

    Every rank builds its own heterogeneous synthetic logistic-regression data
    (feature scale grows with ``rank + 1``) and runs the algorithm selected by
    ``args.alg``, mixing information with its neighbours through ``gossip``.
    The local lower-level objective is the training logistic loss plus
    ``0.5 * sum(exp(lam) * w**2)`` (see ``grad_y`` / ``grad_x``); the
    upper-level objective is the unregularized validation logistic loss.

    Rank 0 also gathers all nodes' data and fills ``all_results`` with rows
    ``[train loss, test loss, test accuracy, hypergradient norm, step time,
    communicated blocks]`` evaluated at the network-average iterate.  Row 0 holds
    the initial losses/accuracy; row ``k + 1`` is written during outer step ``k``.

    Args:
        args: parsed command-line namespace (see ``parse_args``).  The number of
            outer steps is ``args.epochs`` (falling back to ``args.o_steps``),
            except for GBDSBO which uses ``args.total_steps``.

    Returns:
        numpy.ndarray of shape ``(outer_steps + 1, 6)`` on rank 0, ``None`` on
        every other rank.

    Raises:
        ValueError: if the hard-coded 8-node ``'ER'`` topology is used with a
            different number of MPI processes.
    """
    torch.set_default_dtype(torch.float32)
    torch.set_default_device('cpu')

    # NOTE: no RNG is seeded here, so ``args.seed`` does not influence the run.

    def frnp(x):
        """Convert a numpy array to a float32 torch tensor."""
        t = torch.from_numpy(x)
        return t.float()

    def tonp(x):
        """Detach a tensor from autograd and return it as a numpy array."""
        return x.detach().numpy()

    def eval_acc(params, x, y):
        """Fraction of rows whose prediction ``x @ params[0] > 0`` matches ``y``."""
        out = x @ params[0]
        pred = 1 * (out > 0)
        acc = pred.eq(y).sum() / len(y)  # fraction of correct predictions
        return acc

    def data_batch(batch_size, x, y):
        """Sample ``batch_size`` distinct row indices among the first ``local_n`` rows."""
        batch_index = np.random.permutation(np.arange(local_n))[0:batch_size]
        return x[batch_index], y[batch_index]

    # setup and initialization
    n, d = args.training_size, args.dimension
    num_of_nodes = args.num_of_nodes
    topology = args.network_topology
    if topology == 'ring':
        # keep `weight` of the own value, split the rest over the two neighbours
        weight = args.network_weight
        weight_neighbor = (1 - weight) / 2
    elif topology == 'ER':
        # fixed mixing matrix of an 8-node random graph
        W = np.array([
            [0.25714286, 0.2, 0., 0., 0., 0.14285714, 0.2, 0.2],
            [0.2, 0.65714286, 0., 0., 0., 0.14285714, 0., 0.],
            [0., 0., 0.29047619, 0.2, 0.16666667, 0.14285714, 0.2, 0.],
            [0., 0., 0.2, 0.29047619, 0.16666667, 0.14285714, 0., 0.2],
            [0., 0., 0.16666667, 0.16666667, 0.19047619, 0.14285714, 0.16666667, 0.16666667],
            [0.14285714, 0.14285714, 0.14285714, 0.14285714, 0.14285714, 0.14285714, 0., 0.14285714],
            [0.2, 0., 0.2, 0., 0.16666667, 0., 0.43333333, 0.],
            [0.2, 0., 0., 0.2, 0.16666667, 0.14285714, 0., 0.29047619]])
        if size != W.shape[0]:
            raise ValueError(f"the 'ER' topology is hard-coded for {W.shape[0]} ranks, got {size}")
        peer_list = np.setdiff1d(np.where(W[rank] > 0)[0], rank)
        local_weight = W[rank, rank]
    else:
        # complete graph: uniform averaging over all ranks (was hard-coded to 8)
        W = np.ones((size, size)) / size
        peer_list = np.setdiff1d(np.arange(size), rank)
        local_weight = W[rank, rank]
    N = args.N
    heter_rate = args.heter_rate
    noise_rate = args.noise_rate
    # upper-level search direction used by MADSBO / DSBO_JHIP / GBDSBO
    r = torch.zeros(d)

    # args.epochs may be absent (the parser historically had no such option)
    outer_steps = getattr(args, 'epochs', None)
    if outer_steps is None:
        outer_steps = args.o_steps
    eval_interval = args.eval_interval
    T = args.iterations

    # gamma (beta) and alpha respectively in GBDSBO
    inner_lr = args.inner_lr
    outer_lr = args.outer_lr

    higp_lr = args.higp_lr
    jhip_lr = args.jhip_lr

    # synthetic data generation
    local_n = n // num_of_nodes

    w_oracle, noise = None, None
    if rank == 0:
        # ground-truth weights and label noise are drawn on node 0 and shared
        w_oracle = np.random.randn(d)
        noise = np.random.randn(local_n)

    w_oracle = comm.bcast(w_oracle, root=0)
    noise = comm.bcast(noise, root=0)

    # heterogeneous data generation: feature scale grows with the rank
    x_train = heter_rate * (rank + 1) * np.random.randn(local_n, d)
    x_val = heter_rate * (rank + 1) * np.random.randn(local_n, d)
    y_train = x_train @ w_oracle + noise_rate * noise
    y_val = x_val @ w_oracle + noise_rate * noise
    y_train = (y_train > 0.).astype(float)
    y_val = (y_val > 0.).astype(float)

    x_train, x_val, y_train, y_val = frnp(x_train), frnp(x_val), frnp(y_train), frnp(y_val)

    params = [torch.randn(d).requires_grad_(True)]  # omega (lower-level variable)
    params_theta = [torch.randn(d).requires_grad_(True)]  # theta (used by SUN-* / DSGDA-GT)
    hparams = [torch.randn(d).requires_grad_(True)]  # lambda (upper-level variable)

    # rank 0 receives lists of every rank's data; other ranks receive None
    x_train_full = comm.gather(torch.clone(x_train), root=0)
    y_train_full = comm.gather(torch.clone(y_train), root=0)
    x_val_full = comm.gather(torch.clone(x_val), root=0)
    y_val_full = comm.gather(torch.clone(y_val), root=0)

    # network sums of the initial iterates (rank 0 only; averaged below)
    hparams_mean = comm.reduce(torch.clone(hparams[0]), op=MPI.SUM, root=0)
    params_mean = comm.reduce(torch.clone(params[0]), op=MPI.SUM, root=0)
    params_theta_mean = comm.reduce(torch.clone(params_theta[0]), op=MPI.SUM, root=0)

    if args.alg == 'GBDSBO':
        # maintain MA updates in GBDSBO
        # we do not need s (defined in Algorithm 1 of Yang's paper) since grad_x f = 0
        # the Jacobian is a diagonal matrix so we only maintain a vector u
        mu, Lg, b = args.mu, args.Lg, args.b
        h, u, V, Q = torch.zeros(d), torch.zeros(d), (mu/Lg) * torch.eye(d), mu * torch.eye(d)

        # in GBDSBO T = 1
        T, outer_steps = 1, args.total_steps

    if args.alg in {'SUN-SE', 'SUN-GT', 'DSGDA-GT', 'SUN-HR'}:
        # naming in these branches: x <-> hparams (upper level),
        # y <-> params (lower level), theta <-> params_theta
        alpha_x = args.alpha_x
        alpha_y = args.alpha_y
        alpha_theta = args.alpha_theta
        ck_bar = args.ck_bar
        exp = args.exp
        gamma = args.gamma

    if args.alg in {'SUN-GT', 'DSGDA-GT', 'SUN-HR'}:
        # gradient-tracking variables (u_*) and previous local directions (h_*)
        u_x, u_y, u_theta = 0, 0, 0
        h_x, h_y, h_theta = 0, 0, 0

    if args.alg == 'SUN-HR':
        beta_x = args.beta_x
        beta_y = args.beta_y
        beta_theta = args.beta_theta
        v_x, v_y, v_theta = 0, 0, 0
        params_old = [torch.randn(d).requires_grad_(True)]
        hparams_old = [torch.randn(d).requires_grad_(True)]
        params_theta_old = [torch.randn(d).requires_grad_(True)]

    if rank == 0:
        # columns: train loss, test loss, test acc, grad norm, time, communicated blocks
        all_results = np.zeros((outer_steps + 1, 6))

        x_test = np.random.randn(args.test_size, d)
        y_test = x_test @ w_oracle
        y_test = (y_test > 0.).astype(float)
        x_test, y_test = frnp(x_test), frnp(y_test)

        # comm.reduce returned sums: average BEFORE evaluating (the initial
        # losses used to be computed at the summed parameters)
        hparams_mean = hparams_mean / size
        params_mean = params_mean / size
        params_theta_mean = params_theta_mean / size

        x_train_full = torch.stack(x_train_full)
        y_train_full = torch.stack(y_train_full)
        x_val_full = torch.stack(x_val_full)
        y_val_full = torch.stack(y_val_full)

        all_results[0, 2] = eval_acc([params_mean], x_test, y_test)
        train_loss = F.binary_cross_entropy_with_logits(x_train_full @ params_mean, y_train_full)
        all_results[0, 0] = tonp(train_loss)
        test_loss = F.binary_cross_entropy_with_logits(x_test @ params_mean, y_test)
        all_results[0, 1] = tonp(test_loss)

        hparams_history = [hparams_mean]
        params_history = [params_mean]
        params_theta_history = [params_theta_mean]

        total_time = 0

    # losses
    def local_logistic_loss():
        """Return f(w): logistic loss on a fresh random minibatch of local training data per call."""
        def f(w):
            x_train_batch, y_train_batch = data_batch(args.batch_size, x_train, y_train)
            return F.binary_cross_entropy_with_logits(x_train_batch @ w, y_train_batch)
        return f

    def global_logistic_loss():
        """Return f(w): logistic loss on all gathered training data (rank 0 only)."""
        def f(w):
            return F.binary_cross_entropy_with_logits(x_train_full @ w, y_train_full)
        return f

    def grad_y(params, hparams, x, y, reg=True):
        """Minibatch gradient w.r.t. ``params[0]`` of the logistic loss on (x, y).

        With ``reg=True`` (lower level) the regularizer gradient
        ``exp(hparams[0]) * params[0]`` is added; ``reg=False`` is the upper level.
        """
        x_train_batch, y_train_batch = data_batch(args.batch_size, x, y)
        loss = F.binary_cross_entropy_with_logits(x_train_batch @ params[0], y_train_batch)

        if reg:
            gy = torch.autograd.grad(loss, params)[0] + torch.exp(hparams[0]) * params[0]
        else:
            gy = torch.autograd.grad(loss, params)[0]

        return gy

    def grad_x(params, hparams, x, y, reg=True):
        """Gradient w.r.t. ``hparams[0]`` of the regularizer 0.5*sum(exp(lam)*w**2).

        Returns ``0.5 * exp(hparams[0]) * params[0]**2`` when ``reg`` else 0.
        The minibatch loss below is unused; it is kept so the sequence of
        random batch draws matches the original implementation.
        """
        x_train_batch, y_train_batch = data_batch(args.batch_size, x, y)
        loss = F.binary_cross_entropy_with_logits(x_train_batch @ params[0], y_train_batch)

        if reg:
            gx = 0.5 * torch.exp(hparams[0]) * (params[0] ** 2)
        else:
            gx = 0

        return gx

    def grad_y_hr(params, hparams, params_old, hparams_old, x, y, reg=True):
        """Like ``grad_y`` for the current and the previous iterate, on the same minibatch."""
        x_train_batch, y_train_batch = data_batch(args.batch_size, x, y)
        loss = F.binary_cross_entropy_with_logits(x_train_batch @ params[0], y_train_batch)
        loss1 = F.binary_cross_entropy_with_logits(x_train_batch @ params_old[0], y_train_batch)
        if reg:
            gy = torch.autograd.grad(loss, params)[0] + torch.exp(hparams[0]) * params[0]
            gy_old = torch.autograd.grad(loss1, params_old)[0] + torch.exp(hparams_old[0]) * params_old[0]
        else:
            gy = torch.autograd.grad(loss, params)[0]
            gy_old = torch.autograd.grad(loss1, params_old)[0]
        return gy, gy_old

    def grad_x_hr(params, hparams, params_old, hparams_old, x, y, reg=True):
        """Like ``grad_x`` for the current and the previous iterate (batch draw kept for RNG parity)."""
        x_train_batch, y_train_batch = data_batch(args.batch_size, x, y)
        loss = F.binary_cross_entropy_with_logits(x_train_batch @ params[0], y_train_batch)

        if reg:
            gx = 0.5 * torch.exp(hparams[0]) * (params[0] ** 2)
            gx_old = 0.5 * torch.exp(hparams_old[0]) * (params_old[0] ** 2)
        else:
            gx = 0
            gx_old = 0
        return gx, gx_old

    def gossip(info):
        """One neighbour-averaging round: returns ``sum_j W[rank, j] * info_j``.

        ``info`` is sent to every neighbour and theirs is received.  The gap
        between tracemalloc's peak and current traced memory during the
        exchange is added to the module-level ``com_blocks`` counter.
        """
        global com_blocks
        tracemalloc.start()
        # Sends are non-blocking so every rank reaches its receives: with
        # blocking sends on all ranks first, messages above the MPI eager
        # threshold (e.g. JHIP's d x d matrices) can deadlock.
        if topology == 'ring':
            requests = [
                comm.isend(info, dest=(rank + 1) % size, tag=rank),
                comm.isend(info, dest=(rank - 1) % size, tag=rank + size),
            ]
            info_for = comm.recv(source=(rank + 1) % size, tag=(rank + 1) % size + size)
            info_back = comm.recv(source=(rank - 1) % size, tag=(rank - 1) % size)
            for req in requests:
                req.wait()
            output = weight_neighbor * (info_for + info_back) + weight * info
        else:
            requests = [comm.isend(info, dest=int(peer_id)) for peer_id in peer_list]
            info_recv = [comm.recv(source=int(peer_id)) for peer_id in peer_list]
            for req in requests:
                req.wait()
            output = local_weight * info
            for peer_id, received in zip(peer_list, info_recv):
                output += W[rank, peer_id] * received

        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        com_blocks += peak - current

        return output

    def HIGP(params, hparams, N, higp_lr=0.01, get_memory=False):
        """N gradient-tracking iterations aimed at z with (H + diag(exp(lam))) z = grad_w f.

        H is the Hessian of ``local_logistic_loss`` (a random local minibatch) and
        grad_w f a minibatch gradient of the unregularized validation loss.
        ``get_memory`` is unused.
        """
        z = torch.zeros(d)
        g_old = torch.zeros(d)

        H = torch.autograd.functional.hessian(local_logistic_loss(), params[0])
        y = -grad_y(params, hparams, x_val, y_val, reg=False)

        for i in range(N):
            z = gossip(z) - higp_lr * y
            g_new = H @ z + torch.exp(hparams[0]) * z
            y = gossip(y) + g_new - g_old
            g_old = g_new
        return z

    def JHIP(params, hparams, N, jhip_lr=0.01):
        """N gradient-tracking iterations aimed at Z with H Z = diag(exp(lam) * w).

        Here H = Hessian of ``local_logistic_loss`` + diag(exp(lam)).
        """
        Z = torch.zeros((d, d))
        G_old = torch.zeros((d, d))

        exp_temp = torch.exp(hparams[0])

        H = torch.autograd.functional.hessian(local_logistic_loss(), params[0]) + torch.diag(exp_temp)
        Y = -torch.diag(exp_temp * params[0])

        for i in range(N):
            Z = gossip(Z) - jhip_lr * Y
            G_new = H @ Z
            Y = gossip(Y) + G_new - G_old
            G_old = G_new
        return Z

    def inner_dsgt(hparams, params, gamma, u, h, T, alpha=0):
        """Gradient-tracking inner loop (currently not used by any branch).

        Returns the updated params, the tracking variable u and the last local
        direction ``h_new``.
        """
        for t in range(T):
            h_new = grad_y(params, hparams, x_train, y_train, reg=False)

            # alpha > 0 when the inner loop is used to update y, otherwise z
            if alpha > 0:
                h_new = (1 + alpha) * h_new + alpha * torch.exp(hparams[0]) * params[0]
            else:
                h_new = h_new + torch.exp(hparams[0]) * params[0]

            # communication and update
            u = gossip(u) + h_new - h
            params[0] = gossip(params[0]) - gamma * u

            # update h
            h = h_new.clone().detach()

        return params, u, h_new

    def inner_dsgd(hparams, params, gamma, T):
        """T steps of decentralized SGD on the regularized lower-level loss."""
        for t in range(T):
            inner_grad = grad_y(params, hparams, x_train, y_train, reg=True)

            # communication and update
            params[0] = gossip(params[0]) - gamma * inner_grad

        return params

    for o_step in range(outer_steps):
        if rank == 0:
            step_start_time = time.time()

        # inner loop or lower level
        if args.alg in {'GBDSBO', 'DSBO_JHIP', 'MADSBO'}:
            # GBDSBO, DSBO_JHIP, MADSBO only have one inner loop without gradient tracking
            params = inner_dsgd(hparams, params, inner_lr, T)

        # compute the network average to evaluate losses
        params_mean = comm.reduce(params[0], op=MPI.SUM, root=0)

        if rank == 0:
            params_mean = params_mean / size
            params_history.append(params_mean)
            train_loss = F.binary_cross_entropy_with_logits(x_train_full @ params_mean, y_train_full)
            test_loss = F.binary_cross_entropy_with_logits(x_test @ params_mean, y_test)
            all_results[o_step+1, 0] = tonp(train_loss)
            all_results[o_step+1, 1] = tonp(test_loss)
            all_results[o_step+1, 2] = eval_acc([params_mean], x_test, y_test)

        # outer loop / upper level
        if args.alg == 'MADSBO':
            z = HIGP(params, hparams, N, higp_lr)
            outer_grad_new = -torch.exp(hparams[0]) * params[0] * z

        elif args.alg == 'DSBO_JHIP':
            Z = JHIP(params, hparams, N, jhip_lr)
            outer_grad_y = grad_y(params, hparams, x_val, y_val, reg=False)
            outer_grad_new = -Z.t() @ outer_grad_y

        elif args.alg == 'GBDSBO':
            r = - u * (Q @ h)

        elif args.alg == 'SUN-SE':
            # single loop, plain gossip (no gradient tracking)
            ck = ck_bar * 1 / ((1+o_step)**exp)
            outer_grad_y = grad_y(params, hparams, x_val, y_val, reg=False)
            inner_grad_y = grad_y(params, hparams, x_train, y_train, reg=True)
            inner_grad_x = grad_x(params, hparams, x_train, y_train, reg=True)
            inner_grad_thetax = grad_x(params_theta, hparams, x_train, y_train, reg=True)
            inner_grad_theta = grad_y(params_theta, hparams, x_train, y_train, reg=True)
            h_y_new = ck*outer_grad_y + inner_grad_y + gamma*(params_theta[0]-params[0])
            h_theta_new = inner_grad_theta + gamma*(params_theta[0]-params[0])
            h_x_new = inner_grad_x - inner_grad_thetax
            params[0] = gossip(params[0] - alpha_y * h_y_new)
            params_theta[0] = gossip(params_theta[0] - alpha_theta * h_theta_new)
            hparams[0] = gossip(hparams[0] - alpha_x * h_x_new)

        elif args.alg == 'SUN-GT':
            # single loop with gradient tracking
            ck = ck_bar * 1 / ((1+o_step)**exp)
            outer_grad_y = grad_y(params, hparams, x_val, y_val, reg=False)
            inner_grad_y = grad_y(params, hparams, x_train, y_train, reg=True)
            inner_grad_x = grad_x(params, hparams, x_train, y_train, reg=True)
            inner_grad_thetax = grad_x(params_theta, hparams, x_train, y_train, reg=True)
            inner_grad_theta = grad_y(params_theta, hparams, x_train, y_train, reg=True)
            h_y_new = ck*outer_grad_y + inner_grad_y + gamma*(params_theta[0]-params[0])
            h_theta_new = inner_grad_theta + gamma*(params_theta[0]-params[0])
            h_x_new = inner_grad_x - inner_grad_thetax

            u_y = gossip(u_y + h_y_new - h_y)
            u_theta = gossip(u_theta + h_theta_new - h_theta)
            u_x = gossip(u_x + h_x_new - h_x)

            params[0] = gossip(params[0] - alpha_y * u_y)
            params_theta[0] = gossip(params_theta[0] - alpha_theta * u_theta)
            hparams[0] = gossip(hparams[0] - alpha_x * u_x)

            h_y = h_y_new.clone().detach()
            h_theta = h_theta_new.clone().detach()
            h_x = h_x_new.clone().detach()

        elif args.alg == 'DSGDA-GT':
            # constant penalty ck; mix first, then subtract the tracked direction
            ck = ck_bar
            outer_grad_y = grad_y(params, hparams, x_val, y_val, reg=False)
            inner_grad_y = grad_y(params, hparams, x_train, y_train, reg=True)
            inner_grad_x = grad_x(params, hparams, x_train, y_train, reg=True)
            inner_grad_thetax = grad_x(params_theta, hparams, x_train, y_train, reg=True)
            inner_grad_theta = grad_y(params_theta, hparams, x_train, y_train, reg=True)
            h_y_new = outer_grad_y + ck * inner_grad_y
            h_theta_new = ck * (inner_grad_theta)
            h_x_new = ck * (inner_grad_x - inner_grad_thetax)

            u_y = gossip(u_y) + h_y_new - h_y
            u_theta = gossip(u_theta) + h_theta_new - h_theta
            u_x = gossip(u_x) + h_x_new - h_x

            params[0] = gossip(params[0]) - alpha_y * u_y
            params_theta[0] = gossip(params_theta[0]) - alpha_theta * u_theta
            hparams[0] = gossip(hparams[0]) - alpha_x * u_x

            h_y = h_y_new.clone().detach()
            h_theta = h_theta_new.clone().detach()
            h_x = h_x_new.clone().detach()

        elif args.alg == 'SUN-HR':
            # single loop with gradient tracking and a recursive (STORM-like)
            # correction using gradients at the previous iterate
            ck = ck_bar * 1 / ((1+o_step)**exp)
            outer_grad_y, outer_grad_y_old = grad_y_hr(params, hparams, params_old, hparams_old, x_val, y_val, reg=False)
            inner_grad_y, inner_grad_y_old = grad_y_hr(params, hparams, params_old, hparams_old, x_train, y_train, reg=True)
            inner_grad_x, inner_grad_x_old = grad_x_hr(params, hparams, params_old, hparams_old, x_train, y_train, reg=True)
            inner_grad_thetax, inner_grad_thetax_old = grad_x_hr(params_theta, hparams, params_theta_old, hparams_old, x_train, y_train, reg=True)
            inner_grad_theta, inner_grad_theta_old = grad_y_hr(params_theta, hparams, params_theta_old, hparams_old, x_train, y_train, reg=True)
            h_y_new = ck*outer_grad_y + inner_grad_y + gamma*(params_theta[0]-params[0])
            h_theta_new = inner_grad_theta + gamma*(params_theta[0]-params[0])
            h_x_new = inner_grad_x - inner_grad_thetax
            if o_step == 0:
                h_y = 0
                h_theta = 0
                h_x = 0
            else:
                ck_1 = ck_bar * 1 / ((1+o_step-1)**exp)
                h_y = ck_1*outer_grad_y_old + inner_grad_y_old + gamma*(params_theta_old[0]-params_old[0])
                h_theta = inner_grad_theta_old + gamma*(params_theta_old[0]-params_old[0])
                h_x = inner_grad_x_old - inner_grad_thetax_old

            v_y_new = h_y_new - (1-beta_y) * (v_y-h_y)
            u_y = gossip(u_y + v_y_new - v_y)
            v_theta_new = h_theta_new - (1-beta_theta) * (v_theta-h_theta)
            u_theta = gossip(u_theta + v_theta_new - v_theta)
            v_x_new = h_x_new - (1-beta_x) * (v_x-h_x)
            u_x = gossip(u_x + v_x_new - v_x)

            params_old = [params[0].clone().detach().requires_grad_(True)]
            hparams_old = [hparams[0].clone().detach().requires_grad_(True)]
            params_theta_old = [params_theta[0].clone().detach().requires_grad_(True)]

            params[0] = gossip(params[0] - alpha_y * u_y)
            params_theta[0] = gossip(params_theta[0] - alpha_theta * u_theta)
            hparams[0] = gossip(hparams[0] - alpha_x * u_x)

            v_x = v_x_new.clone().detach()
            v_y = v_y_new.clone().detach()
            v_theta = v_theta_new.clone().detach()

        # communication and update of the upper-level variable
        if args.alg in {'MADSBO', 'DSBO_JHIP', 'GBDSBO'}:
            hparams[0] = gossip(hparams[0]) - outer_lr * r

        # moving average for MADSBO
        if args.alg == 'MADSBO':
            r = (1 - outer_lr) * r + outer_lr * outer_grad_new
        elif args.alg == 'DSBO_JHIP':
            r = outer_grad_new

        elif args.alg == 'GBDSBO':
            outer_grad_y = grad_y(params, hparams, x_val, y_val, reg=False)

            # MA update
            h = (1 - inner_lr) * gossip(h) + inner_lr * outer_grad_y
            u = (1 - inner_lr) * gossip(u) + inner_lr * torch.exp(hparams[0]) * params[0]

            Q = torch.eye(d)
            for i in range(b):
                w = params[0]
                loss = local_logistic_loss()
                H = torch.autograd.functional.hessian(loss, w)

                # note that the division step (Lg) is here to reduce complexity
                # so we maintain V divided by Lg here
                V = (1 - inner_lr) * gossip(V) + (inner_lr / Lg) * H
                Q = torch.eye(d) + Q - V @ Q
            Q = Q / Lg

        hparams_mean = comm.reduce(hparams[0], op=MPI.SUM, root=0)
        com_blocks_mean = comm.reduce(com_blocks, op=MPI.SUM, root=0)

        if rank == 0:
            # non in-place division: the reduce result must not be mutated
            hparams_mean = hparams_mean / size
            com_blocks_mean /= size
            hparams_history.append(hparams_mean)
            step_time = time.time()-step_start_time
            total_time += step_time

            # Hypergradient of the network-average problem at the average iterate
            # (implicit function theorem).  The lower-level Hessian is the logistic
            # Hessian on ALL gathered training data plus diag(exp(lam)) from the
            # regularizer, matching the systems HIGP / JHIP solve.  Previously it
            # used one local minibatch of rank 0 and dropped the diagonal term.
            w_avg = params_history[-1]
            lam_avg = hparams_history[-1]
            H_average = (torch.autograd.functional.hessian(global_logistic_loss(), w_avg)
                         + torch.diag(torch.exp(lam_avg)))
            val_loss = F.binary_cross_entropy_with_logits(x_val_full @ w_avg, y_val_full)

            val_grad = torch.autograd.grad(val_loss, w_avg)[0]
            outer_grad_average = -torch.exp(lam_avg) * w_avg * torch.linalg.solve(H_average, val_grad)

            # .item(): store a plain float rather than a grad-requiring tensor
            all_results[o_step+1, 3] = torch.linalg.norm(outer_grad_average).item()
            all_results[o_step+1, 4] = step_time
            all_results[o_step+1, 5] = com_blocks_mean

            # the last step index is outer_steps - 1 (the old `== outer_steps` never fired)
            if o_step % eval_interval == 0 or o_step == outer_steps - 1:
                print('o_step={}({:.2e}s) Hypergrad norm={:.4f} Training Loss={:.4f} Test Loss={:.4f} Accuracy={:.4f} Com_blocks={:.2e}'\
                      .format(o_step+1, step_time, all_results[o_step+1, 3], all_results[o_step+1, 0], all_results[o_step+1, 1],\
                               all_results[o_step+1, 2], all_results[o_step+1, 5]))
                sys.stdout.flush()
    if rank == 0:
        print('total time = {}'.format(total_time))
        return all_results


if __name__ == '__main__':
    args = parse_args()
    all_results = train_model(args)
    # NOTE(review): the seed is forced to 0 before saving, so result files always
    # end in "seed_0" whatever --seed was (train_model does not seed any RNG).
    # Kept unchanged because downstream loaders may rely on these names — confirm.
    args.seed = 0
    # Only rank 0 holds results (train_model returns None elsewhere).  The file
    # name is built by save_results; the unused duplicate computation that used
    # to live here has been removed.
    if rank == 0:
        save_results(args, all_results)

    MPI.Finalize()
    
