import numpy as np
import math
import gym
import torch
import time
from ddpg import DDPGAgent
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

# ---------------------------------------------------------------------------
# Trace-model settings
# ---------------------------------------------------------------------------
tracelim = 150               # most recent (state, action) rows kept in a run's state log
tracebuffersize = 300        # size limit applied when appending to a trace buffer
tracebuffer_batchsize = 32   # rows sampled per trace-model training batch
spike_decay = 0.9            # per-step decay of the retraining probability after a success
N_epochs = 10000             # upper bound on trace-model epochs (scaled by training progress)

# ---------------------------------------------------------------------------
# DDPG hyper-parameters (passed to DDPGAgent when creating sub-goal policies)
# ---------------------------------------------------------------------------
critic_lr = 1e-5
actor_lr = 1e-5
gamma = 0.95
tau = 0.01
buffer_maxlen = 100_000

# ---------------------------------------------------------------------------
# Goal / sub-goal handling
# ---------------------------------------------------------------------------
thresh = 1.0                 # distance below which a goal or sub-goal counts as reached
max_steps_sub = 50           # step limit of one sub-goal episode


def build_tracebuffer(statelog, tracebuffer, success_flag, maxsize=None, decay=0.99):
    """Append a labelled copy of an episode's state log to a trace buffer.

    Each row of ``statelog`` gets one extra trailing element, its trace label.
    Rows are labelled newest-first: the last logged row receives the initial
    label, and each earlier row receives the previous label times ``decay``.
    The initial label is 1 when ``success_flag == 1`` and 0 otherwise, so a
    failed episode labels every row 0.

    Args:
        statelog: sequence of 1-D rows (lists or arrays), oldest first.
        tracebuffer: existing buffer (list of rows). It is not modified.
        success_flag: 1 if the episode reached the goal.
        maxsize: maximum number of rows kept; the oldest rows are dropped
            first. Defaults to the module-level ``tracebuffersize``.
        decay: multiplicative label decay per step back in time.

    Returns:
        A new list: the existing rows followed by the new labelled rows (each
        a 1-D numpy array), trimmed to the newest ``maxsize`` entries.
    """
    if maxsize is None:
        maxsize = tracebuffersize
    buffer = list(tracebuffer)
    lambd = 1 if success_flag == 1 else 0
    for row in reversed(statelog):
        buffer.append(np.append(row, lambd))
        lambd = lambd * decay
    # A single trim at the end gives the same result as dropping the oldest
    # row on every overflowing append, without repeated O(n) copies, and it
    # caps the buffer at exactly ``maxsize`` rows.
    return buffer[max(len(buffer) - maxsize, 0):]


def learn_trace_model(model,tracebuffer,tracebuffer_neg,runno,learn_prog):
    """Fit the trace model on a random batch drawn from both trace buffers.

    Nothing happens until *both* buffers hold more than
    ``tracebuffer_batchsize`` rows. After that, ``tracebuffer_batchsize`` rows
    are drawn with replacement. Each row comes from ``tracebuffer_neg`` or
    ``tracebuffer`` with probability 1/2. A row is ``[inputs..., label]``:
    every element except the last is a model input, and the last element is
    the regression target.

    The model is trained for ``int(N_epochs * learn_prog)`` epochs. If that
    number is positive, the model is then saved once to ``"trace<runno>.h5"``.

    Args:
        model: model with Keras-style ``fit``/``save`` (the Sequential
            regressor built in ``mini_batch_train``).
        tracebuffer: rows produced by ``build_tracebuffer``.
        tracebuffer_neg: rows produced by ``build_tracebuffer``.
        runno: run index used in the checkpoint file name.
        learn_prog: training progress in [0, 1] that scales the epoch count.

    Returns:
        ``model``, trained in place.
    """
    if (len(tracebuffer) <= tracebuffer_batchsize
            or len(tracebuffer_neg) <= tracebuffer_batchsize):
        return model

    rows = []
    for _ in range(tracebuffer_batchsize):
        curr_buffer = tracebuffer_neg if np.random.rand() > 0.5 else tracebuffer
        rows.append(np.asarray(curr_buffer[np.random.randint(len(curr_buffer))]))
    x = np.array([row[:-1] for row in rows])
    y = np.array([row[-1] for row in rows]).reshape(-1, 1)

    n_epochs = int(N_epochs * learn_prog)
    if n_epochs > 0:
        # One fit call with ``epochs`` does the same number of passes as n
        # single-epoch calls, at a fraction of the overhead. The checkpoint
        # is written once instead of after every epoch.
        model.fit(x, y, epochs=n_epochs, verbose=0)
        model.save("trace" + str(runno) + ".h5")
    return model


