from utility import* 
from algorithm import*
import torch
from torch.utils.tensorboard import SummaryWriter
import time
import argparse
# Registry consulted by `opt`: maps the --alg_name command-line value to the
# training routine that implements it (both routines are presumably provided
# by the star-imported `algorithm` module — confirm there).
alg_dict = dict(
    batch_linear_ctd=batch_pmf_sgd,
    batch_linear_ctd_pre=batch_cdf_sgd,
)


def opt(writer, alg_name, replay_buffer, K, n, gamma, theta, F, alpha, state_space, pi,
        algorithms=None):
    """Run the training routine registered under ``alg_name``.

    The routine is looked up in ``algorithms`` (the module-level ``alg_dict``
    when not given) and called with the remaining arguments; ``pi`` and
    ``writer`` are forwarded as keyword arguments.

    Args:
        writer: logging object forwarded as ``writer=`` (the script passes a
            TensorBoard ``SummaryWriter``).
        alg_name: key of the routine to run.
        replay_buffer: sample buffer forwarded unchanged.
        K: number of anchors.
        n: dimension of the feature vector.
        gamma: discount factor.
        theta: anchor locations (the script uses ``linspace(0, 1/(1-gamma), K)``).
        F: feature matrix.
        alpha: learning rate.
        state_space: state indices forwarded unchanged.
        pi: policy/environment object forwarded as ``pi=``.
        algorithms: optional mapping ``name -> callable`` that overrides
            ``alg_dict`` (useful for experiments and tests).

    Returns:
        The 4-tuple ``(w, losslist, valuelist, pmflosslist)`` produced by the
        selected routine.

    Raises:
        ValueError: if ``alg_name`` is not a registered algorithm.
    """
    # Only touch the module-level registry when no override is supplied.
    registry = alg_dict if algorithms is None else algorithms
    alg = registry.get(alg_name)
    if alg is None:
        # Fail loudly with the valid names instead of a bare KeyError (or an
        # UnboundLocalError on `w`, as the old None-guard would have produced).
        raise ValueError(
            f'Unknown algorithm {alg_name!r}; expected one of {sorted(registry)}'
        )
    w, losslist, valuelist, pmflosslist = alg(
        replay_buffer, K, n, gamma, theta, F, alpha, state_space, pi=pi, writer=writer
    )
    return w, losslist, valuelist, pmflosslist

if __name__ == '__main__':
    # The default dtype (and, below, device) must be configured before any of
    # the hard-coded tensors are created, since they pick them up implicitly.
    torch.set_default_dtype(torch.float64)
    date = time.strftime('%Y%m%d', time.localtime())
    parser = argparse.ArgumentParser()

    ### Experimental Setups ###
    parser.add_argument("--exp", default='exp_0', type=str, help='Experiment number')
    parser.add_argument("--dir", default=f"LCTD/runs/{date}", type=str)
    parser.add_argument("--seed", default=0, type=int, help='Seeding for torch')
    parser.add_argument("--device", default=5, type=int,
                        help='CUDA device index for torch; a negative value selects the CPU')
    ### Parameters ###
    parser.add_argument("--K", default=5, type=int, help='Number of anchors')
    parser.add_argument("--n", default=3, type=int, help='Dimension of feature vector')
    parser.add_argument("--gamma", default=0.5, type=float, help='Discount')
    parser.add_argument("--s", default=3, type=int, help='If discrete environment, the number of states')
    parser.add_argument("--r_number", default=3, type=int, help='Sample discrete reward')
    parser.add_argument("--sample_number", default=500000, type=int, help='Number of replay buffer')
    # The default must be a key of alg_dict or opt() cannot dispatch it;
    # `choices` also rejects misspelled names at parse time.
    parser.add_argument("--alg_name", default='batch_linear_ctd', type=str,
                        choices=list(alg_dict), help='Name of the algorithm')
    parser.add_argument("--alpha", default=0.0005, type=float, help='Learning rate')
    args = parser.parse_args()

    # theta's upper bound 1/(1-gamma) below is only finite for gamma < 1.
    if not 0 <= args.gamma < 1:
        parser.error(f'--gamma must be in [0, 1), got {args.gamma}')

    exp = args.exp
    log_dir = args.dir  # not named `dir` to avoid shadowing the builtin
    seed = args.seed
    torch.manual_seed(seed)
    K = args.K
    n = args.n
    gamma = args.gamma
    # Use the requested GPU when CUDA is available; otherwise (or for a
    # negative index) fall back to the CPU.
    if args.device >= 0 and torch.cuda.is_available():
        device = torch.device(f'cuda:{args.device}')
    else:
        device = torch.device('cpu')
    torch.set_default_device(device)
    s = args.s
    r_number = args.r_number
    sample_number = args.sample_number  # NOTE(review): not used anywhere below
    alpha = args.alpha

    F = torch.tensor([[0.4629, 0.4825, 0.7436],
                      [0.5888, 0.6526, 0.4769],
                      [0.4394, 0.7526, 0.4905]])
    print('Feature matrix', F)
    S = torch.arange(0, s).long()
    p_a = torch.tensor([[0.4242, 0.5493, 0.0265],
                        [0.3504, 0.3526, 0.2970],
                        [0.2834, 0.5602, 0.1564]])
    print('Transition', p_a)
    r = torch.tensor([[[0.0117, 0.4665, 0.5218],
                       [0.4397, 0.3620, 0.1983],
                       [0.3042, 0.4752, 0.2206]],
                      [[0.4760, 0.0134, 0.5106],
                       [0.1220, 0.7912, 0.0868],
                       [0.0147, 0.5061, 0.4792]],
                      [[0.1271, 0.5374, 0.3355],
                       [0.5684, 0.2574, 0.1742],
                       [0.3880, 0.1065, 0.5055]]])
    print('Reward distribution of each transition', r)
    reward_list = torch.tensor([0.0000, 0.5000, 1.0000])

    # The environment above is hard-coded; reject --s/--n/--r_number values
    # that disagree with it instead of silently mis-indexing these tensors.
    # (Assumes columns of F are feature dimensions — both are 3 here.)
    expected = {'--s': p_a.shape[0], '--n': F.shape[1], '--r_number': reward_list.numel()}
    given = {'--s': s, '--n': n, '--r_number': r_number}
    mismatched = [f'{flag}={given[flag]} (expected {size})'
                  for flag, size in expected.items() if given[flag] != size]
    if mismatched:
        parser.error('arguments do not match the hard-coded environment: '
                     + ', '.join(mismatched))

    alg_name = args.alg_name
    theta = torch.linspace(0, 1 / (1 - gamma), K)
    pi = Policy(S, p_a, r, reward_list)
    # map_location lets a buffer saved on a different device (e.g. a GPU index
    # absent on this machine) load straight onto the selected device.
    replay_buffer = torch.load('LCTD/replay_buffer', map_location=device).to(device=device)
    writer = SummaryWriter(log_dir=f'{log_dir}/seed_{seed}_K_{K}/{exp}')
    try:
        w, losslist, valuelist, pmflosslist = opt(
            writer, alg_name, replay_buffer, K, n, gamma, theta, F, alpha, S, pi=pi
        )
    finally:
        # Flush pending TensorBoard events and release the event file even if
        # training raises.
        writer.close()

