#%%

from utils import *
from envs import *
from model import *
import numpy as np
from copy import deepcopy
import argparse

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy.  The empty string is still mapped
    to ``False`` so that the old ``--obs ''`` workaround keeps working.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()

# --- training / environment options ---
parser.add_argument('--episodes', type=int, required=False, help='episodes', default=50000)
parser.add_argument('--tmax', type=int, required=False, help='tmax', default=300)
parser.add_argument('--obs', type=_str2bool, required=False, help='obs', default=True)
# NOTE(review): the defaults of the three coordinate options are nested lists, but
# values given on the command line arrive as ONE flat list of floats (nargs='+'),
# which the per-goal / per-obstacle loops further down do not expect.
parser.add_argument('--startcoords', type=float,nargs='+', required=False, help='startcoods', default=[[-0.75,-0.75],[-0.75,0.75]])
parser.add_argument('--goalcoords', type=float,nargs='+', required=False, help='goalcoords', default=[[0.75,-0.75]])
parser.add_argument('--obscoords', type=float,nargs='+', required=False, help='obscoords', default=[[-0.2,0.2,-1,0.5]])
parser.add_argument('--rsz', type=float, required=False, help='rsz', default=0.1)
parser.add_argument('--rmax', type=int, required=False, help='rmax', default=5)

# --- agent initialisation options ---
parser.add_argument('--seed', type=int, required=False, help='seed', default=0)
parser.add_argument('--pcinit', type=str, required=False, help='pcinit', default='uni')
parser.add_argument('--npc', type=int, required=False, help='npc', default=21)
parser.add_argument('--alpha', type=float, required=False, help='alpha', default=1)
parser.add_argument('--sigma', type=float, required=False, help='sigma', default=0.01)

# --- learning rates (plr/clr/llr/alr/slr) and other learning options ---
parser.add_argument('--plr', type=float, required=False, help='plr', default=0.01)
parser.add_argument('--clr', type=float, required=False, help='clr', default=0.01)
parser.add_argument('--llr', type=float, required=False, help='llr', default=0.0001)
parser.add_argument('--alr', type=float, required=False, help='alr', default=0.0001)
parser.add_argument('--slr', type=float, required=False, help='slr', default=0.0001)
parser.add_argument('--gamma', type=float, required=False, help='gamma', default=0.95)
parser.add_argument('--nact', type=int, required=False, help='nact', default=4)
parser.add_argument('--beta', type=float, required=False, help='beta', default=1)

# --- noise options ---
parser.add_argument('--balpha', type=float, required=False, help='balpha', default=0.0)
parser.add_argument('--paramsindex', type=int,nargs='+', required=False, help='paramsindex', default=[0,1,2])
parser.add_argument('--noise', type=float, required=False, help='noise', default=0.000)

# --- analysis / output options ---
parser.add_argument('--analysis', type=str, required=False, help='analysis', default='drift')
parser.add_argument('--datadir', type=str, required=False, help='datadir', default='./data/')
parser.add_argument('--figdir', type=str, required=False, help='figdir', default='./fig/')
parser.add_argument('--csvname', type=str, required=False, help='csvname', default='results')

# parse_known_args tolerates extra arguments (e.g. those injected by a notebook kernel)
args, unknown = parser.parse_known_args()


# training params
train_episodes = args.episodes
tmax = args.tmax
obs = args.obs

# env params
envsize = 1
maxspeed = 0.1
goalsize = args.rsz
startcoord = args.startcoords
goalcoords = args.goalcoords
obscoords = args.obscoords
seed = args.seed
max_reward = args.rmax

# agent params
npc = args.npc**2  # the CLI value is squared to get the total count
sigma = args.sigma
alpha = args.alpha
nact = args.nact

# noise params
noise = args.noise
paramsindex = args.paramsindex
piname = ''.join(map(str, paramsindex))  # e.g. [0, 1, 2] -> '012', used in the experiment name
pcinit = args.pcinit

# learning rates; the order inside `etas` is what `learn` receives
actor_eta = args.plr
critic_eta = args.clr
pc_eta = args.llr
sigma_eta = args.slr
constant_eta = args.alr
etas = [pc_eta, sigma_eta,constant_eta, actor_eta,critic_eta]
gamma = args.gamma
balpha = args.balpha

save_figs= True
savevar = True

exptname = f'2D_obs_td_multi_{noise}ns_{piname}p_{npc}n_{actor_eta}plr_{critic_eta}clr_{pc_eta}llr_{constant_eta}alr_{sigma_eta}slr_{pcinit}_{nact}a_{seed}s_{train_episodes}e_{max_reward}rmax_{goalsize}rsz'

# Honour --figdir / --datadir instead of hard-coding the folders (the CLI
# defaults are the previously hard-coded './fig/' and './data/').  Both values
# are concatenated directly with `exptname` later on, so they should end with
# a path separator.
figdir = args.figdir
datadir = args.datadir

print(exptname)

# Build the initial agent parameters for the requested initialisation scheme.
if pcinit == 'uni':
    params = uniform_2D_pc_weights(npc, nact, seed, sigma=sigma, alpha=alpha, envsize=envsize)
elif pcinit == 'rand_all':
    params = random_all_pc_weights(npc, nact, seed, sigma=sigma, alpha=alpha, envsize=envsize)
else:
    # Fail here with a clear message rather than with a NameError on `params` below.
    raise ValueError(f"unknown pcinit {pcinit!r}; expected 'uni' or 'rand_all'")

# Keep an untouched copy of the starting parameters for later comparison/plots.
initparams = deepcopy(params)
plot_all_pc([initparams],0)

