import copy
import random

import numpy as np
import torch
from moving_out.benchmarks.moving_out import MovingOutEnv
from moving_out.utils.states_encoding import StatesEncoder
from moving_out.utils.trajectory_path_analyzer import TrajectoryPathAnalyzer
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset


class TrajectoryDataset(Dataset):
    """Sliding-window dataset over recorded two-agent trajectories.

    Each trajectory is a sequence of steps ``[state, actions]`` where
    ``actions[0]`` is the controlled agent's action and ``actions[1]`` is the
    other agent's action. An item pairs the state at step ``t`` with the ``m``
    preceding states and the actions of steps ``t .. t+k-1``. Windows that run
    past either end of the trajectory are padded with the first state / last
    actions.

    Columns ``8:16`` of a state vector are the slice that the
    ``no_another_states``, ``shift_another_states`` and recombination options
    replace; presumably they encode the other agent (TODO confirm against
    ``StatesEncoder``).
    """

    def __init__(
        self,
        data_path,
        m,
        k,
        use_loaded_numpy=False,
        add_noise=False,
        noise_std=0.1,
        no_another_states=False,
        shift_another_states=False,
        shift_range=1,
        recombination_trajectories=False,
        recombination_analyzer=None,
        img_obs=False,
        img_obs_map_id=None,
    ):
        """Build the dataset.

        Args:
            data_path: Path to a ``.npy`` file of trajectories, or, with
                ``use_loaded_numpy=True``, the already-loaded trajectories.
            m: Number of previous states returned with each item.
            k: Number of future actions returned with each item.
            use_loaded_numpy: Treat ``data_path`` as in-memory trajectories.
            add_noise: Add Gaussian noise to the other-agent part of the
                observation (state columns ``8:16``, or robot_2's pose when
                rendering images).
            noise_std: Standard deviation of that noise.
            no_another_states: Zero state columns ``8:16``.
            shift_another_states: Replace columns ``8:16`` with those of a
                step up to ``shift_range`` steps away in the same trajectory.
            shift_range: Maximum absolute shift for ``shift_another_states``.
            recombination_trajectories: Replace columns ``8:16`` with those
                of another trajectory whose discretized start/end cells match.
            recombination_analyzer: Object providing
                ``query_by_start_and_end`` used for recombination.
            img_obs: Also return rendered image observations. In this mode
                each step's first element is a JSON state that is encoded
                with ``StatesEncoder``.
            img_obs_map_id: Map name used to construct the rendering env.
        """
        self.img_obs = img_obs
        self.img_obs_map_id = img_obs_map_id
        # Created lazily on the first __getitem__ (presumably so each
        # DataLoader worker builds its own env instead of pickling one).
        self.env = None
        # Raw JSON states per trajectory; only filled in image mode.
        self.json_data = []

        if use_loaded_numpy:
            self.pair_data = data_path
        else:
            # allow_pickle unpickles arbitrary objects: only load trusted files.
            self.pair_data = np.load(data_path, allow_pickle=True)

        states_encoder = StatesEncoder() if img_obs else None
        self.data = []
        for traj in self.pair_data:
            if img_obs:
                self.data.append(
                    [
                        [states_encoder.get_state_by_json(step[0])[0], step[1]]
                        for step in traj
                    ]
                )
                self.json_data.append([[step[0], step[1]] for step in traj])
            else:
                self.data.append(traj)

        self.m = m
        self.k = k
        self.add_noise = add_noise
        self.noise_std = noise_std
        self.no_another_states = no_another_states
        self.shift_another_states = shift_another_states
        self.shift_range = shift_range
        self.recombination_trajectories = recombination_trajectories
        self.recombination_analyzer = recombination_analyzer

        if self.recombination_trajectories:
            from moving_out.utils.discretizer import Discretizer

            self.discretizer = Discretizer(-1.2, 1.2, -23, 24)

    def __len__(self):
        """Total number of steps across all trajectories."""
        return sum(len(traj) for traj in self.data)

    @staticmethod
    def _as_float_tensor(values):
        """Copy ``values`` into a new float32 tensor.

        Going through a numpy array avoids torch's slow list-of-ndarrays path;
        ``torch.tensor`` always copies, so later in-place edits never touch the
        stored trajectories.
        """
        return torch.tensor(np.asarray(values, dtype=np.float32))

    def _padded_history(self, seq, index):
        """Return the first element of the ``m`` steps before ``index``.

        The result is in chronological order, padded with ``seq[0][0]`` where
        the history runs before the start of ``seq``. Returns
        ``(items, number_of_padded_items)``.
        """
        history = []
        padded = 0
        for offset in range(self.m, 0, -1):
            step = index - offset
            if step >= 0:
                history.append(seq[step][0])
            else:
                history.append(seq[0][0])
                padded += 1
        return history, padded

    def _padded_future_actions(self, traj, index):
        """Return both agents' actions for steps ``index .. index+k-1``.

        Steps past the end of ``traj`` repeat its last actions. Returns
        ``(ego_actions, other_agent_actions, number_of_padded_steps)``.
        """
        ego_actions = []
        other_actions = []
        padded = 0
        for step in range(index, index + self.k):
            if step < len(traj):
                actions = traj[step][1]
            else:
                actions = traj[-1][1]
                padded += 1
            ego_actions.append(actions[0])
            other_actions.append(actions[1])
        return ego_actions, other_actions, padded

    @staticmethod
    def _copy_robot_2_pose(target, source):
        """Overwrite robot_2's ``pos``/``angle`` in JSON state ``target``."""
        for key in ("pos", "angle"):
            target["states"]["robot_2"][key] = copy.deepcopy(
                source["states"]["robot_2"][key]
            )

    @staticmethod
    def _jitter_robot_2(json_state, noise_std, noise_mean):
        """Add Gaussian noise to robot_2's pose in ``json_state`` (in place)."""
        robot_2 = json_state["states"]["robot_2"]
        robot_2["pos"] = np.array(robot_2["pos"]) + (
            np.random.randn(2) * noise_std + noise_mean
        )
        robot_2["angle"] = np.array(robot_2["angle"]) + (
            np.random.randn(1) * noise_std + noise_mean
        )

    def _render_json_states(self, current_json, prev_jsons):
        """Render the current and previous JSON states with ``self.env``.

        Returns ``(current_image, [previous_images...])``.
        """
        self.env.update_env_by_given_state(current_json)
        current_img_obs = self.env.render("rgb_array")
        previous_img_obs = []
        for json_state in prev_jsons:
            self.env.update_env_by_given_state(json_state)
            previous_img_obs.append(self.env.render("rgb_array"))
        return current_img_obs, previous_img_obs

    def get_item_by_traj_idx(
        self, traj_idx, index, img_obs=False, recomb_idx_index=None
    ):
        """Build the sample for step ``index`` of trajectory ``traj_idx``.

        Args:
            traj_idx: Trajectory number in ``self.data``.
            index: Step within that trajectory.
            img_obs: Also render image observations (needs ``self.env``).
            recomb_idx_index: Optional ``(traj_idx, step)`` of another
                trajectory whose robot_2 pose replaces this one's in the
                rendered images. Only used when ``add_noise`` and ``img_obs``
                are both enabled.

        Returns:
            ``(prev_states, state, next_actions, the_other_agent_action,
            repected_beginning, repected_end, previous_img_obs,
            current_img_obs)``, where the ``repected_*`` values count padded
            steps and the image entries are ``"_"`` when ``img_obs`` is off.
        """
        traj = self.data[traj_idx]
        state = traj[index][0]
        prev_states, repected_beginning = self._padded_history(traj, index)
        (
            next_actions,
            the_other_agent_actions,
            repected_end,
        ) = self._padded_future_actions(traj, index)

        if img_obs:
            traj_json = self.json_data[traj_idx]
            current_json = traj_json[index][0]
            prev_json_states, _ = self._padded_history(traj_json, index)

        prev_states = self._as_float_tensor(prev_states)
        next_actions = self._as_float_tensor(next_actions)
        the_other_agent_action = self._as_float_tensor(the_other_agent_actions)
        state = self._as_float_tensor(state)

        previous_img_obs = "_"
        current_img_obs = "_"
        if self.add_noise and not self.no_another_states:
            noise_std = self.noise_std
            noise_mean = 0
            if not img_obs:
                noise = torch.randn_like(state[8:16]) * noise_std + noise_mean
                state[8:16] += noise
                noise = torch.randn_like(prev_states[:, 8:16]) * noise_std + noise_mean
                prev_states[:, 8:16] += noise
            else:
                # Work on copies so the stored JSON states are never mutated.
                current_json = copy.deepcopy(current_json)
                prev_jsons = [copy.deepcopy(s) for s in prev_json_states]
                if recomb_idx_index is not None:
                    # Use the partner pose from the recombined trajectory at its
                    # own step (and its own padded history).
                    recomb_traj_idx, recomb_index = recomb_idx_index
                    recomb_traj_json = self.json_data[recomb_traj_idx]
                    self._copy_robot_2_pose(
                        current_json, recomb_traj_json[recomb_index][0]
                    )
                    recomb_prev_jsons, _ = self._padded_history(
                        recomb_traj_json, recomb_index
                    )
                    for target, source in zip(prev_jsons, recomb_prev_jsons):
                        self._copy_robot_2_pose(target, source)
                for json_state in [current_json, *prev_jsons]:
                    self._jitter_robot_2(json_state, noise_std, noise_mean)
                current_img_obs, previous_img_obs = self._render_json_states(
                    current_json, prev_jsons
                )
        elif self.no_another_states:
            if img_obs:
                raise NotImplementedError(
                    "no_another_states is not supported together with img_obs"
                )
            state[8:16] = 0.0
            prev_states[:, 8:16] = 0.0
        elif img_obs:
            current_img_obs, previous_img_obs = self._render_json_states(
                current_json, prev_json_states
            )

        return (
            prev_states,
            state,
            next_actions,
            the_other_agent_action,
            repected_beginning,
            repected_end,
            previous_img_obs,
            current_img_obs,
        )

    def get_item_by_indx(self, index, img_obs=False, recomb_idx_index=None):
        """Like ``get_item_by_traj_idx`` but addressed by a global step index."""
        traj_idx, traj_index = self._locate_traj_by_indx(index)
        return self.get_item_by_traj_idx(
            traj_idx, traj_index, img_obs=img_obs, recomb_idx_index=recomb_idx_index
        )

    def _locate_traj_by_indx(self, index):
        """Map a global step index to ``(trajectory number, step within it)``.

        Raises:
            IndexError: if ``index`` is past the last step.
        """
        step = index
        for traj_idx, traj in enumerate(self.data):
            if step < len(traj):
                return traj_idx, step
            step -= len(traj)
        raise IndexError(f"Index {index} is out of bounds!")

    def _get_shifted_item(self, index):
        """Sample whose columns ``8:16`` come from a nearby step.

        The nearby step lies up to ``shift_range`` steps away in the same
        trajectory. It falls back to the unshifted sample when the shift is 0
        or leaves the trajectory. Returns the same 8-tuple as
        ``get_item_by_traj_idx``.
        """
        traj_idx, step = self._locate_traj_by_indx(index)
        shift = random.randint(-self.shift_range, self.shift_range)
        item = self.get_item_by_traj_idx(traj_idx, step, img_obs=self.img_obs)
        prev_states, state = item[0], item[1]
        shifted_step = step + shift
        if shift != 0 and 0 <= shifted_step < len(self.data[traj_idx]):
            shifted_prev_states, shifted_state, *_ = self.get_item_by_traj_idx(
                traj_idx, shifted_step
            )
            prev_states[:, 8:16] = shifted_prev_states[:, 8:16]
            state[8:16] = shifted_state[8:16]
        return item

    def _get_recombined_item(self, index):
        """Sample whose columns ``8:16`` may come from another trajectory.

        The other trajectory must share the same discretized start/end cells
        over the k-step action window. Returns the same 8-tuple as
        ``get_item_by_traj_idx``.
        """
        item = self.get_item_by_indx(index, self.img_obs)
        prev_states, state, repected_end = item[0], item[1], item[5]
        if repected_end != 0:
            # The action window runs past the trajectory end: no recombination.
            return item

        # No padding means steps index .. index+k-1 lie in the same trajectory.
        end_state = self.get_item_by_indx(index + self.k - 1, img_obs=False)[1]
        start_pos = self.discretizer.discretize([float(state[0]), float(state[1])])
        end_pos = self.discretizer.discretize(
            [float(end_state[0]), float(end_state[1])]
        )
        result = self.recombination_analyzer.query_by_start_and_end(
            start_pos, end_pos
        )
        if len(result["trajectories"]) > 1:
            recomb_traj_idx, _, start_indx, _ = random.choice(result["trajectories"])
            recomb_prev_states, recomb_state, *_ = self.get_item_by_traj_idx(
                recomb_traj_idx, start_indx
            )
            if not self.img_obs:
                state[8:16] = recomb_state[8:16]
                prev_states[:, 8:16] = recomb_prev_states[:, 8:16]
            # NOTE(review): in image mode the recombined partner is computed but
            # not applied to the rendered images (unchanged from the original).
        return item

    def __getitem__(self, index):
        """Return one training sample for global step ``index``.

        Returns:
            ``(prev_states, state, next_actions, the_other_agent_action,
            previous_img_obs, current_img_obs)``; the image entries are
            ``"_"`` when ``img_obs`` is off.

        Raises:
            IndexError: if ``index`` is out of range (negative indices count
                from the end).
        """
        if self.img_obs and self.env is None:
            self.env = MovingOutEnv(map_name=self.img_obs_map_id)
        total = len(self)
        position = index + total if index < 0 else index
        if not 0 <= position < total:
            raise IndexError(f"Index {index} is out of bounds!")

        if self.shift_another_states:
            item = self._get_shifted_item(position)
        elif self.recombination_trajectories:
            item = self._get_recombined_item(position)
        else:
            item = self.get_item_by_indx(position, img_obs=self.img_obs)
        (
            prev_states,
            state,
            next_actions,
            the_other_agent_action,
            _,
            _,
            previous_img_obs,
            current_img_obs,
        ) = item
        return (
            prev_states,
            state,
            next_actions,
            the_other_agent_action,
            previous_img_obs,
            current_img_obs,
        )


