import numpy as np
import torch
import torch.nn.functional as F
import exp_utils as PQ
import matplotlib.pyplot as plt
import pickle


def softminus(x: torch.Tensor):
    """Smooth approximation of ``min(x, 0)``, the mirror image of softplus.

    Computed as ``-softplus(-x)`` (mathematically equal to ``log(sigmoid(x))``).
    """
    mirrored = F.softplus(x.neg())
    return mirrored.neg()


def swish(x):
    """Swish (a.k.a. SiLU) activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate


def barrier(states: torch.Tensor, *, min_angle: float = -np.pi / 2, max_angle: float = np.pi / 2) -> torch.Tensor:
    """Barrier penalty on the first state coordinate (presumably the pendulum angle).

    The coordinate is rescaled so that ``[min_angle, max_angle]`` maps to ``[0, 1]``.
    Strictly inside that interval the penalty is a log barrier, normalised to be ~0 at
    the midpoint. Outside it, the penalty continues linearly from the barrier's
    (eps-capped) boundary value, so the output is finite everywhere.

    Args:
        states: tensor whose last axis holds state components; only ``states[..., 0]``
            is used.
        min_angle: lower bound of the safe interval (default ``-pi / 2``).
        max_angle: upper bound of the safe interval (default ``+pi / 2``).

    Returns:
        Tensor of shape ``states.shape[:-1]`` with the penalty value.
    """

    def interval_barrier(x, lb, rb):
        x = (x - lb) / (rb - lb)
        eps = 1e-6
        # Evaluate the log branch on a clamped copy so it is always finite. Without this,
        # the branch that torch.where discards produces NaN/-inf for points outside [0, 1].
        # That yields NaN gradients (0 * inf) when x hits exactly -eps or 1 + eps.
        # Strictly inside (0, 1) the clamp is the identity, so the selected values are unchanged.
        x_in = x.clamp(0., 1.)
        b = -((x_in + eps) * (1 - x_in + eps) / (0.5 + eps) ** 2).log()
        # Approximate value and slope magnitude of the log barrier at x = 0 / x = 1;
        # used to extend it linearly outside the interval.
        b_max = -np.log(4 * eps)
        grad = 1. / eps - 1
        out = grad * torch.max(-x, x - 1)
        inside = (0 < x) & (x < 1)
        return torch.where(inside, b, b_max + out)

    return interval_barrier(states[..., 0], min_angle, max_angle)


def plot_pendulum_set(fns, device, clouds, filename, title, max_thresh=3, *,
                      x_min=-np.pi / 2, x_max=np.pi / 2, y_min=-2, y_max=2,
                      encode=lambda x: x, decode=lambda x: (x[:, 0], x[:, 1]), xlabel="angle", ylabel="angvel"):
    """Render each function in ``fns`` as a heat map over a 2-D state grid and save the figure.

    One panel is drawn per entry of ``fns``. When an ``'L'`` entry is present, its zero
    level set is overlaid on every panel. The point ``clouds`` are scattered on the
    ``'L'`` panel.

    Args:
        fns: mapping from panel name to a callable evaluated on the grid tensor.
            Names with a dedicated colormap are 'hardD', 'softD', 'U', 'L', 'barrier'
            and 'logBarrier'; any other name falls back to BrBG.
        device: torch device the grid tensor is created on.
        clouds: mapping from cloud name to point data; each is turned into ``(x, y)``
            arrays by ``decode``. Names 'traj', 's' and 'expl' get fixed colours.
        filename: output path passed to ``fig.savefig``.
        title: prefix of every panel title.
        max_thresh: upper bound of the symmetric colour range used for panels other
            than 'L', 'U' and 'logBarrier'.
        x_min, x_max, y_min, y_max: extent of the 201 x 201 evaluation grid and the plot.
        encode: maps the stacked numpy grid ``[X, Y]`` (shape ``(2, 201, 201)``) to
            the model's input. The result is permuted ``(1, 2, 0)``, so it must be 3-D
            and channel-first.
        decode: maps a cloud to the ``(x, y)`` arrays to plot.
        xlabel, ylabel: axis labels.
    """
    xs = np.linspace(x_min, x_max, 201)
    ys = np.linspace(y_min, y_max, 201)

    X, Y = np.meshgrid(xs, ys)
    # encode() sees the grid channel-first; the permute moves channels last -> (201, 201, C).
    points = torch.tensor(encode(np.array([X, Y])), dtype=torch.float32, device=device).permute(1, 2, 0)
    values = {key: fn(points).cpu().detach().numpy() for key, fn in fns.items()}

    fig, axes = plt.subplots(nrows=1, ncols=len(values), figsize=(8 * len(values), 6))
    if not isinstance(axes, np.ndarray):
        axes = [axes]

    cmaps = {
        'hardD': plt.cm.RdBu,
        'softD': plt.cm.BrBG,
        'U': plt.cm.PRGn,
        'L': plt.cm.BrBG,
        'barrier': plt.cm.BrBG,
        'logBarrier': plt.cm.BrBG,
    }
    cloud_colors = {'traj': 'C8', 's': 'C3', 'expl': 'C6'}

    for ax, (key, value) in zip(axes, values.items()):
        if key in ['L', 'U']:
            # Shift by +1 so the zero level of L / U sits in the middle of [0, 2].
            vmin, vmax = 0, 2
            value = value + 1
        elif key == 'logBarrier':
            vmin, vmax = -3, 3
        else:
            # Symmetric range around 0. Matplotlib requires vmin <= vmax; the 1e-6
            # floor guarantees -thresh < thresh.
            thresh = max(min(np.max(value), -np.min(value), max_thresh), 1e-6)
            vmin, vmax = -thresh, thresh

        im = ax.imshow(value, cmap=cmaps.get(key, plt.cm.BrBG), extent=[x_min, x_max, y_min, y_max],
                       aspect='auto', origin='lower', vmax=vmax, vmin=vmin)
        if 'L' in values:
            # Zero level set of L (drawn as L + 1 == 1), on this panel's own axes.
            contours = ax.contour(X, Y, values['L'] + 1, levels=[1.], colors=['tab:orange'])
            ax.clabel(contours)
        fig.colorbar(im, ax=ax)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)

        if key == 'L':
            for name, cloud in clouds.items():
                cloud_x, cloud_y = decode(cloud)
                # Fade large clouds. alpha is 1 up to ~272 points (100 * e), then
                # 1 / log(n / 100); the max(..., 1) avoids log(0) for an empty cloud.
                alpha = 1 / max(np.log(max(len(cloud), 1) / 100.), 1.)
                ax.plot(cloud_x, cloud_y, label=name, markersize=1, ls=' ', marker='o',
                        color=cloud_colors.get(name), alpha=alpha)

            if len(clouds):
                ax.legend(loc=1)

        ax.set_title(title + f", fn = {key}")
        ax.grid()

    fig.tight_layout()
    fig.savefig(filename, dpi=150)
    plt.close(fig)


def find_max_barrier(states, model, policy, barrier, horizon):
    """Roll ``policy`` through ``model`` and return the elementwise worst barrier value seen.

    The barrier is evaluated on the initial states and after every one of the
    ``horizon`` steps, i.e. ``horizon + 1`` evaluations feed the running maximum.
    The step index is printed every 10 steps as a progress indicator.

    Args:
        states: initial states.
        model: dynamics callable ``model(states, actions) -> next_states``.
        policy: callable ``policy(states) -> actions``.
        barrier: callable mapping states to barrier values.
        horizon: number of rollout steps.
    """
    worst = barrier(states)
    current = states
    for step in range(horizon):
        if not step % 10:
            print(step)
        current = model(current, policy(current))
        worst = torch.maximum(worst, barrier(current))
    return worst


def plot_policy_safe_region(oracle, policy, device, *, horizon=200):
    """Plot, over a grid of initial states, the worst barrier value reached by ``policy``.

    Each grid state is rolled out for ``horizon`` steps through ``oracle`` via
    ``find_max_barrier`` with this module's ``barrier``. The per-state maximum is shown
    as a heat map and saved to ``PQ.log_dir / 'safe_region.png'``.

    Args:
        oracle: dynamics callable ``oracle(states, actions) -> next_states``.
        policy: callable ``policy(states) -> actions``.
        device: torch device the state grid is created on.
        horizon: number of rollout steps (default 200, the previously hard-coded value).
    """
    x_min, x_max = -np.pi / 2, np.pi / 2
    y_min, y_max = -1, 1
    xs = np.linspace(x_min, x_max, 201)
    ys = np.linspace(y_min, y_max, 201)

    X, Y = np.meshgrid(xs, ys)
    # Stack into a single ndarray first: torch.tensor() on a list of ndarrays is very slow
    # (and emits a UserWarning).
    states = torch.tensor(np.stack([X, Y]), dtype=torch.float32, device=device).permute(1, 2, 0)

    # NOTE(review): the rollout runs with autograd enabled; wrap it in torch.no_grad()
    # if oracle/policy never need gradients here, to save memory.
    max_barrier = find_max_barrier(states, oracle, policy, barrier, horizon)
    # imshow needs a host numpy array; the rollout result may live on GPU and/or carry grad.
    max_barrier = max_barrier.detach().cpu().numpy()

    fig, ax = plt.subplots()
    im = ax.imshow(max_barrier, cmap=plt.cm.RdBu, extent=[x_min, x_max, y_min, y_max], aspect='auto',
                   origin='lower')
    ax.set_xlabel('angle')
    ax.set_ylabel('angular vel')
    ax.set_title('real safe zone')
    ax.grid()

    fig.colorbar(im)
    fig.tight_layout()
    fig.savefig(PQ.log_dir / 'safe_region.png', dpi=150)
    plt.close(fig)


def from_state_to_observation(x):
    """Map ``[angle, angvel]`` states to ``[cos(angle), sin(angle), angvel]`` observations.

    Args:
        x: tensor whose last axis is ``[angle, angvel]``; any number of leading batch
            dimensions (including none) is accepted.

    Returns:
        Tensor of shape ``x.shape[:-1] + (3,)``.
    """
    angle, angvel = x[..., 0], x[..., 1]
    # Stack on the last axis instead of stacking on axis 0 and transposing: ``.t()`` only
    # supports tensors of rank <= 2, so the old form broke for extra batch dimensions.
    return torch.stack([angle.cos(), angle.sin(), angvel], dim=-1)


def from_observation_to_state(x):
    """Map ``[cos, sin, angvel]`` observations back to ``[angle, angvel]`` states.

    The angle is recovered with ``atan2(sin, cos)``, so it lies in ``(-pi, pi]``.

    Args:
        x: tensor whose last axis is ``[cos, sin, angvel]``; any number of leading batch
            dimensions (including none) is accepted.

    Returns:
        Tensor of shape ``x.shape[:-1] + (2,)``.
    """
    angle = torch.atan2(x[..., 1], x[..., 0])
    # Stack on the last axis rather than ``stack(...).t()``, which fails for rank > 2.
    return torch.stack([angle, x[..., 2]], dim=-1)


@torch.no_grad()
def plot_real(env, L, U, s, filename, device):
    """Plot L, U and hardD on a dense 1-D grid over [-1, 1] and save the figure.

    hardD equals U where L <= 0 and ``L * 0`` elsewhere (0, or NaN where L is inf/NaN).
    All three curves are clamped at 2 for display. The points ``s`` are drawn on the
    x-axis, and vertical lines mark ``-env.barrier`` and ``+env.barrier``.

    Args:
        env: object exposing a scalar ``barrier`` attribute.
        L, U: callables evaluated on a ``(100000, 1)`` grid tensor.
        s: tensor of sample points to mark on the x-axis.
        filename: output path passed to ``fig.savefig``.
        device: torch device the grid is created on.
    """
    fig, ax = plt.subplots()

    grid = torch.linspace(-1, 1, 100000, device=device)[:, None]
    L_s = L(grid)
    U_s = U(grid)
    hardD = torch.where(L_s <= 0, U_s, L_s * 0)
    grid_np = grid.cpu().numpy()
    # Draw on the explicit `ax` instead of pyplot's implicit current axes, so the plot
    # cannot land on another figure if L/U touch pyplot state.
    ax.plot(grid_np, L_s.clamp(max=2).cpu().numpy(), label='L')
    ax.plot(grid_np, U_s.clamp(max=2).cpu().numpy(), label='U')
    ax.plot(grid_np, hardD.clamp(max=2).cpu().numpy(), label='hardD')
    s = s.cpu().detach().numpy()
    ax.plot(s, np.zeros_like(s), ls='', label='s', marker='.')

    lb = -env.barrier
    rb = env.barrier

    ax.axvline(lb, color='C3')
    ax.axvline(rb, color='C3')
    ax.set_xlim(-1.1, 1.1)

    ax.grid()
    ax.set_xlabel('s')
    ax.set_ylabel('value')
    ax.legend(loc=1)
    fig.tight_layout()
    fig.savefig(filename, dpi=150)
    plt.close(fig)


@torch.no_grad()
def plot_x_vs_L(buf_path, L, *, filename, title, device='cuda'):
    """Scatter L(s) against the first state coordinate for buffered states with L(s) <= 0.1.

    A point is coloured 'C3' when ``L(s) < 0`` and ``L(s') > 0``, otherwise 'C0'.

    Args:
        buf_path: path to a pickled buffer exposing a tensor attribute ``state``.
        L: callable evaluated on the buffer states.
        filename: output path passed to ``fig.savefig``.
        title: axes title.
        device: device the states are moved to before calling ``L``. The default
            'cuda' matches the previous hard-coded ``.cuda()``.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    # pickle.load can execute arbitrary code: only point buf_path at trusted files.
    with open(buf_path, "rb") as f:
        buf = pickle.load(f)
    states = buf.state.to(device)
    # NOTE(review): next_states is read from buf.state, so it is identical to `states`.
    # Hence Us == Ls and the red "l < 0 and u > 0" colouring can never trigger. This most
    # likely should read the buffer's next-state field — confirm the buffer schema.
    next_states = buf.state.to(device)
    xs = states[..., 0].cpu().numpy()
    Ls = L(states).cpu().numpy()
    Us = L(next_states).cpu().numpy()
    indices = np.where(Ls <= 0.1)[0]

    colors = ['C3' if l < 0 and u > 0 else 'C0' for l, u in zip(Ls[indices], Us[indices])]
    ax.scatter(xs[indices], Ls[indices], label="Ls", marker='.', s=1, c=colors)
    ax.set_xlabel('$x(s)$')
    ax.set_ylabel("$L(s)$")
    ax.set_title(title)
    ax.legend()
    ax.grid()

    fig.tight_layout()
    fig.savefig(filename, dpi=150)
    plt.close(fig)

