import os
import pytorch_lightning as pl
import torchvision.transforms as torch_transforms
from datasets.shapes_generation.attributes import Views
from datasets.shapes_generation.shapes_synset_to_class_name import (
    SYNSET_TO_CLASS_NAME,
)
import torch
import json
import pandas as pd
import itertools
from typing import List, OrderedDict, Tuple, Any, Dict, Optional, Set
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import Subset, ConcatDataset
from pytorch_lightning.trainer import supporters
import matplotlib.pyplot as plt
from datasets.augmentations import (
    DetRandomResizedCrop,
    DetRandomHorizontalFlip,
    DetRandomApply,
    DetGaussianBlur,
    DetColorJitter,
    SimCLRTrainDataTransform,
)

# Augmentation classes imported above from datasets.augmentations, collected
# in one list (SimCLRTrainDataTransform is imported too but is not part of it).
# NOTE(review): the "Det" prefix presumably stands for "deterministic" —
# confirm against datasets.augmentations. The list is not referenced in the
# visible part of this file.
DET_TRANSFORMS = [
    DetRandomResizedCrop,
    DetRandomHorizontalFlip,
    DetRandomApply,
    DetGaussianBlur,
    DetColorJitter,
]
import random
from ast import literal_eval
import random
import numpy as np

# Value of the USER environment variable (None when it is unset).
USER = os.getenv("USER")
# Per-channel normalization statistics, computed across a subset of training
# instances. They are the default mean/std handed to
# torch_transforms.Normalize by the dataset and data modules below.
MEANS = [0.1154, 0.1129, 0.1017]  # per-channel means over the 52 Shapes classes
# NOTE(review): standard deviations above 1 are unexpected for ToTensor output
# in [0, 1] — TODO confirm these were not measured on already-normalized tensors.
STDS = [1.1017, 1.0993, 1.0832]

# MEANS = [0.1458, 0.1502, 0.1539] #DARK BACKGROUND
# STDS = [0.1431, 0.1290, 0.1142]
# MEANS = [0.8356, 0.8427, 0.7686] #WHITE BACKGROUND
# STDS = [0.2249, 0.2350, 0.2212]
# MEANS = [0.5033, 0.4263, 0.3149] # ALL SOLID BACKGROUND
# STDS =[0.3889, 0.3109, 0.2125]


