import numpy as np

def region_occupancy(trajectories, res=64):
    """Return a density-normalised 2-D occupancy histogram of all points.

    All trajectories are stacked along axis 0 and binned on a ``res x res``
    grid over the fixed range ``[0, 1] x [0, 1]`` (``density=True``).

    Args:
        trajectories: sequence of point arrays; columns 0 and 1 are binned.
        res: number of bins per axis.

    Returns:
        ``(H, xedges, yedges)``. Rows of ``H`` follow column-0 bins and columns
        follow column-1 bins. ``xedges`` are the bin edges for column 1 and
        ``yedges`` those for column 0.

    Raises:
        ValueError: if ``trajectories`` is empty (from ``np.concatenate``).
    """
    points = np.concatenate(trajectories)
    first_coord = points[:, 0]
    second_coord = points[:, 1]

    counts, xedges, yedges = np.histogram2d(
        second_coord,
        first_coord,
        bins=(res, res),
        range=((0, 1), (0, 1)),
        density=True,
    )
    # histogram2d indexes its output as [x_bin, y_bin]; transpose so rows follow column 0.
    return counts.T, xedges, yedges


def crop_signal_with_rotation(signal, region, min_len=24):
    """Map a trajectory into a region's local frame and split it where it leaves.

    Processing steps:

    1. Shift the signal by the region origin.
    2. Divide it per axis by the region size.
    3. Rotate it by ``+rotation_deg`` with the standard 2-D rotation matrix
       applied to (column 0, column 1).

    Points that end up in the unit square ``[0, 1] x [0, 1]`` count as inside.
    The transformed signal is split into maximal runs of consecutive inside
    points, and runs shorter than ``min_len`` are dropped.

    Args:
        signal: point sequence of shape (N, 2). The column count is fixed by
            the broadcast against ``region[0:2]``.
        region: 5-vector ``(origin_0, origin_1, rotation_deg, size_0, size_1)``.
        min_len: minimum number of points a kept segment must contain. Empty
            segments are never returned, even if ``min_len <= 0``.

    Returns:
        List of (M, 2) arrays in the normalised, rotated frame. They are slices
        (views) of one transformed copy of ``signal``. An empty ``signal``
        yields ``[]``.
    """
    origin = region[0:2]
    rotation = region[2] * np.pi / 180
    size = region[3:5]

    # NOTE(review): scaling is applied before rotating, so for a non-square
    # ``size`` the accepted area is not a rotated rectangle; confirm intent.
    local = (signal - origin) / size

    rotated = np.zeros_like(local)
    rotated[:, 0] = np.cos(rotation) * local[:, 0] + np.sin(-rotation) * local[:, 1]
    rotated[:, 1] = np.sin(rotation) * local[:, 0] + np.cos(-rotation) * local[:, 1]

    # Phrased as "< 0 or > 1" (not ">= 0 and <= 1") so NaN coordinates keep
    # being treated as inside, matching the original comparisons.
    outside = (
        (rotated[:, 0] < 0) | (rotated[:, 0] > 1)
        | (rotated[:, 1] < 0) | (rotated[:, 1] > 1)
    )

    keep_len = max(min_len, 1)
    parts = []
    start = 0  # first index of the current run of inside points
    for stop in np.flatnonzero(outside):
        if stop - start >= keep_len:
            parts.append(rotated[start:stop])
        start = stop + 1
    # The trailing run extends through the final point (inclusive).
    if len(rotated) - start >= keep_len:
        parts.append(rotated[start:])
    return parts


def extract_signals(dataset, region, min_len=24):
    """Crop every trajectory in ``dataset`` to ``region`` and collect the segments.

    Args:
        dataset: iterable of trajectories, each passed to
            ``crop_signal_with_rotation``.
        region: 5-vector ``(origin_0, origin_1, rotation_deg, size_0, size_1)``
            in the layout read by ``crop_signal_with_rotation``.
        min_len: minimum segment length, forwarded to the cropper. Previously
            this argument was accepted but silently ignored.

    Returns:
        Flat list of all kept segments, in dataset order.
    """
    all_parts = []
    for traj in dataset:
        all_parts.extend(crop_signal_with_rotation(traj, region, min_len=min_len))
    return all_parts

def random_region(size, bounds):
    """Draw one random region as a float32 5-vector.

    The origin is drawn uniformly so the unrotated region still fits inside
    ``bounds``, the same convention as ``sample_regions``. The rotation is
    drawn uniformly from ``[0, 360)`` degrees.

    Args:
        size: ``(height, width)`` of the region.
        bounds: ``(min_lat, max_lat, min_lon, max_lon)`` limits for the origin.

    Returns:
        ``np.float32`` array ``(origin_lat, origin_lon, rotation_deg, height, width)``.
    """
    height, width = size
    min_lat, max_lat, min_lon, max_lon = bounds

    region = np.zeros(5, dtype=np.float32)
    # Draw scalars (no size argument). Assigning a size-1 array into a scalar
    # slot relies on array->scalar conversion, which NumPy >= 1.25 deprecates.
    region[0] = np.random.uniform(min_lat, max_lat - height)
    region[1] = np.random.uniform(min_lon, max_lon - width)
    region[2] = np.random.uniform(0, 360)
    region[3:5] = size

    return region


