import pickle
from pathlib import Path
from typing import Callable, Dict, List, Tuple

import numpy as np
from sklearn.cross_decomposition import CCA


def split_trajectories(
        dataset: Dict[str, np.ndarray]) -> Dict[int, List[np.ndarray]]:
    """Split a flat transition dataset into per-task lists of trajectories.

    A trajectory starts at index 0 and at every index immediately following
    a nonzero ``terminals`` entry. Each trajectory is grouped under the
    ``infos/goal_id`` value of its first step. Each task's list is then
    shuffled in place using NumPy's global random state (``np.random``).

    Args:
        dataset (Dict[str, np.ndarray]):
            dataset providing the keys "terminals", "observations" and
            "infos/goal_id", all indexed by transition along axis 0.

    Returns:
        Dict[int, List[np.ndarray]]: mapping from task_id to a shuffled list
        of observation slices, one per trajectory. The mapping is empty when
        the dataset contains no transitions.
    """
    terminals = np.asarray(dataset["terminals"]).astype(bool)
    if terminals.size == 0:
        # Without this guard, `starts[0] = True` below would raise IndexError.
        return {}

    # A step starts a trajectory if it is the first step or follows a terminal.
    starts = np.empty_like(terminals)
    starts[0] = True
    starts[1:] = terminals[:-1]

    # Append the dataset length so consecutive pairs give [begin, end) spans.
    boundaries = np.append(np.flatnonzero(starts), len(starts))
    observations = dataset["observations"]
    goal_ids = dataset["infos/goal_id"]

    trajectories: Dict[int, List[np.ndarray]] = {}
    for begin, end in zip(boundaries[:-1], boundaries[1:]):
        trajectories.setdefault(goal_ids[begin], []).append(
            observations[begin:end])

    # Shuffle each task's list in insertion order (uses the global RNG).
    for trajs in trajectories.values():
        np.random.shuffle(trajs)

    return trajectories


def _pad_with_last(traj: np.ndarray, length: int) -> np.ndarray:
    """Extend ``traj`` along axis 0 to ``length`` by repeating its last step."""
    missing = length - len(traj)
    if missing <= 0:
        return traj
    # Pad only the time axis; "edge" repeats the final row. Works for any rank.
    pad_width = [(0, missing)] + [(0, 0)] * (traj.ndim - 1)
    return np.pad(traj, pad_width, mode="edge")


def create_dataset_for_cca(
    source_dataset: Dict[str, np.ndarray],
    target_dataset: Dict[str, np.ndarray],
    task_ids: List[int],
) -> Tuple[np.ndarray, np.ndarray]:
    """Build paired, length-aligned observation arrays for CCA.

    Both datasets are split into shuffled per-task trajectory lists via
    ``split_trajectories``. For each task, the i-th source trajectory is then
    paired with the i-th target trajectory. Surplus trajectories on the side
    with more are dropped. The shorter trajectory of each pair is padded with
    its last observation, so both sides have the same number of rows.

    Args:
        source_dataset (Dict[str, np.ndarray]):
            dataset from source domain, specific keys ("terminals", "observations", "infos/goal_id") are required

        target_dataset (Dict[str, np.ndarray]):
            dataset from target domain, specific keys ("terminals", "observations", "infos/goal_id") are required

        task_ids (List[int]): task_ids whose trajectories are paired

    Returns:
        Tuple[np.ndarray, np.ndarray]: source and target observations,
        row-aligned and concatenated over all pairs and tasks.

    Raises:
        ValueError: if ``task_ids`` is empty.
        KeyError: if a task_id has no trajectories in one of the datasets.
    """
    if len(task_ids) == 0:
        raise ValueError("task_ids must contain at least one task id")

    source_trajs = split_trajectories(source_dataset)
    target_trajs = split_trajectories(target_dataset)

    source_observations = []
    target_observations = []
    for task_id in task_ids:
        if task_id not in source_trajs or task_id not in target_trajs:
            raise KeyError(
                f"task_id {task_id} has no trajectories in both the source "
                "and target datasets")
        # zip truncates to the smaller number of trajectories for this task.
        for src, tgt in zip(source_trajs[task_id], target_trajs[task_id]):
            length = max(len(src), len(tgt))
            source_observations.append(_pad_with_last(src, length))
            target_observations.append(_pad_with_last(tgt, length))

    return np.concatenate(source_observations), np.concatenate(target_observations)


def train_cca(
    source_dataset: Dict[str, np.ndarray],
    target_dataset: Dict[str, np.ndarray],
    task_ids: List[int],
    n_components: int,
    model_path: Path,
) -> Tuple[Callable, Callable]:
    """Fit CCA on paired source/target observations and return projections.

    Args:
        source_dataset (Dict[str, np.ndarray]):
            dataset from source domain, specific keys ("terminals", "observations", "infos/goal_id") are required

        target_dataset (Dict[str, np.ndarray]):
            dataset from target domain, specific keys ("terminals", "observations", "infos/goal_id") are required

        task_ids (List[int]):
            task_ids for proxy tasks

        n_components (int):
            number of components to keep

        model_path (Path):
            path to the pkl file where the trained model will be saved

    Returns:
        Tuple[Callable, Callable]: ``(source_transform, target_transform)``.
        Each callable accepts a single observation (1-D) or a batch (2-D) and
        returns the projection onto the shared CCA space, always with a
        leading batch axis.
    """
    source_observations, target_observations = create_dataset_for_cca(
        source_dataset=source_dataset,
        target_dataset=target_dataset,
        task_ids=task_ids,
    )

    cca = CCA(n_components=n_components, max_iter=2000)
    cca.fit(source_observations, target_observations)

    # CCA.transform always projects X. To obtain only the target (Y) scores,
    # pass a tiny placeholder X and discard its scores. np.array copies the
    # row, so the closure does not keep the whole training array alive (a
    # slice would be a view onto it).
    placeholder_source = np.array(source_observations[:1])

    def source_transform(x: np.ndarray) -> np.ndarray:
        return cca.transform(np.atleast_2d(x))

    def target_transform(x: np.ndarray) -> np.ndarray:
        return cca.transform(placeholder_source, np.atleast_2d(x))[1]

    # The context manager guarantees the file handle is flushed and closed.
    with open(model_path, "wb") as f:
        pickle.dump(cca, f)
    print("CCA model saved to", model_path)

    return source_transform, target_transform