# def create_dataloader(data_path, m, k, batch_size):
#     dataset = TrajectoryDataset(data_path, m, k)
#     return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)


def create_train_test_dataloader(
    data_path,
    m,
    k,
    batch_size,
    train_test_split_ration=0.9,
    add_noise=False,
    noise_std=0.1,
    no_another_states=False,
    shift_another_states=False,
    shift_range=1,
    recombination_trajectories=False,
    recombination_trajectories_cache_path=None,
    img_obs=False,
    img_obs_map_id=None,
    num_workers=6,
):
    """Build shuffled train and validation DataLoaders from one ``.npy`` file.

    Args:
        data_path: Path to the trajectories ``.npy`` file.
        m: Number of previous states per sample.
        k: Number of future actions per sample.
        batch_size: Batch size for both loaders.
        train_test_split_ration: Fraction of trajectories used for training.
            When exactly ``1``, train and validation use the same dataset.
        add_noise, noise_std, no_another_states, shift_another_states,
        shift_range, recombination_trajectories: Forwarded to the training
            ``TrajectoryDataset``. The validation split (when split) only
            receives ``no_another_states``, so it is not augmented.
        recombination_trajectories_cache_path: Currently unused.
        img_obs, img_obs_map_id: Image-observation settings, forwarded to
            every dataset so train and validation produce the same inputs.
        num_workers: Worker processes per DataLoader.

    Returns:
        ``(train_loader, test_loader)``.
    """
    # allow_pickle unpickles arbitrary objects: only load trusted files.
    dataset = np.load(data_path, allow_pickle=True)
    if recombination_trajectories:
        # NOTE(review): the analyzer indexes trajectories of the full file, but
        # after train_test_split the training dataset holds a shuffled subset,
        # so returned trajectory indices presumably refer to different
        # trajectories — confirm before using recombination with a split.
        analyzer = TrajectoryPathAnalyzer(grid_size=(-23, 24))
        analyzer.compute_from_data_file(data_path)
    else:
        analyzer = None

    train_options = dict(
        use_loaded_numpy=True,
        add_noise=add_noise,
        noise_std=noise_std,
        no_another_states=no_another_states,
        shift_another_states=shift_another_states,
        shift_range=shift_range,
        recombination_trajectories=recombination_trajectories,
        recombination_analyzer=analyzer,
        img_obs=img_obs,
        img_obs_map_id=img_obs_map_id,
    )

    if train_test_split_ration == 1:
        train_dataset = TrajectoryDataset(dataset, m, k, **train_options)
        test_dataset = train_dataset
    else:
        train_data, val_data = train_test_split(
            dataset, test_size=1 - train_test_split_ration, random_state=42
        )
        train_dataset = TrajectoryDataset(train_data, m, k, **train_options)
        test_dataset = TrajectoryDataset(
            val_data,
            m,
            k,
            use_loaded_numpy=True,
            no_another_states=no_another_states,
            img_obs=img_obs,
            img_obs_map_id=img_obs_map_id,
        )

    def make_loader(ds):
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=num_workers,
        )

    return make_loader(train_dataset), make_loader(test_dataset)


if __name__ == "__main__":
    data_path = "/p/lialabdatasets/moving_out_dataset/111_0820.npy"
    m = 5  # number of past observations
    k = 3  # number of future actions

    dataset = TrajectoryDataset(data_path, m, k)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    print("Dataset size: len: ", len(dataset))
    # Items are 6-tuples; the two trailing image entries are "_" placeholders
    # because img_obs is off here.
    for prev_states, state, next_actions, the_other_agent_action, _, _ in dataloader:
        print("Previous States: ", prev_states.shape)
        print("Current State: ", state.shape)
        print("Next Actions: ", next_actions.shape)
        print("The other agents Actions: ", the_other_agent_action.shape)
        break