def sample_regions(n, size, bounds):
    """Draw ``n`` random regions as rows of a float64 ``(n, 5)`` array.

    Each row is ``(origin_lat, origin_lon, rotation_deg, height, width)``:

    * Origins are uniform over ``[min_lat, max_lat - height)`` x
      ``[min_lon, max_lon - width)``, so the unrotated region fits in
      ``bounds``.
    * Rotations are uniform over ``[0, 360)`` degrees.
    * The last two columns repeat ``size``.

    Args:
        n: number of regions.
        size: ``(height, width)`` shared by every region.
        bounds: ``(min_lat, max_lat, min_lon, max_lon)``.
    """
    height, width = size
    lat_low, lat_high, lon_low, lon_high = bounds

    # Draw order (lat, lon, angle) is kept so seeded runs stay reproducible.
    lat = np.random.uniform(lat_low, lat_high - height, n)
    lon = np.random.uniform(lon_low, lon_high - width, n)
    angle = np.random.uniform(0, 360, n)
    extent = np.tile(np.asarray(size, dtype=np.float64), (n, 1))

    return np.column_stack([lat, lon, angle, extent])


def extract_trajectories(region_trajectories, n_trajectories_per_region, seq_len, region, heatmap):
    """Cut one random fixed-length window from each of the first usable trajectories.

    Trajectories are visited in order. Each one with at least ``seq_len``
    points contributes one window ``traj[idx:idx + seq_len]``. The start
    ``idx`` is drawn uniformly from every valid position
    ``0 .. len(traj) - seq_len`` (inclusive). Shorter trajectories are skipped.

    Args:
        region_trajectories: sequence of trajectory arrays.
        n_trajectories_per_region: number of windows wanted.
        seq_len: window length.
        region, heatmap: copied unchanged into every output tuple.

    Returns:
        List of ``n_trajectories_per_region`` tuples
        ``(region, heatmap, window)``. If too few usable trajectories exist,
        a warning is printed and ``[]`` is returned.
    """
    region_subbatch = []
    for traj in region_trajectories:
        if len(region_subbatch) >= n_trajectories_per_region:
            break
        # A trajectory of exactly seq_len points still yields one window.
        if len(traj) < seq_len:
            continue
        # randint's upper bound is exclusive; +1 makes the last valid start reachable.
        idx = np.random.randint(0, len(traj) - seq_len + 1)
        region_subbatch.append((region, heatmap, traj[idx:idx + seq_len]))

    if len(region_subbatch) < n_trajectories_per_region:
        print(f"Warning: Could not find enough trajectories long enough! Found: {len(region_subbatch)}. Restarting.")
        return []
    return region_subbatch


def sample_random_region(dataset, n_trajectories_per_region, seq_len, size, bounds):
    """Sample one random region and return fixed-length windows cropped to it.

    The function keeps trying until it succeeds:

    1. Draw regions (via ``sample_regions``) until one yields at least
       ``n_trajectories_per_region`` cropped segments (via ``extract_signals``
       with its default ``min_len``).
    2. Shuffle those segments.
    3. Build their 64x64 occupancy heatmap.
    4. Try to cut the windows with ``extract_trajectories``.

    If step 4 returns too few windows, the whole process starts over with a
    new region.

    NOTE(review): there is no attempt limit, so this loops forever if no
    region in ``bounds`` can supply enough long segments — confirm callers
    guarantee that.

    Returns:
        List of ``(region, heatmap, window)`` tuples from ``extract_trajectories``.
    """
    subbatch = []
    while len(subbatch) < n_trajectories_per_region:
        candidates = []
        while len(candidates) < n_trajectories_per_region:
            region = sample_regions(1, size, bounds).flatten()
            candidates = extract_signals(dataset, region)

        np.random.shuffle(candidates)
        heatmap = region_occupancy(candidates, res=64)[0]

        subbatch = extract_trajectories(
            candidates, n_trajectories_per_region, seq_len, region, heatmap
        )

    return subbatch


def generate_batch_transposed(n_regions, n_trajectiories_per_region, seq_len):
    """Concatenate the samples of ``n_regions`` calls to ``sample_region``.

    Returns a flat list holding every item yielded by each call, in call order.

    NOTE(review): ``sample_region`` is not defined in this part of the file.
    The visible sampler is ``sample_random_region``, which takes more
    arguments. Confirm ``sample_region`` is provided elsewhere (e.g. as a
    bound wrapper).
    """
    return [
        sample
        for _ in range(n_regions)
        for sample in sample_region(n_trajectiories_per_region, seq_len)
    ]


def generate_batch_numpy(n_regions, n_trajectiories_per_region, seq_len):
    """Sample ``n_regions`` regions and stack their samples into arrays.

    Each item returned by ``sample_region`` is unpacked as
    ``(region, heatmap, traj)``. The three fields are stacked separately with
    ``np.stack``.

    Returns:
        ``(regions, heatmaps, trajectories)`` stacked arrays.

    Raises:
        ValueError: from ``np.stack`` when no samples were produced
        (e.g. ``n_regions == 0``).

    NOTE(review): ``sample_region`` is not defined in this part of the file;
    confirm it is provided elsewhere.
    """
    samples = [
        sample
        for _ in range(n_regions)
        for sample in sample_region(n_trajectiories_per_region, seq_len)
    ]
    regions = np.stack([region for region, _, _ in samples])
    heatmaps = np.stack([heatmap for _, heatmap, _ in samples])
    trajectories = np.stack([traj for _, _, traj in samples])

    return regions, heatmaps, trajectories