import io
import math

import numpy as np
import cv2
import matplotlib

matplotlib.use('Agg')
import matplotlib.pyplot as plt


def plot_trajectories(writer, label, epoch, env_traj, model_traj, means, stds):
    """Plot environment vs. model trajectories and log the figure as an image.

    One subplot is drawn per dimension of a trajectory entry: the environment
    trajectory in black, the model trajectory in blue, and the model mean in
    red with a shaded band of +/- one std around it. Dimension 0 is titled
    'reward', dimension 1 'terminal' and dimension ``k >= 2`` 'state dim k-2'.

    Args:
        writer: Object exposing ``add_image(tag, img, step)``; presumably a
            TensorBoard-style summary writer (not verifiable from this file).
        label: Tag forwarded to ``writer.add_image``.
        epoch: Step index forwarded to ``writer.add_image``.
        env_traj: Sequence of 1-D arrays, one per environment step. The size
            of the first entry determines how many subplots are drawn.
        model_traj: Sequence of per-step vectors from the model rollout,
            indexed like the entries of ``env_traj``.
        means: Per-step model mean vectors, indexed like ``model_traj``.
        stds: Per-step model std vectors, indexed like ``model_traj``.
            ``means`` and ``stds`` share the x-axis of ``model_traj`` and so
            must contain the same number of steps.

    The image handed to the writer is a float array of shape (3, H, W) with
    values scaled to [0, 1].
    """
    state_dim = env_traj[0].size

    # Regroup the step-major inputs into one series per dimension.
    model_states = [[obs[s] for obs in model_traj] for s in range(state_dim)]
    env_states = [[obs[s] for obs in env_traj] for s in range(state_dim)]
    means = [np.array([mean[s] for mean in means]) for s in range(state_dim)]
    stds = [np.array([std[s] for std in stds]) for s in range(state_dim)]

    cols = 1
    rows = math.ceil(state_dim / cols)

    # squeeze=False always yields a 2-D array of axes, so ravel() also works
    # when there is a single subplot.
    fig, axes = plt.subplots(rows, cols, figsize=(9 * cols, 3 * rows), squeeze=False)
    try:
        axes = axes.ravel()

        for i in range(state_dim):
            ax = axes[i]
            X = range(len(model_states[i]))

            ax.fill_between(X, means[i] + stds[i], means[i] - stds[i], color='r', alpha=0.5)
            ax.plot(env_states[i], color='k')
            ax.plot(model_states[i], color='b')
            ax.plot(means[i], color='r')

            if i == 0:
                ax.set_title('reward')
            elif i == 1:
                ax.set_title('terminal')
            else:
                ax.set_title('state dim {}'.format(i - 2))
        fig.tight_layout()

        buf = io.BytesIO()
        # savefig has no 'layout' option; tight_layout() above already
        # applied the tight layout.
        fig.savefig(buf, format='png')
    finally:
        # Close exactly the figure created here, even if plotting failed,
        # without touching any figure the caller may have open.
        plt.close(fig)

    # np.fromstring on binary data is deprecated; frombuffer is the supported
    # way to view the PNG bytes as a uint8 array.
    png_bytes = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    img = cv2.imdecode(png_bytes, -1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # HWC uint8 -> CHW float in [0, 1].
    img = img.transpose(2, 0, 1) / 255.

    writer.add_image(label, img, epoch)


'''
    record_trajectories passes the writer a video tensor laid out as
    [ batch x timesteps x channels x height x width ]
'''


def record_trajectories(writer, label, epoch, env_images, model_images=None):
    """Log a trajectory of frames to the writer as a single-channel video.

    When ``model_images`` is given, each environment frame is concatenated
    with the matching model frame along the first (height) axis so both
    rollouts appear in one video.

    Args:
        writer: Object exposing ``add_video(tag, vid_tensor, step, fps=...)``;
            presumably a TensorBoard-style summary writer (not verifiable
            from this file).
        label: Name of the video; it is logged under ``'video_' + label``.
        epoch: Step index forwarded to ``writer.add_video``.
        env_images: Sequence of frames indexed as [H, W, C], one per step.
        model_images: Optional sequence of frames with the same length as
            ``env_images``.

    Raises:
        AssertionError: If ``model_images`` is given with a different length
            than ``env_images``.
    """
    # Imported here because the module does not import torch at file level;
    # this keeps the plotting helpers usable when torch is unavailable.
    import torch

    traj_length = len(env_images)
    if model_images is not None:
        assert len(env_images) == len(model_images)
        images = [
            np.concatenate((env_img, model_img))
            for env_img, model_img in zip(env_images, model_images)
        ]
    else:
        images = env_images

    ## [ traj_length, 2 * H, W, C ]
    images = np.array(images)
    images = torch.Tensor(images)

    ## [ traj_length, C, 2 * H, W ]
    images = images.permute(0, 3, 1, 2)
    ## [ B, traj_length, C, 2 * H, W ]
    images = images.unsqueeze(0)

    # Scale by 1/255 (assumes 0-255 pixel values -- TODO confirm with callers).
    images = images / 255.
    # Keep only the first channel, re-adding the channel axis:
    ## [ B, traj_length, 1, 2 * H, W ]
    images = images[:, :, 0].unsqueeze(2)

    print('[ Visualization ] Saving to {}'.format(label))
    # Aim for roughly a five-second clip, clamped to between 2 and 30 fps.
    fps = min(max(traj_length / 5, 2), 30)
    writer.add_video('video_' + label, images, epoch, fps=fps)


def visualize_policy(real_env, fake_env, policy, writer, timestep, max_steps=100, focus=None, label='model_vis',
                     img_dim=128):
    """Roll out a policy in the real and the model environment and plot both.

    Both rollouts start from ``real_env.reset()``. At every step the policy
    acts on the current observation; the real environment is stepped with the
    action, and the model environment is stepped from the current observation
    with the same action. The next observation is the model's prediction
    while the model rollout is alive, and the real observation once the model
    has terminated. The loop stops when both rollouts have terminated or
    after ``max_steps + 1`` iterations.

    The collected rewards, terminal flags and observations are handed to
    ``plot_trajectories`` along with the model's per-step mean and std.

    Args:
        real_env: Environment with ``reset()`` and
            ``step(act) -> (obs, reward, terminal, info)``.
        fake_env: Model environment with
            ``step(obs, act) -> (obs, reward, terminal, info)`` where ``info``
            contains ``'mean'`` and ``'std'``.
        policy: Object with ``actions_np(batch_of_obs)`` returning a batch of
            actions.
        writer: Summary writer forwarded to ``plot_trajectories``.
        timestep: Step index under which the plot is logged.
        max_steps: Bound on the rollout; the loop runs while ``i <= max_steps``.
        focus: Unused by this function.
        label: Tag under which the plot is logged.
        img_dim: Unused by this function.
    """
    init_obs = real_env.reset()
    obs = init_obs.copy()

    observations_r = [obs]
    observations_f = [obs]
    rewards_r = [0]
    rewards_f = [0]
    terminals_r = [False]
    terminals_f = [False]
    # Step 0 has no model prediction: use zeros for the reward/terminal slots,
    # the initial observation as the mean, and zero std.
    means_f = [np.concatenate((np.zeros(2), obs))]
    stds_f = [np.concatenate((np.zeros(2), obs * 0))]

    i = 0
    term_r, term_f = False, False
    while not (term_r and term_f) and i <= max_steps:

        act = policy.actions_np(obs[None])[0]
        if not term_r:
            next_obs_r, rew_r, term_r, _ = real_env.step(act)
            observations_r.append(next_obs_r)
            rewards_r.append(rew_r)
            terminals_r.append(term_r)

        if not term_f:
            next_obs_f, rew_f, term_f, info_f = fake_env.step(obs, act)
            observations_f.append(next_obs_f)
            rewards_f.append(rew_f)
            terminals_f.append(term_f)
            means_f.append(info_f['mean'])
            stds_f.append(info_f['std'])

        # Follow the model rollout while it is alive; fall back to the real
        # observation once the model has terminated.
        if not term_f:
            obs = next_obs_f
        else:
            obs = next_obs_r

        i += 1

    # Turn the per-step scalars into column vectors so they can be stacked in
    # front of the observations.
    terminals_r = np.array([terminals_r]).astype(np.uint8).T
    terminals_f = np.array([terminals_f]).astype(np.uint8).T
    rewards_r = np.array([rewards_r]).T
    rewards_f = np.array([rewards_f]).T

    rewards_observations_r = np.concatenate((rewards_r, terminals_r, np.array(observations_r)), -1)
    rewards_observations_f = np.concatenate((rewards_f, terminals_f, np.array(observations_f)), -1)
    plot_trajectories(writer, label, timestep, rewards_observations_r, rewards_observations_f, means_f, stds_f)
    # NOTE: no video is recorded here. The previous trailing call to
    # record_trajectories referenced undefined names (`epoch`, `images_r`) and
    # always raised NameError; this rollout never collects rendered frames.