def execute_subpols(st_state_orig,retall,gp,env,max_steps,maxsteps_total,batch_size,agent,stepcount,target_state,subagent):
    """Run, and keep training, a sub-goal policy from a start state toward one sub-goal.

    The environment is reset to the (x, y) in ``st_state_orig[0:2]`` via
    ``env.reset_xy``. At most ``episodes`` (= 1) episode of up to
    ``max_steps_sub`` steps is then run. At each step the action comes from
    the main ``agent`` with probability ``stepcount / maxsteps_total``, and
    from ``subagent`` otherwise.

    Two rewards are computed from the bot position after every step:

    * the goal reward: 1 within ``thresh`` of ``gp``, else -0.1. It is
      appended to ``retall`` and pushed to ``agent``'s replay buffer.
    * the sub-goal reward: 1 within ``thresh`` of ``target_state``, else -0.1.
      It is pushed to ``subagent``'s replay buffer.

    Each agent is updated once its buffer holds more than ``batch_size``
    transitions. The episode stops when the sub-goal or the goal is reached,
    or when the step limit is hit.

    Args:
        st_state_orig: start observation; only elements 0 and 1 are used.
        retall: per-step reward log, appended to in place and returned.
        gp: goal position, compared against the bot's (x, y).
        env: environment exposing ``reset_xy``, ``step`` and
            ``env.wrapped_env.get_xy()``.
        max_steps: unused here.
        maxsteps_total: total step budget. It is checked only before an
            episode starts, not inside the step loop.
        batch_size: minibatch size for both agents' updates.
        agent: main agent, trained on the goal reward.
        stepcount: global step counter; incremented per env step and returned.
        target_state: sub-goal position.
        subagent: sub-goal policy, trained on the sub-goal reward.

    Returns:
        (next_state, retall, agent, stepcount, done, subagent).
        ``next_state`` is the last observation, or the reset observation if
        no step was taken. ``done`` tells whether ``gp`` was reached, either
        at the last step or already at the start position.
    """
    print("Executing subgoal...")
    #subagent = DDPGAgent(env, gamma, tau, buffer_maxlen, critic_lr, actor_lr)
    # Number of attempts made at this sub-goal.
    episodes=1
    

    state=env.reset_xy([st_state_orig[0],st_state_orig[1]])
    next_state=state
    # NOTE(review): this action is never used, because a fresh action is
    # chosen before every env.step below. The call only matters if
    # get_action has side effects.
    action = subagent.get_action(state)

    # Goal / sub-goal flags at the start position. They are returned
    # unchanged if no step is taken.
    dist=np.linalg.norm(gp-np.array([st_state_orig[0],st_state_orig[1]]))
    if dist<thresh:
        done=True
    else:
        done=False
    #print(done)

    dist_sub=np.linalg.norm(target_state-np.array([st_state_orig[0],st_state_orig[1]]))
    #for i in range(episodes):
    if dist_sub<thresh:
        done_sub=True
    else:
        done_sub=False
    #print(done_sub)
    '''
    if done_sub==True:
        print("Subfound prematurely:")
        print(target_state)
        print(np.array([st_state_orig[0],st_state_orig[1]]))
        time.sleep(1)
    '''
    #print("Start and target states:")
    #print(st_state_orig)
    #print(target_state)
    episode_cnt=0
    # NOTE: this loop test uses ``dist_sub>thresh``, while the flag above uses
    # ``dist_sub<thresh``. A start exactly ``thresh`` away therefore neither
    # runs an episode nor counts as reached.
    while dist_sub>thresh and stepcount<=maxsteps_total and episode_cnt<episodes:
        state=env.reset_xy([st_state_orig[0],st_state_orig[1]])
        episode_reward = 0
        #print(i)
        episode_cnt+=1
        print("Stepcount:")
        print(stepcount)

        for step in range(max_steps_sub):
            #if 0.8>np.random.rand():
            # Hand control to the main agent with probability equal to the
            # fraction of the total step budget already used.
            if (stepcount/maxsteps_total)>np.random.rand():
                action = agent.get_action(state)
            else:
                action = subagent.get_action(state)


            # The env's own reward/done are overwritten by the distance-based
            # values computed below.
            next_state, reward, done, _ = env.step(action)

            
            ###########################
            #statelog.append(np.append(state,action))
            botpos=np.array(env.env.wrapped_env.get_xy())
            dist=np.linalg.norm(botpos-gp)
            dist_sub=np.linalg.norm(botpos-target_state)

            # Goal reward (used for the main agent and retall).
            if dist<thresh:
                done=True
                reward=1
                print("Goal executed to completion")
            else:
                done=False
                reward=-0.1

            # Sub-goal reward (used for the sub-goal policy).
            if dist_sub<thresh:
                done_sub=True
                reward_sub=1
                print(target_state)
                #print("Subgoal found")
            else:
                done_sub=False
                reward_sub=-0.1                

            retall.append(reward)
            stepcount+=1
            # The same transition trains both agents, each with its own reward.
            subagent.replay_buffer.push(state, action, reward_sub, next_state, done_sub)
            agent.replay_buffer.push(state, action, reward, next_state, done)
            # episode_reward is accumulated but never read.
            episode_reward += reward
            if len(subagent.replay_buffer) > batch_size:
                subagent.update(batch_size)
            if len(agent.replay_buffer) > batch_size:
                agent.update(batch_size)   
            
            if done_sub==True or step == max_steps_sub-1 or done==True:    
                break   
            state = next_state

        if done_sub==True or done==True:
            break
 
    return next_state,retall,agent,stepcount,done,subagent