# inner training loop: one episode with a parameter update after every step
def run_trial(params, env, trial):
    """Run one episode in ``env`` and learn online.

    At every step the current state is passed through ``predict_placecell`` and
    ``predict_action_prob``, an action is drawn with ``get_onehot_action``, the
    environment is stepped, and ``learn`` returns updated parameters.  The loop
    stops after ``tmax`` steps (module-level setting) or as soon as the
    environment reports ``done``.

    Args:
        params: current agent parameters; the value returned by ``learn``
            replaces them after each step.
        env: environment exposing ``reset()`` (returning four values, the first
            being the initial state) and ``step(action)`` (returning
            ``(newstate, reward, done)``).
        trial: episode index; accepted for the caller's convenience but not
            used inside the function.

    Returns:
        A tuple ``(coords, rewards, actions, td_sum, t, params)`` where
        ``coords``/``rewards``/``actions`` are arrays of the states visited
        (before each step), rewards received and one-hot actions taken,
        ``td_sum`` is the sum over steps of the squared third return value of
        ``learn`` (presumably the TD error), ``t`` is the zero-based index of
        the last executed step, and ``params`` are the updated parameters.
    """
    coords = []
    actions = []
    rewards = []
    tds = []

    # reset() returns four values; only the initial state is needed here
    state, _, _, _ = env.reset()

    t = 0  # keeps the returned step index defined even when tmax <= 0
    for t in range(tmax):

        pcact = predict_placecell(params, state)

        aprob = predict_action_prob(params, pcact)

        onehotg = get_onehot_action(aprob, nact=nact)

        newstate, reward, done = env.step(onehotg)

        # online update using the transition just observed; `grads` is not used further
        params, grads, td = learn(params, reward, newstate, state, onehotg,aprob, gamma, etas,balpha, noise, paramsindex)

        # log the state the action was taken FROM, not the resulting state
        coords.append(state)
        actions.append(onehotg)
        rewards.append(reward)
        tds.append(td**2)

        state = newstate.copy()

        if done:
            break

    return np.array(coords), np.array(rewards), np.array(actions),np.sum(tds), t, params


#%%
# Per-episode logs collected across every goal / obstacle configuration.
losses = []
latencys = []
allcoords = []
logparams = []
logparams.append(initparams)  # index 0 holds the parameters before any learning
cum_rewards = []
all_rewards = []

for goalcoord in goalcoords:

    for obscoord in obscoords:
        # A fresh environment per (goal, obstacle) pair; `params` carries over between them.
        env = NDimNav(startcoord=startcoord, goalcoord=goalcoord, goalsize=goalsize, tmax=tmax, 
                        maxspeed=maxspeed,envsize=envsize, nact=nact, max_reward=max_reward, obstacles=obs, obscoord=obscoord)

        for episode in range(train_episodes):

            coords, rewards, actions,tds, latency, params = run_trial(params, env, episode)

            # NOTE(review): computed every episode but not used anywhere in the visible script
            discount_rewards = get_discounted_rewards(rewards, gamma)

            allcoords.append(coords)
            # snapshot the parameters after every episode so later entries are not aliased
            logparams.append(deepcopy(params))
            latencys.append(latency)
            losses.append(tds)
            all_rewards.append(env.total_reward)

            # report the goal currently being trained on, not the full list of goals
            print(f'Goal {goalcoord}, Trial {episode+1}, G {env.total_reward:.3f}, t {latency}, L {tds:.3f}')


import os  # local import: only needed to prepare the output folders below

# Persist the parameter history, per-episode returns and trajectories.
if savevar:
    # `datadir` is used as a plain string prefix, so create the folder part of
    # the resulting path up front; training has already finished at this point
    # and a missing folder would otherwise lose the results.
    # NOTE(review): assumes saveload writes at this path prefix - confirm in utils.
    save_parent = os.path.dirname(datadir+exptname)
    if save_parent:
        os.makedirs(save_parent, exist_ok=True)
    saveload(datadir+exptname, [logparams, all_rewards, allcoords], 'save')


# `env` is the environment left over from the last training configuration.
env.plot_trajectory()
plot_all_pc(logparams,-1)
f,score, drift = plot_analysis(logparams, latencys,all_rewards, allcoords, train_episodes//2, exptname=exptname, rsz=goalsize)

if save_figs:
    fig_path = figdir+exptname+'.svg'
    # savefig does not create missing folders itself
    fig_parent = os.path.dirname(fig_path)
    if fig_parent:
        os.makedirs(fig_parent, exist_ok=True)
    f.savefig(fig_path)


# Count, for several radii and training snapshots, how many rows of the first
# parameter entry (presumably the place-field centres) lie close to the first goal.
thresholds = [0.1, 0.2, 0.25, 0.3]
trials = [0, train_episodes // 4, train_episodes]
n_panels = len(trials)

# grs[e, t] = number of rows of logparams[trials[t]][0] within thresholds[e] of goalcoords[0]
grs = np.zeros([len(thresholds), n_panels])
for t, trial in enumerate(trials):
    # the distances depend only on the snapshot, so compute them once per trial
    goal_dist = np.linalg.norm(goalcoords[0] - logparams[trial][0], axis=1)
    for e, threshold in enumerate(thresholds):
        within = goal_dist < threshold
        grs[e, t] = np.sum(within)


# One scatter panel per snapshot, using the first two columns as x/y.
f, ax = plt.subplots(1, n_panels, figsize=(3 * n_panels, 2 * 1))

for panel, trial in zip(ax, trials):
    xy = logparams[trial][0]
    panel.scatter(xy[:, 0], xy[:, 1], s=2, color='k')
    panel.set_aspect('equal')
f.tight_layout()