# %%
from data import load_dataset_from_str
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as patches
from tddpm.data_generation import random_region, region_occupancy, crop_signal_with_rotation
from tddpm.evaluation import unscaled_occupancy, undo_rotation_and_scaling
from tddpm.crop import crop_signal_with_rotation as crop_signal_with_rotation_optim
import random
from tqdm import tqdm

# %%
# Load the 'porto-100' dataset from under data/. Later cells iterate it
# trajectory by trajectory and reduce it with min/max over axes (0, 1) into two
# values, so it is presumably an array of shape (n_traj, n_points, 2) -- TODO confirm.
unconditional_dataset = load_dataset_from_str('porto-100', prefix='data/')

# %%
# Reference map extent, ordered (min_lat, max_lat, min_lon, max_lon).
ORIGINAL_REGION = (39.9, 40, 116.3, 116.4)
(MIN_LAT, MAX_LAT, MIN_LON, MAX_LON) = ORIGINAL_REGION

# Size of the reference map along each axis.
MAP_HEIGHT, MAP_WIDTH = MAX_LAT - MIN_LAT, MAX_LON - MIN_LON

# A sampled region covers one tenth of the reference map in each direction.
REGION_HEIGHT, REGION_WIDTH = 1/10 * MAP_HEIGHT, 1/10 * MAP_WIDTH
SIZE = (REGION_HEIGHT, REGION_WIDTH)

# Used as the minimum crop length in find_region and in the output file name.
SEQ_LEN = 128

# %%
# Rebind the lat/lon limits to the extremes observed in the loaded dataset.
# This overwrites the values unpacked from ORIGINAL_REGION; MAP_HEIGHT,
# MAP_WIDTH and SIZE were computed earlier and keep their ORIGINAL_REGION-based
# values.
(MIN_LAT, MIN_LON), (MAX_LAT, MAX_LON) = (
    unconditional_dataset.min(axis=(0, 1)),
    unconditional_dataset.max(axis=(0, 1)),
)

# %%
# Bounds handed to random_region when sampling candidate regions.
TRAIN_BOUNDS = (MIN_LAT, MAX_LAT, MIN_LON, MAX_LON)
print(TRAIN_BOUNDS)

# %%
"""
heatmap_data, xedges, yedges = unscaled_occupancy(unconditional_dataset, 512)

plt.imshow(np.log(heatmap_data), extent=[yedges[0], yedges[-1], xedges[0], xedges[-1]], origin='lower', interpolation=None)

for i in range(100):
    region = random_region(size=SIZE, bounds=TRAIN_BOUNDS)
    region_lat, region_lon, rotation, height, width = region
    rect = patches.Rectangle((region_lon, region_lat), width, height, angle=rotation, linewidth=1, edgecolor='r', facecolor='none')
    plt.gca().add_patch(rect)
# plt.show()
"""

# %%
def region_heatmap(dataset, region, res=64):
    """Return a density histogram of every dataset point in region coordinates.

    All points of ``dataset`` are stacked, shifted by ``region[:2]``, divided
    by ``region[3:]`` and then rotated by ``region[2]`` (given in degrees).
    Points whose transformed coordinates fall outside ``[0, 1] x [0, 1]`` are
    ignored by the histogram.

    Args:
        dataset: Sequence of point arrays to concatenate along axis 0
            (presumably ``(n_points, 2)`` each -- TODO confirm).
        region: Indexable of five values; elsewhere in this file it is unpacked
            as ``(lat, lon, rotation, height, width)``.
        res: Number of bins along each axis.

    Returns:
        A ``(res, res)`` array normalised as a density, transposed so that the
        first axis indexes bins of the first transformed coordinate.
    """
    # np.concatenate builds a new array, so the in-place ops below do not
    # touch the caller's data.
    points = np.concatenate(dataset, axis=0)
    offset, extent = region[:2], region[3:]
    points -= offset
    points /= extent

    theta = region[2] * np.pi / 180
    cos_neg = np.cos(-theta)
    sin_neg = np.sin(-theta)
    sin_pos = np.sin(theta)

    first, second = points[:, 0], points[:, 1]
    # Same dtype as the input points; both columns are fully overwritten.
    local = np.empty_like(points)
    local[:, 0] = cos_neg * first + sin_neg * second
    local[:, 1] = sin_pos * first + cos_neg * second

    counts, _, _ = np.histogram2d(
        local[:, 1],
        local[:, 0],
        bins=(res, res),
        range=((0, 1), (0, 1)),
        density=True,
    )
    return counts.T

