'''
Script to run tabular experiments in batch mode.
'''
import os
import numpy as np
import pandas as pd
import argparse
import sys
import warnings
warnings.filterwarnings("ignore")
sys.path.append("..")
import src.bayes_ofu.environment as environment
import src.bayes_ofu.finite_agents as finite_agents
import src.bayes_ofu.bayesian_model as bayesian_model

from src.bayes_ofu.experiment import run_finite_tabular_experiment



def make_prior(nState, nAction):
    """Build the prior-parameter array used by the 'Random' env and the agents.

    Returns a new array of shape (nState, nAction, nState + 2). Along the
    last axis it holds nState ones, followed by a single 0 and a single 1.
    NOTE(review): presumably transition pseudo-counts followed by reward-prior
    parameters -- confirm against bayesian_model / environment.

    A fresh array is built on every call. This avoids sharing state across
    trials in case a consumer modifies it in place.
    """
    return np.concatenate(
        (
            np.ones((nState, nAction, nState)),
            np.zeros((nState, nAction, 1)),
            np.ones((nState, nAction, 1)),
        ),
        axis=-1,
    )


def build_parser():
    """Return the command-line parser for one tabular experiment run.

    ``--alg``, ``--env`` and ``--nState`` are required. Without them the
    script cannot choose an agent or environment, or format the output file
    name. All defaults are kept exactly as before: several of them (e.g.
    ``--zeta``, whose default is the int 1) are embedded verbatim in the
    output file name.
    """
    parser = argparse.ArgumentParser(description='Run tabular RL experiment')
    parser.add_argument('--alg', help='Agent constructor', type=str, required=True)
    parser.add_argument('--env', help='Environment constructor', type=str, required=True,
                        choices=['Chain', 'RiverSwim', 'Random'])
    parser.add_argument('--nState', help='number of states', type=int, required=True)
    parser.add_argument('--nAction', help='number of actions', type=int, default=5)
    parser.add_argument('--epLen', help='length of episodes', type=int, default=5)
    parser.add_argument('--scaling', help='scaling', type=float, default=1.0)
    parser.add_argument('--alpha', help='k^(1-alpha)', type=float, default=0.5)
    parser.add_argument('--delta', help='level of confidence', type=float, default=0.05)
    parser.add_argument('--seed', help='random seed', type=int, default=1)
    parser.add_argument('--nEps', help='number of episodes', type=int, default=100000)
    parser.add_argument('--nTrials', help='number of independent trials', type=int, default=1)
    parser.add_argument('--nSamp', help='number of posterior samples', type=int, default=2)
    parser.add_argument('--pal', help='Parallel or not', action="store_true")
    parser.add_argument('--v1', help='the v_1^* for lambda_k in our article', type=float, default=0.5)
    parser.add_argument('--v2', help='the v_2^* for lambda_k in our article', type=float, default=0)
    parser.add_argument('--l1', help='the v_1^* for value variation in our article', type=float, default=0.5)
    parser.add_argument('--l2', help='the v_2^* for value variation in our article', type=float, default=-1)
    parser.add_argument('--zeta', help='entropy scaling parameter', type=float, default=1)
    parser.add_argument('--gamma', help='distribution weighted parameter', type=float, default=0.998)
    parser.add_argument('--batch_size', help='mini batch size', type=int, default=0)
    return parser


def make_file_name(args):
    """Return the CSV file name that encodes the run's flags.

    The format is ``<env>_alg=..._scal=..._..._gamma=....csv``. It must stay
    stable so that results from earlier runs remain comparable.
    """
    parts = [
        args.env,
        'alg=' + str(args.alg),
        'scal=' + '%03.2f' % args.scaling,
        'nSamp=' + '%03d' % args.nSamp,
        'nState=' + '%03d' % args.nState,
        'nAction=' + '%03d' % args.nAction,
        'seed=' + str(args.seed),
        'nTrials=' + str(args.nTrials),
        'zeta=' + str(args.zeta),
        'parallel=' + str(args.pal),
        'eplen=' + str(args.epLen),
        'batchsize=' + str(args.batch_size),
        'gamma=' + str(args.gamma),
    ]
    return '_'.join(parts) + '.csv'


def next_run_folder(base):
    """Create and return the next numbered run folder under ``base``.

    Entries of ``base`` whose names are decimal integers are treated as
    earlier runs. The new folder is numbered one past the largest (1 when
    there are none). Non-numeric entries such as ``.gitkeep`` are ignored.
    If another concurrently running job grabs the same number first, the
    number is bumped until ``os.makedirs`` succeeds.

    Returns ``base + '/' + N + '/'``. The trailing slash is kept because
    callers concatenate file names directly onto it.

    Raises FileNotFoundError if ``base`` does not exist. It is deliberately
    not created, so a wrong working directory fails loudly.
    """
    run_numbers = [int(name) for name in os.listdir(base) if name.isdecimal()]
    number = max(run_numbers, default=0) + 1
    while True:
        path = base + '/' + str(number) + '/'
        try:
            os.makedirs(path)
        except FileExistsError:
            number += 1  # lost a race with a parallel job; try the next slot
            continue
        return path


