'''
Prompt utilities for training and evaluation.
'''

import numpy as np
#import gym
import json, pickle, random, os, torch
from collections import namedtuple
#from .prompt_evaluate_episodes import prompt_evaluate_episode, prompt_evaluate_episode_rtg
import random
from copy import deepcopy
import gymnasium as gym
from datetime import datetime

# Fifteen-character binary attribute patterns. prompt_evaluate_episode samples one
# of these (converted digit-by-digit to ints) as the centre of the avoid box when
# the raw observation length equals state_truncated_dim.
cardio_attr_states = [
    '000000000000000',
    '000010010100000',
    '001111011011111',
    '010111111010011',
    '100000000001100',
    '100010010101100',
]

def check_avoid_success(observations, boxes, obs_start_index=0, prompt_dim=3):
    """Return True when no observation row lands inside any avoid box.

    Args:
        observations: 2-D tensor; only the columns
            ``obs_start_index : obs_start_index + prompt_dim`` are checked.
        boxes: iterable of flat sequences. In each, the first ``prompt_dim``
            entries are the lower corner and the rest the upper corner (both
            inclusive).
        obs_start_index: first observation column compared against the boxes.
        prompt_dim: number of compared columns.

    Returns:
        Python bool. True if every box is avoided by every row, which includes
        the case of no boxes.
    """
    window = observations[:, obs_start_index:obs_start_index + prompt_dim]
    device = observations.device
    box_avoided = []
    for box in boxes:
        lower = torch.tensor(box[:prompt_dim], device=device)
        upper = torch.tensor(box[prompt_dim:], device=device)
        # A row is inside the box when it is within both bounds on every column.
        rows_inside = (window >= lower).all(dim=1) & (window <= upper).all(dim=1)
        box_avoided.append(not rows_inside.any().item())
    return all(box_avoided)

''' Sampling batches from the trajectories dataset '''

def discount_cumsum(x, gamma):
    """Return discounted suffix sums of ``x`` along its first axis.

    ``out[-1] = x[-1]`` and ``out[t] = x[t] + gamma * out[t + 1]``. With
    ``gamma == 1`` this is the plain reverse cumulative sum (return-to-go).

    Args:
        x: array-like of per-step values; accumulation runs over axis 0.
        gamma: discount factor.

    Returns:
        A new array with the same shape as ``x``. Floating/complex inputs keep
        their dtype. Integer and bool inputs are promoted to float64 so that
        fractional discounting is not truncated. An empty input yields an empty
        array.
    """
    x = np.asarray(x)
    # np.zeros_like alone would inherit an integer dtype and truncate every
    # x[t] + gamma * out[t + 1] when gamma is fractional.
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    out = np.zeros_like(x, dtype=out_dtype)
    if x.shape[0] == 0:
        return out
    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out