def findsubgoal(model,st_state,st_state_orig,retall,gp,env,max_steps,maxsteps_total,batch_size,agent,stepcount,subgoal_targs):
    """Explore from a start state with a fresh DDPG policy to propose the next sub-goal.

    A new ``DDPGAgent`` (``subagent``) is created. It runs for ``episodes``
    (= 100) episodes of at most ``max_steps_sub`` steps, each starting from
    ``env.reset_xy(st_state_orig[0:2])``. Every step pushes one transition
    into each replay buffer: the goal reward for ``agent``, and the sub-goal
    reward toward the current ``target_state`` for ``subagent``. Each agent
    is updated once its buffer exceeds ``batch_size`` transitions.

    While no target has been accepted (``targ_found == 0``):

    * Every (state, action) row is appended to ``sa_log``, which is *not*
      reset between episodes.
    * At step ``max_steps_sub - 2``, or when the sub-goal flag is set, the
      trace ``model`` scores every logged row.
    * If the bot is within ``thresh`` of ``gp``, the goal itself becomes the
      target and ``all_found`` is set to 1.
    * Otherwise, the best score may exceed ``rew_trace * subgoal_thresh``.
      ``rew_trace`` is the model's score for the start state paired with the
      subagent's first action. In that case the state part of the
      best-scoring row becomes the candidate target. The exception is when
      the current ``target_state`` is within ``thresh`` of the origin; then
      ``gp`` is kept.
    * A candidate is rejected (target reset to ``gp``, search continues in
      the next episode) when it lies within ``thresh`` of an entry of
      ``subgoal_targs``, or within ``2 * thresh`` of the origin. Both checks
      run only when ``subgoal_targs`` is non-empty.

    Once a target is accepted, the remaining episodes keep training
    ``subagent`` toward it. Each of those episodes ends when the target is
    reached or at the step limit.

    Args:
        model: trace model exposing Keras-style ``predict``.
        st_state: unused here.
        st_state_orig: start observation; elements 0 and 1 give the reset (x, y).
        retall: per-step goal-reward log, appended to in place and returned.
        gp: goal position.
        env: environment exposing ``reset_xy``, ``step`` and
            ``env.wrapped_env.get_xy()``.
        max_steps: unused here.
        maxsteps_total: unused here.
        batch_size: minibatch size for both agents' updates.
        agent: main agent, trained on the goal reward.
        stepcount: global step counter; incremented per env step and returned.
        subgoal_targs: previously accepted sub-goal positions (read only).

    Returns:
        (subagent, target_state, retall, target_state_orig, agent, stepcount,
        all_found).
        ``target_state`` is the chosen (x, y) sub-goal, or ``gp``.
        ``target_state_orig`` is the full vector the target was taken from.
        NOTE(review): ``target_state_orig`` is not reset when a candidate is
        rejected, so it can hold a stale, rejected candidate.
    """
    print("Finding subgoal...")
    subagent = DDPGAgent(env, gamma, tau, buffer_maxlen, critic_lr, actor_lr)
    episodes=100
    #max_steps_sub=50
    # A candidate must out-score the start state's trace by this factor.
    subgoal_thresh=1.0
    #state = st_state_orig#env.reset()
    state=env.reset_xy([st_state_orig[0],st_state_orig[1]])
    action = subagent.get_action(state)
    state_sa=[]
    state_sa.append(np.append(state,action))
    # Two identical rows so predict receives a 2-D batch; row 0's score is the baseline.
    dummy_state=np.vstack((state_sa,state_sa))
    rew_trace=model.predict(dummy_state)[0][0]
    target_state_orig=st_state_orig
    #print(rew_trace)
    target_state=gp
    # statelog, success_flag and success_cnt are written but never read here.
    statelog=[]
    sa_log=[]
    success_flag=0
    #stepcount=0
    targ_found=0
    all_found=0

    print(st_state_orig)
    for i in range(episodes):
        #state=st_state_orig
        state=env.reset_xy([st_state_orig[0],st_state_orig[1]])

        episode_reward = 0
        print("Start:")
        print(state)
        print(i)
        for step in range(max_steps_sub):
            action = subagent.get_action(state)
            # The env's reward/done are overwritten by the distance-based values below.
            next_state, reward_sub, done_sub, _ = env.step(action)
            #env.render()            
            ###########################
            
            statelog.append(np.append(state,action))
            botpos=np.array(env.env.wrapped_env.get_xy())
            dist=np.linalg.norm(botpos-gp)
            dist_sub=np.linalg.norm(botpos-target_state)

            if dist<thresh:
                done=True
                reward=1
                success_flag=1
                success_cnt=stepcount
            else:
                done=False
                reward=-0.1


            if dist_sub<thresh:
                done_sub=True
                reward_sub=1
                print(target_state)
                print("Subgoal found")
            else:
                done_sub=False
                reward_sub=-0.1                

            retall.append(reward)
            stepcount+=1
            subagent.replay_buffer.push(state, action, reward_sub, next_state, done_sub)
            agent.replay_buffer.push(state, action, reward, next_state, done)
            episode_reward += reward
            if len(subagent.replay_buffer) > batch_size:
                subagent.update(batch_size)
            if len(agent.replay_buffer) > batch_size:
                agent.update(batch_size)   
            #print(step)
            # End the episode once an accepted target is reached, or at the last step.
            if (done_sub==True and targ_found==1) or step == max_steps_sub-1:
                print(target_state)
                #episode+=1                
                #print("Step Number: " + str(stepcount) + ": " + str(episode_reward)+" Episode:"+str(episode)+" Run:"+str(runno))
                #tracebuffer_neg=build_tracebuffer(statelog,tracebuffer_neg,success_flag)
                break

            if targ_found==1:
                if done_sub==True or step == max_steps_sub-1:
                    break
            else:
                #print(step)
                #statelog.append(np.append(state,action))
                if len(sa_log)==0:
                    sa_log=np.append(state,action)
                else:
                    sa_log=np.vstack((sa_log,np.append(state,action)))
                if step==max_steps_sub-2 or done_sub==True:
                    s_salog=np.shape(sa_log)
                    #print("salog:")
                    #print(s_salog)
                    #print(len(state))
                    # A single logged row is still 1-D; duplicate it so predict
                    # gets a 2-D batch. (This also fires, harmlessly, when
                    # exactly len(state)+len(action) rows have been logged.)
                    if s_salog[0]==len(state)+len(action):
                        sa_log=np.vstack((sa_log,sa_log))
                    #print(sa_log)
                    traces=model.predict(sa_log)
                    #print(traces)
                    targ_trace=np.max(traces)

                    if np.linalg.norm(botpos-gp)<thresh:
                        # Goal reached directly: it becomes the final target.
                        targ_found=1
                        target_state_orig=next_state
                        target_state=gp
                        all_found=1
                    elif targ_trace>rew_trace*subgoal_thresh:# or np.linalg.norm(botpos-gp)<thresh:
                        ind=np.argmax(traces)
                        max_sa=sa_log[ind]
                        targ_found=1
                        if (np.linalg.norm(target_state-[0.,0.]))<thresh:
                            target_state=gp
                        else:    
                            # Drops the last two entries, presumably the action
                            # (assumes a 2-D action - TODO confirm).
                            target_state_orig=max_sa[0:(len(max_sa)-2)]#np.array([max_sa[0],max_sa[1]])
                            target_state=target_state_orig[0:2]
                        ss=np.shape(subgoal_targs)

                        print("Prev subgoals:")
                        print(subgoal_targs)
                        print(ss)
                        sa_prev=[]
                        # Reject candidates near a known sub-goal or near the
                        # origin. ii-1 still visits every index once, starting
                        # with the last one.
                        for ii in range(ss[0]):                    
                            if (np.linalg.norm(target_state-subgoal_targs[ii-1]))<thresh or (np.linalg.norm(target_state-[0.,0.]))<2*thresh:
                                target_state=gp
                                targ_found=0
                                break
                        # Leaves the step loop (ends this episode); the print below is unreachable.
                        break
                        print(target_state)
                        #time.sleep(3)
            state = next_state
        #time.sleep(3)
 
    return subagent,target_state,retall,target_state_orig,agent,stepcount,all_found

