import os

import IPython
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pickle
from typing import Any, Iterable
import torch
import torch.nn.functional as F
import tqdm
from PIL import Image
import random

class MapPretrainDataset(Dataset):
    def __init__(
        self,
        configs: dict,
        data_split_type: str,
        transform,
        use_huggingface_transform: bool,
    ) -> None:
        """Read the split index files and build the trajectory sampling tables.

        Args:
            configs: Settings mapping. Required keys: ``data_folder``,
                ``data_split_folder``, ``len_trajectory_pred``, ``action_num``,
                ``location_num`` and ``generate_index_size_<data_split_type>``.
                Optional keys (default): ``min_query_num`` (5),
                ``max_query_num`` (10), ``sample_traj_by_length`` (True),
                ``location_dataset_name`` ('location'),
                ``navigation_dataset_name`` ('navigation'),
                ``goal_augment_prob`` (0.0), ``ego_from_navigation_prob`` (0.6).
            data_split_type: Either ``'train'`` or ``'val'``.
            transform: Image transform stored for later use on loaded images.
            use_huggingface_transform: Stored flag selecting how ``transform``
                is invoked (called on the whole image list with
                ``return_tensors="pt"`` when true, per image otherwise).

        Raises:
            ValueError: If ``data_split_type`` is neither 'train' nor 'val'.
        """
        super().__init__()

        # Mandatory settings.
        self.data_folder = configs['data_folder']
        self.data_split_folder = configs['data_split_folder']
        self.data_split_type = data_split_type
        self.len_traj_pred = configs['len_trajectory_pred']
        self.action_num = configs['action_num']
        self.generate_index_size = configs[f'generate_index_size_{data_split_type}']
        self.location_num = configs['location_num']

        # Optional settings with their fallbacks.
        self.min_query_num = configs.get('min_query_num', 5)
        self.max_query_num = configs.get('max_query_num', 10)
        self.sample_traj_by_length = configs.get('sample_traj_by_length', True)
        self.location_dataset_name = configs.get('location_dataset_name', 'location')
        self.navigation_dataset_name = configs.get('navigation_dataset_name', 'navigation')
        self.goal_augment_prob = configs.get('goal_augment_prob', 0.0)
        self.ego_from_navigation_prob = configs.get('ego_from_navigation_prob', 0.6)

        self.transform = transform
        self.use_huggingface_transform = use_huggingface_transform

        # Trajectory ids of the requested split.
        split_folder = self.data_split_folder
        self.traj_indexes = read_content(
            os.path.join(split_folder, f'{self.navigation_dataset_name}-{data_split_type}.txt')
        )

        # Location ids of both splits; the map is always built from the
        # training locations, only the query pool depends on the split.
        location_train_indexes = read_content(
            os.path.join(split_folder, f'{self.location_dataset_name}-train.txt')
        )
        location_val_indexes = read_content(
            os.path.join(split_folder, f'{self.location_dataset_name}-val.txt')
        )

        self.map_indexes = location_train_indexes
        query_pools = {'train': location_train_indexes, 'val': location_val_indexes}
        if data_split_type not in query_pools:
            raise ValueError(f'{data_split_type} not supported')
        self.query_indexes = query_pools[data_split_type]

        # Number of timesteps of every trajectory and their total.
        self.traj_lengths = [self.compute_length(traj_idx) for traj_idx in self.traj_indexes]
        self.navigation_total_steps = sum(self.traj_lengths)

        # Flat list of every (trajectory id, timestep) pair, so that a uniform
        # draw from it picks trajectories proportionally to their length.
        self.traj_state_indexes = [
            (traj_idx, step)
            for traj_idx, traj_len in zip(self.traj_indexes, self.traj_lengths)
            for step in range(traj_len)
        ]

        self.generate_stats()
        
    def compute_length(self, traj_idx):
        """Return the number of timesteps stored for trajectory ``traj_idx``.

        The length is the number of entries under the ``'position'`` key of
        ``<data_folder>/<navigation_dataset_name>/<traj_idx>/traj-data.pkl``.
        """
        pickle_path = os.path.join(
            self.data_folder,
            self.navigation_dataset_name,
            traj_idx,
            'traj-data.pkl',
        )
        with open(pickle_path, 'rb') as handle:
            # Keys used elsewhere in this file:
            # 'position', 'yaw', 'goal_position', 'goal_yaw', 'action'
            trajectory = pickle.load(handle)
        return len(trajectory['position'])
        
    def get_dataset_state(self):
        """Return a string identifying the split, sizes and dataset names in use."""
        pieces = [
            f'{self.data_split_type}_{self.generate_index_size}',
            f'location_{self.location_num}',
            f'min_{self.min_query_num}',
            f'max_{self.max_query_num}',
            f'{self.location_dataset_name}_{self.navigation_dataset_name}',
        ]
        return '_'.join(pieces)
            
    def get_primitive_data_state(self):
        """Return ``'<location_dataset_name>_<navigation_dataset_name>'``."""
        return '{}_{}'.format(self.location_dataset_name, self.navigation_dataset_name)

    def generate_stats(self):
        """Load the cached normalisation statistics, or compute and cache them.

        Populates ``self.stats`` with three entries:

        * ``position_mean`` / ``position_std``: per-axis mean and standard
          deviation of the ``'position'`` of every location listed in
          ``self.map_indexes``;
        * ``path_mean_norm``: per-axis root-mean-square displacement between
          trajectory positions that are 5 steps apart, taken over every
          trajectory in ``self.traj_indexes`` that has more than 5 steps.

        The result is cached in ``<data_split_folder>/dataset_stats.pkl`` and
        that file is reused whenever it exists.
        NOTE(review): the cache file name encodes neither the split nor the
        dataset names, so an existing file is reused even if those change --
        delete it after changing the underlying data.

        Raises:
            ValueError: If statistics must be computed but ``self.map_indexes``
                is empty or no trajectory has more than 5 steps.
        """
        # Positions this many steps apart define one "local path" displacement.
        step_gap = 5

        stats_path = os.path.join(self.data_split_folder, 'dataset_stats.pkl')
        if os.path.exists(stats_path):
            with open(stats_path, 'rb') as file:
                self.stats = pickle.load(file)
        else:
            print('Generating dataset stats')

            # Mean / std of the map location positions, per axis.
            location_position = []
            for map_index in self.map_indexes:
                location_info_path = os.path.join(
                    self.data_folder,
                    self.location_dataset_name,
                    f'{map_index}.pkl'
                )
                with open(location_info_path, 'rb') as file:
                    location_info = pickle.load(file)
                location_position.append(location_info['position'])
            if not location_position:
                raise ValueError(
                    'Cannot compute dataset stats: no map locations are listed'
                )
            location_position = np.stack(location_position, 0)
            position_std = location_position.std(0)
            position_mean = location_position.mean(0)

            # Displacements between positions `step_gap` steps apart.
            gaps = []
            for traj_index in self.traj_indexes:
                traj_info_path = os.path.join(
                    self.data_folder,
                    self.navigation_dataset_name,
                    traj_index,
                    'traj-data.pkl'
                )
                with open(traj_info_path, 'rb') as file:
                    traj_info = pickle.load(file)
                # asarray keeps ndarray input untouched and lets plain nested
                # lists take part in the slicing arithmetic below.
                positions = np.asarray(traj_info['position'])
                if len(positions) <= step_gap:
                    # Too short to contain a single displacement.
                    continue
                gaps.append(positions[step_gap:] - positions[:-step_gap])
            if not gaps:
                raise ValueError(
                    'Cannot compute dataset stats: no trajectory has more than '
                    f'{step_gap} steps'
                )
            gaps = np.concatenate(gaps, 0)
            gaps_mean_norm = np.sqrt((gaps ** 2).mean(0))

            self.stats = {
                'position_std': position_std,
                'position_mean': position_mean,
                'path_mean_norm': gaps_mean_norm
            }
            with open(stats_path, 'wb') as file:
                pickle.dump(self.stats, file)

        print('Using normalize data:')
        for key, value in self.stats.items():
            print(f'{key} = {value}')

    def sample_index_data(self): 
        """Randomly draw one set of locations and one trajectory state.

        Draws ``location_num`` location ids (``query_count`` from
        ``self.query_indexes`` followed by the rest from ``self.map_indexes``)
        and loads their position and yaw. Then picks a trajectory and a current
        timestep, optionally replaces the saved final goal by a later state of
        the same trajectory (training split only, with probability
        ``goal_augment_prob``), and derives the next actions, the local and
        global paths in the egocentric frame of the current state, and the
        path length to the goal.

        Returns:
            dict with keys ``location_index``, ``query_count``,
            ``location_position``, ``location_yaw``, ``traj_index``,
            ``curr_time``, ``goal_time``, ``goal_is_augmented``, ``curr_pos``,
            ``curr_yaw``, ``goal_pos``, ``goal_yaw``, ``local_actions``,
            ``local_paths``, ``global_paths`` and ``distance``.
        """
        # Decide how many of the location_num slots are query locations; the
        # remaining slots are filled with map locations.
        query_count = np.random.randint(self.min_query_num, self.max_query_num + 1)
        map_count = self.location_num - query_count
        # Both draws are without replacement within their own pool; the two
        # pools are sampled independently of each other.
        query_index_sample = np.random.choice(
            self.query_indexes, 
            query_count, 
            replace=False
        )
        map_index_sample = np.random.choice(
            self.map_indexes, 
            map_count, 
            replace=False
        )
        # Query locations come first, map locations after them.
        location_index_sample = np.concatenate([query_index_sample, map_index_sample])
        location_position = []
        location_yaw = []
        for location_index_item in location_index_sample:
            location_info_path = os.path.join(
                self.data_folder, 
                self.location_dataset_name, 
                f'{location_index_item}.pkl'
            )
            with open(location_info_path, 'rb') as file:
                location_info = pickle.load(file)
            location_position.append(location_info['position'])
            location_yaw.append(location_info['yaw'])
        
        # We need to get all of the following from trajectory info: 
        #    (1) The index of current state; 
        #    (2) The goal state; 
        #    (3) The corresponding position and yaw; 
        #    (4) The local actions; 
        #    (5) The local path and global path. 
        #    (6) The distance to the goal state. 
        
        # Select a random trajectory, and a random starting state index.
        # Drawing a (trajectory, step) pair weights trajectories by length;
        # otherwise every trajectory is equally likely and the step is drawn
        # further below once the trajectory length is known.
        if self.sample_traj_by_length:
            traj_index, curr_time = random.choice(self.traj_state_indexes)
        else:
            traj_index = random.choice(self.traj_indexes)
            
        # Read the information in the trajectory
        traj_info_path = os.path.join(
            self.data_folder, 
            self.navigation_dataset_name, 
            traj_index, 
            'traj-data.pkl'
        )
        with open(traj_info_path, 'rb') as file:
            traj_info = pickle.load(file)
        
        # The variable "traj_info" contains: 
        # 'position', 'yaw', 'goal_position', 'goal_yaw', 'action'
        traj_len = len(traj_info['position'])
        
        if not self.sample_traj_by_length:
            curr_time = np.random.randint(0, traj_len)
        
        # Add augmented goal sampled in the episode
        # Only perform goal augmentation during training
        if self.data_split_type == 'train' and np.random.rand() < self.goal_augment_prob:
            # Augment goal within the episode: any step in [curr_time, traj_len)
            goal_time = np.random.randint(curr_time, traj_len)
        else:
            # Use saved final goal; goal_time == traj_len marks this case
            goal_time = traj_len
        
        # Get position and yaw of current time
        curr_pos, curr_yaw = traj_info['position'][curr_time], traj_info['yaw'][curr_time]
        
        # Get goal position and yaw
        if goal_time == traj_len:
            goal_pos, goal_yaw = traj_info['goal_position'], traj_info['goal_yaw']
        else:
            goal_pos, goal_yaw = traj_info['position'][goal_time], traj_info['yaw'][goal_time]
        
        # Get the local actions: the next len_traj_pred actions, padded once
        # the goal time is reached.
        # NOTE(review): presumably action 7 is a "halt" action and 0 a "stop"
        # action; with use_halt_action fixed to False every 7 is mapped to 0
        # and the padding is 0 -- confirm against the action definition.
        use_halt_action = False
        local_actions = []
        for step_delta in range(self.len_traj_pred):
            if curr_time + step_delta < goal_time:
                action = traj_info['action'][curr_time + step_delta]
                if not use_halt_action and action == 7:
                    action = 0
                local_actions.append(action)
            else:
                # Fill in stop actions
                if use_halt_action:
                    local_actions.append(7)
                else:
                    local_actions.append(0)
        
        # Get the local path
        # One movement step will take 0.4s, while turn step will take 0.2s. 
        # We will use the position & yaw at the next len_traj_pred steps
        # (1s - 2s when len_traj_pred is 5).
        # Steps at or past the goal time repeat the last trajectory state
        # (saved goal) or the augmented goal state.
        local_pos, local_yaw = [], []
        for step_delta in range(1, self.len_traj_pred + 1):
            if curr_time + step_delta < goal_time:
                local_pos.append(traj_info['position'][curr_time + step_delta])
                local_yaw.append(traj_info['yaw'][curr_time + step_delta])
            elif goal_time == traj_len:
                local_pos.append(traj_info['position'][-1])
                local_yaw.append(traj_info['yaw'][-1])
            else:
                local_pos.append(traj_info['position'][goal_time])
                local_yaw.append(traj_info['yaw'][goal_time])
        local_pos, local_yaw = np.array(local_pos), np.array(local_yaw)
        
        # Get the global path: len_traj_pred states evenly spaced in time
        # between the current step and the goal step, clipped to the last
        # stored state.
        global_pos, global_yaw = [], []
        for step_float in np.linspace(curr_time, goal_time, self.len_traj_pred):
            step_future = round(step_float)
            step_future = min(step_future, traj_len - 1)
            global_pos.append(traj_info['position'][step_future])
            global_yaw.append(traj_info['yaw'][step_future])
        global_pos, global_yaw = np.array(global_pos), np.array(global_yaw)
        
        # Change the coordinate to egocentric
        local_path = to_ego_coords(
            local_pos, local_yaw, curr_pos, curr_yaw
        )
        global_path = to_ego_coords(
            global_pos, global_yaw, curr_pos, curr_yaw
        )
        
        # Compute the distance to goal: sum of the Euclidean lengths of the
        # consecutive steps from curr_time up to the goal step (or the last
        # stored state when the saved goal is used).
        end_time = min(goal_time + 1, traj_len)
        future_path = traj_info['position'][curr_time:end_time]
        distance_step = np.linalg.norm(future_path[1:] - future_path[:-1], axis=-1)
        distance = distance_step.sum()
                
        # Shapes of the returned entries (positions are assumed to be 3-D,
        # matching the 3x3 rotation used by to_ego_coords):
        # location_index (location_num)
        # query_count ()
        # location_position (location_num, 3)
        # location_yaw (location_num)
        # traj_index ()
        # curr_time ()
        # curr_pos (3)
        # curr_yaw ()
        # goal_pos (3)
        # goal_yaw ()
        # local_actions (len_traj_pred)
        # local_paths (len_traj_pred, 6)
        # global_paths (len_traj_pred, 6)
        # distance ()
        
        index_data = {
            'location_index': np.array(location_index_sample), 
            'query_count': query_count, 
            'location_position': np.array(location_position), 
            'location_yaw': np.array(location_yaw), 
            'traj_index': traj_index, 
            'curr_time': curr_time, 
            'goal_time': goal_time, 
            'goal_is_augmented': goal_time < traj_len,
            'curr_pos': np.array(curr_pos), 
            'curr_yaw': curr_yaw, 
            'goal_pos': np.array(goal_pos), 
            'goal_yaw': goal_yaw, 
            'local_actions': np.array(local_actions, dtype=np.int64), 
            'local_paths': np.array(local_path),
            'global_paths': np.array(global_path), 
            'distance': distance
        }
        
        return index_data
            
    def __len__(self) -> int:
        """Return the configured number of samples per pass over the dataset.

        Samples are generated randomly on access, so this is the configured
        ``generate_index_size`` rather than a count of stored items.
        """
        samples_per_epoch = self.generate_index_size
        return samples_per_epoch
    
    def __getitem__(self, index: int) -> Any:
        """Build one randomly sampled training example.

        ``index`` is ignored: every call draws a fresh sample through
        ``sample_index_data``. With probability ``ego_from_navigation_prob``
        the ego frame is the sampled trajectory state, in which case location
        slot 0 is replaced by the current state and slot 1 by the goal.
        Otherwise one of the sampled locations is chosen as ego and all path,
        action and distance targets are zero.

        Location type codes: 0 = ego, 1 = map location, 2 = query location.

        Returns:
            dict of tensors with keys ``images``, ``map_positions``,
            ``map_yaws``, ``ego_positions``, ``ego_yaws``, ``pos_scale``,
            ``local_path_scale``, ``location_types``, ``goal_masks``,
            ``local_path``, ``global_path``, ``local_actions``, ``distance``.
        """
        def load_rgb(*path_parts):
            # Open through a context manager so the file handle is released
            # as soon as the pixels are decoded; convert() returns an
            # independent in-memory image.
            with Image.open(os.path.join(self.data_folder, *path_parts)) as image:
                return image.convert('RGB')

        # Locations
        index_data = self.sample_index_data()

        location_index = index_data['location_index']
        query_count = index_data['query_count']
        location_position = index_data['location_position']
        location_yaw = index_data['location_yaw']

        # Build location type: queries come first, map locations after them
        location_types = [2] * query_count + [1] * (self.location_num - query_count)
        location_types = np.array(location_types)

        # Build image list (images are read from disk on every call)
        images = [
            load_rgb(self.location_dataset_name, f'{idx}.bmp')
            for idx in location_index
        ]

        # Use trajectory data as center with probability
        # ego_from_navigation_prob, location data as center otherwise
        if np.random.rand() < self.ego_from_navigation_prob:
            # Per-sample shapes coming from sample_index_data
            # (positions assumed 3-D):
            # traj_index ()
            # curr_time ()
            # curr_pos (3)
            # curr_yaw ()
            # goal_pos (3)
            # goal_yaw ()
            # local_actions (len_traj_pred)
            # local_paths (len_traj_pred, 6)
            # global_paths (len_traj_pred, 6)
            traj_index = index_data['traj_index']
            curr_time = index_data['curr_time']
            goal_time = index_data['goal_time']
            goal_is_augmented = index_data['goal_is_augmented']
            curr_pos = index_data['curr_pos']
            curr_yaw = index_data['curr_yaw']
            goal_pos = index_data['goal_pos']
            goal_yaw = index_data['goal_yaw']
            local_actions = index_data['local_actions']
            local_paths = index_data['local_paths']
            global_paths = index_data['global_paths']
            distance = index_data['distance']
            ego_pos = curr_pos.copy()
            ego_yaw = curr_yaw

            # Substitute 2 of the queries into current state and goal
            location_types[0] = 0
            location_position[0, :] = curr_pos
            location_yaw[0] = curr_yaw

            # Substitute image: an augmented goal uses the frame recorded at
            # goal_time, the saved final goal uses the dedicated goal image
            curr_image = load_rgb(
                self.navigation_dataset_name, traj_index, f'online_{curr_time}.bmp'
            )
            if goal_is_augmented:
                goal_image = load_rgb(
                    self.navigation_dataset_name, traj_index, f'online_{goal_time}.bmp'
                )
            else:
                goal_image = load_rgb(
                    self.navigation_dataset_name, traj_index, 'goal.bmp'
                )
            images[0] = curr_image
            images[1] = goal_image

            location_position[1, :] = goal_pos
            location_yaw[1] = goal_yaw
            goal_masks = np.array([0, 1] + [0] * (self.location_num - 2))

            # The same targets are repeated for every location slot
            local_paths = np.repeat(
                np.expand_dims(local_paths, 0),
                self.location_num, 0
            )
            global_paths = np.repeat(
                np.expand_dims(global_paths, 0),
                self.location_num, 0
            )
            local_actions = np.repeat(
                np.expand_dims(local_actions, 0),
                self.location_num, 0
            )
            distance = np.repeat(distance, self.location_num, 0)
        else:
            # A random sampled location becomes the ego frame; there is no
            # goal and no navigation target in this case
            ego_index = np.random.randint(0, self.location_num)
            location_types[ego_index] = 0
            goal_masks = np.array([0] * self.location_num)
            ego_pos = location_position[ego_index]
            ego_yaw = location_yaw[ego_index]

            local_paths = np.zeros((self.location_num, self.len_traj_pred, 6))
            global_paths = np.zeros((self.location_num, self.len_traj_pred, 6))
            local_actions = np.zeros((self.location_num, self.len_traj_pred), dtype=np.int64)
            distance = np.zeros(self.location_num)

        # Compute egocentric positions and yaw; only the translated/rotated
        # position columns of to_ego_coords are kept
        ego_position_other = to_ego_coords(
            location_position, location_yaw, ego_pos, ego_yaw
        )[:, :3]
        ego_yaw_other = location_yaw - ego_yaw

        # Reuse the shared helper instead of duplicating its logic here
        image_tensor = self.process_images(images)

        # World positions are centred and scaled; egocentric positions are
        # already relative to the ego, so they are only scaled
        position_mean, position_std = self.stats['position_mean'], self.stats['position_std']
        map_positions_normalized = (location_position - position_mean) / position_std
        ego_positions_normalized = ego_position_other / position_std
        data = {
            'images': image_tensor,
            'map_positions': torch.tensor(map_positions_normalized, dtype=torch.float32),
            'map_yaws': torch.tensor(location_yaw, dtype=torch.float32),
            'ego_positions': torch.tensor(ego_positions_normalized, dtype=torch.float32),
            'ego_yaws': torch.tensor(ego_yaw_other, dtype=torch.float32),
            'pos_scale': torch.tensor(self.stats['position_std'], dtype=torch.float32),
            'local_path_scale': torch.tensor(self.stats['path_mean_norm'], dtype=torch.float32),
            'location_types': torch.tensor(location_types),
            'goal_masks': torch.tensor(goal_masks),
            'local_path': torch.tensor(local_paths, dtype=torch.float32),
            'global_path': torch.tensor(global_paths, dtype=torch.float32),
            'local_actions': torch.tensor(local_actions),
            'distance': torch.tensor(distance, dtype=torch.float32)
        }

        return data
    
    def process_images(self, images):
        """Turn one image or a sequence of images into a batched tensor.

        Args:
            images: A single image, or a list/tuple of images.

        Returns:
            With ``use_huggingface_transform`` set: the ``'pixel_values'``
            entry of ``self.transform(images, return_tensors="pt")``.
            Otherwise ``self.transform`` is applied per image and the results
            are stacked along a new leading dimension; a single image gets a
            leading dimension of size 1.
        """
        if self.use_huggingface_transform:
            return self.transform(images, return_tensors="pt")['pixel_values']
        # isinstance also accepts list subclasses and tuples, which an exact
        # `type(...) is list` check would wrongly treat as a single image.
        if isinstance(images, (list, tuple)):
            return torch.stack([self.transform(im) for im in images], 0)
        return self.transform(images).unsqueeze(0)
        