class iLab(Dataset):
    """Dataset of iLab images restricted to given backgrounds and classes.

    File names listed in ``images_names`` are split on ``-``: field 0 is the
    class name and field 2 the background id (e.g. ``b0000``). Only names whose
    class and background appear in ``classes`` / ``background`` are kept.

    Args:
        data_dir: directory containing the images and the ``images_names`` file
        fov_dir: directory containing ``fov.csv`` (used to build the class index)
        images_names: text file inside ``data_dir`` listing image file names
        background: background ids to keep; every id must start with ``"b"``
        classes: class names (first field of the file name) to keep
        mean: per-channel mean passed to ``Normalize``
        std: per-channel std passed to ``Normalize``
        img_transforms: transforms applied before ``ToTensor`` in ``__getitem__``
        online_transforms: transforms composed into ``to_tensor_online``
    """

    # NOTE: the list defaults below are shared between calls; they are only
    # read (never mutated) inside this class, so sharing them is harmless.
    def __init__(
        self,
        data_dir: str = "/checkpoint/USER/ilab/iLab-2M/home2/toy/iLab2M/test_img/",
        fov_dir: str = "DATADIR/datasets/shapes_renderings/",
        images_names: str = "test_images_names.txt",
        background: List = ["b0000"],
        classes: List = ["car", "train", "bus", "plane"],
        mean: List[float] = MEANS,
        std: List[float] = STDS,
        img_transforms: List[torch.nn.Module] = [
            torch_transforms.Resize(256),
            torch_transforms.CenterCrop(224),
        ],
        online_transforms: List[torch.nn.Module] = [
            torch_transforms.Resize(256),
            torch_transforms.CenterCrop(224),
        ],
    ):

        # Sanity check: passing a bare string such as "b0000" instead of a
        # list iterates over its characters and trips on the first non-"b" one.
        for b in background:
            assert b.startswith("b"), f"background ids must start with 'b': {b!r}"
        self.data_dir = data_dir
        self.background = background
        self.classes = classes
        # genfromtxt returns a 0-d array when the file holds a single name;
        # atleast_1d keeps the result iterable in that case.
        samples = np.atleast_1d(
            np.genfromtxt(os.path.join(self.data_dir, images_names), dtype="U")
        )
        self.samples = self.filter_samples_on_background(samples)
        self.samples = self.filter_samples_on_classes(self.samples)

        self.fov_dir = fov_dir
        self.mean = mean
        self.std = std

        self.fov_df_shapes = self.load_fov(self.fov_dir)
        self.class_to_idx = self.map_class_to_idx()

        self.img_transforms = img_transforms
        self.online_transforms = online_transforms
        # Pipeline used by __getitem__: user transforms, then tensor + normalize.
        self.to_tensor = torch_transforms.Compose(
            self.img_transforms
            + [
                torch_transforms.ToTensor(),
                torch_transforms.Normalize(mean=self.mean, std=self.std),
            ]
        )

        # Same pipeline built from online_transforms; not used inside this class.
        self.to_tensor_online = torch_transforms.Compose(
            self.online_transforms
            + [
                torch_transforms.ToTensor(),
                torch_transforms.Normalize(mean=self.mean, std=self.std),
            ]
        )

    def filter_samples_on_background(self, samples):
        """Keep the names whose third ``-``-separated field is in ``self.background``."""
        return [k for k in samples if k.split("-")[2] in self.background]

    def filter_samples_on_classes(self, samples):
        """Keep the names whose first ``-``-separated field is in ``self.classes``."""
        return [k for k in samples if k.split("-")[0] in self.classes]

    @staticmethod
    def load_fov(data_dir: str) -> pd.DataFrame:
        """Load the tab-delimited ``fov.csv`` found in ``data_dir``.

        Returns:
            DataFrame indexed by ``(instance_id, pose_x, pose_y, pose_z)`` and
            sorted on that index; the remaining columns (e.g. ``class``) stay
            regular columns.
        """
        df = pd.read_csv(
            os.path.join(data_dir, "fov.csv"),
            delimiter="\t",
            dtype={
                "class": str,
                "instance_id": str,
                "image_path": str,
                "pose_x": float,
                "pose_y": float,
                "pose_z": float,
            },
        )
        # set columns as indices for fast filtering
        df = df.set_index(["instance_id", "pose_x", "pose_y", "pose_z"])
        # sorting speeds up lookups
        df = df.sort_index()
        return df

    def class_name_to_synset(self, val):
        """Reverse lookup in SYNSET_TO_CLASS_NAME: class name -> synset key.

        ``"plane"`` is looked up as ``"airplane"``. When no entry matches, the
        literal string ``"key doesn't exist"`` is returned (no exception).
        """
        if val == "plane":
            val = "airplane"

        for key, value in SYNSET_TO_CLASS_NAME.items():
            if val == value:
                return key

        return "key doesn't exist"

    def synset_to_class_name(self, synset: str) -> str:
        """Maps a synset n20393 -> dog (human readable class name)"""
        return SYNSET_TO_CLASS_NAME[synset]

    def map_class_to_idx(self) -> Dict[str, int]:
        """Map each distinct ``class`` value of the fov table to an integer.

        Indices follow the sorted order of the class values.
        """
        classes = sorted(self.fov_df_shapes["class"].unique())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return class_to_idx

    def get_info_from_name(self, image_name):
        """Split an image file name into its seven ``-``-separated fields.

        Returns:
            ``(cl, i, bckgd, cam, rot, lig, foc)``: ``cl`` and ``i`` are the raw
            strings; the other five are the integers that follow the ``b``,
            ``c``, ``r``, ``l`` and ``f`` prefixes. The file extension is
            removed from the last field before parsing.
        """

        cl, i, bckgd, cam, rot, lig, foc = image_name.split("-")
        foc = foc.split(".")[0]
        foc = int(foc.split("f")[1])
        rot = int(rot.split("r")[1])
        lig = int(lig.split("l")[1])
        cam = int(cam.split("c")[1])
        bckgd = int(bckgd.split("b")[1])
        return cl, i, bckgd, cam, rot, lig, foc

    def __getitem__(self, index: int):
        """Return ``(image_tensor, label_idx, fov)`` for the sample at ``index``.

        ``fov`` is a dict with keys ``pose`` (the integer parsed from the
        ``r`` field of the file name), ``synset``, ``class_name`` and
        ``image_path``.
        """
        image_path = self.samples[index]
        image = self.pil_loader(os.path.join(self.data_dir, image_path))

        x = self.to_tensor(image)
        # only the class name and the rotation field are needed here
        label, _, _, _, pose, _, _ = self.get_info_from_name(image_path)
        synset = self.class_name_to_synset(label)
        # class_to_idx is keyed by the text after the first "n" of the synset
        # (presumably fov.csv stores synsets without the "n" prefix — confirm).
        label_idx = self.class_to_idx[synset.split("n")[1]]
        fov = {
            "pose": pose,
            "synset": synset,
            "class_name": self.synset_to_class_name(synset),
            "image_path": image_path,
        }
        return x, label_idx, fov

    @staticmethod
    def pil_loader(path: str) -> Image.Image:
        """Load the image at ``path`` and return it converted to RGB."""
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, "rb") as f:
            img = Image.open(f)
            # convert() returns a new image; its result must be the one
            # returned, otherwise non-RGB files keep their original mode.
            return img.convert("RGB")

    def __len__(
        self,
    ):
        """Number of samples left after background and class filtering."""
        return len(self.samples)