# given a trajectory, convert it into Transformer sequence
def get_sequence(trajectory, K, info, gamma=1):
    """Turn one trajectory into right-padded Transformer input tensors.

    Real steps come first and padding follows, up to length ``K``.

    Args:
        trajectory: dict with numpy arrays ``'observations'``, ``'actions'``,
            ``'rewards'`` and ``'terminals'``, plus ``'timesteps'`` (a step
            count, passed to ``np.arange``).
        K: target sequence length.
        info: dict with ``'state_dim'``, ``'act_dim'``, ``'device'`` and
            ``'discrete_action'``.
        gamma: discount used for the returns-to-go.

    Returns:
        ``(s, a, r, d, rtg, timesteps, mask)``, all tensors on ``info['device']``
        with a leading dimension of 1:

        - ``s``: float32 states.
        - ``a``: long (discrete) or float32 actions.
        - ``r``: float32 rewards.
        - ``d``: long terminals. Padding uses the value 2, presumably to
          distinguish it from real 0/1 flags; TODO confirm.
        - ``rtg``: float32 returns-to-go, carrying one extra trailing zero
          (K + 1 entries when rewards cover every step). Presumably the caller
          drops the last entry; TODO confirm.
        - ``timesteps``: long step indices (0 for padding).
        - ``mask``: float64, 1 for real steps and 0 for padding.

    Raises:
        ValueError: if the trajectory has more than ``K`` steps (negative
            padding size).
    """
    state_dim = info['state_dim']
    act_dim = info['act_dim']
    device = info['device']
    discrete_action = info['discrete_action']

    obs_seq = trajectory['observations'].reshape(1, -1, state_dim)
    act_seq = trajectory['actions'].reshape((1, -1) if discrete_action else (1, -1, act_dim))
    rew_seq = trajectory['rewards'].reshape(1, -1, 1)
    done_seq = trajectory['terminals'].reshape(1, -1)
    step_seq = np.arange(0, trajectory['timesteps']).reshape(1, -1)

    seq_len = obs_seq.shape[1]
    rtg_seq = discount_cumsum(trajectory['rewards'], gamma=gamma)[:seq_len + 1].reshape(1, -1, 1)
    if rtg_seq.shape[1] <= seq_len:
        # Append one zero return-to-go after the final real step.
        rtg_seq = np.concatenate([rtg_seq, np.zeros((1, 1, 1))], axis=1)

    pad_len = K - seq_len

    def pad_right(arr, fill, *feature_shape):
        # The float64 filler keeps the concatenated dtype identical to padding
        # with np.zeros / np.ones * 2.
        return np.concatenate([arr, np.full((1, pad_len) + feature_shape, fill)], axis=1)

    obs_seq = pad_right(obs_seq, 0.0, state_dim)
    act_seq = pad_right(act_seq, 0.0) if discrete_action else pad_right(act_seq, 0.0, act_dim)
    rew_seq = pad_right(rew_seq, 0.0, 1)
    done_seq = pad_right(done_seq, 2.0)
    rtg_seq = pad_right(rtg_seq, 0.0, 1)
    step_seq = pad_right(step_seq, 0.0)
    mask = np.concatenate([np.ones((1, seq_len)), np.zeros((1, pad_len))], axis=1)

    action_dtype = torch.long if discrete_action else torch.float32
    return (
        torch.from_numpy(obs_seq).to(dtype=torch.float32, device=device),
        torch.from_numpy(act_seq).to(dtype=action_dtype, device=device),
        torch.from_numpy(rew_seq).to(dtype=torch.float32, device=device),
        torch.from_numpy(done_seq).to(dtype=torch.long, device=device),
        torch.from_numpy(rtg_seq).to(dtype=torch.float32, device=device),
        torch.from_numpy(step_seq).to(dtype=torch.long, device=device),
        torch.from_numpy(mask).to(device=device),
    )


