from typing import Dict

import gym
import numpy as np
from numba import njit
import wandb



def get_eval_stats(agent, env: gym.Env, num_episodes: int=10) -> Dict[str, float]:
    """Roll out ``agent`` in ``env`` and average the per-episode statistics.

    Actions are obtained with ``agent.sample_actions(observation, temperature=0.0)``.
    When an episode ends, the ``'return'`` and ``'length'`` entries of
    ``info['episode']`` from the final step are recorded.

    Args:
        agent: Object exposing ``sample_actions(observation, temperature=...)``.
        env: Environment whose ``step`` returns ``(observation, reward, done, info)``.
        num_episodes: Number of episodes to run.

    Returns:
        Dict with the mean ``'return'`` and ``'length'``. A ``'success'`` entry
        (sum of the final ``info['is_success']`` values divided by
        ``num_episodes``) is added only if at least one final ``info``
        contained ``'is_success'``.
    """
    metric_names = ('return', 'length')
    per_episode = {name: [] for name in metric_names}
    success_total = None

    for _episode in range(num_episodes):
        observation = env.reset()
        done = False
        while not done:
            action = agent.sample_actions(observation, temperature=0.0)
            observation, _reward, done, info = env.step(action)

        # `info` is the dict returned by the last step of the episode.
        for name in metric_names:
            per_episode[name].append(info['episode'][name])

        if 'is_success' in info:
            # Start counting lazily so 'success' is only reported when available.
            if success_total is None:
                success_total = 0.0
            success_total += info['is_success']

    stats = {name: np.mean(values) for name, values in per_episode.items()}
    if success_total is not None:
        stats['success'] = success_total / num_episodes
    return stats

def get_eval_stats_planning(agent, env: gym.Env, step: int, num_episodes: int=10) -> Dict[str, float]:
    """Roll out a planning ``agent`` in ``env`` and average per-episode statistics.

    Actions are obtained with
    ``agent.sample_actions(observation, traj_start, eval_mode=1, step=step)``,
    where ``traj_start`` is ``1`` for the first action of every episode and
    ``0`` for all later ones.

    Args:
        agent: Object exposing the ``sample_actions`` call described above.
        env: Environment whose ``step`` returns ``(observation, reward, done, info)``.
        step: Value forwarded unchanged to ``agent.sample_actions`` as ``step``.
        num_episodes: Number of episodes to run.

    Returns:
        Dict with the mean ``'return'`` and ``'length'`` taken from the final
        ``info['episode']`` of each episode. A ``'success'`` entry (sum of the
        final ``info['is_success']`` values divided by ``num_episodes``) is
        added only if at least one final ``info`` contained ``'is_success'``.
    """
    metric_names = ('return', 'length')
    per_episode = {name: [] for name in metric_names}
    success_total = None

    for _episode in range(num_episodes):
        observation = env.reset()
        done = False
        first_step_flag = 1  # tells the agent a new trajectory begins
        while not done:
            action = agent.sample_actions(
                observation, first_step_flag, eval_mode=1, step=step
            )
            observation, _reward, done, info = env.step(action)
            first_step_flag = 0

        # `info` is the dict returned by the last step of the episode.
        for name in metric_names:
            per_episode[name].append(info['episode'][name])

        if 'is_success' in info:
            # Start counting lazily so 'success' is only reported when available.
            if success_total is None:
                success_total = 0.0
            success_total += info['is_success']

    stats = {name: np.mean(values) for name, values in per_episode.items()}
    if success_total is not None:
        stats['success'] = success_total / num_episodes
    return stats



def get_observations_and_actions(traj):
    """Stack the observations and actions of one trajectory into two arrays.

    Every element of ``traj`` is unpacked as a 7-field tuple; the first field
    is used as the observation, the second as the action, and the remaining
    five are ignored.

    Returns:
        ``(observations, actions)``, each built with ``np.stack`` over the
        steps, so the leading axis indexes the step.
    """
    # Single pass over `traj`; the 7-field unpacking is kept on purpose so a
    # step with a different arity still fails loudly.
    obs_act_pairs = [(obs, act) for (obs, act, _, _, _, _, _) in traj]
    stacked_obs = np.stack([obs for obs, _ in obs_act_pairs])
    stacked_acts = np.stack([act for _, act in obs_act_pairs])
    return stacked_obs, stacked_acts

