import os
import re
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
from fastai.vision.all import untar_data, URLs, get_image_files
from ofa.imagenet_classification.data_providers.al_sampler import ImagenetALDataProvider
from path_prefix import dataset_root


class FLRALDataProvider(ImagenetALDataProvider):
    """Active-learning data provider for the facial-landmark regression data.

    Every split is read from the module-level ``dataset_root``; the train and
    test splits are built by ``FLRTrainDataset`` and ``FLRTestDataset``.
    """

    @staticmethod
    def name():
        # Short identifier string for this provider.
        return "flw"

    @property
    def n_classes(self):
        # Number of regression outputs; presumably 5 landmarks x (x, y)
        # -- TODO confirm against the label files.
        return 10

    @property
    def save_path(self):
        # Re-pinned to the dataset root on every access (also cached on
        # ``self._save_path``).
        root = dataset_root
        self._save_path = root
        return root

    def train_dataset(self, _transforms):
        root = self.save_path
        return FLRTrainDataset(root, _transforms)

    def test_dataset(self, _transforms):
        root = self.save_path
        return FLRTestDataset(root, _transforms)


class FLRTrainDataset(Dataset):
    """Training split of the facial-landmark regression (FLR) data.

    Reads both ``train/trainImageList.txt`` and ``train/testImageList.txt``
    under ``data_root`` (both are used for training) and eagerly loads every
    listed image as RGB.

    Each list line is whitespace-separated: an image path relative to
    ``data_root/train``, then numbers. The first four numbers after the path
    are skipped (presumably a face bounding box -- TODO confirm). The rest are
    pixel coordinates laid out as ``x0 y0 x1 y1 ...``: even positions are
    divided by the image width and odd positions by the height in
    ``__getitem__``.

    Args:
        data_root: dataset root directory.
        transform: callable applied to each PIL image in ``__getitem__``.
    """

    def __init__(self, data_root, transform) -> None:
        list_dir = os.path.join(data_root, 'train')
        all_data = []
        for list_name in ('trainImageList.txt', 'testImageList.txt'):
            with open(os.path.join(list_dir, list_name), 'r') as f:
                all_data += f.read().splitlines()
        self.transform = transform
        self.f_path = []
        self.imgs = []
        self.labels = []
        self.img_shapes = []
        for line in all_data:
            # split() (not split(' ')) tolerates repeated/trailing whitespace.
            fields = line.split()
            if not fields:
                # Skip blank lines, e.g. a trailing empty line at end of file.
                continue
            # List files may contain Windows separators; normalise to '/'.
            abs_path = os.path.join(list_dir, fields[0]).replace("\\", "/")
            self.f_path.append(abs_path)
            self.labels.append(fields[5:])
            with Image.open(abs_path) as im:
                # convert() returns a fully loaded copy, so the file can close.
                image = im.convert("RGB")
            self.imgs.append(image)
            self.img_shapes.append((image.width, image.height))
        self.labels = np.asarray(self.labels, dtype=float)

    def __getitem__(self, index):
        """Return ``(transform(image), coords)`` with coords scaled to [0, 1]."""
        width, height = self.img_shapes[index]
        # Copy: self.labels[index] is a view, and normalising it in place
        # would re-divide the stored coordinates on every access (every epoch).
        lab = self.labels[index].copy()
        lab[0::2] /= width   # even positions: x coordinates
        lab[1::2] /= height  # odd positions: y coordinates
        assert max(lab) <= 1
        return self.transform(self.imgs[index]), lab

    def __len__(self) -> int:
        return len(self.f_path)


class FLRTestDataset(Dataset):
    """Evaluation split of the facial-landmark regression (FLR) data.

    Reads ``compare/lfpw/facial_points_positions_lfpw_test_249_label.txt``
    under ``data_root``. Its first line is treated as a header and dropped.
    Every other line is an image path relative to ``data_root/test``,
    followed by pixel coordinates laid out as ``x0 y0 x1 y1 ...``: even
    positions are divided by the image width and odd positions by the height
    in ``__getitem__``. All images are loaded eagerly as RGB.

    NOTE(review): the BioID (1521) and LFPW-train (781) label files under
    ``compare/`` are not part of this split. If they should be, read them
    here too, dropping each file's own header line.

    Args:
        data_root: dataset root directory.
        transform: callable applied to each PIL image in ``__getitem__``.
    """

    def __init__(self, data_root, transform) -> None:
        label_file = os.path.join(
            data_root, 'compare', 'lfpw',
            'facial_points_positions_lfpw_test_249_label.txt')
        with open(label_file, 'r') as f:
            # Drop the first (header) line.
            all_data = f.read().splitlines()[1:]
        self.transform = transform
        self.f_path = []
        self.imgs = []
        self.labels = []
        self.img_shapes = []
        for line in all_data:
            # split() (not split(' ')) tolerates repeated/trailing whitespace.
            fields = line.split()
            if not fields:
                # Skip blank lines, e.g. a trailing empty line at end of file.
                continue
            # Label files may contain Windows separators; normalise to '/'.
            abs_path = os.path.join(data_root, 'test', fields[0]).replace("\\", "/")
            self.f_path.append(abs_path)
            self.labels.append(fields[1:])
            with Image.open(abs_path) as im:
                # convert() returns a fully loaded copy, so the file can close.
                image = im.convert("RGB")
            self.imgs.append(image)
            self.img_shapes.append((image.width, image.height))
        self.labels = np.asarray(self.labels, dtype=float)

    def __getitem__(self, index):
        """Return ``(transform(image), coords)`` with coords scaled to [0, 1]."""
        width, height = self.img_shapes[index]
        # Copy: self.labels[index] is a view, and normalising it in place
        # would re-divide the stored coordinates on every access.
        lab = self.labels[index].copy()
        lab[0::2] /= width   # even positions: x coordinates
        lab[1::2] /= height  # odd positions: y coordinates
        assert max(lab) <= 1
        return self.transform(self.imgs[index]), lab

    def __len__(self) -> int:
        return len(self.f_path)


