import matplotlib.pyplot as plt
from sklearn import datasets

from rlkit.torch import pytorch_util as ptu

import torch
from torch_geometric.utils import k_hop_subgraph

from torch_geometric.data import Data
from torch_geometric.transforms import Compose, Distance, KNNGraph, ToDevice

import numpy as np

# Visualise a swiss-roll point cloud, its k-NN graph, and a k-hop subgraph drawn
# from it. Writes three PNGs to the working directory:
#   swiss_roll_scatter.png      - raw points, coloured by the values make_swiss_roll returns
#   edged_swiss_roll.png        - the same points with every k-NN edge (k=4) in black
#   sampled_edge_swiss_roll.png - a 2-hop subgraph around random seed nodes highlighted in red
num_samples = 1000
azim_val = -70
elev_val = 12

point_size = 8
edge_linewidth = 0.45

# Keyword arguments shared by every saved figure.
savefig_kwargs = dict(dpi=700, transparent=True, pad_inches=0, bbox_inches="tight")


def _new_3d_axes():
    """Create a new 8x6-inch pyplot figure holding a single 3D axes.

    Returns:
        ``(fig, ax)`` tuple.
    """
    fig = plt.figure(figsize=(8, 6))
    # add_subplot already attaches the axes to the figure; a separate add_axes is redundant.
    ax = fig.add_subplot(111, projection="3d")
    return fig, ax


def _scatter_points(ax, points, colors, alpha=0.9):
    """Scatter the x/y/z columns of ``points`` onto ``ax`` using the shared marker size."""
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=colors, s=point_size, alpha=alpha)


def _plot_edges(ax, src_pos, dst_pos, color):
    """Draw one straight segment from each row of ``src_pos`` to the matching row of ``dst_pos``."""
    for p, q in zip(src_pos, dst_pos):
        ax.plot([p[0], q[0]], [p[1], q[1]], [p[2], q[2]], color=color, linewidth=edge_linewidth)


def _save_and_close(fig, ax, filename):
    """Apply the shared camera angle, write ``fig`` to ``filename``, then release it."""
    ax.view_init(azim=azim_val, elev=elev_val)
    fig.savefig(filename, **savefig_kwargs)
    # plt.close frees the figure; plt.clf would only empty it and leave it registered.
    plt.close(fig)


sr_points, sr_color = datasets.make_swiss_roll(n_samples=num_samples, random_state=10)

# --- Figure 1: raw point cloud ----------------------------------------------
fig, ax = _new_3d_axes()
_scatter_points(ax, sr_points, sr_color)
_save_and_close(fig, ax, "swiss_roll_scatter.png")

# --- Build the k-NN graph over a shuffled copy of the points ----------------
graph_sample_size = num_samples
random_idxs = np.random.permutation(num_samples)[:graph_sample_size]
sample_sr_points, sample_sr_colors = sr_points[random_idxs], sr_color[random_idxs]

composed_transforms = Compose([ToDevice(ptu.device), KNNGraph(force_undirected=True, k=4), Distance()])

sample_tensor = torch.from_numpy(sample_sr_points)
obs_data = Data(pos=sample_tensor)
obs_dataset = composed_transforms(obs_data)

# The transformed graph lives on ptu.device (possibly a GPU). Bring it back to the CPU
# once so NumPy indexing and k_hop_subgraph below all work on a single device.
edge_index_cpu = obs_dataset.edge_index.cpu()
all_edges_np = edge_index_cpu.numpy()
all_pos = obs_dataset.pos.cpu().numpy()
num_nodes = all_pos.shape[0]

# --- Figure 2: point cloud with every k-NN edge -----------------------------
fig, ax = _new_3d_axes()
_scatter_points(ax, sample_sr_points, sample_sr_colors)
source_np = all_pos[all_edges_np[0]]
target_np = all_pos[all_edges_np[1]]
print(source_np.shape, target_np.shape)
_plot_edges(ax, source_np, target_np, "black")
_save_and_close(fig, ax, "edged_swiss_roll.png")

# --- Figure 3: 2-hop subgraph around random seed nodes ----------------------
sample_nodes_size = 50
nodes_sample_idxs = torch.as_tensor(np.random.permutation(num_nodes)[:sample_nodes_size], dtype=torch.long)

# Use the CPU edge_index so both index tensors share a device; mixing a CPU node_idx
# with a CUDA edge_index raises a device-mismatch error.
nodes_subset, edges_subset, _, _ = k_hop_subgraph(nodes_sample_idxs, 2, edge_index_cpu, num_nodes=num_nodes)

edges_subset_np = edges_subset.numpy()
nodes_subset_np = nodes_subset.numpy()

# Boolean membership mask: O(1) per lookup instead of scanning an array/list each time.
in_subgraph = np.zeros(num_nodes, dtype=bool)
in_subgraph[nodes_subset_np] = True

# Indices of nodes outside the subgraph, in ascending order.
reduced_sample = np.flatnonzero(~in_subgraph)

fig, ax = _new_3d_axes()

reduced_sample_sr_points = sample_sr_points[reduced_sample]
reduced_sample_sr_colors = sample_sr_colors[reduced_sample]
_scatter_points(ax, reduced_sample_sr_points, reduced_sample_sr_colors)

# Subgraph nodes and edges are highlighted in red.
sgd_sample_pos = sample_sr_points[nodes_subset_np]
_scatter_points(ax, sgd_sample_pos, "red", alpha=0.7)

sgd_edge_sampled_src = sample_sr_points[edges_subset_np[0]]
sgd_edge_sampled_target = sample_sr_points[edges_subset_np[1]]
_plot_edges(ax, sgd_edge_sampled_src, sgd_edge_sampled_target, "red")

# Every edge with at least one endpoint outside the subgraph is drawn in black.
black_edge_mask = ~(in_subgraph[all_edges_np[0]] & in_subgraph[all_edges_np[1]])
reduced_sample_edges_0 = all_edges_np[0][black_edge_mask]
reduced_sample_edges_1 = all_edges_np[1][black_edge_mask]
_plot_edges(ax, all_pos[reduced_sample_edges_0], all_pos[reduced_sample_edges_1], "black")

_save_and_close(fig, ax, "sampled_edge_swiss_roll.png")
