import numpy as np
import torch
from decision_transformer.models.s4_muj import *
import logging
import time
import os
import sys
from memory_profiler import memory_usage
logger = logging.getLogger(__name__)

def evaluate_episode(
        env,
        state_dim,
        act_dim,
        model,
        max_ep_len=1000,
        device='cuda',
        target_return=None,
        mode='normal',
        state_mean=0.,
        state_std=1.,
):
    """Roll out one episode with ``model`` acting in ``env`` and return its score.

    At every step the full history of normalized states, the action and reward
    histories (whose last row is a zero "padding" slot) and the fixed
    ``target_return`` are passed to ``model.get_action``.

    Args:
        env: Environment whose ``reset()`` returns a numpy observation and whose
            ``step(action)`` returns ``(state, reward, done, info)``.
        state_dim: Size of one observation; each state is reshaped to
            ``(1, state_dim)``.
        act_dim: Size of one action.
        model: Object exposing ``eval()``, ``to(device=...)`` and
            ``get_action(states, actions, rewards, target_return=...)``; the
            returned tensor is written into a ``(act_dim,)`` row.
        max_ep_len: Maximum number of environment steps.
        device: Torch device the histories are kept on.
        target_return: Value passed (as a float32 tensor) to every
            ``get_action`` call; ``None`` is not supported.
        mode: Accepted for signature compatibility; not used here.
        state_mean: Mean used to normalize states (scalar or array).
        state_std: Standard deviation used to normalize states (scalar or array).

    Returns:
        ``(episode_return, episode_length)``: the summed environment reward and
        the number of steps taken.
    """
    model.eval()
    model.to(device=device)

    # torch.from_numpy only accepts ndarrays, so the scalar defaults (0., 1.)
    # used to raise; np.asarray handles both and keeps array dtypes unchanged.
    state_mean = torch.as_tensor(np.asarray(state_mean), device=device)
    state_std = torch.as_tensor(np.asarray(state_std), device=device)

    state = env.reset()

    # we keep all the histories on the device
    # note that the latest action and reward will be "padding"
    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)
    target_return = torch.tensor(target_return, device=device, dtype=torch.float32)

    episode_return, episode_length = 0, 0
    for _ in range(max_ep_len):

        # add padding: zero slots for this step's action and reward, filled in
        # once they are known
        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
        rewards = torch.cat([rewards, torch.zeros(1, device=device)])

        action = model.get_action(
            (states.to(dtype=torch.float32) - state_mean) / state_std,
            actions.to(dtype=torch.float32),
            rewards.to(dtype=torch.float32),
            target_return=target_return,
        )
        actions[-1] = action
        action = action.detach().cpu().numpy()

        state, reward, done, _ = env.step(action)

        cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        episode_return += reward
        episode_length += 1

        if done:
            break
    return episode_return, episode_length