def mini_batch_train(env, agent, max_steps, batch_size, maxsteps_total,gp,stepcount,runno):
    """Train ``agent`` on ``env`` with trace-model-guided sub-goal discovery.

    The function loops until the step counter exceeds ``maxsteps_total``.
    Each outer iteration resets ``env`` and runs up to ``max_steps`` steps
    of ``agent``. While within the step budget, rewards are re-shaped from
    the bot's distance to ``gp``: 1 within ``thresh``, -0.1 otherwise.
    Alongside the agent's own replay updates, it:

    * keeps a rolling log of the last ``tracelim`` (state, action) rows. On
      reaching ``gp`` the log is added to the trace buffer. Every episode
      that ends normally is also added to ``tracebuffer_neg``.
    * after a success, retrains a small Keras trace regressor with
      probability ``spike_decay ** (steps since the success)``.
    * once a success has happened, calls ``findsubgoal`` on every step where
      the env's own ``done`` is False. New sub-goals that are well separated
      are recorded together with their sub-policies.
    * when ``findsubgoal`` reports the goal itself was reached
      (``all_found``), replays the recorded sub-goal chain with
      ``execute_subpols`` and ends the episode.

    Args:
        env: environment exposing ``reset``, ``step``, ``reset_xy`` and
            ``env.wrapped_env.get_xy()``.
        agent: main agent with ``get_action``, ``replay_buffer``, ``update``,
            ``critic`` and ``actor``.
        max_steps: step limit per episode.
        batch_size: minibatch size for agent updates.
        maxsteps_total: total step budget.
        gp: goal position.
        stepcount: ignored; counting always restarts at 0.
        runno: run index used in log messages and the trace-model file name.

    Returns:
        (episode_rewards, stepcount, retall): per-episode summed rewards,
        the final step count, and the per-step reward log.

    The agent's critic/actor weights are saved to 'criticsaved_fk.pt' and
    'actorsaved_fk.pt'.
    """
    episode_rewards = []
    tracebuffer=[]
    tracebuffer_neg=[]
    statelog=[]
    success_flag=0
    # The ``stepcount`` argument is overwritten: counting always starts at 0.
    stepcount=0
    success_attempts=0
    retall=[]
    subgoal_targs=[]
    subgoal_targs_orig=[]
    sub_pol=[]

    # Trace model: maps a 9-wide (state, action) row to its trace label.
    model = Sequential()
    model.add(Dense(128, input_shape=(9,), activation="relu"))
    model.add(Dense(128, activation="relu"))
    model.add(Dense(1, activation="linear"))
    # ``learning_rate`` is the supported keyword. ``lr`` is a deprecated alias
    # that newer Keras releases ignore or reject outright.
    model.compile(loss="mse", optimizer=Adam(learning_rate=0.001))
    episode=0
    all_found=0

    while stepcount<=maxsteps_total:
        state = env.reset()
        init_state=state
        episode_reward = 0
        done=False
        success_flag=0
        for step in range(max_steps):
            action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)

            # Sub-goal search starts from the latest recorded sub-goal, or
            # from this episode's initial state while none exist.
            s_sg=np.shape(subgoal_targs)
            if s_sg[0]==0:
                st_state_orig=init_state
                st_state=st_state_orig[0:2]
            else:
                s_subgoaltargs=np.shape(subgoal_targs)
                st_state_orig=subgoal_targs_orig[s_subgoaltargs[0]-1]
                st_state=subgoal_targs[s_subgoaltargs[0]-1]

            # Once findsubgoal has reached the goal, replay the recorded chain
            # of sub-goal policies from a fresh reset and end the episode.
            if all_found==1:
                st_state_fin_orig = env.reset()
                sz=np.shape(subgoal_targs)
                print("All subgoals:")
                print(subgoal_targs)
                for k in range(sz[0]):
                    if k==0:
                        st_state_fin_orig,retall,agent,stepcount,doneflag,subpol_ex=execute_subpols(st_state_fin_orig,retall,gp,env,max_steps,maxsteps_total,batch_size,agent,stepcount,subgoal_targs[k],sub_pol[k])
                        sub_pol[k]=subpol_ex
                    else:
                        # Later legs start from the previous sub-goal's (x, y).
                        st_state_fin_orig[0]=subgoal_targs[k-1][0]
                        st_state_fin_orig[1]=subgoal_targs[k-1][1]
                        st_state_fin_orig,retall,agent,stepcount,doneflag,subpol_ex=execute_subpols(st_state_fin_orig,retall,gp,env,max_steps,maxsteps_total,batch_size,agent,stepcount,subgoal_targs[k],sub_pol[k])
                        sub_pol[k]=subpol_ex

                    if doneflag==True:
                        break
                # Account for the env.step above.
                # NOTE(review): a reward of 1 is logged whether or not the chain
                # reached the goal, and execute_subpols already logs its own
                # step rewards. This may inflate "Total Successes"; confirm intent.
                stepcount+=1
                retall.append(1)
                break

            # Look for a new sub-goal after the first success, while the env
            # itself has not signalled termination.
            if success_attempts>0 and done==False:
                Q_sub,sub_goal,retall,subgoal_orig,agent,stepcount,all_found=findsubgoal(model,st_state,st_state_orig,retall,gp,env,max_steps,maxsteps_total,batch_size,agent,stepcount,subgoal_targs)
                s_sgtarg=np.shape(subgoal_targs)
                # Keep the proposal only if it is at least ``thresh`` from the
                # search start and more than ``thresh`` from every known sub-goal.
                closecount=0
                if np.linalg.norm(st_state-sub_goal)>=thresh:
                    for ik in range(s_sgtarg[0]):
                        if np.linalg.norm(subgoal_targs[ik]-sub_goal)<=thresh:
                            closecount+=1
                else:
                    closecount+=1
                if closecount==0:
                    subgoal_targs.append(sub_goal)
                    subgoal_targs_orig.append(subgoal_orig)
                    sub_pol.append(Q_sub)

                print(type(sub_pol))
                print(np.shape(sub_pol))

            # Rolling log of the last ``tracelim`` (state, action) rows.
            if len(statelog)>=(tracelim):
                statelog=statelog[1:(tracelim)]
            statelog.append(np.append(state,action))
            botpos=np.array(env.env.wrapped_env.get_xy())
            dist=np.linalg.norm(botpos-gp)

            # Distance-based reward shaping. Past the step budget, the env's
            # own reward/done are kept unless the goal is reached.
            if stepcount<=maxsteps_total:
                done=False
                reward=-0.1
            if dist<thresh:
                done=True
                print(dist)
                print(botpos)
                print("Goal")
                success_attempts+=1
                reward=1
                success_flag=1
                success_cnt=stepcount
                tracebuffer=build_tracebuffer(statelog,tracebuffer,success_flag)
            retall.append(reward)

            stepcount+=1

            # Retrain the trace model with a probability that decays with the
            # number of steps since the last success.
            if success_flag>0:
                train_prob=spike_decay**(stepcount-success_cnt)
                if np.random.sample()<train_prob:
                    model=learn_trace_model(model,tracebuffer,tracebuffer_neg,runno,stepcount/maxsteps_total)
                    print("Model trained")
            agent.replay_buffer.push(state, action, reward, next_state, done)
            episode_reward += reward

            if len(agent.replay_buffer) > batch_size:
                agent.update(batch_size)

            if done or step == max_steps-1:
                episode+=1
                episode_rewards.append(episode_reward)
                print("Step Number: " + str(stepcount) + ": " + str(episode_reward)+" Episode:"+str(episode)+" Run:"+str(runno))
                tracebuffer_neg=build_tracebuffer(statelog,tracebuffer_neg,success_flag)
                break

            state = next_state
    torch.save(agent.critic.state_dict(), 'criticsaved_fk.pt')
    torch.save(agent.actor.state_dict(), 'actorsaved_fk.pt')
    # Count successes: zero out the -0.1 step penalties and sum the rest.
    retallcopy=np.array(retall.copy())
    retallcopy[retallcopy==-0.1]=0
    print("Total Successes:"+str(np.sum(retallcopy)))

    return episode_rewards,stepcount,retall



