import os
from pathlib import Path
from typing import Union
from copy import deepcopy

import numpy as np
import pandas as pd
from scipy.stats import circmean
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.datasets.folder import pil_loader
import yaml
from torchvision.datasets import ImageNet


def get_image_meta_path(index_df, idx, dataset_path):
    """
    Get the image path and meta data path for a given index.

    index_df: DataFrame whose first four columns are read positionally as
        (image index, scene name, wnid, model name)
    idx: positional row number into index_df
    dataset_path: root directory of the dataset, either a str or a Path
    return: tuple (img_path, img_meta_path) of Path objects, both located in
        <dataset_path>/images/<scene name>/<wnid>/<model name>/
    """
    image_idx, scene_name, wnid, model_name = index_df.iloc[idx, 0:4]

    # Coerce to Path so that a plain string root works as well as a Path.
    record_path = Path(dataset_path).joinpath('images', scene_name, wnid, model_name)
    # File names embed the image index zero-padded to 10 digits.
    file_stem = f"img_{image_idx:010d}"
    img_path = record_path.joinpath(f"{file_stem}.jpg")
    img_meta_path = record_path.joinpath(f"{file_stem}_info.csv")
    return img_path, img_meta_path


def center_circ_array(arr):
    """
    Shift circular data so that its circular mean sits at 0 degrees.
    arr: array of circular data, 1d numpy array, in degrees, range (-inf, inf)
    return: new array (the input is not modified), in degrees, wrapped into
        (-180, 180]; its circular mean is 0, its arithmetic mean need not be
    """
    # Wrap onto [0, 360) first so circmean sees values inside its [low, high] range.
    wrapped = arr % 360.0
    mean_direction = circmean(wrapped, high=360.0, low=0.0)

    # Rotate by the mean direction and wrap again.
    centered = (wrapped - mean_direction) % 360.0

    # Fold the upper half of the circle onto the negative side.
    upper_half = centered > 180.0
    centered[upper_half] -= 360.0
    return centered


def create_mapping(str_list: list):
    """
    Create a mapping from string to int and int to string.

    Integer ids are assigned in sorted order of the strings, starting at 0.
    The input is left untouched: a sorted copy is used internally.

    param: str_list: list of strings with no duplicates
    return: two dictionaries that contains the mapping from string to int and int to string
    """
    # sorted() builds a new list, so the caller's list keeps its original order.
    ordered_names = sorted(str_list)
    str2int_map = {str_name: i for i, str_name in enumerate(ordered_names)}
    int2str_map = {i: str_name for str_name, i in str2int_map.items()}
    return str2int_map, int2str_map


def load_image(image_filepath):
    """Open an image file and return it as a detached PIL.Image object.

    Images whose mode name contains 'L', 'A' or 'P' are pasted onto a fresh
    RGB canvas of the same size; every other image is returned as a copy.
    from https://github.com/brain-score/model-tools/blob/75365b54670d3f6f63dcdf88395c0a07d6b286fc/model_tools/activations/pytorch.py#L118
    """
    with Image.open(image_filepath) as pil_image:
        mode_name = pil_image.mode.upper()
        needs_rgb_canvas = any(letter in mode_name for letter in ('L', 'A', 'P'))

        if needs_rgb_canvas:
            # make sure these images end up in RGB
            canvas = Image.new("RGB", pil_image.size)
            canvas.paste(pil_image)
            return canvas

        # Copy before the file is closed: work around to
        # https://github.com/python-pillow/Pillow/issues/1144,
        # see https://stackoverflow.com/a/30376272/2225200
        return pil_image.copy()


