from kan import *
import time
import os
import sys

def _make_dataset(problem, device):
    """Build the regression dataset for one benchmark problem.

    Args:
        problem: 0 selects the 2-variable target
            ``exp(sin(pi * x0) + x1 ** 2)``; any other value selects the
            100-variable target ``exp(mean(sin(pi / 2 * x) ** 2))``.
        device: torch device passed through to ``create_dataset``.

    Returns:
        A ``(input_dim, dataset)`` tuple, where ``input_dim`` is the number
        of input variables (2 or 100) and ``dataset`` is whatever
        ``create_dataset`` returns for 1000 training samples.
    """
    if problem == 0:
        input_dim = 2

        def target(x):
            return torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
    else:
        input_dim = 100

        def target(x):
            return torch.exp(torch.mean(torch.sin(torch.pi/2*x)**2, dim=1))

    dataset = create_dataset(target, n_var=input_dim, device=device, train_num=1000)
    return input_dim, dataset


def run():
    """Train KANs with grid extension for this task's share of the sweep.

    Command-line arguments:
        ``sys.argv[1]``: index of this task (``my_task_id``).
        ``sys.argv[2]``: total number of tasks (``num_tasks``).

    The full sweep is the Cartesian product of problem, optimizer, width,
    depth and learning-rate index. Task ``t`` handles configurations
    ``t, t + num_tasks, t + 2 * num_tasks, ...`` of that product. For each
    configuration a KAN is trained on successively refined grids, and the
    concatenated train/test loss curves plus the wall-clock time are written
    to ``./results/kan_w_ge/``.

    Raises:
        IndexError: if fewer than two command-line arguments are given.
        ValueError: if an argument is not an integer.
    """
    from itertools import product

    # Grab the arguments that are passed in
    my_task_id = int(sys.argv[1])
    num_tasks = int(sys.argv[2])

    problems = [0, 1]
    opt_names = ["LBFGS", "Adam"]
    widths = [1, 3, 10, 30, 100]
    depths = [2, 3, 4]
    lr_ids = [0, 1, 2, 3, 4]

    # Optimizer is the slowest-varying axis, then problem, width, depth and
    # lr index. This is the same enumeration order that flattening
    # np.meshgrid(problems, opt_names, ...) with its default 'xy' indexing
    # produces, so each task id keeps the same set of configurations.
    # Using product keeps the ints as ints and the names as plain strings.
    configs = [
        (problem, opt_name, width, depth, lr_id)
        for opt_name, problem, width, depth, lr_id
        in product(opt_names, problems, widths, depths, lr_ids)
    ]
    my_configs = configs[my_task_id:len(configs):num_tasks]
    if not my_configs:
        # Nothing is assigned to this task id.
        return

    # Create the output directory up front so a missing directory cannot
    # make np.savetxt fail after a long training run.
    results_dir = './results/kan_w_ge'
    os.makedirs(results_dir, exist_ok=True)

    # Loop-invariant settings.
    lrs_lbfgs = [1e-2, 3e-2, 1e-1, 3e-1, 1]
    lrs_adam = [1e-4, 3e-4, 1e-3, 3e-3, 1e-2]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    grids = np.array([3, 5, 10, 20, 50, 100])
    k = 3
    seed = 0

    for problem, opt_name, width, depth, lr_id in my_configs:
        # Re-seed for every configuration so each run starts from the same
        # random state.
        torch.manual_seed(seed)
        np.random.seed(seed)

        if opt_name == "Adam":
            lr = lrs_adam[lr_id]
            steps = 10000
        else:
            lr = lrs_lbfgs[lr_id]
            steps = 200

        # create dataset
        input_dim, dataset = _make_dataset(problem, device)

        train_losses = []
        test_losses = []

        start_time = time.time()

        # Train on the coarsest grid first, then refine the grid and keep
        # training; `steps` optimizer steps are run at every grid size.
        model = KAN(width=[input_dim]+(depth-1)*[width]+[1], grid=grids[0], k=k, seed=0, device=device)
        for stage, grid_size in enumerate(grids):
            if stage != 0:
                model = model.refine(grid_size)
            results = model.fit(dataset, opt=opt_name, steps=steps, lr=lr)
            train_losses.extend(results['train_loss'])
            test_losses.extend(results['test_loss'])

        wall_time = time.time() - start_time

        tag = f'problem_{problem}_opt_{opt_name}_width_{width}_depth_{depth}_lrid_{lr_id}'
        np.savetxt(f'{results_dir}/train_loss_{tag}', train_losses)
        np.savetxt(f'{results_dir}/test_loss_{tag}', test_losses)
        np.savetxt(f'{results_dir}/walltime_{tag}', [wall_time])


# Only launch the sweep when executed as a script, so that importing this
# module does not read sys.argv or start training.
if __name__ == "__main__":
    run()
