"""
Very specific code for debugging the antmaze environment.
"""
import matplotlib
matplotlib.use('Agg')
from matplotlib import patches

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from functools import partial
from mpl_toolkits.axes_grid1 import make_axes_locatable

import gym
import d4rl
import numpy as np
import functools as ft

# from src import roomworld_utils
from .d4rl_ant import get_canvas_image
import os
import os.path as osp

class Visualizer:
    def __init__(self, env_name, viz_env, dataset):
        data_path = osp.abspath(osp.join(osp.dirname(__file__), f'antmaze_aux/{env_name}-aux.npz'))
        print('Attempting to load from: ', data_path)
        data = np.load(data_path)
        self.data = {k: data[k] for k in data}
        self.dataset = dataset
        self.viz_env = viz_env
        self.K = 6
    
    def get_metrics(self, policy_fn):
        directions = self.get_gradients(policy_fn)
        goods = np.clip(self.is_goods(directions), -2.0, 2.0)

        masks = 1.0 - self.data['masked'][::self.K, ::self.K]

        return {
            'average_advantage': np.mean(goods),
            'pct_aligned': np.mean(goods > 0),
            'masked_average_advantage': np.mean(goods * masks) / np.mean(masks),
            'masked_pct_aligned': np.mean((goods > 0) * masks) / np.mean(masks),
        }

    def is_goods(self, directions):
        X, Y = self.data['X'][::self.K], self.data['Y'][::self.K]
        directions = directions.reshape((len(Y), len(X), 2))
        nY, nX = self.data['V'].shape
        print(X.shape, Y.shape, nX, nY)

        goods = np.zeros(directions.shape[:-1])
        for i in range(len(X)):
            for j in range(len(Y)):
                adv = -float('inf')
                for dist in range(9, 12):
                    d = np.round(directions[j, i] * dist).astype(int) # x = x_i, y = y_j
                    new_adv = self.data['V'][
                        np.clip(j * self.K + d[1], 0, nY-1),
                        np.clip(i * self.K + d[0], 0, nX-1)
                    ] - self.data['V'][j * self.K, i * self.K]
                    adv = max(adv, new_adv)
                goods[j, i] = adv
        return goods

    def get_gradients(self, policy_fn, N=20):
        X, Y = np.meshgrid(self.data['X'][::self.K],self.data['Y'][::self.K])
        observations = np.array([X.flatten(), Y.flatten()]).T
        base_observation = np.copy(self.dataset['observations'][0])
        base_observations = np.tile(base_observation, (observations.shape[0], 1))
        base_observations[:, :2] = observations

        policies = policy_fn(base_observations)
        directions = policies / (1e-6 + np.linalg.norm(policies, axis=1, keepdims=True))
        return directions
    
    def policy_image(self, policy_fn):
        fig = plt.figure(tight_layout=True)
        canvas = FigureCanvas(fig)
        ax = plt.gca()
        
        X, Y = np.meshgrid(self.data['X'][::self.K],self.data['Y'][::self.K])
        directions = self.get_gradients(policy_fn)
        goods = np.clip(self.is_goods(directions), -2.0, 2.0)

        
        true_dx = self.data['dX'][::self.K, ::self.K] / 3.0
        true_dy = self.data['dY'][::self.K, ::self.K] / 3.0
        mesh = ax.quiver(X, Y, directions[..., 0], directions[..., 1], goods, cmap=plt.cm.coolwarm_r)
#         plt.colorbar()
#         plt.clim(-2, 2)
        self.viz_env.draw(ax)
        image = get_canvas_image(canvas)
        plt.close(fig)
        return image

    def get_distances(self, trajs):
        final_points = np.array([trajectory['observation'][-1][:2] for trajectory in trajs])
        final_points = np.stack([final_points[:, 1], final_points[:, 0]], axis=1)
        print(final_points.shape)
        from scipy.interpolate import interpn
        return interpn((self.data['Y'], self.data['X']), self.data['pV'], final_points, method='linear', bounds_error=False, fill_value=-300.0)

    def get_distance_metrics(self, trajs):
        import wandb
        distances = self.get_distances(trajs)
        bins = np.arange(self.data['pV'].min(), self.data['pV'].max(), 20)
        hist = np.histogram(distances, bins)
        metrics = {
            'average_distance': np.mean(distances),
            'pct_within_10': np.mean(distances > -10),
            'pct_within_20': np.mean(distances > -20),
            'median_distance': np.median(distances),
            'dist_hist': wandb.Histogram(np_histogram=hist),
        }
        return metrics