import argparse
import csv
import json
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np

from algorithms import *
from RobustRL_utils import *
from utils import *





# Values accepted by --run-mode; each one selects a branch in main().
RUN_MODES = ('algo1', 'algo1_new', 'algo5')


def build_parser():
    """Create the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser exposing a single ``--run-mode``
        option restricted to ``RUN_MODES`` (default ``'algo1_new'``).
    """
    parser = argparse.ArgumentParser(
        description='Run one of the algorithm experiments defined in this project',
    )
    parser.add_argument(
        '--run-mode',
        default='algo1_new',
        choices=RUN_MODES,
        help='which algorithm experiment to run (default: %(default)s)',
    )
    return parser


def save_results(run_mode, data, norms, timestamp,
                 csv_dir='outputs', json_dir='json_outputs'):
    """Write one run's norms to CSV and its full record to JSON.

    Args:
        run_mode: run-mode name, used as the file-name prefix.
        data: JSON-serialisable dict describing the run.
        norms: array-like accepted by ``np.savetxt``.
        timestamp: string appended to both file names.
        csv_dir: directory receiving ``<run_mode>_<timestamp>.csv``.
        json_dir: directory receiving ``<run_mode>_<timestamp>.json``.

    Returns:
        tuple[str, str]: the CSV path and the JSON path that were written.
    """
    # Create the target directories up front so that a fresh checkout does
    # not lose a finished run to FileNotFoundError.
    os.makedirs(csv_dir, exist_ok=True)
    os.makedirs(json_dir, exist_ok=True)

    csv_path = os.path.join(csv_dir, f'{run_mode}_{timestamp}.csv')
    json_path = os.path.join(json_dir, f'{run_mode}_{timestamp}.json')

    np.savetxt(csv_path, norms, delimiter=' ', fmt='%f')
    with open(json_path, 'w') as json_file:
        json.dump(data, json_file, indent=4)
    return csv_path, json_path


def main(argv=None):
    """Parse the command line and run the selected experiment.

    Args:
        argv: optional list of argument strings; ``None`` makes argparse
            read ``sys.argv[1:]``.

    Returns:
        dict: the record assembled for the run (always holds ``'run_mode'``,
        plus the hyper-parameters and results of the chosen branch).
    """
    args = build_parser().parse_args(argv)
    run_mode = args.run_mode
    data = {'run_mode': run_mode}

    if run_mode == 'algo1':  # Algo1
        algo1(num_states=10, num_actions=5, K=50, K_prime=50, T=20, T_prime=20,
              L=10, alpha=0.001, beta=0.02, a=0.01, b=0.02)

    elif run_mode == 'algo5':
        # Algo 5: Nonconvex-Nonconcave Setting
        K = 400
        data['K'] = K
        alpha, beta = 0.008, 0.2
        num_states = 10
        num_actions = 5
        data['num_states'], data['num_actions'] = num_states, num_actions

        norms1, complexs = test_bc_pseg_plus(
            num_states, num_actions, K, alpha, beta,
            showfigs=['steps', 'complexity'], save_fig=True,
        )
        data['norms'] = norms1.tolist()
        data['complexs'] = complexs.tolist()

        current_time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        save_results(run_mode, data, norms1, current_time_str)

    elif run_mode == 'algo1_new':
        # Algo 2 (using Algo 1, Algo 3): Convex Setting
        # Smaller settings tried previously: (K, K_prime) = (20, 50), (50, 100)
        K, K_prime = 200, 300
        data['K'], data['K_prime'] = K, K_prime
        # Smaller setting tried previously: (T, T_prime) = (10, 15)
        T, T_prime = 25, 35
        data['T'], data['T_prime'] = T, T_prime
        num_states = 10
        num_actions = 5
        data['num_states'], data['num_actions'] = num_states, num_actions

        L = 20
        data['L_xi_xi'] = L
        alpha, beta, a, b = 0.002, 0.001, 0.002, 0.002
        data['alpha'], data['beta'], data['a'], data['b'] = alpha, beta, a, b

        pi, xi, norms1, complexs = algo1_new(
            num_states, num_actions, K, K_prime, T, T_prime, L,
            alpha, beta, a, b, step_size=-1,
        )
        data['norms'] = norms1.tolist()
        data['complexs'] = complexs.tolist()
        # NOTE(review): persisting this mode's results was commented out in
        # the original script, so nothing is written to disk here. If that
        # was unintentional, save_results(...) with a fresh timestamp is the
        # hook to re-enable it.

    return data


if __name__ == '__main__':
    main()