# NOTE: for now the numba version runs slower than the naive version, probably due to
# the overhead of numba compilation caused by the variable-length argument passing.
@njit
def numba_get_observations_and_actions(traj):
    """Numba-compiled variant of ``get_observations_and_actions``.

    Reads the observation from index 0 and the action from index 1 of every
    step in ``traj`` and copies them into two freshly allocated arrays whose
    leading axis indexes the step.

    ``traj`` must contain at least one step, because the per-step shapes are
    taken from ``traj[0]``. The output arrays are created with ``np.empty``
    without an explicit dtype, so they use NumPy's default float dtype.

    Returns:
        ``(observations, actions)``.
    """
    # Per-step shapes come from the first step only; every later step is
    # assigned into a slot of that shape.
    observation_shape = traj[0][0].shape
    action_shape = traj[0][1].shape
    # Uninitialised buffers; every row is overwritten in the loop below.
    observations = np.empty((len(traj), *observation_shape))
    actions = np.empty((len(traj), *action_shape))
    for i in range(len(traj)):
        observations[i] = traj[i][0]
        actions[i] = traj[i][1]
    return observations, actions

def get_predictive_model_eval_stats(agent, valid_dataset, test_seq_lengths = [10, 50, 100]):
    """Measure the agent's multi-step prediction error on validation trajectories.

    For every sequence length, each trajectory from
    ``valid_dataset.split_into_trajectories()`` is cut into consecutive,
    non-overlapping windows of that length. For each window the agent is
    given the first observation and the window's first ``seq_length - 1``
    actions via ``agent.predict_next_state``; the result is compared with the
    window's remaining observations using the mean squared error. Errors are
    averaged per trajectory and then across trajectories.

    Args:
        agent: Object exposing ``predict_next_state(first_obs, actions)``.
            Its output is presumably one predicted observation per action
            (it is subtracted from ``seq_length - 1`` observations) — confirm
            against the agent implementation.
        valid_dataset: Object exposing ``split_into_trajectories()``, which
            must return a re-iterable collection of trajectories accepted by
            ``get_observations_and_actions``.
        test_seq_lengths: Window lengths to evaluate. The default list is
            never mutated.

    Returns:
        Dict mapping ``'traj_prediction_error_<seq_len>'`` to the mean error
        for that length, or NaN when no trajectory yields a window of that
        length. The dict is also printed.
    """
    trajs = valid_dataset.split_into_trajectories()
    stats = {f'traj_prediction_error_{seq_len}': None for seq_len in test_seq_lengths}

    # Stacking does not depend on the sequence length, so do it once per
    # trajectory instead of once per (sequence length, trajectory) pair.
    stacked_trajs = [get_observations_and_actions(traj) for traj in trajs]

    for seq_length in test_seq_lengths:
        # Fresh accumulator per sequence length: sharing one list across
        # lengths would mix errors of shorter windows into longer ones.
        traj_prediction_error_per_traj = []
        for observations, actions in stacked_trajs:
            traj_len = len(observations)
            num_seq_per_traj = traj_len // seq_length
            traj_pred_error_per_seq = []
            for i in range(num_seq_per_traj):
                start = i * seq_length
                end = start + seq_length
                # NOTE(review): this guard also drops the final window when it
                # ends on (or one step before) the last transition, although
                # the slices below never index past `end - 1`. Kept as-is to
                # preserve which windows are evaluated — confirm intent.
                if end + 1 >= traj_len:
                    break
                obs = observations[start:end]
                acts = actions[start:end-1]
                next_obs = observations[start+1:end]
                pred_next_obs = agent.predict_next_state(obs[0], acts)
                traj_pred_error_per_seq.append(np.mean(np.square(pred_next_obs - next_obs)))
            # A trajectory too short for this length contributes nothing;
            # averaging an empty list would give NaN and poison the statistic.
            if traj_pred_error_per_seq:
                traj_prediction_error_per_traj.append(np.mean(traj_pred_error_per_seq))

        if traj_prediction_error_per_traj:
            mean_error = np.mean(traj_prediction_error_per_traj)
        else:
            mean_error = np.nan
        stats[f'traj_prediction_error_{seq_length}'] = mean_error
    print(stats)
    return stats

