# Generate trajectories from a pre-trained PPO policy for randomly sampled goal weights and save them to disk.
from gym.wrappers.time_limit import TimeLimit
import numpy as np
import gym
import torch
import re
from mdtoyenv import MultiDiscreteToyEnv
from util import cari_concat, get_hypersphere_locations, flatten_buffer
from reward_fns import within_radius, neg_distance

import os
from datetime import datetime
# import pybullet_envs
from PPO import PPO
from torch.utils.tensorboard import SummaryWriter
import argparse
import random


def add_trajectories_to_replay(N_episodes, weights, env, model, buffer, verbose=False, cat_state_weights=True):
    """
    Roll out ``N_episodes`` episodes of ``model`` in ``env`` and record every transition.

    Every step is appended both to ``buffer`` (its ``states``, ``actions``,
    ``logprobs``, ``rewards``, ``is_terminals`` and ``next_states`` lists) and
    to per-call lists that are returned. The environment's own reward is
    discarded: the reward is recomputed with ``neg_distance`` against every
    goal in ``env.goal_locs`` / ``env.goal_rads`` (one column per goal) and
    then combined as a ``weights``-weighted sum over goals.

    Notice that the number of trajectories will be n_arenas * N_episodes
    (``n_arenas`` is fixed to 1 here).

    Args:
        N_episodes: number of episodes to roll out.
        weights: per-goal reward weights; reshaped to shape (n_arenas, -1).
        env: environment exposing ``reset``/``step`` plus ``n_dim``,
            ``goal_locs`` and ``goal_rads``.
        model: agent exposing ``select_action`` and ``policy_old.actor``
            (with ``get_logprobs``).
        buffer: rollout buffer whose list attributes are appended to.
        verbose: accepted for API compatibility; currently unused.
        cat_state_weights: if True, every observation is passed through
            ``cari_concat(obs, weights)`` before it is fed to the policy and stored.

    Returns:
        Tuple of lists ``(states, actions, log_probs, rewards, is_terminals,
        next_states)`` with one entry per recorded step.
    """
    n_arenas = 1
    weights = weights.reshape((1, -1)).repeat(n_arenas, axis=0)
    #rew_fn = within_radius
    rew_fn = neg_distance
    policy_rewards = []
    policy_is_terminals = []
    policy_states = []
    policy_next_states = []
    policy_actions = []
    policy_log_probs = []
    for _ in range(N_episodes):
        state = env.reset()
        if cat_state_weights:
            state = cari_concat(state, weights)
        done = False
        while not done:
            # select action with policy
            action = model.select_action(state, greedy=False)
            action_to_save = action.clone()
            next_state, _, done, _ = env.step(action)
            if cat_state_weights:
                next_state = cari_concat(next_state, weights)
            # compute the saveables
            state_to_save = torch.FloatTensor(state.copy())
            next_state_to_save = torch.FloatTensor(next_state.copy())
            device = action_to_save.device
            # The log-probs are stored as plain data. Computing them without
            # autograd stops the buffer from keeping every step's forward graph
            # alive; the values themselves are unchanged.
            with torch.no_grad():
                dists = model.policy_old.actor(state_to_save.to(device))
                logprobs_to_save = model.policy_old.actor.get_logprobs(action_to_save, dists)
            # Calculate rewards: one entry per goal, scored on the first n_dim
            # features of the next state (presumably the agent position).
            rewards = [
                rew_fn(
                    next_state_to_save[:, :env.n_dim].numpy(),
                    goal.reshape((1, env.n_dim)),
                    rad.reshape((1, 1)),
                )
                for goal, rad in zip(env.goal_locs, env.goal_rads)
            ]
            # Stack the per-goal rewards as columns, then take the weighted sum over goals.
            reward = np.sum(weights * np.column_stack(rewards), axis=1)

            # Some envs report `done` as a list; normalise to a scalar so the
            # loop condition below is evaluated on the actual flag.
            if isinstance(done, list):
                done = done[0]
            # if done either by hp or by time
            done_mask = np.repeat(done, n_arenas)
            # saving reward and is_terminals
            buffer.states.append(state_to_save)
            buffer.actions.append(action_to_save)
            buffer.logprobs.append(logprobs_to_save)
            buffer.rewards.append(reward)
            buffer.is_terminals.append(done_mask)
            buffer.next_states.append(next_state_to_save)

            policy_states.append(state_to_save)
            policy_actions.append(action_to_save)
            policy_log_probs.append(logprobs_to_save)
            policy_rewards.append(reward)
            policy_is_terminals.append(done_mask)
            policy_next_states.append(next_state_to_save)

            # The episode ends via the `while not done` condition.
            state = next_state

    return policy_states, policy_actions, policy_log_probs, policy_rewards, policy_is_terminals, policy_next_states


