import numpy as np
import random
from typing import List, Tuple, Dict


def sample(buffer, data_size: int, local_data_size: int):
    """
    Draw ``local_data_size`` entries from ``buffer`` uniformly, with replacement.

    Indices are drawn independently from ``[0, data_size)`` with
    ``np.random.randint`` and used to fancy-index ``buffer``.

    Args:
        buffer: Indexable container supporting integer-array indexing
            (e.g. a NumPy array).
        data_size: Exclusive upper bound for the drawn indices; presumably the
            number of valid entries in ``buffer`` — TODO confirm with callers.
        local_data_size: Number of entries to draw; must be at least 1.

    Returns:
        ``buffer`` indexed by the drawn index array.

    Raises:
        AssertionError: if ``local_data_size`` is below 1 (only when
            assertions are enabled).
    """
    assert local_data_size >= 1

    picked = np.random.randint(0, data_size, local_data_size)
    return buffer[picked]


def collect_replaytrajs(
    start_list: List[int],
    end_list: List[int],
    local_data_size: int,
    min_remaining: int = 50,
) -> Tuple[List[Tuple[int, int]], List[int]]:
    """
    Collect a run of consecutive trajectory segments for a replay dataset.

    If ``end_list[-1] > local_data_size``, a random offset is drawn from
    ``[1, end_list[-1] - local_data_size]``. Collection then begins at the first
    trajectory whose start lies strictly after that offset; if no trajectory
    does, it begins at the first trajectory. Otherwise all data fits, and
    collection begins at the first trajectory. Trajectories are then taken in
    list order until ``local_data_size`` steps have been gathered.

    Args:
        start_list: Inclusive start index of each trajectory. Assumed to be in
            ascending order so that "consecutive" is meaningful — TODO confirm.
        end_list: Inclusive end index of each trajectory, aligned with
            ``start_list``.
        local_data_size: Maximum total number of steps to collect.
        min_remaining: The trajectory that would overflow the budget is
            truncated to fill it only if more than this many steps remain;
            otherwise it is dropped and collection stops.

    Returns:
        A pair ``(segments, lengths)``. ``segments`` holds inclusive
        ``(start, end)`` index pairs and ``lengths`` holds the matching
        segment lengths. Both are empty for empty input.
    """
    # len() rather than truthiness so NumPy arrays are accepted as well.
    if len(start_list) == 0 or len(end_list) == 0:
        return [], []

    traj_lengths = [end - start + 1 for start, end in zip(start_list, end_list)]
    trajs = list(zip(start_list, traj_lengths))

    if end_list[-1] > local_data_size:
        random_start = random.randint(1, end_list[-1] - local_data_size)
        first_index = next(
            (i for i, (start, _) in enumerate(trajs) if start > random_start),
            None,
        )
        if first_index is None:
            # No trajectory starts after the offset (e.g. the last one is very
            # long), so fall back to the beginning of the list.
            first_index = 0
    else:
        # Everything fits: take trajectories from the very first one.
        # Searching for ``start > 1`` here would silently drop the first
        # trajectory (or two) of the dataset.
        first_index = 0

    collected_trajs = []
    total_length_list = []
    total_length = 0

    for start, length in trajs[first_index:]:
        if total_length + length > local_data_size:
            remaining_length = local_data_size - total_length
            # Keep a truncated tail only if it is long enough to be useful.
            if remaining_length > min_remaining:
                collected_trajs.append((start, start + remaining_length - 1))
                total_length_list.append(remaining_length)
                total_length += remaining_length
            break
        collected_trajs.append((start, start + length - 1))
        total_length_list.append(length)
        total_length += length

    return collected_trajs, total_length_list


def collect_trajs(
    start_list: List[int],
    end_list: List[int],
    local_data_size: int,
    min_remaining: int = 10,
) -> Tuple[List[Tuple[int, int]], List[int]]:
    """
    Randomly collect trajectory segments from a trajectory list.

    Trajectories are shuffled with ``random.shuffle``. Trajectories with a
    non-positive length are skipped. The rest are taken in shuffled order until
    ``local_data_size`` steps have been gathered.

    Args:
        start_list: Inclusive start index of each trajectory.
        end_list: Inclusive end index of each trajectory, aligned with
            ``start_list``.
        local_data_size: Maximum total number of steps to collect.
        min_remaining: The trajectory that would overflow the budget is
            truncated to fill it only if more than this many steps remain;
            otherwise it is dropped and collection stops.

    Returns:
        A pair ``(segments, lengths)``. ``segments`` holds inclusive
        ``(start, end)`` index pairs and ``lengths`` holds the matching
        segment lengths.
    """
    traj_lengths = [end - start + 1 for start, end in zip(start_list, end_list)]
    trajs = list(zip(start_list, traj_lengths))
    random.shuffle(trajs)

    collected_trajs = []
    total_length_list = []
    total_length = 0

    for start, length in trajs:
        if length <= 0:
            # Empty/inverted ranges contribute nothing.
            continue
        if total_length + length > local_data_size:
            remaining_length = local_data_size - total_length
            # Keep a truncated tail only if it is long enough to be useful.
            if remaining_length > min_remaining:
                collected_trajs.append((start, start + remaining_length - 1))
                total_length_list.append(remaining_length)
                total_length += remaining_length
            break
        collected_trajs.append((start, start + length - 1))
        total_length_list.append(length)
        total_length += length

    return collected_trajs, total_length_list


def extract_and_combine_trajs(dataset: Dict, collected_trajs: List[Tuple[int, int]]) -> Dict:
    """
    Extract trajectory segments from every field of a dataset and concatenate them.

    Args:
        dataset: Mapping from field name to a sliceable sequence (e.g. a NumPy
            array). All fields are assumed to share the same leading index
            axis — TODO confirm.
        collected_trajs: Inclusive ``(start, end)`` index pairs, e.g. as
            returned by ``collect_trajs`` / ``collect_replaytrajs``.

    Returns:
        A new dict with the same keys, in the same order. Each value is the
        concatenation of the selected slices along axis 0. If
        ``collected_trajs`` is empty, each value is a zero-length array that
        keeps the field's trailing shape and dtype.
    """
    # Materialize once: the segments are iterated once per field below.
    segments = list(collected_trajs)

    if not segments:
        # np.concatenate raises on an empty sequence, and the collectors can
        # legitimately return no segments.
        return {key: np.asarray(values)[:0] for key, values in dataset.items()}

    return {
        key: np.concatenate([values[start:end + 1] for start, end in segments], axis=0)
        for key, values in dataset.items()
    }