def get_predictive_model_eval_stats_ramdom_seq_sampling(agent, valid_buffer, pred_reward = 0, test_seq_lengths = [10, 50, 100], num_sequences_to_test = 200, batch_size = 256):
    """Estimate sequence-prediction error on randomly sampled validation batches.

    For every sequence length, ``num_sequences_to_test`` batches are drawn
    with ``valid_buffer.sample_seq(batch_size, seq_length)``. The agent
    predicts from the first observation of each sequence and the batch's
    actions. All predictions except the last are compared with the
    observations from the second time step onwards, using a squared error
    multiplied by the sequence mask.

    The per-batch error is ``np.mean`` over *all* elements of the masked
    squared error, so masked-out entries count as zeros in the average.

    Args:
        agent: Object exposing ``predict_next_state(first_obs, actions)`` and,
            when ``pred_reward`` is truthy,
            ``predict_next_state_and_reward(first_obs, actions)``.
        valid_buffer: Object exposing ``sample_seq(batch_size, seq_length)``
            that returns a batch with ``observations``, ``actions``,
            ``seq_masks`` and ``rewards`` attributes. Presumably these are
            laid out as (batch, time, ...) with a 2-D mask — confirm against
            the buffer implementation.
        pred_reward: If truthy, also evaluate reward predictions.
        test_seq_lengths: Sequence lengths to evaluate. The default list is
            never mutated.
        num_sequences_to_test: Number of batches sampled per sequence length.
        batch_size: Number of sequences per sampled batch (previously
            hard-coded to 256).

    Returns:
        Dict mapping ``'observation_seq_pred_error_<seq_len>'`` and, when
        ``pred_reward`` is truthy, ``'reward_seq_pred_error<seq_len>'`` (no
        underscore before the length; kept for compatibility with existing
        consumers of these keys) to the mean error over the sampled batches.
        The dict is also printed.
    """
    stats = {f'observation_seq_pred_error_{seq_len}': None for seq_len in test_seq_lengths}
    if pred_reward:
        stats.update({f'reward_seq_pred_error{seq_len}': None for seq_len in test_seq_lengths})

    for seq_length in test_seq_lengths:
        pred_error_list = []
        rew_error_list = []
        for _ in range(num_sequences_to_test):
            batch = valid_buffer.sample_seq(batch_size, seq_length)

            obs = batch.observations
            acts = batch.actions
            seq_mask = batch.seq_masks
            rews = batch.rewards

            if pred_reward:
                pred_next_obs, pred_rew_seq = agent.predict_next_state_and_reward(obs[:, 0], acts)
            else:
                pred_next_obs = agent.predict_next_state(obs[:, 0], acts)

            # Prediction t is matched with observation t + 1; the last
            # prediction has no target inside the window and is dropped.
            obs_sq_err = np.square(pred_next_obs[:, :-1] - obs[:, 1:])
            pred_error_list.append(np.mean(obs_sq_err * seq_mask[:, 1:, None]))

            if pred_reward:
                # np.squeeze removes every size-1 axis of the reward prediction
                # before it is compared with the rewards.
                rew_sq_err = np.square(np.squeeze(pred_rew_seq) - rews)
                rew_error_list.append(np.mean(rew_sq_err * seq_mask))

        stats[f'observation_seq_pred_error_{seq_length}'] = np.mean(pred_error_list)
        if pred_reward:
            stats[f'reward_seq_pred_error{seq_length}'] = np.mean(rew_error_list)
    print(stats)
    return stats