def mini_batch_train_plain(env, agent, max_steps, batch_size, maxsteps_total,gp,stepcount,runno):
    """Baseline training loop: ``agent`` alone, with no trace model or sub-goals.

    The function loops until the step counter exceeds ``maxsteps_total``.
    Each outer iteration resets ``env`` and runs up to ``max_steps`` steps.
    While within the step budget, the env's reward/done are replaced by
    distance-based ones: reward 1 and done when the bot is within ``thresh``
    of ``gp``, otherwise reward -0.1 and not done. Past the budget, the env's
    values are kept unless the goal is reached. Every transition is pushed
    to the agent's replay buffer, and the agent is updated once the buffer
    holds more than ``batch_size`` transitions.

    Args:
        env: environment exposing ``reset``, ``step`` and
            ``env.wrapped_env.get_xy()``.
        agent: agent with ``get_action``, ``replay_buffer``, ``update``,
            ``critic`` and ``actor``.
        max_steps: step limit per episode.
        batch_size: minibatch size for agent updates.
        maxsteps_total: total step budget.
        gp: goal position.
        stepcount: ignored; counting always restarts at 0.
        runno: run index used in log messages.

    Returns:
        (episode_rewards, stepcount, retall): per-episode summed rewards,
        the final step count, and the per-step reward log.

    The agent's critic/actor weights are saved to 'criticplainsaved.pt' and
    'actorplainsaved.pt'.
    """
    episode_rewards = []
    retall = []
    stepcount = 0
    episode = 0

    while stepcount <= maxsteps_total:
        print(stepcount)
        state = env.reset()
        episode_reward = 0
        for step in range(max_steps):
            action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)

            botpos = np.array(env.env.wrapped_env.get_xy())
            dist = np.linalg.norm(botpos - gp)

            # Distance-based reward shaping while within the step budget.
            if stepcount <= maxsteps_total:
                done = False
                reward = -0.1
            if dist < thresh:
                done = True
                print("Episode length:")
                print(step)
                print("Goal")
                reward = 1
            retall.append(reward)
            stepcount += 1

            agent.replay_buffer.push(state, action, reward, next_state, done)
            episode_reward += reward

            if len(agent.replay_buffer) > batch_size:
                agent.update(batch_size)

            if done or step == max_steps-1:
                episode += 1
                episode_rewards.append(episode_reward)
                print("Step Number: " + str(stepcount) + ": " + str(episode_reward)+" Episode:"+str(episode)+" Run:"+str(runno))
                break

            state = next_state
    torch.save(agent.critic.state_dict(), 'criticplainsaved.pt')
    torch.save(agent.actor.state_dict(), 'actorplainsaved.pt')
    # Count successes: zero out the -0.1 step penalties and sum the rest.
    # np.array already copies, so retall itself is left untouched.
    retallcopy = np.array(retall)
    retallcopy[retallcopy == -0.1] = 0
    print("Total Successes:"+str(np.sum(retallcopy)))
    return episode_rewards, stepcount, retall