class iLabDataModule(pl.LightningDataModule):
    """Lightning data module exposing a single iLab test split.

    Args:
        data_dir: directory containing the images and the ``images_names`` file
        fov_dir: directory containing ``fov.csv``
        images_names: text file inside ``data_dir`` listing image file names
        background: background ids to keep
        classes: class names to keep
        batch_size: batch size of every loader
        num_workers: default worker count for loaders built by ``_make_loader``
        mean: per-channel mean used for normalization
        std: per-channel std used for normalization
    """

    # NOTE: the list defaults are shared between calls; they are only read here.
    def __init__(
        self,
        data_dir: str = "/checkpoint/USER/ilab/iLab-2M/home2/toy/iLab2M/test_img/",
        fov_dir: str = "DATADIR/datasets/shapes_renderings/",
        images_names: str = "test_images_names.txt",
        background: List = ["b0000"],
        classes: List = ["car", "train", "bus", "plane"],
        batch_size: int = 8,
        num_workers: int = 8,
        mean: List[float] = MEANS,
        std: List[float] = STDS,
    ):
        super().__init__()

        self.data_dir = data_dir
        self.fov_dir = fov_dir
        self.background = background
        self.classes = classes
        self.images_names = images_names
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.mean = mean
        self.std = std
        # The class count comes from the fov table, not from ``classes``.
        self.fov_df_shapes = iLab.load_fov(self.fov_dir)
        self.num_classes = len(self.fov_df_shapes["class"].unique())

        # This module provides no train/val loaders.
        self.train_loader_names = []
        self.val_loader_names = []

        self.test_loader_names = [
            "test",
        ]
        self.val_augmentations = [
            torch_transforms.Resize(256),
            torch_transforms.CenterCrop(224),
        ]

        # only prepare data from rank 0
        self.prepare_data_per_node = False

    def setup(self, stage: Optional[str] = None):
        """Build ``self.test_dataset``; ``stage`` is accepted but not used."""

        self.test_dataset = iLab(
            data_dir=self.data_dir,
            fov_dir=self.fov_dir,
            images_names=self.images_names,
            background=self.background,
            classes=self.classes,
            img_transforms=self.val_augmentations,
            mean=self.mean,
            std=self.std,
        )

    def _make_loader(
        self, dataset: Dataset, shuffle: bool = True, num_workers=None
    ) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's batch size.

        Args:
            dataset: dataset to load from
            shuffle: whether the loader shuffles samples
            num_workers: worker count; ``None`` falls back to ``self.num_workers``
        """
        if num_workers is None:
            # use default if not provided
            num_workers = self.num_workers
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            # honor the resolved argument (it used to be ignored in favor of
            # self.num_workers)
            num_workers=num_workers,
            shuffle=shuffle,
            pin_memory=True,
            drop_last=False,
        )

    def check_number_samples(self, loaders: List[DataLoader]):
        """Raise if any loader has fewer samples than distributed processes.

        Does nothing unless torch.distributed is available and initialized.

        Raises:
            ValueError: a loader's dataset is smaller than the world size.
        """
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            print("Size", torch.distributed.get_world_size())
            world_size = torch.distributed.get_world_size()
            for loader in loaders:
                if len(loader.dataset) < world_size:
                    raise ValueError(
                        f"Loader {loader} not enough samples to give one per GPU"
                    )

    def test_dataloader(self, domain=None) -> List[DataLoader]:
        """Return the list of test loaders (a single, unshuffled one).

        Args:
            domain: only a falsy value or ``"all"`` is supported.

        Raises:
            ValueError: any other ``domain`` value.
        """
        loaders = OrderedDict(
            [
                (
                    "test",
                    self._make_loader(self.test_dataset, shuffle=False, num_workers=1),
                ),
            ]
        )
        assert (
            list(loaders.keys()) == self.test_loader_names
        ), "test loader names don't match"
        if not domain or domain == "all":
            return list(loaders.values())
        raise ValueError(f"domain {domain} not supported")

    def compute_mean_std(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute per-channel mean and std over ``self.test_dataset``.

        Each sample contributes its per-channel mean and mean of squares; the
        variance uses an ``n / (n - 1)`` correction where ``n`` is the number
        of samples, so fewer than two samples yields non-finite values.

        NOTE(review): samples come out of ``iLab.__getitem__`` already
        normalized with the configured mean/std, so these are statistics of
        the normalized tensors — confirm this is intended.
        """
        x_sum = torch.tensor([0.0, 0.0, 0.0])
        x_squared_sum = torch.tensor([0.0, 0.0, 0.0])
        n = 0

        for x, _, _ in self.test_dataset:
            # reduce over the two spatial axes, keeping the channel axis
            x_sum += x.mean(axis=[1, 2])
            x_squared_sum += (x ** 2).mean(axis=[1, 2])
            n += 1

        n = float(n)  # this is to get float division on n/n-1
        mean = x_sum / n
        var = (x_squared_sum / (n - 1)) - (mean ** 2) * n / (n - 1)
        std = torch.sqrt(var)
        return mean, std