def read_content(path):
    """Return the whitespace-separated entries stored in the text file ``path``.

    Entries may be separated by any run of whitespace (spaces, tabs or
    newlines); an empty or blank file yields an empty list.
    """
    with open(path, 'r') as file:
        # str.split() without arguments splits on runs of whitespace and never
        # produces empty strings, so repeated separators are harmless.
        return file.read().split()

def yaw_rotmat(yaw: float) -> np.ndarray:
    """Return the 3x3 matrix rotating by ``yaw`` degrees about the third axis."""
    theta = yaw * np.pi / 180
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    return np.array([
        [cos_t, -sin_t, 0],
        [sin_t, cos_t, 0],
        [0, 0, 1],
    ])

def to_ego_coords(positions, yaws, curr_pos, curr_yaw) -> np.ndarray:
    """Express poses relative to a reference pose.

    Each row of ``positions`` is translated by ``-curr_pos`` and multiplied on
    the right by ``yaw_rotmat(curr_yaw)``. The yaw difference (in degrees) is
    encoded as ``(cos, sin, 0, ...)`` in an array of the same shape as the
    transformed positions.

    Returns:
        The transformed positions and the heading encoding concatenated along
        axis 1.
    """
    rotation = yaw_rotmat(curr_yaw)
    relative_pos = (positions - curr_pos).dot(rotation)

    relative_yaw_rad = (yaws - curr_yaw) * np.pi / 180
    heading = np.zeros_like(relative_pos)
    heading[:, 0] = np.cos(relative_yaw_rad)
    heading[:, 1] = np.sin(relative_yaw_rad)

    return np.concatenate((relative_pos, heading), axis=1)