# from the dataset, sample a batch of prompt-sequence pairs for training
def get_prompt_batch(trajectories, info, variant, get_prompt, get_avoid_prompt=None, n_prioritized_traj = None, n_prioritized_per_batch = None):
    """Create a sampler that draws batches of (prompt, sequence) training pairs.

    Args:
        trajectories: indexable collection of trajectory dicts.
        info: dict passed through to ``get_sequence``; ``'device'`` is also
            read here.
        variant: config dict. Reads ``'max_prompt_len'``, ``'batch_size'``,
            ``'K'``, ``'subsample_trajectory'``, ``'subsample_min_len'`` and
            ``'gamma'``. Also reads ``'max_avoid_prompt_len'`` when
            ``get_avoid_prompt`` is given.
        get_prompt: callable ``(trajectory, max_prompt_length=, device=,
            use_optimal_prompt=)`` returning ``(prompt, mask)`` tensors. These
            are concatenated along dim 0 across the batch.
        get_avoid_prompt: optional callable ``(trajectory,
            max_avoid_prompt_length=, device=, use_optimal_prompt=)`` returning
            ``(avoid_prompt, mask)`` or ``(avoid_prompt, mask, success)``.
        n_prioritized_traj: if truthy, trajectories ``[0, n_prioritized_traj)``
            are prioritized. ``n_prioritized_per_batch`` of them are drawn
            without replacement into every batch. The remaining
            ``max(batch_size - n_prioritized_per_batch, 0)`` slots are drawn
            with replacement from the other trajectories.
        n_prioritized_per_batch: see ``n_prioritized_traj``.

    Returns:
        ``fn(batch_size=variant['batch_size'])``. Each call returns:

        - ``(prompt, batch)`` when ``get_avoid_prompt`` is not given;
        - ``(prompt, avoid_prompt, batch)`` when it never reports success;
        - ``(prompt, avoid_prompt, batch, success)`` when it returns a third
          element. ``success`` is a tensor on ``device``.

        ``prompt``/``avoid_prompt`` are ``(tokens, mask)`` pairs and ``batch``
        is ``(s, a, r, d, rtg, timesteps, mask)`` as produced by
        ``get_sequence``.

    Raises:
        ValueError: (from ``fn``) if ``get_avoid_prompt`` reports success for
            only some of the sampled trajectories, since the flags could not be
            aligned with the batch.
    """
    num_trajectories = len(trajectories)
    max_prompt_len = variant['max_prompt_len']
    max_avoid_prompt_len = variant['max_avoid_prompt_len'] if get_avoid_prompt else None
    device = info['device']
    K = variant['K']
    subsample, subsample_minlen = variant['subsample_trajectory'], variant['subsample_min_len']

    def _sample_indices(batch_size):
        # Draw trajectory indices for one batch (prioritized or uniform).
        if n_prioritized_traj:
            prioritized = np.random.choice(
                np.arange(n_prioritized_traj),
                size=n_prioritized_per_batch,
                replace=False,
            )
            additional = np.random.choice(
                np.arange(n_prioritized_traj, num_trajectories),
                size=max(batch_size - n_prioritized_per_batch, 0),
                replace=True,
            )
            return np.concatenate([prioritized, additional])
        return np.random.choice(np.arange(num_trajectories), size=batch_size, replace=True)

    def fn(batch_size=variant['batch_size']):
        batch_inds = _sample_indices(batch_size)
        gamma = variant['gamma']

        prompt_list, p_mask_list = [], []
        avoid_prompt_list, avoid_p_mask_list, success_list = [], [], []
        s_list, a_list, r_list, d_list, rtg_list, timesteps_list, mask_list = [], [], [], [], [], [], []

        for i in batch_inds:
            if subsample:
                trajectory = subsample_trajectory(trajectories[i], subsample_minlen)
            else:
                trajectory = trajectories[i]

            p, p_mask = get_prompt(trajectory, max_prompt_length=max_prompt_len, device=device, use_optimal_prompt=False)
            prompt_list.append(p)
            p_mask_list.append(p_mask)

            if get_avoid_prompt:
                avoid_prompt_tuple = get_avoid_prompt(trajectory, max_avoid_prompt_length=max_avoid_prompt_len, device=device, use_optimal_prompt=False)
                avoid_prompt_list.append(avoid_prompt_tuple[0])
                avoid_p_mask_list.append(avoid_prompt_tuple[1])
                if len(avoid_prompt_tuple) == 3:
                    success_list.append(avoid_prompt_tuple[2])

            s, a, r, d, rtg, timesteps, mask = get_sequence(trajectory, K, info, gamma=gamma)
            s_list.append(s)
            a_list.append(a)
            r_list.append(r)
            d_list.append(d)
            rtg_list.append(rtg)
            timesteps_list.append(timesteps)
            mask_list.append(mask)

        prompt = torch.cat(prompt_list, dim=0), torch.cat(p_mask_list, dim=0)
        batch = (
            torch.cat(s_list, dim=0),
            torch.cat(a_list, dim=0),
            torch.cat(r_list, dim=0),
            torch.cat(d_list, dim=0),
            torch.cat(rtg_list, dim=0),
            torch.cat(timesteps_list, dim=0),
            torch.cat(mask_list, dim=0),
        )
        if not get_avoid_prompt:
            return prompt, batch

        avoid_prompt = torch.cat(avoid_prompt_list, dim=0), torch.cat(avoid_p_mask_list, dim=0)
        # success_list is always a list here, so test for emptiness, not None.
        # Success flags are only returned when get_avoid_prompt supplied them.
        if success_list:
            if len(success_list) != len(batch_inds):
                raise ValueError(
                    f"get_avoid_prompt returned success flags for {len(success_list)} "
                    f"of {len(batch_inds)} sampled trajectories"
                )
            return prompt, avoid_prompt, batch, torch.tensor(success_list, device=device)
        return prompt, avoid_prompt, batch

    return fn