class iLabLoadersDataModule(pl.LightningDataModule):
    """Data Module containing train, validation and test samples.

    Train and validation samples are a 90/10 random split of the images under
    ``<data_dir>/train_img``; test samples come from ``<data_dir>/test_img``.

    Args:
        data_dir: parent directory holding ``train_img`` and ``test_img``
        fov_dir: directory containing ``fov.csv``
        background: background ids to keep
        classes: class names to keep
        batch_size: batch size of every loader
        num_workers: default worker count for loaders built by ``_make_loader``
        mean: per-channel mean used for normalization
        std: per-channel std used for normalization
    """

    # NOTE: the list defaults are shared between calls; they are only read here.
    def __init__(
        self,
        data_dir: str = "/checkpoint/USER/ilab/iLab-2M/home2/toy/iLab2M/",
        fov_dir: str = "DATADIR/datasets/shapes_renderings/",
        background: List = ["b0000"],
        classes: List = ["car", "train", "bus", "plane"],
        batch_size: int = 8,
        num_workers: int = 8,
        mean: List[float] = MEANS,
        std: List[float] = STDS,
    ):
        super().__init__()

        self.data_dir = data_dir
        self.fov_dir = fov_dir
        self.background = background
        self.classes = classes
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.mean = mean
        self.std = std
        # The class count comes from the fov table, not from ``classes``.
        self.fov_df_shapes = iLab.load_fov(self.fov_dir)
        self.num_classes = len(self.fov_df_shapes["class"].unique())

        self.train_loader_names = ["train"]
        self.val_loader_names = ["val"]

        self.test_loader_names = [
            "test",
        ]

        # NOTE(review): train_augmentations is defined but setup() builds every
        # dataset with val_augmentations — confirm whether that is intended.
        self.train_augmentations = [
            torch_transforms.RandomResizedCrop(224),
            torch_transforms.RandomHorizontalFlip(),
        ]
        self.val_augmentations = [
            torch_transforms.Resize(256),
            torch_transforms.CenterCrop(224),
        ]

        # only prepare data from rank 0
        self.prepare_data_per_node = False

    def setup(self, stage: Optional[str] = None):
        """Build the train/val split and the test dataset.

        ``stage`` is accepted but not used: all datasets are always built.
        """

        self.all_train_dataset = iLab(
            data_dir=os.path.join(self.data_dir, "train_img"),
            fov_dir=self.fov_dir,
            images_names="train_images_names.txt",
            background=self.background,
            classes=self.classes,
            img_transforms=self.val_augmentations,
            mean=self.mean,
            std=self.std,
        )

        # Hold out 10% (rounded down) of the training images for validation.
        # NOTE(review): random_split is called without a generator, so the
        # split depends on the global torch RNG state — confirm seeding.
        n_val_samples = int(len(self.all_train_dataset) * 0.10)
        n_train_samples = len(self.all_train_dataset) - n_val_samples
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            self.all_train_dataset, [n_train_samples, n_val_samples]
        )

        self.test_dataset = iLab(
            data_dir=os.path.join(self.data_dir, "test_img"),
            fov_dir=self.fov_dir,
            images_names="test_images_names.txt",
            background=self.background,
            classes=self.classes,
            img_transforms=self.val_augmentations,
            mean=self.mean,
            std=self.std,
        )

    def _make_loader(
        self, dataset: Dataset, shuffle: bool = True, num_workers=None
    ) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's batch size.

        Args:
            dataset: dataset to load from
            shuffle: whether the loader shuffles samples
            num_workers: worker count; ``None`` falls back to ``self.num_workers``
        """
        if num_workers is None:
            # use default if not provided
            num_workers = self.num_workers
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            # honor the resolved argument (it used to be ignored in favor of
            # self.num_workers)
            num_workers=num_workers,
            shuffle=shuffle,
            pin_memory=True,
            drop_last=False,
        )

    def train_dataloader(self) -> supporters.CombinedLoader:
        """Return a CombinedLoader holding the (shuffled) ``"train"`` loader."""
        loaders = {
            "train": self._make_loader(self.train_dataset, num_workers=1),
        }
        combined_loaders = supporters.CombinedLoader(
            loaders,
        )
        return combined_loaders

    def val_dataloader(self) -> List[DataLoader]:
        """Return the list of validation loaders (a single, unshuffled one)."""
        # use an ordereddict to ensure order is preserved
        loaders = OrderedDict(
            [
                (
                    "val",
                    self._make_loader(self.val_dataset, shuffle=False, num_workers=1),
                ),
            ]
        )
        assert (
            list(loaders.keys()) == self.val_loader_names
        ), "val loader names don't match"

        return list(loaders.values())

    def test_dataloader(self, domain=None) -> List[DataLoader]:
        """Return the list of test loaders (a single, unshuffled one).

        Args:
            domain: only a falsy value or ``"all"`` is supported.

        Raises:
            ValueError: any other ``domain`` value.
        """
        loaders = OrderedDict(
            [
                (
                    "test",
                    self._make_loader(self.test_dataset, shuffle=False, num_workers=1),
                ),
            ]
        )
        assert (
            list(loaders.keys()) == self.test_loader_names
        ), "test loader names don't match"
        if not domain or domain == "all":
            return list(loaders.values())
        raise ValueError(f"domain {domain} not supported")

    def compute_mean_std(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute per-channel mean and std over ``self.test_dataset``.

        Each sample contributes its per-channel mean and mean of squares; the
        variance uses an ``n / (n - 1)`` correction where ``n`` is the number
        of samples, so fewer than two samples yields non-finite values.

        NOTE(review): samples come out of ``iLab.__getitem__`` already
        normalized with the configured mean/std, so these are statistics of
        the normalized tensors — confirm this is intended.
        """
        x_sum = torch.tensor([0.0, 0.0, 0.0])
        x_squared_sum = torch.tensor([0.0, 0.0, 0.0])
        n = 0

        for x, _, _ in self.test_dataset:
            # reduce over the two spatial axes, keeping the channel axis
            x_sum += x.mean(axis=[1, 2])
            x_squared_sum += (x ** 2).mean(axis=[1, 2])
            n += 1

        n = float(n)  # this is to get float division on n/n-1
        mean = x_sum / n
        var = (x_squared_sum / (n - 1)) - (mean ** 2) * n / (n - 1)
        std = torch.sqrt(var)
        return mean, std


if __name__ == "__main__":
    # Point the module at the training image list: compute_mean_std iterates
    # the module's test_dataset, which setup() builds from these arguments.
    datamodule = iLabDataModule(
        images_names="train_images_names.txt",
        data_dir="/checkpoint/USER/ilab/iLab-2M/home2/toy/iLab2M/train_img/",
    )
    datamodule.setup()
    stats = datamodule.compute_mean_std()
    # prints the mean tensor followed by the std tensor
    print(*stats)
