import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Generator, Optional, Tuple

import pandas as pd
from tqdm import tqdm, trange
from trajdata import UnifiedDataset
from trajdata.caching import EnvCache, SceneCache
from trajdata.caching.df_cache import DataFrameCache
from trajdata.data_structures import Scene


def process_scene(scene_path: Path, cache_path: Path):
    """Load one scene's cached data and tag every row with the scene name.

    Args:
        scene_path: Location handed to ``EnvCache.load`` to obtain the scene.
        cache_path: Cache root handed to ``DataFrameCache`` together with the
            loaded scene.

    Returns:
        A copy of the cache's ``scene_data_df`` whose index has ``scene_id``
        (set to the scene's name) appended as an additional, innermost level.
    """
    loaded_scene = EnvCache.load(scene_path)
    scene_frame = DataFrameCache(cache_path, loaded_scene).scene_data_df

    # ``assign`` returns a new frame, so the cache's own dataframe is not
    # modified by the index change that follows.
    tagged_frame = scene_frame.assign(scene_id=loaded_scene.name)
    return tagged_frame.set_index("scene_id", append=True)


def parallel_extract_dataframes(
    dataset: UnifiedDataset,
    max_workers: int = os.cpu_count(),
) -> pd.DataFrame:
    """Collect the dataframes of every scene in ``dataset`` into one frame.

    One ``process_scene`` job is submitted to a thread pool per entry of
    ``dataset._scene_index``; the per-scene results are then concatenated
    row-wise.

    Args:
        dataset: Dataset providing ``_scene_index`` (the scene paths) and
            ``cache_path``.
        max_workers: Size of the thread pool.

    Returns:
        The concatenated dataframe after swapping the ``scene_id`` and
        ``agent_id`` index levels and then the ``agent_id`` and ``scene_ts``
        levels. Assuming the per-scene index is
        ``(agent_id, scene_ts, scene_id)`` -- TODO confirm against
        ``DataFrameCache.scene_data_df`` -- the result is ordered
        ``(scene_id, agent_id, scene_ts)``.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = []
        for scene_path in tqdm(dataset._scene_index):
            pending.append(pool.submit(process_scene, scene_path, dataset.cache_path))

        # Results are gathered in completion order, so the row order of the
        # concatenated frame follows whichever scenes finish first.
        progress = tqdm(
            as_completed(pending),
            desc=f"Extracting Dataframes ({max_workers} CPUs)",
            total=len(pending),
        )
        per_scene_frames = [job.result() for job in progress]

    combined = pd.concat(per_scene_frames, axis=0)
    return combined.swaplevel("scene_id", "agent_id").swaplevel("agent_id", "scene_ts")


def process_scene_idx(scene_path: Path, cache_path: Path) -> pd.DataFrame:
    """Read one scene's agent data file and annotate it with type and scene.

    Args:
        scene_path: Location handed to ``EnvCache.load`` to obtain the scene.
        cache_path: Cache root used to locate the scene's cache directory.

    Returns:
        The feather file's contents indexed by ``(agent_id, scene_ts)``, with
        two extra columns: ``agent_type`` (the agent's type name, looked up by
        agent name) and ``scene_id`` (the scene's name).
    """
    scene = EnvCache.load(scene_path)
    type_name_by_agent = {agent.name: agent.type.name for agent in scene.agents}

    cache_dir = SceneCache.scene_cache_dir(cache_path, scene.env_name, scene.name)
    agent_data_file = cache_dir / DataFrameCache._agent_data_file(scene.dt)

    raw_frame = pd.read_feather(agent_data_file, use_threads=False)
    indexed_frame = raw_frame.set_index(["agent_id", "scene_ts"])

    # Agents missing from the mapping end up with NaN as their type.
    agent_ids = indexed_frame.index.get_level_values("agent_id")
    indexed_frame["agent_type"] = agent_ids.map(type_name_by_agent)

    return indexed_frame.assign(scene_id=scene.name)


def get_hist_cache_path(plot_dir: Path, env_name: str, chunk_start: int, chunk_end: int) -> Path:
    """Build the ``.npz`` file path associated with one chunk of scenes.

    Args:
        plot_dir: Directory the file lives in.
        env_name: Environment name used as the file name prefix.
        chunk_start: First value of the chunk range, embedded in the name.
        chunk_end: Second value of the chunk range, embedded in the name.

    Returns:
        ``plot_dir / "<env_name>_<chunk_start>_<chunk_end>.npz"``.
    """
    file_name = f"{env_name}_{chunk_start}_{chunk_end}.npz"
    return plot_dir / file_name


def parallel_generate_dataframes(
    dataset: UnifiedDataset,
    max_workers: int = 10 * (os.cpu_count() or 1),
    chunk_size: int = 1000,
    max_scenes: Optional[int] = None,
    hist_cache_dir: Optional[Path] = None,
    skip_if_cached: bool = False,
) -> Generator[Tuple[Optional[pd.DataFrame], Optional[Path]], None, None]:
    """Yield the agent data of ``dataset`` one chunk of scenes at a time.

    Scenes ``0 .. max_scenes - 1`` of ``dataset._scene_index`` are walked in
    chunks of ``chunk_size``. For every chunk the scenes are loaded through
    ``process_scene_idx`` on a thread pool and concatenated into one frame.

    Args:
        dataset: Dataset providing ``num_scenes()``, ``_scene_index``,
            ``cache_path`` and ``envs_dict``.
        max_workers: Size of the thread pool used for each chunk. The default
            is ten times the CPU count, falling back to a CPU count of 1 when
            ``os.cpu_count()`` cannot determine it (it may return ``None``).
        chunk_size: Number of scenes per yielded dataframe.
        max_scenes: Upper bound on the number of scenes processed; ``None``
            means all scenes. Values above ``dataset.num_scenes()`` are
            clamped.
        hist_cache_dir: Directory in which per-chunk cache files are looked
            up. When ``None`` no cache path is computed and every chunk is
            loaded.
        skip_if_cached: If true, a chunk whose cache file already exists is
            not loaded.

    Yields:
        ``(dataframe, cache_file)`` pairs, one per chunk. ``dataframe`` is
        indexed by ``(scene_id, agent_id, agent_type, scene_ts)``, or is
        ``None`` when the chunk was skipped because its cache file exists.
        ``cache_file`` is the chunk's path from ``get_hist_cache_path`` (named
        after the first key of ``dataset.envs_dict``), or ``None`` when
        ``hist_cache_dir`` is ``None``.
    """
    num_available = dataset.num_scenes()
    max_scenes = num_available if max_scenes is None else min(num_available, max_scenes)

    for chunk_beg in trange(0, max_scenes, chunk_size, desc="DataFrame Chunks", position=0):
        chunk_end = min(chunk_beg + chunk_size, max_scenes)

        # Only build a cache path when a cache directory was actually given;
        # joining onto ``None`` would raise a TypeError.
        curr_cache_file: Optional[Path] = None
        if hist_cache_dir is not None:
            env_name = list(dataset.envs_dict.keys())[0]
            curr_cache_file = get_hist_cache_path(hist_cache_dir, env_name, chunk_beg, chunk_end)

        # Check the flag first so the filesystem is only touched when needed.
        if skip_if_cached and curr_cache_file is not None and curr_cache_file.exists():
            yield None, curr_cache_file
            continue

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(process_scene_idx, dataset._scene_index[scene_idx], dataset.cache_path)
                for scene_idx in trange(
                    chunk_beg,
                    chunk_end,
                    total=chunk_end - chunk_beg,
                    desc=f"{chunk_beg} to {chunk_end}",
                    position=1,
                    leave=False,
                )
            ]

            # Results arrive in completion order, so row order within a chunk
            # is not tied to scene order.
            dataframes = [
                future.result()
                for future in tqdm(
                    as_completed(futures),
                    desc=f"Extracting Dataframes ({max_workers} CPUs)",
                    total=len(futures),
                    leave=False,
                    position=2,
                )
            ]

        overall_df = pd.concat(dataframes, axis=0)
        # The per-scene frames carry scene_id and agent_type as columns;
        # promote them into the index next to agent_id and scene_ts.
        overall_df.set_index(["scene_id", "agent_type"], append=True, inplace=True)
        overall_df = overall_df.reorder_levels(["scene_id", "agent_id", "agent_type", "scene_ts"])

        yield overall_df, curr_cache_file
