import os
import pickle
import random as random
import sys

import matplotlib.pyplot as plt
import numpy as np
from scipy import optimize
from sklearn.linear_model import LinearRegression
from  tqdm import tqdm

from agents import User_LB, User
from algorithms import DBGD, NCSMD, Doubler, Sparring

if __name__ == "__main__":
    var1 = sys.argv[1]
    var2 = sys.argv[2]
    var3 = sys.argv[3]
    var4 = sys.argv[4]

    rho = int(var1)/100
    user = int(var2)
    repeat = int(var3)
    beta = int(var4)/100  # Equivalent to alpha in main paper theorem 2
    
    random.seed(repeat)
    np.random.seed(repeat)
    path = "/home/user/ICLR_LIHF"  #TODO: change to your own path 
    
    # Simulation setting
    d = 5
    radius = 10
    n_iter = 100
    corrupted = True
    
    # Randomly simulate preference parameters
    theta = np.random.randn(d, 1)
    theta /= np.linalg.norm(theta, axis=0)
    theta = radius * theta
    x_star = -theta
    
    # Set sparring parameter
    r = 0.95*radius
    nu = radius/radius**2/np.sqrt(n_iter)
    delta = (r*radius**2 * d**2/12/n_iter)**(1/3)
    alpha = (radius * d/2/np.sqrt(n_iter)/r)**(1/3)

    # Run sparring
    user_sparring = User(theta = theta, V0 = 2*np.eye(d), corrupted = corrupted, rho = rho, n_iter = n_iter, method = "SP", corruption_mode="LU")
    sparring = Sparring(feature_dim=d, user=user_sparring, radius = radius,  nu=nu, delta = delta, alpha = alpha, corruption_mode="LU")
    sparring.simulate(T = n_iter)

    # Run doubler
    user_doubler = User(theta = theta, V0 = 2*np.eye(d), corrupted = corrupted, rho = rho, n_iter = n_iter, method = "DB", corruption_mode="LU")
    doubler = Doubler(feature_dim=d, user = user_doubler , radius = radius, nu=nu, delta = delta, alpha = alpha, corruption_mode="LU")
    doubler.simulate_lift(T = n_iter)
     
    # Set DBGD parameter
    delta_1 = (n_iter)**(-beta)*np.sqrt(d)
    gamma_1 = (n_iter)**(-0.5)
    
    # Run DBGD
    user_dbgd = User(theta = theta, V0 = 2*np.eye(d), corrupted = corrupted, rho = rho, n_iter = n_iter, method = "DBGD", corruption_mode="LU")
    dbgd = DBGD(feature_dim = d, radius = radius, delta = delta_1, gamma = gamma_1, user = user_dbgd)
    dbgd.simulate(n_iter, user_dbgd)
    
    # Set NC-SMD parameter
    mu = 0.1
    lmbd = 0.05
    eta = np.log(n_iter)**(0.5) / n_iter**(1 - 2*beta)/2/d
    
    # Run NC-SMD
    user_smd = User(theta = theta, V0 = 2*np.eye(d), corrupted = corrupted, rho = rho, n_iter = n_iter, method = "SMD", corruption_mode="LU")
    ncsmd = NCSMD(feature_dim = d, lmbd = lmbd, radius = radius, mu = mu, eta = eta, user = user_smd)
    ncsmd.simulate(n_iter, user_smd)
    
    # Save experiment data
    with open(f"{path}/tradeoff/lu/doubler_result_{rho}_{user}_{repeat}_{beta}.json", 'wb') as f:
        pickle.dump(doubler.regret, f)
    with open(f"{path}/tradeoff/lu/sparring_result_{rho}_{user}_{repeat}_{beta}.json", 'wb') as f:
        pickle.dump(sparring.regret, f)
    with open(f"{path}/tradeoff/lu/dbgd_result_{rho}_{user}_{repeat}_{beta}.json", 'wb') as f:
        pickle.dump(dbgd.regret, f)
    with open(f"{path}/tradeoff/lu/ncsmd_result_{rho}_{user}_{repeat}_{beta}.json", 'wb') as f:
        pickle.dump(ncsmd.regret, f)  

    #------------------------Greedy Attack---------------------------------------------------------------------
    user_doubler = User(theta = theta, V0 = 2*np.eye(d), corrupted = corrupted, rho = rho, n_iter = n_iter, method = "SP", corruption_mode="G")
    doubler = Doubler(feature_dim=d, user = user_doubler , radius = radius, nu=nu, delta = delta, alpha = alpha, corruption_mode="G")
    doubler.simulate_lift(T = n_iter)

    user_sparring = User(theta = theta, V0 = 2*np.eye(d), corrupted = corrupted, rho = rho, n_iter = n_iter, method = "DB", corruption_mode="G")
    sparring = Sparring(feature_dim=d, user=user_sparring, radius = radius,  nu=nu, delta = delta, alpha = alpha, corruption_mode="G")
    sparring.simulate(T = n_iter)

    user_dbgd = User(theta = theta, V0 = 2*np.eye(d), corrupted = corrupted, rho = rho, n_iter = n_iter, method = "DBGD", corruption_mode="G")
    dbgd = DBGD(feature_dim = d, radius = radius, delta = delta_1, gamma = gamma_1, user = user_dbgd)
    dbgd.simulate(n_iter, user_dbgd)

    user_smd = User(theta = theta, V0 = 2*np.eye(d), corrupted = corrupted, rho = rho, n_iter = n_iter, method = "SMD", corruption_mode="G")
    ncsmd = NCSMD(feature_dim = d, lmbd = lmbd, radius = radius, mu = mu, eta = eta, user = user_smd)
    ncsmd.simulate(n_iter, user_smd)
    
    # Save experiment data
    with open(f"{path}/tradeoff/greedy/doubler_result_{rho}_{user}_{repeat}.json", 'wb') as f:
        pickle.dump(doubler.regret, f)
    with open(f"{path}/tradeoff/greedy/sparring_result_{rho}_{user}_{repeat}.json", 'wb') as f:
        pickle.dump(sparring.regret, f)
    with open(f"{path}/tradeoff/greedy/dbgd_result_{rho}_{user}_{repeat}.json", 'wb') as f:
        pickle.dump(dbgd.regret, f)
    with open(f"{path}/tradeoff/greedy/ncsmd_result_{rho}_{user}_{repeat}.json", 'wb') as f:
        pickle.dump(ncsmd.regret, f)  