class TDWDataset(Dataset):
    """TDW dataset.
    images are 256x256 pixels
    neg_x: distance of object from the screen, in TDW world-space units, where the screen is at x = 0, + is going into the image
    ty: horizontal position of object, in pixels, center of image is 0, + is going right
    tz: vertical position of object, in pixels, center of image is 0, + is going up
    euler_1, euler_2, euler_3: rotation of object, in degrees, returned by TDW local transform relative to the camera
    euler_x_proc: rotation of object, in degrees, processed to be in the range (-180, 180), centered around 0

    tdw_image_dataset_small_multi_env(_hdri): 8 categories, ~5,000 images
    tdw_image_dataset_large_20230907: 117 categories, 587 objects, ~1,350,000 images

    Files read from the dataset root:
        mappings.yml: label mappings; 'category_str2int' and 'object_str2int' are used here
        img_meta_headers.txt: newline-separated header names, stored in self.headers
        norm_column_mean_std.csv: first row provides '<column>_mean' / '<column>_std'
        index_img_shuffled_with_meta.csv: one row per image, with its meta data
        images/<scene>/<wnid>/<model>/img_<10-digit image_index>.jpg: the images

    Each item is a dict with the keys 'image', 'category_label', 'object_label',
    'rel_rot_euler_0..2' (float32) and the normalized norm_columns (float32).
    """
    norm_columns = ['rel_pos_x', 'rel_pos_y', 'rel_pos_z'] # columns that need to be normalized
    vis_collumns = ['rel_rot_euler_0', 'rel_rot_euler_1', 'rel_rot_euler_2'] + norm_columns

    def __init__(self,
                 root_dir: Union[str, Path] = './data/tdw_image_dataset_small_multi_env_hdri',
                 split: str = 'train',
                 transform = None,
                 fraction: float = 1.0,
                 shuffle_cat: bool = False,
                 ):
        """
        Arguments:
            root_dir (string, or Path): Directory with all the images.
            split (string): 'all', 'train' or 'val'; the validation rows are taken
                from the end of the index file
            transform (callable, optional): applied to each loaded PIL image
            fraction (float): fraction of the dataset to use, 1.0 means use all the data
            shuffle_cat (bool): whether to shuffle the object category and identity labels

        Raises:
            ValueError: if split is not one of 'all', 'train', 'val'
        """
        self.root_path: Path = Path(root_dir)
        self.dset_name = self.root_path.name + '_' + split + f'_{fraction}'.replace('.', '_')
        self.transform = transform
        self.shuffle_cat = shuffle_cat

        with open(self.root_path.joinpath('mappings.yml'), 'r') as file:
            mappings = yaml.safe_load(file)
        self.mappings = mappings

        self.headers = self.root_path.joinpath('img_meta_headers.txt').read_text(encoding="utf-8").split("\n")
        self.means_stds = pd.read_csv(self.root_path.joinpath('norm_column_mean_std.csv'), index_col=0).iloc[0]

        dataset_index = pd.read_csv(self.root_path.joinpath('index_img_shuffled_with_meta.csv'), index_col=0)
        self.dataset_index = self._select_split(dataset_index, split, fraction)

        if self.shuffle_cat:
            # Independently permute the category and object labels across rows. The
            # original 'wnid' / 'model' columns are kept because they locate the image files.
            self.dataset_index['wnid_s'] = self.dataset_index['wnid'].sample(frac=1.0).to_list()
            self.dataset_index['model_s'] = self.dataset_index['model'].sample(frac=1.0).to_list()

    @staticmethod
    def _select_split(dataset_index, split, fraction):
        """Return an independent copy of the rows of dataset_index used for split and fraction."""
        full_dset_size = len(dataset_index)

        # clamp the validation set size to be between 1,000 and 50,000
        val_set_size = min(max(round(full_dset_size * 0.2), 1000), 50000)
        train_set_size = full_dset_size - val_set_size
        assert train_set_size > 0

        if split == 'all':
            pass
        elif split == 'train':
            dataset_index = dataset_index[:train_set_size]
        elif split == 'val':
            dataset_index = dataset_index[train_set_size:].reset_index(drop=True)
        else:
            raise ValueError('split must be either all, train, or val')

        use_index = round(len(dataset_index) * fraction)
        # Copy the slice: columns may be added to the result later, and assigning to a
        # slice of another DataFrame triggers pandas' SettingWithCopyWarning.
        return dataset_index[:use_index].copy()

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.dataset_index)

    def __getitem__(self, idx):
        """Load the image at positional index idx and return it with its labels as a dict."""
        image_meta = self.dataset_index.iloc[idx]
        image_idx = image_meta['image_index']
        img_path = self.root_path.joinpath('images', image_meta['scene'], image_meta['wnid'], image_meta['model'],
                                           f"img_{image_idx:010d}.jpg")
        image = pil_loader(img_path)
        if self.transform:
            image = self.transform(image)
        sample = {'image': image}

        # With shuffle_cat the labels come from the permuted columns, while the image
        # path above always uses the true wnid / model.
        if self.shuffle_cat:
            read_wnid = image_meta['wnid_s']
            read_model = image_meta['model_s']
        else:
            read_wnid = image_meta['wnid']
            read_model = image_meta['model']
        sample['category_label'] = self.mappings['category_str2int'][read_wnid]
        sample['object_label'] = self.mappings['object_str2int'][read_model]

        # for i in range(2, 9):
        #     # reduce the 8 category labels (0..7) to 2, 3, 4, 5, 6, 7, 8 category labels
        #     sample[f'cat_label_reduce{i}'] = sample['category_label'] if sample['category_label'] < i else (i - 1)

        for i in range(3):
            sample[f'rel_rot_euler_{i}'] = np.float32(image_meta[f'rel_rot_euler_{i}'])

        # Standardize each position column with the mean / std read from norm_column_mean_std.csv.
        for column in self.norm_columns:
            sample[column] = np.float32((image_meta[column] - self.means_stds[f'{column}_mean']) / self.means_stds[f'{column}_std'])

        return sample


class MyImageNet(ImageNet):
    """ImageNet variant that yields dict samples and exposes category name mappings.

    After construction, self.mappings holds 'category_str2int' and
    'category_int2str', built from the first name of each entry in self.classes.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ImageNet, then build the category mappings."""
        super().__init__(*args, **kwargs)

        # resolve some duplicate class names in ImageNet
        renamed = {
            134: ('crane bird',),
            517: ('crane machine',),
            638: ('maillot',),
            639: ('maillot tank suit',),
        }
        # Work on a copy so self.classes itself stays as the parent class built it.
        class_tuples = deepcopy(self.classes)
        for class_idx, new_names in renamed.items():
            class_tuples[class_idx] = new_names

        # Only the first name of every class tuple is used as its label string.
        name_to_idx = {names[0]: idx for idx, names in enumerate(class_tuples)}
        idx_to_name = {idx: name for name, idx in name_to_idx.items()}

        self.mappings = {
            'category_str2int': name_to_idx,
            'category_int2str': idx_to_name,
        }

    def __getitem__(self, idx):
        """Return the parent's (image, target) pair repackaged as a dict."""
        image, label = super().__getitem__(idx)
        return {'image': image, 'category_label': label}