# subsample into traj[0:len], len in [minlen, traj len]
def subsample_trajectory(trajectory, minlen):
    """Return a random prefix ``trajectory[0:length]``.

    ``length`` is drawn uniformly (inclusive) from
    ``[minlen, trajectory['timesteps']]``.

    If the drawn length equals the full length, the input dict itself is
    returned, not a copy. Otherwise a new dict is built:

    - per-step entries (observations, actions, rewards, terminals,
      next_observations) are sliced to the prefix and deep-copied;
    - ``'timesteps'`` is set to the prefix length;
    - every other entry is deep-copied unchanged.
    """
    full_len = trajectory['timesteps']
    cut = random.randint(minlen, full_len)
    if cut == full_len:
        return trajectory

    per_step_keys = ('observations', 'actions', 'rewards', 'terminals', 'next_observations')

    def pick(key, value):
        if key == 'timesteps':
            return cut
        if key in per_step_keys:
            return deepcopy(value[0:cut])
        return deepcopy(value)

    return {key: pick(key, value) for key, value in trajectory.items()}


""" evaluation """

def eval_episodes(info, variant, envs, model, prompt_len, get_prompt, trajectories, get_avoid_prompt = None, buffer_size = 0.03, name_prefix = None, prompt_dim = 3):
    """Evaluate ``model`` on every environment and return mean metrics.

    Runs ``variant['num_eval_episodes']`` episodes per env in ``envs`` through
    ``prompt_evaluate_episode`` under ``torch.no_grad()``.

    Args:
        info: dict with ``'max_ep_len'``, ``'discrete_action'``,
            ``'state_dim'``, ``'act_dim'`` and ``'device'``.
        variant: config dict. Reads ``'num_eval_episodes'`` and
            ``'avoid_prompt'``. When the latter is truthy, also reads
            ``'state_truncated_dim'``, ``'obs_start_index'`` and, optionally,
            ``'num_avoid'`` (falls back to 1 when missing or falsy).
        envs: sequence of environments.
        model: policy forwarded to ``prompt_evaluate_episode``.
        prompt_len: only used to build the metric names.
        get_prompt, trajectories, get_avoid_prompt: currently unused; kept for
            interface compatibility.
        buffer_size, name_prefix, prompt_dim: forwarded on the avoid-prompt
            path only.

    Returns:
        dict with ``prompt_len_{prompt_len}_`` prefixed means of:

        - ``return``, ``cost_return``;
        - ``normalized_costs`` (cost return / episode length);
        - ``ep_len``, ``success_rate``;
        - ``avoid_success_rate`` and ``final_distances``, only when episodes
          reported them.
    """
    max_ep_len, discrete_action = info['max_ep_len'], info['discrete_action']
    state_dim, act_dim, device = info['state_dim'], info['act_dim'], info['device']
    num_eval_episodes = variant['num_eval_episodes']

    returns, cost_returns, normalized_costs, ep_lens, successes = [], [], [], [], []
    avoid_successes, final_distances = [], []
    for env in envs:
        for _ in range(num_eval_episodes):
            with torch.no_grad():
                if variant["avoid_prompt"]:
                    results = prompt_evaluate_episode(
                        env,
                        state_dim,
                        act_dim,
                        discrete_action,
                        model,
                        max_ep_len=max_ep_len,
                        device=device,
                        avoid_prompt=True,
                        buffer_size=buffer_size,
                        name_prefix=name_prefix,
                        state_truncated_dim=variant["state_truncated_dim"],
                        obs_start_index=variant["obs_start_index"],
                        prompt_dim=prompt_dim,
                        num_avoid=variant.get("num_avoid") or 1,
                    )
                else:
                    results = prompt_evaluate_episode(
                        env,
                        state_dim,
                        act_dim,
                        discrete_action,
                        model,
                        max_ep_len=max_ep_len,
                        device=device,
                    )
            # prompt_evaluate_episode returns (return, cost_return, length,
            # success, avoid_success, final_distance). Read the optional tail
            # by position instead of requiring an exact tuple length; an exact
            # "== 5" check never matched the 6-tuple.
            ret, cost_ret, ep_len, success = results[:4]
            avoid_success = results[4] if len(results) > 4 else None
            final_distance = results[5] if len(results) > 5 else None

            returns.append(ret)
            cost_returns.append(cost_ret)
            normalized_costs.append(cost_ret / ep_len)
            ep_lens.append(ep_len)
            successes.append(success)
            if avoid_success is not None:
                avoid_successes.append(avoid_success)
            if final_distance is not None:
                final_distances.append(final_distance)

    metrics = {
        f'prompt_len_{prompt_len}_return_mean': np.mean(returns),
        f'prompt_len_{prompt_len}_cost_return_mean': np.mean(cost_returns),
        f'prompt_len_{prompt_len}_normalized_costs_mean': np.mean(normalized_costs),
        f'prompt_len_{prompt_len}_ep_len_mean': np.mean(ep_lens),
        f'prompt_len_{prompt_len}_success_rate': np.mean(successes),
    }
    if avoid_successes:
        metrics[f'prompt_len_{prompt_len}_avoid_success_rate'] = np.mean(avoid_successes)
    if final_distances:
        metrics[f'prompt_len_{prompt_len}_final_distances'] = np.mean(final_distances)
    return metrics


