from pathlib import Path
import json
import tarfile
import io
import re
import logging

from torchvision.transforms import v2 as transforms

import pandas as pd
import numpy as np
from PIL import Image

from .image_utils import (
    ImageDataset, ImageTransformDataset, ImageSubset, IndexedDataset, random_split)


_ROOT = "biwi_kinect"
txt_logger = logging.getLogger("sfda_reg")

class BiwiKinect(ImageDataset):
    """Head-pose regression dataset read directly from ``<biwi_path>/biwi_kinect.tar``.

    Each sample is an RGB frame, center-cropped to a square, paired with head
    rotation angles parsed from the frame's ``frame_<id>_pose.txt`` file.
    """

    # Pose files look like ".../faces_0/<person>/frame_<digits>_pose.txt".
    # Group 1 is the person directory, group 2 the frame id.
    _POSE_RE = re.compile(r"faces_0/([^/]+)/frame_(\d+)_pose\.txt$")

    def __init__(self, biwi_path: str, gender: str, target: str, y_unit: str = "rad"):
        """Index the pose files of every person listed under ``gender``.

        Args:
            biwi_path: Directory containing ``biwi_kinect.tar`` and ``gender.json``.
            gender: Key into ``gender.json``; its value lists the person
                directories (under ``faces_0/``) to load.
            target: ``"yaw"``, ``"pitch"``, ``"roll"`` or ``"all"``; selects the
                label returned by ``__getitem__``.
            y_unit: ``"rad"`` or ``"deg"``; unit of the stored angles.

        Raises:
            ValueError: if ``y_unit`` is neither ``"rad"`` nor ``"deg"``.

        Side effect: writes the parsed metadata to
        ``<biwi_path>/gender_<gender>_<y_unit>.csv``.
        """
        assert target in ("yaw", "pitch", "roll", "all")
        # Validate before any work; an unknown unit used to leave the angle
        # variables unbound and crash mid-parse.
        if y_unit not in ("rad", "deg"):
            raise ValueError(f"y_unit must be 'rad' or 'deg', got {y_unit!r}")
        self.target = target

        # Presumably a list of person directory names such as "01" - confirm
        # against gender.json.
        with Path(biwi_path, "gender.json").open("r", encoding="utf-8") as f:
            person_dirs: list[str] = json.load(f)[gender]

        tar_path = Path(biwi_path, f"{_ROOT}.tar")
        self.tar = tarfile.open(str(tar_path), mode="r")
        try:
            members = self.tar.getmembers()
            # name -> TarInfo. TarFile.extractfile(<str>) scans every member
            # linearly per call; looking the TarInfo up here makes per-sample
            # reads O(1). Later duplicates win, matching TarFile.getmember.
            self._members = {m.name: m for m in members}
            self.metadata = self._read_poses(members, person_dirs, y_unit)
            self.metadata.to_csv(Path(biwi_path, f"gender_{gender}_{y_unit}.csv"))
        except BaseException:
            # Don't leak the archive handle if construction fails.
            self.tar.close()
            raise

    def _read_poses(
            self,
            members: list[tarfile.TarInfo],
            person_dirs: list[str],
            y_unit: str) -> pd.DataFrame:
        """Parse every pose file of ``person_dirs`` into a metadata DataFrame."""
        to_angles = matrix_to_angles if y_unit == "rad" else matrix_to_degree

        # Group pose files by person in one pass over the archive, keeping
        # archive order within each person.
        poses: dict[str, list[tuple[tarfile.TarInfo, str]]] = {}
        for member in members:
            if not member.isfile():
                continue
            hit = self._POSE_RE.search(member.name)
            if hit is not None:
                poses.setdefault(hit.group(1), []).append((member, hit.group(2)))

        rows = {"person": [], "frame": [], "yaw": [], "roll": [], "pitch": []}
        for person in person_dirs:
            for member, frame in poses.get(f"{person}", []):
                fp = self.tar.extractfile(member)
                assert fp is not None
                with fp:
                    lines = fp.read().decode(encoding="utf-8").strip().splitlines()

                # The first three lines hold the rows of the 3x3 rotation matrix.
                rot_matrix = np.array(
                    [[float(x) for x in line.split()] for line in lines[:3]])
                yaw, roll, pitch = to_angles(rot_matrix)

                rows["person"].append(person)
                rows["frame"].append(frame)
                rows["yaw"].append(yaw)
                rows["roll"].append(roll)
                rows["pitch"].append(pitch)

        return pd.DataFrame(rows)

    def __len__(self) -> int:
        return len(self.metadata)

    def __getitem__(self, i: int) -> tuple[Image.Image, float | np.ndarray]:
        """Return ``(image, label)`` for metadata row ``i``.

        The label is the angle named by ``self.target``, or a
        ``(yaw, pitch, roll)`` array when the target is ``"all"``.
        """
        row = self.metadata.iloc[i].to_dict()
        if self.target == "all":
            y = np.array((row["yaw"], row["pitch"], row["roll"]))
        else:
            y = row[self.target]

        img_path = f"{_ROOT}/faces_0/{row['person']}/frame_{row['frame']}_rgb.png"
        fp = self.tar.extractfile(self._members[img_path])
        assert fp is not None

        with fp, io.BytesIO(fp.read()) as bio:
            img = Image.open(bio).convert("RGB")

        # Center-crop horizontally to an h x h square. Using left + h as the
        # right edge keeps the result square even when w - h is odd.
        w, h = img.width, img.height
        left = (w - h) // 2
        img = img.crop((left, 0, left + h, h))

        return img, y

    def close(self):
        """Close the underlying tar archive."""
        self.tar.close()


