import math

import matplotlib.pyplot as plt
import numpy as np

from solvers import *
# Run every solver on K problem instances and plot the trial-averaged curves.
def _mean_curve(runs):
    """Return the round-by-round mean of several equally long curves.

    Args:
        runs: Non-empty list of per-trial sequences, one value per round.

    Returns:
        A list whose i-th entry is the average of ``run[i]`` over all runs.

    Raises:
        ValueError: If ``runs`` is empty or the runs differ in length.
    """
    if not runs:
        raise ValueError("need at least one run to average")
    if len({len(run) for run in runs}) != 1:
        # zip() would silently truncate to the shortest run; refuse instead.
        raise ValueError("all runs must have the same number of rounds")
    count = len(runs)
    return [sum(values) / count for values in zip(*runs)]


def _plot_curves(curves, ylabel):
    """Draw labelled per-round curves on one figure and display it.

    Args:
        curves: Iterable of ``(label, values)`` pairs, plotted in order.
        ylabel: Text for the y axis.
    """
    for label, values in curves:
        plt.plot(values, label=label, linewidth=2.0)
    plt.xlabel("Rounds", fontsize=18)
    plt.ylabel(ylabel, fontsize=18)
    plt.legend(fontsize=13)
    plt.tick_params(labelsize=13)
    plt.show()


def exp(K):
    """Run every solver on K COCO instances and plot the averaged curves.

    Each trial builds one new ``COCO`` problem from ``dic``. It then runs
    RECOO, Algorithms 1 and 2 of Yi, and Yuan's algorithm, all on that same
    instance. The solvers' ``loss_list`` and ``constraint_violation_list``
    curves are averaged over the K trials and shown in two figures.

    Args:
        K: Number of trials to average over; must be at least 1.

    Raises:
        ValueError: If ``K`` is smaller than 1.
    """
    if K < 1:
        raise ValueError("K must be at least 1")

    dic = {"x_dim": 10, "constraint_dim": 2, "A_range": [0, 2], "a_range": [0, 1]}
    # Presumably the number of rounds T the solvers run for; the step sizes
    # below are chosen as functions of it. TODO confirm against COCO.
    T = 5000

    loss_runs = {key: [] for key in ("recoo", "yi1", "yi2", "yuan")}
    violation_runs = {key: [] for key in loss_runs}

    def record(key, solver):
        """Run ``solver`` and store its per-round curves under ``key``."""
        solver.run()
        loss_runs[key].append(solver.loss_list)
        violation_runs[key].append(solver.constraint_violation_list)
        return solver

    for _ in range(K):
        # All solvers in a trial share one instance so curves are comparable.
        prob = COCO(dictionary=dic)

        # RECOO
        recoo = record("recoo", RECOO(problem=prob))
        print(recoo.constraint_violation_list)

        # Algorithm 1 in Yi
        record(
            "yi1",
            Alg1_Yi(problem=prob, alpha=2 / math.sqrt(T), gamma=math.pow(T, 1 / 4)),
        )

        # Algorithm 2 in Yi: must receive the same instance `prob` as the rest.
        record(
            "yi2",
            Alg2_Yi(
                alpha=[2 / math.pow(T, 1 / 2)] * 5,
                gamma=math.pow(T, 1 / 4),
                beta=3 / math.pow(T, 1 / 2),
                problem=prob,
            ),
        )

        # Algorithm in Yuan
        record(
            "yuan",
            Alg_Yuan(problem=prob, eta=1 / (8 * np.sqrt(2 * T)), sigma=125),
        )

    loss_mean = {key: _mean_curve(runs) for key, runs in loss_runs.items()}
    violation_mean = {
        key: _mean_curve(runs) for key, runs in violation_runs.items()
    }

    _plot_curves(
        [
            ("RPCOO", loss_mean["recoo"]),
            ("Algorithm 1 in [26]", loss_mean["yi1"]),
            ("Algorithm 2 in [26]", loss_mean["yi2"]),
            ("Algorithm 1 in [31]", loss_mean["yuan"]),
        ],
        "Loss",
    )
    # NOTE(review): the loss figure cites Yi's algorithms as [26] but this
    # figure cites them as [27]; one of the two labels is likely wrong.
    _plot_curves(
        [
            ("RPCOO", violation_mean["recoo"]),
            ("Algorithm 1 in [27]", violation_mean["yi1"]),
            ("Algorithm 2 in [27]", violation_mean["yi2"]),
            ("Algorithm 1 in [31]", violation_mean["yuan"]),
        ],
        "Constraint Violation",
    )

if __name__ == "__main__":
    # Number of independent trials whose curves are averaged.
    num_trials = 1
    exp(K=num_trials)