def prompt_evaluate_episode(
        env,
        state_dim,
        act_dim,
        discrete_action,
        model,
        max_ep_len=1000,
        device='cuda',
        prompt=None,
        avoid_prompt=False,
        buffer_size=0.06,
        name_prefix=None,
        state_truncated_dim=10,
        obs_start_index=0,
        prompt_dim=3,
        num_avoid=1,
    ):

    if name_prefix:
        # print("debug 1")
        env.unwrapped.render_mode = "rgb_array"
        env.unwrapped.camera_name = "free"
        env = gym.wrappers.RecordVideo(env=env, video_folder="videos", name_prefix=name_prefix + datetime.now().strftime("%Y%m%d%H%M%S"), episode_trigger=lambda x: x % 2 == 0)
        env.env._max_episode_steps = 50
    # print("debug 2")

    model.eval()
    model.to(device=device)

    obs, info = env.reset()
    state = obs["observation"]

    # customized: create a prompt to be whatever the environment goal is
    goal = obs["desired_goal"]
    prompt_goal = torch.from_numpy(np.array([goal]).reshape(1, -1, prompt_dim)).to(dtype=torch.float32, device=device)
    prompt_mask = torch.ones(prompt_goal.shape[:2]).to(device=device)
    prompt = (prompt_goal, prompt_mask)

    if avoid_prompt:
        # avoid_state = state[10:13]
        # avoid_prompt_state = torch.from_numpy(np.array([avoid_state]).reshape(1, -1, 3)).to(dtype=torch.float32, device=device)
        # avoid_prompt_mask = torch.ones(avoid_prompt_state.shape[:2]).to(device=device)
        # avoid_prompt = (avoid_prompt_state, avoid_prompt_mask)
        if state_truncated_dim == len(state):
            avoid_state_readable = np.random.choice(cardio_attr_states)
            avoid_state = np.array([int(x) for x in avoid_state_readable])
            bts_avoid_states = [np.concatenate([avoid_state + np.array([-buffer_size] * prompt_dim), avoid_state + np.array([buffer_size] * prompt_dim)])]
            avoid_states = [np.concatenate([avoid_state + np.array([-buffer_size] * prompt_dim), avoid_state + np.array([buffer_size] * prompt_dim)])]
        else:
            bts_avoid_states = [np.concatenate([state[state_truncated_dim : state_truncated_dim + prompt_dim] + np.array([-buffer_size] * prompt_dim), state[state_truncated_dim : state_truncated_dim + prompt_dim] + np.array([buffer_size] * prompt_dim)])]
            avoid_states = [np.concatenate([state[state_truncated_dim : state_truncated_dim + prompt_dim] + np.array([-buffer_size] * prompt_dim), state[state_truncated_dim : state_truncated_dim + prompt_dim] + np.array([buffer_size] * prompt_dim)])]
            # avoid_states = [np.concatenate([state[10:13] + np.array([-0.02, -0.02, -0.03]), state[10:13] + np.array([0.02, 0.02, 0.03])])]
        avoid_prompt_state = torch.from_numpy(np.array(avoid_states).reshape(1, -1, prompt_dim*2)).to(dtype=torch.float32, device=device)
        avoid_prompt_mask = torch.ones(avoid_prompt_state.shape[:2]).to(device=device)
        avoid_prompt = (avoid_prompt_state, avoid_prompt_mask)
    else:
        avoid_prompt = None    

    state = state[:state_truncated_dim] # remove obstacle information


    # we keep all the histories on the device
    # note that the latest action and reward will be "padding"
    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    if discrete_action:
        actions = torch.zeros((0,), device=device, dtype=torch.long)
    else:
        actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    #rewards = torch.zeros(0, device=device, dtype=torch.float32)
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

    sim_states = []

    episode_return, episode_cost_return, episode_length = 0, 0, 0
    for t in range(max_ep_len):
        # print('evaluate/t', t)
        # add padding
        if discrete_action:
            actions = torch.cat([actions, torch.zeros((1,), device=device, dtype=torch.long)], dim=0)
        else:
            actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        #rewards = torch.cat([rewards, torch.zeros(1, device=device)])
        #print(states.shape, actions, timesteps, prompt)
        action = model.get_action(
            states.to(dtype=torch.float32),
            actions,
            timesteps.to(dtype=torch.long),
            prompt=prompt,
            avoid_prompt=avoid_prompt,
            # success_list=torch.tensor([False], device=device),
            success_list=torch.tensor([True], device=device)
        )
            
        actions[-1] = action
        action = action.detach().cpu().numpy()

        #env.render()
        # state, reward, done, infos = env.step(action)
        obs, reward, cost, done, infos = env.step(action)
        state = obs["observation"]
        state = state[:state_truncated_dim] # remove obstacle information

        cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        #rewards[-1] = reward
        
        timesteps = torch.cat(
            [timesteps,
             torch.ones((1, 1), device=device, dtype=torch.long) * (t+1)], dim=1)

        episode_return += reward
        episode_cost_return += cost
        episode_length += 1
        infos['episode_length'] = episode_length

        env.render()

        if done or infos["is_success"]: # stop when success is reached
            break
        
    ag = obs["achieved_goal"]
    g = obs["desired_goal"]
    infos['final_distance'] = np.linalg.norm(ag - g)

    infos['avoid_success'] = check_avoid_success(states, bts_avoid_states, obs_start_index, prompt_dim)
    # env.close()

    return episode_return, episode_cost_return, episode_length, infos['is_success'], infos['avoid_success'], infos['final_distance']