class BIWIALDataProvider(ImagenetALDataProvider):
    """Active-learning data provider for the BIWI head-pose dataset.

    The dataset is fetched with fastai's ``untar_data`` and split by person,
    not by frame. The 24 person ids (1..24) are shuffled with seed 0; the
    first 70% go to training and the rest to testing, so no subject appears
    in both splits. Targets are 2-D head centres produced by ``BIWIDataset``.
    """

    def __init__(self, **kwargs):
        path = untar_data(URLs.BIWI_HEAD_POSE)
        self.img_files_path = get_image_files(path)
        people_ids = np.arange(1, 25)
        # NOTE(review): this reseeds NumPy's *global* RNG as a side effect,
        # which also affects later np.random use (possibly including the base
        # class initialiser below). Kept so existing runs stay reproducible.
        np.random.seed(0)
        np.random.shuffle(people_ids)
        cut_point = int(0.7 * len(people_ids))
        self.train_people_ids = people_ids[:cut_point]
        self.test_people_ids = people_ids[cut_point:]

        train_ids = set(self.train_people_ids.tolist())
        self.train_imgs_path = []
        self.test_imgs_path = []
        for item in self.img_files_path:
            # Frames are assumed to sit in one directory per person ("01",
            # "02", ...), the same layout the ``path/'01'/'rgb.cal'`` lookup
            # below relies on. Read the id from that directory name rather
            # than the first digit run of the full path: the full path can
            # contain digits in any parent directory (e.g. a home dir).
            person_id = int(os.path.basename(os.path.dirname(str(item))))
            if person_id in train_ids:
                self.train_imgs_path.append(item)
            else:
                self.test_imgs_path.append(item)

        def label_files(x):
            # Replace the last 7 characters (presumably the "rgb.jpg" suffix of
            # "frame_XXXXX_rgb.jpg") with "pose.txt".
            return f"{str(x)[:-7]}pose.txt"

        self.train_label_files_path = [label_files(p) for p in self.train_imgs_path]
        self.test_label_files_path = [label_files(p) for p in self.test_imgs_path]
        # Calibration from person 01 only, used for every person;
        # skip_footer=6 keeps the leading rows (presumably the 3x3 intrinsic
        # matrix -- confirm against rgb.cal).
        self.cal = np.genfromtxt(path/'01'/'rgb.cal', skip_footer=6)
        # Called last: the base initialiser may build datasets via
        # train_dataset/test_dataset, which need the attributes set above.
        super().__init__(**kwargs)

    @staticmethod
    def name():
        return "biwi"

    @property
    def n_classes(self):
        # Two regression outputs: normalised head-centre (x, y).
        return 2

    def train_dataset(self, _transforms):
        return BIWIDataset(self.train_imgs_path, self.train_label_files_path, self.cal, _transforms)

    def test_dataset(self, _transforms):
        return BIWIDataset(self.test_imgs_path, self.test_label_files_path, self.cal, _transforms)


class BIWIDataset(Dataset):
    """Per-frame BIWI samples whose target is the projected 2-D head centre.

    Args:
        imgs_path: image file paths, aligned index-for-index with ``labels_path``.
        labels_path: pose text files; the rows after the first three hold the
            3-D head centre.
        cal: calibration matrix indexed as ``cal[row][col]``. Entries [0][0],
            [1][1] act as focal lengths and [0][2], [1][2] as the principal
            point (presumably a 3x3 camera intrinsic matrix -- confirm).
        transform: callable applied to each RGB PIL image.
    """

    def __init__(self, imgs_path, labels_path, cal, transform):
        self.img_files_path = imgs_path
        self.label_files_path = labels_path
        self.transform = transform
        self.cal = cal

    def _get_ctr(self, f):
        """Project the head centre stored in pose file *f* to pixel coordinates.

        The first three rows (head rotation matrix) are skipped. The centre
        ``(X, Y, Z)`` is projected as ``u = X * fx / Z + cx`` and
        ``v = Y * fy / Z + cy``. Returns the tensor ``[u, v]``.
        """
        centre = np.genfromtxt(f, skip_header=3)
        depth = centre[2]
        u = centre[0] * self.cal[0][0] / depth + self.cal[0][2]
        v = centre[1] * self.cal[1][1] / depth + self.cal[1][2]
        return torch.tensor([u, v])

    def __getitem__(self, index):
        """Return ``(transform(image), centre)`` with centre scaled by image size."""
        target = self._get_ctr(self.label_files_path[index])
        rgb = Image.open(self.img_files_path[index]).convert("RGB")
        width, height = rgb.size
        # Express the pixel coordinates as fractions of the image size.
        target[0] /= width
        target[1] /= height
        assert max(target) <= 1
        return self.transform(rgb), target

    def __len__(self) -> int:
        return len(self.img_files_path)
    