def get_valid_filename(weight):
    """Turn a weight vector into a filesystem-safe string.

    The vector is rendered with ``np.array2string`` using ``_`` as the element
    separator. Spaces (numpy's alignment padding) are dropped, and every
    character other than ``-``, word characters and ``.`` is removed; this
    strips the enclosing brackets and any line breaks. For example,
    ``np.array([0.5, 0.25, 0.25])`` becomes ``"0.5_0.25_0.25"``.
    """
    s = np.array2string(weight, separator="_")
    s = s.replace(" ", "")
    return re.sub(r'(?u)[^-\w.]', '', s)


if __name__ == "__main__":

    env_name = "mdenv"

    parser = argparse.ArgumentParser(description='MDENV generate trajectories')
    parser.add_argument("--N_goals", default=3, type=int)
    parser.add_argument("--N_weights", default=7, type=int)  # number of randomly sampled weights from dirichlet
    parser.add_argument("--policy_path", default="Anonymous/ScIRL/PPO_preTrained/mdtoyenv/3.pth", type=str)
    parser.add_argument("--N_trajs", type=int, default=128)
    parser.add_argument("--random_seed", type=int, default=123)

    args = parser.parse_args()

    seed = args.random_seed
    print(f"Setting seed to {seed}")
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    n_arenas = 1
    # Infer the environment dimensionality from the saved policy. The second
    # axis of the first stored tensor is taken as the policy input size, which
    # the PPO construction below sets to state_dim + N_goals.
    # map_location="cpu" lets a GPU-saved checkpoint be inspected on a
    # CPU-only machine; only its shape is used here.
    first_tensor = next(iter(torch.load(args.policy_path, map_location="cpu").values()))
    n_env_dims = int((first_tensor.shape[1] - args.N_goals) / (1 + args.N_goals))
    goal_locs, goal_rads = get_hypersphere_locations(n_env_dims, args.N_goals)
    print(f"GOALS: {goal_locs}, RADS {goal_rads}")
    env = MultiDiscreteToyEnv(goal_locs, goal_rads)
    # state space dimension
    state_dim = env.observation_space.shape[0]
    # action space dimension
    action_space = "multidiscrete"
    action_dim = env.action_space.nvec

    # Dirichlet concentration: uniform over the goals.
    alphas = np.ones(args.N_goals)

    # Initialize a PPO agent and load the pre-trained weights. The 0.0
    # arguments are presumably training hyper-parameters that do not affect
    # rollouts -- confirm against PPO.__init__.
    ppo_agent = PPO(state_dim + args.N_goals, action_dim, 0.0, 0.0, 0.0, 0.0, 0.0, action_space, 0.0)
    ppo_agent.load(args.policy_path)
    max_ep_len = 150
    env = TimeLimit(env, max_ep_len)

    # Define the weights used to generate expert trajs. Keep this after the
    # setup calls above: if any of them draws from np.random, moving it would
    # change the sampled weights for a given seed.
    weights = np.random.dirichlet(alphas, (args.N_weights,))

    print(f"Generating {args.N_trajs * n_arenas} trajectories in {env_name}")
    traj_dir = os.path.join("trajectories", "mdtoyenv", "scalability", str(args.N_goals))
    os.makedirs(traj_dir, exist_ok=True)
    # Rounded weights are used both for generation and for the file names.
    weights = weights.round(5)

    for weight in weights:
        print(f"\t with weight {weight}")
        replay = ppo_agent.buffer
        add_trajectories_to_replay(args.N_trajs, weight, env, ppo_agent, replay, verbose=True)
        flatten_buffer(replay)
        weight_str = get_valid_filename(weight)
        replay.save(os.path.join(traj_dir, f"weights{weight_str}.pkl"))
        # Empty the shared agent buffer so the next weight starts fresh.
        replay.clear()