if __name__ == '__main__':
    # Run a tabular experiment according to command line arguments.
    parser = build_parser()
    args = parser.parse_args()

    # Draw 2 * nTrials seeds from the master seed: randomseed[i] seeds the
    # environment of trial i, and randomseed[i + nTrials] seeds its agent.
    np.random.seed(args.seed)
    randomseed = np.random.randint(low=1, high=10**9, size=2 * args.nTrials)
    # Seed pairs from earlier runs, kept for reference:
    # 717354022,592322120
    # 946286477,410430190
    # 784077897,729053693
    # 491264,630361479
    # 550290314,796884250
    # 224766668,166716595
    # 630311760,413652245
    # 396591249,726684927
    # 629559426,271757670
    # 799981517,795511699
    print(randomseed)

    # Environment-specific overrides. The env sampler and agent factory
    # below read `args` lazily, so these take effect for both.
    if args.env == 'Chain':
        args.nAction = 2
        args.epLen = args.nState
    elif args.env == 'RiverSwim':
        args.nAction = 2

    # Make the environment (one per trial, built when the sampler is called).
    env_dict = {
        'Chain': lambda: [
            environment.make_stochasticChain(args.nState, randomseed[i])
            for i in range(args.nTrials)],
        'RiverSwim': lambda: [
            environment.make_riverSwim(args.nState, args.epLen, randomseed[i])
            for i in range(args.nTrials)],
        'Random': lambda: [
            environment.make_random_gaussian_reward_tabular_MDP(
                make_prior(args.nState, args.nAction), args.epLen, randomseed[i])
            for i in range(args.nTrials)],
    }
    env_sampler = env_dict[args.env]

    # Resolve the agent class before touching the filesystem, so a typo in
    # --alg does not leave an empty numbered folder behind.
    alg_dict = {'PSRL': finite_agents.PSRL,
                'OptimisticPSRL': finite_agents.OptimisticPSRL,
                'OptimisticEnvelopePSRL': finite_agents.OptimisticEnvelopePSRL,
                'ROPSRL': finite_agents.ROPSRL,
                'BOSS': finite_agents.BOSS,
                'FiniteCBOO': finite_agents.FiniteCBOO,
                'FiniteBOO': finite_agents.FiniteBOO,
                'FiniteCBOOACR': finite_agents.FiniteCBOOACR,
                'TabularCBOO': finite_agents.TabularCBOO,
                'TabularCBOOACR': finite_agents.TabularCBOOACR,
                'UCRL2': finite_agents.UCRL2,
                'EpsilonGreedy': finite_agents.EpsilonGreedy,
                'BPS': finite_agents.BOOviaPS}
    if args.alg not in alg_dict:
        parser.error('unknown --alg %r; choose from %s' % (args.alg, ', '.join(alg_dict)))
    agent_constructor = alg_dict[args.alg]

    # Only these algorithms use more than one posterior sample.
    if args.alg not in ['OptimisticPSRL', 'OptimisticEnvelopePSRL', 'ROPSRL', 'BOSS', 'BPS']:
        args.nSamp = 1

    # Output location: a fresh numbered folder plus a flag-encoding file name.
    fileName = make_file_name(args)
    folderName = next_run_folder('../test/data')
    targetPath = folderName + fileName
    print('******************************************************************')
    print(fileName)
    print('******************************************************************')

    def agent():
        """Build one freshly seeded agent per trial."""
        return [
            agent_constructor(
                bayesian_model.GaussianRewardSoftmaxBayesianTabularModel(
                    make_prior(args.nState, args.nAction),
                    batch_size=args.batch_size,
                    seed=randomseed[i + args.nTrials]),
                args.epLen, scaling=args.scaling, nSamp=args.nSamp, delta=args.delta,
                v1=args.v1, v2=args.v2, l1=args.l1, l2=args.l2, zeta=args.zeta,
                gamma=args.gamma, seed=randomseed[i + args.nTrials])
            for i in range(args.nTrials)]

    # Run the experiment
    run_finite_tabular_experiment(agent, env_sampler, args.nEps, args.nTrials, args.seed, args.pal, args.alg,
                                  recFreq=100, fileFreq=1000, targetPath=targetPath, folderName=folderName)