def evaluate_episode_rtg(
        env,
        state_dim,
        act_dim,
        model,
        max_ep_len=1000,
        scale=1000.,
        state_mean=0.,
        state_std=1.,
        device='cuda',
        target_return=None,
        mode='normal',
    ):
    """Roll out one episode, conditioning ``model`` on a return-to-go sequence.

    The return-to-go starts at ``target_return`` and, unless ``mode`` is
    ``'delayed'``, is decreased by ``reward / scale`` after every step. Each
    action comes from ``model.get_action`` given the normalized state history,
    the action/reward histories, and the return-to-go and timestep sequences.

    When ``model`` is an ``S4_mujoco_wrapper`` whose ``config.single_step_val``
    is set, the model runs in recurrent mode: ``s4_states`` is passed to and
    returned from every ``get_action`` call. At steps 501-503 the function also
    logs the wall-clock step time, system memory (via the ``free`` command) and
    the ``memory_profiler`` profile of one extra ``get_action`` call.

    Args:
        env: Environment whose ``reset()`` returns a numpy observation and whose
            ``step(action)`` returns ``(state, reward, done, info)``.
        state_dim: Size of one observation.
        act_dim: Size of one action.
        model: Policy exposing ``eval``, ``to`` and ``get_action``. In recurrent
            S4 mode it must also expose ``config``, ``get_initial_state`` and
            ``forward``.
        max_ep_len: Maximum number of environment steps.
        scale: Divisor applied to each reward before it is subtracted from the
            return-to-go.
        state_mean: Mean used to normalize states (scalar or array).
        state_std: Standard deviation used to normalize states (scalar or array).
        device: Torch device the histories are kept on.
        target_return: Initial return-to-go. It is presumably expressed in the
            same ``reward / scale`` units as the per-step decrements; confirm
            this with callers.
        mode: ``'noise'`` adds N(0, 0.1) noise to the initial observation only.
            ``'delayed'`` keeps the return-to-go constant. Any other value gives
            the normal behavior.

    Returns:
        ``(episode_return, episode_length, average_diff, last_action_diff)``.
        The two diffs are mean absolute differences between
        ``predicted_actions[:, :-1]`` from one full ``model.forward`` pass and
        ``actions[:, 1:]``, taken over all steps and over the last step
        respectively. Both are 0 unless recurrent S4 mode is active and
        ``model.config.track_step_err`` is set.
    """
    model.eval()
    model.to(device=device)

    # Recurrent S4 evaluation: s4_states is threaded through get_action.
    s4_rec = False
    if isinstance(model, S4_mujoco_wrapper):
        if model.config.single_step_val:
            s4_rec = True
            s4_states = [r.detach() for r in model.get_initial_state((1), device)]

    # torch.from_numpy only accepts ndarrays, so the scalar defaults (0., 1.)
    # used to raise; np.asarray handles both and keeps array dtypes unchanged.
    state_mean = torch.as_tensor(np.asarray(state_mean), device=device)
    state_std = torch.as_tensor(np.asarray(state_std), device=device)

    state = env.reset()
    if mode == 'noise':
        # Only the initial observation is perturbed.
        state = state + np.random.normal(0, 0.1, size=state.shape)

    # we keep all the histories on the device
    # note that the latest action and reward will be "padding"
    states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)
    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)
    rewards = torch.zeros(0, device=device, dtype=torch.float32)

    ep_return = target_return
    target_return = torch.tensor(ep_return, device=device, dtype=torch.float32).reshape(1, 1)
    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)

    episode_return, episode_length = 0, 0
    # Padding slot for the first step's action and reward.
    actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
    rewards = torch.cat([rewards, torch.zeros(1, device=device)])
    logging.info(f"LOG EVAL TIME: STARTING EVAL :: {device}")
    for t in range(max_ep_len):

        # Steps 501-503 are timed and memory-profiled below.
        profile_step = 500 < t <= 503
        if profile_step:
            eval_start = time.time()
        if s4_rec:
            action, s4_states = model.get_action(
                (states.to(dtype=torch.float32) - state_mean) / state_std,
                actions.to(dtype=torch.float32),
                rewards.to(dtype=torch.float32),
                target_return.to(dtype=torch.float32),
                timesteps.to(dtype=torch.long),
                s4_states = s4_states,
            )
            if model.config.base_model == "ant_reward_target":
                action = action[0]
            if model.config.discrete > 0:
                # Discrete head: the output is (logits, pred_rtg); argmax bin
                # indices k are mapped to actions via k * 2 / discrete - 1.
                action, pred_rtg = action
                maxim = torch.argmax(action, dim=-1).to(dtype=action.dtype)
                action = (maxim * 2 / model.config.discrete - 1 )[0, -1, :model.action_dim]
                # pred_state is currently not used further.
                pred_state = (maxim / (model.config.discrete / (model.config.state_bound[1] -model.config.state_bound[0])) - model.config.state_bound[0])[0, -1, model.action_dim:]
            # In recurrent mode the padding slot is appended after the call,
            # so for t > 0 the last action row seen by the model is the
            # previous step's action.
            if t > 0:
                actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
                rewards = torch.cat([rewards, torch.zeros(1, device=device)])
        else:
            # add padding before querying the model (step 0 was padded above)
            if t > 0:
                actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)
                rewards = torch.cat([rewards, torch.zeros(1, device=device)])
            action = model.get_action(
                (states.to(dtype=torch.float32) - state_mean) / state_std,
                actions.to(dtype=torch.float32),
                rewards.to(dtype=torch.float32),
                target_return.to(dtype=torch.float32),
                timesteps.to(dtype=torch.long),
            )
        if profile_step:
            curr_time = time.time()
            logging.info(f"LOG EVAL TIME: step time: {curr_time-eval_start}")
            # Use `with` so the pipe to `free` is closed. The previous bare
            # os.popen(...) call leaked it.
            with os.popen('free -t -m') as free_proc:
                free_total_line = free_proc.readlines()[-1]
            tot_m, used_m, free_m = map(int, free_total_line.split()[1:])
            logging.info(f"LOG EVAL MEM1: tot_m: {tot_m} , used_m: {used_m} , free_m: {free_m}")
            if not s4_rec:
                memtime =  memory_usage((model.get_action, ((states.to(dtype=torch.float32) - state_mean) / state_std,
                                                                     actions.to(dtype=torch.float32)[:-1,...],
                                                                     rewards.to(dtype=torch.float32)[:-1,...],
                                                                     target_return.to(dtype=torch.float32),
                                                                     timesteps.to(dtype=torch.long),
                                                                     )), interval=0.0001)
                logging.info(f"MAXLEN: {model.max_length}")
            else:
                memtime = memory_usage((model.get_action, ((states.to(dtype=torch.float32)[-3:, ...] - state_mean) / state_std,
                                                 actions.to(dtype=torch.float32)[-3:-1, ...],
                                                 rewards.to(dtype=torch.float32)[-3:-1, ...],
                                                 target_return.to(dtype=torch.float32)[-3:, ...],
                                                 timesteps[-3:, ...].to(dtype=torch.long),
                                                 ), {'s4_states': s4_states}), interval=0.0001)
            memtime = np.array(memtime)
            logging.info(f"LOG EVAL MEM2: mean_used: {memtime.mean()} , max_uesd: {memtime.max()}")
            sys.stdout.flush()

        actions[-1] = action
        action = action.detach().cpu().numpy()

        state, reward, done, _ = env.step(action)

        cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)
        states = torch.cat([states, cur_state], dim=0)
        rewards[-1] = reward

        # Extend the return-to-go sequence; 'delayed' mode repeats the last value.
        if mode != 'delayed':
            pred_return = target_return[0,-1] - (reward/scale)
        else:
            pred_return = target_return[0,-1]
        target_return = torch.cat(
            [target_return, pred_return.reshape(1, 1)], dim=1)
        timesteps = torch.cat(
            [timesteps,
             torch.ones((1, 1), device=device, dtype=torch.long) * (t+1)], dim=1)

        episode_return += reward
        episode_length += 1

        if done:
            break
    average_diff, last_action_diff = 0, 0
    if s4_rec:
        if model.config.track_step_err:
            # Re-run the whole trajectory through model.forward and compare
            # its predictions with the actions actually taken.
            actions = actions.to(dtype=torch.float32).reshape(1,-1,act_dim)
            states = states.to(dtype=torch.float32).reshape(1,-1,state_dim)
            target_return = target_return.to(dtype=torch.float32).reshape(1,-1,1)
            _, predicted_actions, _ = model.forward(
                (states.to(dtype=torch.float32)[0, :-1, :].unsqueeze(0) - state_mean) / state_std,
                actions,
                rewards.to(dtype=torch.float32),
                target_return.to(dtype=torch.float32)[0, :-1, 0].unsqueeze(0),
                timesteps,
                running=True
            )
            delta = predicted_actions[:,:-1,:] - actions[:,1:,:]
            average_diff = delta.abs().mean().cpu().item()
            last_action_diff = delta[:,-1,:].abs().mean().cpu().item()
    return episode_return, episode_length, average_diff, last_action_diff