def matrix_to_angles(m: np.ndarray) -> tuple[float, float, float]:
    """Convert a head rotation matrix to ``(yaw, roll, pitch)`` in radians.

    For R = Rz(a) @ Ry(b) @ Rx(c) the formulas below recover a, b and c. This
    module names the z angle "roll", the y angle "yaw" and the x angle
    "pitch". An earlier version had yaw and roll swapped.
    """
    m00, m10 = m[0, 0], m[1, 0]
    m20, m21, m22 = m[2, 0], m[2, 1], m[2, 2]

    roll = np.arctan2(m10, m00)
    yaw = np.arctan2(-m20, np.sqrt(m21 * m21 + m22 * m22))
    pitch = np.arctan2(m21, m22)

    return yaw, roll, pitch


def matrix_to_degree(m: np.ndarray) -> tuple[float, float, float]:
    """Convert a head rotation matrix to ``(yaw, roll, pitch)`` in degrees.

    Delegates to ``matrix_to_angles`` so that the radian and degree outputs
    always share the same axis convention. Each angle is then converted with
    ``np.degrees``.
    """
    yaw, roll, pitch = matrix_to_angles(m)
    return np.degrees(yaw), np.degrees(roll), np.degrees(pitch)



def get_biwi_kinect(
        fetch_dset: dict,
        domain: str,
        biwi_path: str,
        apply_normalization: bool = True,
) -> tuple[IndexedDataset, IndexedDataset, IndexedDataset, IndexedDataset]:
    """Build the Biwi Kinect train/val datasets for one gender domain.

    Args:
        fetch_dset: Dataset config.
            Required keys:
            - ``"task_info"``: ``"<target>_<unit>"``, e.g. ``"yaw_rad"``.
            - ``"tr_val_split"``: selects the split rule (see below).
            Optional keys:
            - ``"train_aug_type"`` (default ``"basic"``).
            - ``"aug_type"`` (default ``"basic"``).
            - ``"seed"`` (default 42).
        domain: Gender key, passed to ``BiwiKinect`` as ``gender``.
        biwi_path: Directory containing the Biwi archive and ``gender.json``.
        apply_normalization: If True, tensors are normalized with the ImageNet
            mean/std used by ImageNet-pretrained weights.

    Returns:
        ``(train, train_aug, val, val_aug)``, each wrapped in
        ``IndexedDataset``. The ``*_aug`` variants are the same samples under
        the ``aug_type`` transform.

    Split rule:
        If ``tr_val_split > 1``, validation indices are loaded from
        ``configs/data/biwi-kinect_<domain>_val_indices.npy``. That path is
        relative to the working directory. Every other sample is used for
        training. Otherwise ``tr_val_split`` is the training fraction of a
        seeded random split.
    """
    if apply_normalization:  # ImageNet statistics, matching pre-trained weights
        to_tensor = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ])
    else:
        to_tensor = transforms.ToTensor()

    train_aug_type = fetch_dset.get("train_aug_type", 'basic')
    aug_type = fetch_dset.get("aug_type", 'basic')
    task_info = fetch_dset["task_info"]
    tr_val_split = fetch_dset["tr_val_split"]
    seed = fetch_dset.get("seed", 42)

    # "<target>_<unit>": target goes to BiwiKinect(target=...), unit to y_unit.
    task_parts = task_info.split('_')
    task, y_unit = task_parts[0], task_parts[1]

    train_transform = get_biwi_kinect_transform(to_tensor, train_aug_type)
    val_transform = get_biwi_kinect_transform(to_tensor, 'val')
    aug_transform = get_biwi_kinect_transform(to_tensor, aug_type)
    txt_logger.info(f"Train set aug is [{train_aug_type}] | Val set aug type is [val] | additional aug [{aug_type}].")

    ds = BiwiKinect(biwi_path, gender=domain, target=task, y_unit=y_unit)

    if tr_val_split > 1:
        # Source training: fixed, precomputed validation indices (noted as an
        # 8:2 split); everything else is training data.
        tr_val_file = f"configs/data/biwi-kinect_{domain}_val_indices.npy"
        val_indices: np.ndarray = np.load(tr_val_file)
        txt_logger.info(f"Dataset Division - BiwiKinect: load val indices from {tr_val_file}")
        train_mask = np.ones(len(ds), dtype=np.bool_)
        train_mask[val_indices] = False
        train_indices = np.arange(len(ds))[train_mask]
        train_raw_ds = ImageSubset(ds, train_indices.tolist())
        val_raw_ds = ImageSubset(ds, val_indices.tolist())
    else:
        # Test-time adaptation: seeded random split with tr_val_split as the
        # train ratio (presumably 0 puts every sample in validation).
        txt_logger.info(f"Dataset Division - BiwiKinect: split randomly, with train ratio {tr_val_split} and seed is {seed}.")
        n_train = int(len(ds) * tr_val_split)
        train_raw_ds, val_raw_ds = random_split(ds, n_train, seed)
    txt_logger.info(f"Dataset Division - BiwiKinect: train data - [{len(train_raw_ds)}] | val data - [{len(val_raw_ds)}]")

    train_ds = ImageTransformDataset(train_raw_ds, train_transform)
    train_aug_ds = ImageTransformDataset(train_raw_ds, aug_transform)
    val_ds = ImageTransformDataset(val_raw_ds, val_transform)
    val_aug_ds = ImageTransformDataset(val_raw_ds, aug_transform)

    return (
        IndexedDataset(train_ds),
        IndexedDataset(train_aug_ds),
        IndexedDataset(val_ds),
        IndexedDataset(val_aug_ds),
    )
    


def get_biwi_kinect_transform(to_tensor, aug_type):
    match aug_type:
        case "basic":
            aug_transform = transforms.Compose(
                [
                    transforms.Resize((256, 256)),
                    transforms.RandomCrop((224, 224)),
                    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), to_tensor
                ])
        case "val":
            aug_transform = transforms.Compose(
                            [transforms.Resize((256, 256)),
                            transforms.CenterCrop((224, 224)), to_tensor])
        case "val_noNormalization":
            aug_transform = transforms.Compose(
                            [transforms.Resize((256, 256)),
                            transforms.CenterCrop((224, 224)),
                            transforms.ToTensor(),])
    return aug_transform