# %%
def inside_region(traj, region):
    """Return the length of the longest crop of ``traj`` relative to ``traj``.

    ``traj`` and ``region`` are cast to float32 and handed to
    ``crop_signal_with_rotation_optim`` with ``min_len=1``; the helper
    presumably returns the pieces of ``traj`` lying inside ``region`` --
    confirm against tddpm.crop.

    Returns:
        ``len(longest piece) / len(traj)`` as a float, or ``0`` when the
        helper returns no pieces.
    """
    pieces = crop_signal_with_rotation_optim(
        traj.astype(np.float32),
        region.astype(np.float32),
        min_len=1,
    )
    if len(pieces) == 0:
        return 0

    longest = max(len(piece) for piece in pieces)
    return longest * 1. / len(traj)

# %%
def find_region(traj, dataset, random_tries=10000):
    """Sample random regions until one yields a crop of ``traj``.

    Each attempt draws a region with ``random_region(size=SIZE,
    bounds=TRAIN_BOUNDS)`` and crops ``traj`` against it with
    ``min_len=SEQ_LEN``. The first attempt that returns at least one crop wins.

    Args:
        traj: Trajectory array to crop (cast to float32 for the crop helper).
        dataset: Data passed to ``region_heatmap`` for the accepted region.
        random_tries: Maximum number of regions to draw.

    Returns:
        ``(True, region, first_crop, heatmap)`` on success, otherwise
        ``(False, None, None, None)`` once all tries are used up.
    """
    for _attempt in range(random_tries):
        candidate = random_region(size=SIZE, bounds=TRAIN_BOUNDS)
        # A fresh float32 copy is made on every attempt, as in the original.
        crops = crop_signal_with_rotation_optim(
            traj.astype(np.float32),
            candidate.astype(np.float32),
            min_len=SEQ_LEN,
        )
        if len(crops) == 0:
            continue
        return True, candidate, crops[0], region_heatmap(dataset, candidate)

    return False, None, None, None

# %%
"""
found, region, scaled_traj, heatmap = find_region(traj, unconditional_dataset)

plt.imshow(heatmap, extent=(0, 1, 0, 1), origin='lower')
plt.plot(scaled_traj[:, 1], scaled_traj[:, 0], color='orange')
plt.show()

heatmap_data, xedges, yedges = unscaled_occupancy(unconditional_dataset, 512)

plt.imshow(heatmap_data, extent=[yedges[0], yedges[-1], xedges[0], xedges[-1]], origin='lower', interpolation=None)

plt.plot(traj[:, 1], traj[:, 0], color='orange')
region_lat, region_lon, rotation, height, width = region
rect = patches.Rectangle((region_lon, region_lat), width, height, angle=rotation, linewidth=1, edgecolor='r', facecolor='none')
plt.gca().add_patch(rect)
plt.show()
"""

# %% [markdown]
# # Pre-sampling regions, heatmaps and cropped trajectories for easier training

# %%
# NOTE(review): `rng` is never passed to find_region/random_region in this
# file, so this seed may not make the sampling below reproducible -- confirm
# how random_region draws its random numbers.
rng = np.random.default_rng(12345)

# found[i] becomes True once at least one region was found for trajectory i.
found = np.zeros(len(unconditional_dataset), dtype=bool)

# Parallel lists: entry k of each list belongs to the same accepted sample.
regions = []
heatmaps = []
trajectories = []

for i, traj in tqdm(enumerate(unconditional_dataset), total=len(unconditional_dataset)):
    # Collect up to 10 samples per trajectory.
    for _ in range(10):
        # Use a separate name for the per-call flag so the `found` array above
        # is not overwritten by a scalar.
        region_found, region, scaled_traj, heatmap = find_region(traj, unconditional_dataset)
        if not region_found:
            # find_region already used all of its random tries; further
            # attempts on this trajectory are skipped.
            break
        found[i] = True
        regions.append(region)
        heatmaps.append(heatmap)
        trajectories.append(scaled_traj)


# %%
# Write the collected samples to one compressed .npz archive.
# NOTE(review): the file name says 'cabspotting-25' but this file loads the
# 'porto-100' dataset -- confirm the output path is the intended one.
filename = f'data/preprocessed/cabspotting-25-conditional-{SEQ_LEN}.npz'
np.savez_compressed(
    filename,
    regions=regions,
    heatmaps=heatmaps,
    trajectories=trajectories,
)

# %%



