import os
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
from PIL import Image
from timm.data import create_transform
from torch.utils.data import Dataset
from torchvision import transforms

from avr.data.dataset import DatasetSplit
from avr.data.transform import shuffle_objects, select_n_answers

# Per-channel (R, G, B) ImageNet statistics, passed to
# torchvision.transforms.Normalize by VasrDataset when do_normalize is True.
IMAGENET_DEFAULT_MEAN = (
    0.485,
    0.456,
    0.406,
)
IMAGENET_DEFAULT_STD = (
    0.229,
    0.224,
    0.225,
)


class VasrDataset(Dataset):
    """Dataset of (A, B, C) context images with candidate answers for D.

    Rows are read from a VASR CSV split file. Each item is a
    ``(context, answers, target)`` triple:

    * ``context`` -- the transformed images from columns ``A_img``, ``B_img``
      and ``C_img``;
    * ``answers`` -- the transformed ``D_img`` followed by the images listed
      in the candidates column, optionally shuffled and then reduced to
      ``num_answers`` entries;
    * ``target`` -- the index of ``D_img`` within ``answers``.

    Args:
        dataset_root_dir: Directory containing ``vasr_dataset/<regime>/`` CSV
            files and the ``images_512/`` image folder.
        split: Dataset split; ``split.value`` is the CSV filename prefix.
        do_shuffle_answers: Shuffle the answers (tracking ``target``) before
            selecting ``num_answers`` of them.
        num_answers: Number of answers returned per item; must be in [1, 4].
        regime: ``"split_random"`` (``<split>_random.csv``, candidates in the
            ``random_candidates`` column) or ``"split_distractors"``
            (``<split>_silver.csv``, candidates in ``distractors``).
        do_normalize: Apply ImageNet mean/std normalization in ``to_tensor``.
        do_vit_transform: Use a timm evaluation transform (384x384, bicubic,
            mean/std 0.5) instead of ``to_tensor``. Mutually exclusive with
            ``do_normalize``.

    Raises:
        ValueError: If ``regime`` is not one of the supported values.
        AssertionError: If ``num_answers`` is outside [1, 4] or both
            ``do_normalize`` and ``do_vit_transform`` are set.
    """

    def __init__(
            self,
            dataset_root_dir: str = ".",
            split: DatasetSplit = DatasetSplit.TRAIN,
            do_shuffle_answers: bool = True,
            num_answers: int = 4,
            regime: str = "split_random",
            do_normalize: bool = False,
            do_vit_transform: bool = False,
    ):
        super().__init__()
        assert 1 <= num_answers <= 4
        assert not (do_normalize and do_vit_transform)
        self.dataset_root_dir = dataset_root_dir
        if regime == "split_random":
            csv_path = f"{split.value}_random.csv"
            self.candidates_key = "random_candidates"
        elif regime == "split_distractors":
            csv_path = f"{split.value}_silver.csv"
            self.candidates_key = "distractors"
        else:
            raise ValueError(f"Unsupported regime: {regime}")
        self.df = pd.read_csv(
            os.path.join(dataset_root_dir, "vasr_dataset", regime, csv_path)
        )
        self.do_shuffle_answers = do_shuffle_answers
        self.num_answers = num_answers
        self.img_dir = os.path.join(dataset_root_dir, "images_512")

        # Always define these: to_tensor reads self.normalize unconditionally,
        # so leaving it unset (do_normalize=False) used to raise AttributeError.
        self.normalize = None
        self._vit_transform = None
        if do_vit_transform:
            self._vit_transform = create_transform(
                **{
                    "input_size": (3, 384, 384),
                    "interpolation": "bicubic",
                    "mean": (0.5, 0.5, 0.5),
                    "std": (0.5, 0.5, 0.5),
                    "crop_pct": 1.0,
                    "crop_mode": "center",
                }
            )
            # A bound method (rather than a lambda closure) keeps the dataset
            # picklable, which DataLoader workers started with "spawn" require.
            self.transform = self._apply_vit_transform
        else:
            if do_normalize:
                self.normalize = transforms.Normalize(
                    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
                )
            self.transform = self.to_tensor

    def __len__(self) -> int:
        """Return the number of rows in the split CSV."""
        return len(self.df)

    def __getitem__(
            self, idx: int
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], int]:
        """Load, arrange and transform the images of row ``idx``."""
        data = self.df.iloc[idx]
        context = [self.load(data[key]) for key in ("A_img", "B_img", "C_img")]
        # The correct answer D is placed first so its index starts at 0.
        answers = [self.load(data["D_img"])]
        answers += [
            self.load(filename)
            for filename in self.parse_answers(data[self.candidates_key])
        ]
        target = 0
        if self.do_shuffle_answers:
            answers, target = shuffle_objects(answers, target)
        answers, target = select_n_answers(answers, target, self.num_answers)
        context = [self.transform(x) for x in context]
        answers = [self.transform(x) for x in answers]
        return context, answers, target

    @staticmethod
    def parse_answers(text: str) -> List[str]:
        """Parse a bracketed, comma-separated filename list from a CSV cell.

        The outer brackets are dropped and every double quote and space is
        removed, e.g. ``'["a.jpg", "b.jpg"]'`` -> ``["a.jpg", "b.jpg"]``.
        Empty entries are skipped, so ``"[]"`` yields ``[]`` instead of
        ``[""]`` (which would otherwise resolve to the image directory).
        """
        text = text[1:-1].replace('"', "").replace(" ", "")
        return [name for name in text.split(",") if name]

    def load(self, filename: str) -> np.ndarray:
        """Read ``img_dir/filename`` as an RGB ``(H, W, 3)`` array."""
        path = os.path.join(self.img_dir, filename)
        with open(path, "rb") as f:
            image = Image.open(f).convert("RGB")
            image = np.asarray(image)
        return image

    def _apply_vit_transform(self, x: np.ndarray) -> torch.Tensor:
        """Apply the timm evaluation transform built in ``__init__``."""
        return self._vit_transform(Image.fromarray(x))

    def to_tensor(self, x: np.ndarray) -> torch.Tensor:
        """Convert an ``(H, W, C)`` image array to a float32 ``(C, H, W)`` tensor.

        Pixel values are NOT rescaled, so they keep the input's range
        (0-255 for images from ``load``). If ``self.normalize`` is set, it is
        applied afterwards.
        """
        # NOTE(review): ImageNet mean/std are conventionally applied to
        # inputs in [0, 1]; here they are applied to raw 0-255 values.
        # Confirm this is intended before changing it, since it alters the
        # model inputs.
        x = torch.tensor(x, dtype=torch.float32)
        x = x.permute(2, 0, 1)
        if self.normalize is not None:
            x = self.normalize(x)
        return x
