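# Generate trajectories in the Derk environment with a pre-trained PPO policy
# and save them to disk, one file per reward-weight vector.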
from gym.wrappers.time_limit import TimeLimit
import numpy as np
import gym
import torch
import re

from gym_derk.envs import DerkEnv
from util import SingleAgentWrapper, DiscreteActionEnv, BotPolicy
from util import cari_concat, set_reward_function, flatten_buffer, beta_params
from reward_fns import RewardCalculator

import os
from PPO import PPO
from torch.utils.tensorboard import SummaryWriter

def add_trajectories_to_replay(N_episodes, weights, env, model, buffer, verbose = False):
    """
    Roll out the greedy policy of ``model`` in ``env`` and append every transition to ``buffer``.

    Each ``env.reset()`` starts ``env.n_arenas`` arenas in parallel, so the number
    of trajectories generated is n_arenas * N_episodes.

    The reward returned by ``env.step`` is ignored. Instead the reward is
    recomputed per step as the weighted sum of the damage-enemy term (divided
    by 10) and the heal-teammate term produced by ``RewardCalculator``.

    Args:
        N_episodes: number of batched episodes to run.
        weights: 1-D array-like of reward weights, ordered (damage, heal). The
            same vector is used for every arena.
        env: wrapped environment exposing ``n_arenas``, ``reset`` and ``step``.
        model: agent exposing ``select_action(state, greedy=True)`` and
            ``policy_old.actor`` (with ``get_logprobs``).
        buffer: object with list attributes ``rewards``, ``is_terminals``,
            ``states``, ``actions``, ``logprobs`` and ``next_states``; one
            entry is appended to each of them per environment step.
        verbose: currently unused; kept for interface compatibility.

    Returns:
        None. ``buffer`` is extended in place.
    """
    n_arenas = env.n_arenas
    # One identical weight row per arena: shape (n_arenas, n_weights).
    weights = np.asarray(weights).reshape((1, -1)).repeat(n_arenas, axis=0)

    for _ in range(N_episodes):
        set_reward_function(weights, env)
        state = env.reset()
        rc = RewardCalculator(n_arenas)
        state = cari_concat(state, weights)
        done = False
        while not done:
            # select action with policy
            action = model.select_action(state, greedy = True)
            action_to_save = action.clone()
            action = [tuple(a.tolist()) for a in action]
            next_state, _, done, _ = env.step(action)
            next_state = cari_concat(next_state, weights)

            # The env may report `done` as a scalar or as a per-agent
            # sequence; collapse it to a single episode-level flag.
            done = bool(np.ravel(done)[0])

            # compute the saveables
            state_to_save = torch.FloatTensor(state.copy())
            next_state_to_save = torch.FloatTensor(next_state.copy())

            device = action_to_save.device
            # The log-probs are only stored, never back-propagated through, so
            # do not keep an autograd graph alive for every buffered step.
            with torch.no_grad():
                dists = model.policy_old.actor(state_to_save.to(device))
                logprobs_to_save = model.policy_old.actor.get_logprobs(action_to_save, dists)
            rc.update_focus(state, action, next_state)

            # calculate the rewards with new weights
            reward_calc_actions = [tuple(a.tolist()) for a in action_to_save]
            state_np = state_to_save.numpy()
            next_state_np = next_state_to_save.numpy()
            # normalize damage to be more equivalent with healing
            reward_dmg = rc.damage_enemy(state_np, reward_calc_actions, next_state_np) / 10.0
            reward_heal = rc.heal_teammate(state_np, reward_calc_actions, next_state_np)
            reward = np.column_stack([reward_dmg, reward_heal])
            reward = np.sum(weights * reward, axis=1)

            # A row is terminal when its first state feature is <= 0
            # (presumably the agent's hit points -- TODO confirm) or when the
            # whole episode has ended, e.g. by the time limit.
            hp_done = (state_to_save[:, 0] <= 0.0).numpy()
            done_mask = np.logical_or(hp_done, done)

            # saving reward and is_terminals
            buffer.rewards.append(reward)
            buffer.is_terminals.append(done_mask)
            buffer.states.append(state_to_save)
            buffer.actions.append(action_to_save)
            buffer.logprobs.append(logprobs_to_save)
            buffer.next_states.append(next_state_to_save)

            state = next_state

if __name__ == "__main__":
    # Script entry point: build the wrapped Derk environment, load a
    # pre-trained PPO agent, sample reward-weight vectors from two Beta
    # distributions and save the trajectories generated for each of them.

    env_name = "derk"

    # init reward functions: every built-in reward term is zeroed here
    base_reward_fn = {
        "killEnemyStatue": 0,
        "killEnemyUnit": 0,
        "damageEnemyUnit": 0,
        "healTeammate1": 0,
        "healTeammate2": 0,
        "friendlyFire": 0,
        "statueDamageTaken": 0,
        "fallDamageTaken": 0,
        "healEnemy": 0,
        "timeScaling": 0
    }
    # NOTE: the random seed is set in derk_appinstance.py, in the
    # create_session() function ("pWkn91perNetBJQ3ymNc9").
    N_trajs = 128
    env = DerkEnv(n_arenas = N_trajs, turbo_mode=True, reward_function=base_reward_fn, session_args=dict(interleaved=False))

    # [low, high] bounds for the first three action components; the last two
    # entries stay None and are handed to DiscreteActionEnv unchanged.
    discretization = [
        [env.action_space[i].low.item(), env.action_space[i].high.item()]
        for i in range(3)
    ] + [None, None]
    d_steps = 5
    for di, interval in enumerate(discretization):
        if interval is not None:
            # d_steps + 1 evenly spaced edges turned into d_steps rows of
            # consecutive [edge_i, edge_{i+1}] pairs.
            disc = torch.linspace(interval[0], interval[1], d_steps+1)
            disc = disc.repeat_interleave(2)
            disc = disc[1:-1].reshape(-1, 2)
            discretization[di] = disc

    env = DiscreteActionEnv(env, discretization)

    # state space dimension
    state_dim = env.observation_space.shape[0]

    # action space dimension
    action_space = "multidiscrete"
    action_dim = env.action_space.nvec

    # initialize a PPO agent; the training hyper-parameters are all passed as
    # 0.0 because the agent is only loaded and used to generate trajectories.
    n_weights = 2
    alphas = np.ones(n_weights)
    bot_weights = np.random.dirichlet(alphas, 1)[0]
    ppo_agent = PPO(state_dim + n_weights, action_dim, 0.0, 0.0, 0.0, 0.0, 0.0, action_space, 0.0)
    ppo_agent.load("PPO_preTrained/derk/nn_bots_policy2.pth")

    #copy if one wants static bot policy
    #bot_net = copy.deepcopy(ppo_agent.policy.actor)
    #if we want to load existing model
    #bot_net = PPO(state_dim + n_weights, action_dim, lr_actor, lr_critic, gamma, K_epochs, eps_clip, action_space, action_std)
    #bot_net.load("PPO_preTrained/derk/PPO_derk_0_0.pth")
    #bot_net = bot_net.policy.actor
    #if we wish to have exactly the same policy
    #bot_net = ppo_agent.policy.actor
    bot_net = None
    bp = BotPolicy(env=env, bot_net=bot_net, weights=bot_weights)
    env = SingleAgentWrapper(env, bp)
    max_ep_len = 150
    env = TimeLimit(env, max_ep_len)

    # define weights that are used to generate expert trajs:
    # N1 samples around mean1 and N2 samples around mean2.
    N1 = 20
    mean1 = 0.8
    var1 = 0.001
    N2 = 25
    mean2 = 0.35
    var2 = 0.008
    def sample_beta_weights(mean, var, size):
        """Draw `size` samples s from Beta(a, b), with (a, b) = beta_params(mean, var), as rows [s, 1 - s]."""
        a, b = beta_params(mean, var)
        s = np.random.beta(a, b, size)
        return np.stack([s, 1-s], axis=1)

    weights1 = sample_beta_weights(mean1, var1, N1)
    weights2 = sample_beta_weights(mean2, var2, N2)
    weights = np.concatenate([weights1, weights2], axis=0)


    def get_valid_filename(s):
        """Turn `s` into a filename-safe string: spaces become '_', anything but word characters, '-' and '.' is dropped."""
        s = str(s).strip().replace(' ', '_')
        return re.sub(r'(?u)[^-\w.]', '', s)

    # Create the output directory up front so that a missing folder cannot
    # make the save fail after a full rollout has already been generated.
    out_dir = "trajectories/derk"
    os.makedirs(out_dir, exist_ok=True)

    n_ep = 1
    print(f"Generating {N_trajs * n_ep} trajectories in {env_name}")
    for weight in weights:
        print(f"\t with weight {weight}")
        replay = ppo_agent.buffer
        add_trajectories_to_replay(n_ep, weight, env, ppo_agent, replay, verbose=True)
        flatten_buffer(replay)
        weight_str = get_valid_filename(weight)
        replay.save(f"{out_dir}/weights{weight_str}.pkl")
        # Empty the buffer so the next weight vector starts from scratch.
        replay.clear()