class SmoothedClassBalancer:
    """Sample classes with probability proportional to ``count ** -0.5``.

    Counts start at zero and are increased with :meth:`add`. Classes that
    have never been added are sampled first (uniformly among themselves),
    which is the limiting behaviour of the ``count ** -0.5`` weight and
    avoids raising ``ZeroDivisionError`` on a zero count.
    """

    def __init__(self, classes: Iterable) -> None:
        """Register every item of ``classes`` with a count of zero."""
        self.counts = {}
        for c in classes:
            self.counts[c] = 0

    def add(self, c):
        """Increase the count of class ``c`` by one."""
        self.counts[c] += 1

    def sample(self):
        """Return a randomly chosen class, favouring rarely added ones."""
        keys = list(self.counts)

        # A zero count has an unbounded weight: such classes take priority.
        unseen = [k for k in keys if self.counts[k] == 0]
        if unseen:
            return unseen[np.random.choice(len(unseen))]

        values = [self.counts[k] ** (-0.5) for k in keys]
        p = np.array(values) / np.sum(values)
        class_index = np.random.choice(len(keys), p=p)
        return keys[class_index]
        

class RandomizedClassBalancer:
    def __init__(self, classes: Iterable) -> None:
        """
        A class balancer that will sample classes randomly, but will prioritize classes that have been sampled less

        Args:
            classes (Iterable): The classes to balance
        """
        self.counts = {}
        for c in classes:
            self.counts[c] = 0

    def sample(self, class_filter_func=None) -> Any:
        """
        Sample the softmax of the negative logits to prioritize classes that have been sampled less

        The count of the returned class is increased by one.

        Args:
            class_filter_func (Callable, optional): A function that takes in a class and returns a boolean. Defaults to None.

        Returns:
            The sampled class, or None when no class passes the filter.
        """
        if class_filter_func is None:
            keys = list(self.counts)
        else:
            keys = [k for k in self.counts if class_filter_func(k)]
        if not keys:
            return None  # no valid classes to sample

        # The minimum is taken over all classes (not only the filtered ones)
        # and computed once rather than once per key.
        min_count = min(self.counts.values())
        values = [-(self.counts[k] - min_count) for k in keys]
        p = F.softmax(torch.tensor(values, dtype=torch.float32), dim=0).detach().cpu().numpy()
        class_index = np.random.choice(len(keys), p=p)
        class_choice = keys[class_index]
        self.counts[class_choice] += 1
        return class_choice

    def __str__(self) -> str:
        """Return one ``"<class>: <count>"`` line per class."""
        return "".join(f"{c}: {count}\n" for c, count in self.counts.items())
        
if __name__ == '__main__':
    # Manual smoke test: build the training dataset with a Hugging Face image
    # processor, iterate over it with a DataLoader and print summary
    # statistics of every tensor once every 64 batches.
    from transformers import ViTImageProcessor

    configs = {
        'data_folder': '../map-pretrain-data',
        'data_split_folder': '../map-pretrain-data/data-split',
        'len_trajectory_pred': 5,
        'action_num': 8,
        'location_num': 50,
        'generate_index_size_train': 10000,
    }

    dataset = MapPretrainDataset(
        configs,
        data_split_type='train',
        transform=ViTImageProcessor.from_pretrained("facebook/vit-mae-base"),
        use_huggingface_transform=True,
    )

    train_loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=16,
        drop_last=True,
    )

    for idx, data in enumerate(train_loader):
        print(f'in {idx} iterations')
        if idx % 64 != 0:
            continue
        for k, v in data.items():
            print(k, end=' ')
            # Mean and std are only reported for float32 tensors.
            summary = [v.min(), v.max()]
            if v.dtype is torch.float32:
                summary += [v.mean(), v.std()]
            print(*summary, v.shape, v.dtype)
        