def mini_batch_train_frames(env, agent, max_frames, batch_size):
    """Train ``agent`` for a fixed number of environment frames.

    Episodes run back to back. Whenever ``env.step`` reports ``done``, the
    episode's summed reward is recorded and the environment is reset. The
    agent is updated every frame once its replay buffer holds more than
    ``batch_size`` transitions.

    Args:
        env: environment exposing ``reset`` and ``step``.
        agent: agent with ``get_action``, ``replay_buffer`` and ``update``.
        max_frames: total number of ``env.step`` calls.
        batch_size: minibatch size for agent updates.

    Returns:
        Total rewards, one per completed episode. A trailing unfinished
        episode is not included.
    """
    episode_rewards = []
    state = env.reset()
    episode_reward = 0

    for frame in range(max_frames):
        action = agent.get_action(state)
        next_state, reward, done, _ = env.step(action)
        agent.replay_buffer.push(state, action, reward, next_state, done)
        episode_reward += reward

        if len(agent.replay_buffer) > batch_size:
            agent.update(batch_size)

        if done:
            episode_rewards.append(episode_reward)
            print("Frame " + str(frame) + ": " + str(episode_reward))
            # The next episode must start from the reset observation. The
            # terminal ``next_state`` must not overwrite it.
            state = env.reset()
            episode_reward = 0
        else:
            state = next_state

    return episode_rewards

# process episode rewards for multiple trials
def process_episode_rewards(many_episode_rewards):
    """Summarise the reward sequences of several trials.

    Args:
        many_episode_rewards: iterable of per-trial reward sequences. It is
            materialised once, so one-shot iterables such as generators also
            work.

    Returns:
        (minimum, maximum, mean): three lists with one entry per trial.
    """
    trials = list(many_episode_rewards)
    minimum = [np.min(rewards) for rewards in trials]
    maximum = [np.max(rewards) for rewards in trials]
    mean = [np.mean(rewards) for rewards in trials]

    return minimum, maximum, mean
