import numpy as np
from algo.utils.wsre import wsre
import datetime

def _wasserstein_episode_score(args, memories, label, state_batch):
    """Return the Wasserstein-based score of one evaluation episode.

    For every mode ``j`` other than ``label``, per-step distances between
    ``state_batch`` and a target batch taken from ``memories[j]`` are computed
    with ``wsre``. The smallest per-mode total is returned, i.e. the score
    against the mode whose distances sum lowest. A mode whose memory holds
    ``args.max_episode_len`` entries or fewer contributes all-zero distances.

    Args:
        args: Config object; reads ``num_modes`` and ``max_episode_len``.
        memories: Per-mode memories supporting ``len()`` and ``dump(n)``; the
            first item yielded by ``dump`` is used as the target state batch.
        label: Mode index the episode was collected with.
        state_batch: Stacked next-states of the episode, one row per step.

    Returns:
        The summed distances of the selected mode, or ``0.`` when there is
        only one mode and therefore nothing to compare against.
    """
    if args.num_modes == 1:
        return 0.
    episode_steps = len(state_batch)
    totals = []
    for j in range(args.num_modes):
        if j == label:
            continue
        dist = np.zeros(episode_steps)
        # Only sample when the memory holds more than max_episode_len entries,
        # so the target batch is non-empty.
        if len(memories[j]) > args.max_episode_len:
            target_state_batch = list(memories[j].dump(args.max_episode_len))[0]
            dist = wsre(state_batch, target_state_batch)
        totals.append(np.sum(dist))
    # The minimum total is the sum of the argmin-selected mode's distances
    # (NaN propagates, as with argmin).
    return float(np.min(totals))


def eval(args, env, trainers, disc_trainer, memories, writer, step, algo="wasserstein", episodes=10):
    """Run deterministic test episodes and log evaluation metrics.

    Episode ``i`` is played with the trainer of mode ``i % args.num_modes``.
    Per-step discriminator scores (DSR, clipped below at
    ``10 * log(1 / num_modes)``) and a per-episode Wasserstein score (WSR)
    are accumulated. Averages are written to ``writer`` under ``test/...``
    and a summary line is printed.

    Note: the name shadows the ``eval`` builtin; it is kept for callers.

    Args:
        args: Config object; reads ``num_modes``, ``max_episode_len``, ``rc``
            and ``src``.
        env: Environment with ``reset()`` and a 4-tuple ``step(action)``.
        trainers: Per-mode policies exposing ``act(state, eval=True)``.
        disc_trainer: Discriminator exposing ``score(states, labels)``. Its
            output is exponentiated for ``avg_acc``; presumably these are
            log-probabilities (TODO confirm).
        memories: Per-mode memories used as Wasserstein targets.
        writer: Logger exposing ``add_scalar(tag, value, step)``.
        step: Global step passed through to ``writer``.
        algo: ``"wasserstein"`` combines reward with WSR in ``avg_r``; any
            other value combines reward with DSR.
        episodes: Number of test episodes (default 10).

    Raises:
        ValueError: if ``episodes`` is less than 1.
    """
    if episodes < 1:
        raise ValueError("episodes must be >= 1, got {}".format(episodes))

    SR_MIN = np.log(1 / args.num_modes) * 10
    avg_reward = 0.
    avg_dsr = 0.
    avg_acc = 0.
    avg_wsr = 0.
    reward_per_mode = [[] for _ in range(args.num_modes)]

    for i in range(episodes):
        state = env.reset()
        episode_reward = 0.
        episode_dsr = 0.
        next_states = []
        label = i % args.num_modes
        done = False
        while not done:
            action, _ = trainers[label].act(state, eval=True)

            next_state, reward, done, _ = env.step(action)
            dsr = max(disc_trainer.score(next_state, np.array([label])), SR_MIN)
            episode_reward += reward
            episode_dsr += dsr

            next_states.append(next_state)
            state = next_state

        # Stack once; reused for the discriminator and the Wasserstein score.
        state_batch = np.stack(next_states, axis=0)
        labels = np.full(len(next_states), label)
        score = disc_trainer.score(state_batch, labels)
        avg_acc += np.mean(np.exp(score))

        avg_wsr += _wasserstein_episode_score(args, memories, label, state_batch)
        avg_reward += episode_reward
        avg_dsr += episode_dsr
        reward_per_mode[label].append(episode_reward)

    avg_reward /= episodes
    avg_dsr /= episodes
    avg_acc /= episodes
    # Average WSR BEFORE combining it into avg_r, so both branches mix
    # per-episode averages (previously the WSR sum was used here).
    avg_wsr /= episodes
    if algo == "wasserstein":
        avg_r = avg_reward * args.rc + avg_wsr * args.src
    else:
        avg_r = avg_reward * args.rc + avg_dsr * args.src

    # Modes that received no episode (num_modes > episodes) are reported as
    # NaN without triggering an empty-mean warning. nanmax/nanmin ignore them
    # instead of depending on list order as builtin max/min do with NaN.
    avg_reward_per_mode = [np.mean(r) if r else float("nan") for r in reward_per_mode]
    max_avg_reward = np.nanmax(avg_reward_per_mode)
    min_avg_reward = np.nanmin(avg_reward_per_mode)

    writer.add_scalar('test/avg_reward', avg_reward, step)
    writer.add_scalar('test/avg_dsr', avg_dsr, step)
    writer.add_scalar('test/avg_r', avg_r, step)
    writer.add_scalar('test/avg_acc', avg_acc, step)
    writer.add_scalar('test/avg_wsr', avg_wsr, step)
    writer.add_scalar('test/max_avg_reward', max_avg_reward, step)
    writer.add_scalar('test/min_avg_reward', min_avg_reward, step)
    for k in range(args.num_modes):
        writer.add_scalar('test/mode_{}_avg_reward'.format(k), avg_reward_per_mode[k], step)

    T = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    print("----------------------------------------")
    print("[{}] Test Episodes: {}, Avg. Reward: {}, Avg. DSR: {}, Avg. WSR: {}, Avg. R: {}, Avg. Acc: {}".format(T, episodes, round(avg_reward, 2), round(avg_dsr, 2), round(avg_wsr, 2), round(avg_r, 2), round(avg_acc, 2)))
    print("----------------------------------------")