"""
This module contains the datasets, collate function and Lightning datamodules for the letters dataset.
"""

import numpy as np
import torch
import torch_geometric
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from PIL import Image, ImageDraw, ImageFont
from typing import Callable, Dict, Optional, Tuple, Type, Union

import sys
from pathlib import Path

# Locate this file's directory and the two directories above it. The parent
# directory is appended to ``sys.path``; the grandparent is where the font
# file used for rendering letters is looked up.
current_directory = Path(__file__).absolute().parent
parent_directory = current_directory.parent
parent_parent_directory = parent_directory.parent
sys.path.append(str(parent_directory))


class letters_dataset(Dataset):
    """Paired point clouds (noisy source, target) sampled from rendered letters.

    Every element of ``self.samples`` describes one letter instance and is a
    tuple ``(source, target)`` of ``(num_points, 2)`` NumPy arrays. When
    ``conditional`` is true the tuple is extended by a one-hot ``condition``
    array of shape ``(num_points, num_conditions)`` whose hot column is unique
    to that letter instance.

    Split layout (see ``get_denoising_samples``):
      * ``"train"``: every letter of the alphabet; "X" and "Y" appear once
        and unrotated, all other letters ``num_rotations`` times with random
        rotations.
      * ``"val"``: ``num_rotations`` randomly rotated copies of "X".
      * ``"test"``: ``num_rotations`` randomly rotated copies of "Y".
    """

    def __init__(
        self,
        noise_scale=0.05,
        source_noise_scale=0.5,
        num_rotations=10,
        ivp_batch_size=None,
        conditional=False,
        seed=0,
        mode="train",
    ) -> None:
        """Store the configuration and eagerly generate all samples.

        Args:
            noise_scale: Std of the Gaussian noise added to the target points.
            source_noise_scale: Std of the extra Gaussian noise that turns a
                target cloud into its source cloud.
            num_rotations: Number of random rotations per rotated letter.
            ivp_batch_size: If set, number of points drawn (without
                replacement) per item in train mode; ``None`` returns all.
            conditional: Whether to attach a one-hot condition to each sample.
            seed: Seed passed to ``torch.manual_seed`` and ``np.random.seed``
                before the samples are generated.
            mode: One of ``"train"``, ``"val"`` or ``"test"``.
        """
        self.noise_scale = noise_scale
        self.source_noise_scale = source_noise_scale
        self.num_rotations = num_rotations
        self.ivp_batch_size = ivp_batch_size
        self.conditional = conditional
        self.seed = seed

        assert mode in [
            "train",
            "val",
            "test",
        ], "Invalid mode. Must be either 'train' or 'val' or 'test'"
        self.mode = mode

        self.alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

        if self.conditional:
            # One condition slot per generated sample, laid out as
            # [train | val | test]. Train holds 24 rotated letters plus the
            # two unrotated held-out letters.
            self.num_train_conditions = (len(self.alphabet) - 2) * self.num_rotations + 2
            self.num_val_conditions = self.num_rotations
            self.num_test_conditions = self.num_rotations
            self.num_conditions = (
                self.num_train_conditions
                + self.num_val_conditions
                + self.num_test_conditions
            )

        self.samples = self.get_denoising_samples(mode, num_rotations=num_rotations, seed=seed)
        self.num_samples = len(self.samples)

    def char_sampler(self, char, font_size=100, rotation=0.0):
        """Render ``char`` and return its pixels as a rotated, noisy 2-D point cloud.

        Args:
            char: Text to draw (a single letter in this module).
            font_size: Font size; the canvas is ``2 * font_size`` pixels square.
            rotation: Rotation angle in radians applied to the final points.

        Returns:
            Float tensor of shape ``(num_points, 2)``. It lives on the GPU when
            CUDA is available and on the CPU otherwise.
        """
        font = ImageFont.truetype(str(parent_parent_directory) + "/arial.ttf", font_size)
        image_size = 2 * font_size
        img = Image.new('L', (image_size, image_size), color=0)
        d = ImageDraw.Draw(img)
        d.text((0, 0), char, fill=(255), font=font)

        # Image row 0 is the top of the canvas; flip the rows vertically.
        img = np.flipud(np.array(img))

        # One point per non-black pixel. The transposed index grid is indexed
        # with the (square) image mask to collect the pixel coordinates.
        grid = np.indices(img.shape).T
        mask = img > 0
        samples = torch.from_numpy(grid[mask]).float()

        # Min-max normalise each axis to [0, 1] and remember the aspect ratio.
        min_vals = samples.min(axis=0).values
        max_vals = samples.max(axis=0).values
        range_vals = max_vals - min_vals
        samples = (samples - min_vals) / range_vals
        ratio = range_vals[0] / range_vals[1]

        # Map to [-3, 3], then shrink the shorter axis so the glyph keeps its
        # original aspect ratio.
        samples = samples * 2 - 1
        samples *= 3.

        if ratio.item() < 1.:
            samples[:, 0] *= ratio
        else:
            samples[:, 1] /= ratio

        samples = samples + self.noise_scale * torch.randn_like(samples)

        # Use the GPU when one is present (this keeps the random stream of the
        # source noise unchanged on CUDA machines) but fall back to the CPU
        # instead of crashing when CUDA is unavailable.
        if torch.cuda.is_available():
            samples = samples.cuda()

        M = torch.tensor([
            [np.cos(rotation), -np.sin(rotation)],
            [np.sin(rotation), np.cos(rotation)]
        ], device=samples.device, dtype=samples.dtype)
        samples = samples @ M.T

        return samples

    def _make_sample(self, char, rotate, cond):
        """Build one ``(source, target[, condition])`` tuple for ``char``.

        A random rotation is drawn only when ``rotate`` is true, so the order
        of random draws matches the per-split loops this helper replaces.
        ``cond`` is the hot column of the condition and is ignored unless the
        dataset is conditional.
        """
        if rotate:
            rotation = 2.0 * torch.pi * torch.rand(1)
            target_samples = self.char_sampler(char, rotation=rotation.item())
        else:
            target_samples = self.char_sampler(char)
        source_samples = target_samples + self.source_noise_scale * torch.randn_like(
            target_samples
        )

        source_np = source_samples.cpu().numpy()
        target_np = target_samples.cpu().numpy()
        if not self.conditional:
            return (source_np, target_np)

        condition = np.zeros((source_np.shape[0], self.num_conditions))
        condition[:, cond] = 1
        return (source_np, target_np, condition)

    def get_denoising_samples(self, mode='train', num_rotations=10, seed=0):
        """Generate every sample of the requested split.

        Args:
            mode: ``"train"``, ``"val"`` or ``"test"``.
            num_rotations: Number of random rotations per rotated letter.
            seed: Seed for the torch and NumPy global generators.

        Returns:
            List of ``(source, target)`` tuples, or
            ``(source, target, condition)`` tuples when conditional.

        Raises:
            ValueError: If ``mode`` is not a known split.
        """
        torch.manual_seed(seed)
        np.random.seed(seed)

        # ``plan`` lists (letter, rotate?) in generation order; ``first_cond``
        # is the condition column of the first planned sample.
        if mode == 'train':
            first_cond = 0
            plan = []
            for letter in self.alphabet:
                if letter == "Y" or letter == "X":
                    # Held-out letters appear once, unrotated, in training.
                    plan.append((letter, False))
                else:
                    plan.extend([(letter, True)] * num_rotations)
        elif mode == 'val':
            first_cond = self.num_train_conditions if self.conditional else None
            plan = [("X", True)] * num_rotations
        elif mode == "test":
            first_cond = (
                self.num_train_conditions + self.num_val_conditions
                if self.conditional
                else None
            )
            plan = [("Y", True)] * num_rotations
        else:
            raise ValueError('Invalid mode')

        samples = []
        for offset, (letter, rotate) in enumerate(plan):
            cond = first_cond + offset if self.conditional else None
            samples.append(self._make_sample(letter, rotate, cond))

        return samples

    def __len__(self):
        """Return the number of letter instances in this split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(idx, source, target)`` or ``(idx, source, target, cond)``.

        In val/test mode, or when ``ivp_batch_size`` is ``None``, the full
        arrays are returned. Otherwise ``ivp_batch_size`` points are drawn
        without replacement, using the same indices for every array.
        """
        arrays = self.samples[idx]  # one "env" (letter instance) at a time
        if self.ivp_batch_size is None or self.mode == "val" or self.mode == "test":
            return (idx, *arrays)

        ivp_idx = np.random.choice(
            np.arange(arrays[0].shape[0]), size=self.ivp_batch_size, replace=False
        )
        return (idx, *(a[ivp_idx] for a in arrays))


class LettersDatamodule(pl.LightningDataModule):
    """Lightning datamodule exposing train/val/test ``letters_dataset`` splits.

    All three datasets are created eagerly in ``__init__`` from the same
    configuration; only ``mode`` differs between them.
    """

    def __init__(
        self,
        batch_size=256,
        ivp_batch_size=None,
        noise_scale=0.05,
        source_noise_scale=0.5,
        num_rotations=10,
        conditional=False,
        seed=0,
    ) -> None:
        """Store the configuration and build the three dataset splits.

        Args:
            batch_size: Batch size of the training dataloader.
            ivp_batch_size: Forwarded to the datasets.
            noise_scale: Forwarded to the datasets.
            source_noise_scale: Forwarded to the datasets.
            num_rotations: Forwarded to the datasets.
            conditional: Forwarded to the datasets.
            seed: Forwarded to the datasets (the same seed for every split).
        """
        super().__init__()
        self.batch_size = batch_size
        self.ivp_batch_size = ivp_batch_size
        self.noise_scale = noise_scale
        self.source_noise_scale = source_noise_scale
        self.num_rotations = num_rotations
        self.conditional = conditional
        self.seed = seed

        self.save_hyperparameters(logger=True)

        # Every split shares these settings.
        dataset_kwargs = dict(
            noise_scale=self.noise_scale,
            source_noise_scale=self.source_noise_scale,
            num_rotations=self.num_rotations,
            ivp_batch_size=self.ivp_batch_size,
            conditional=self.conditional,
            seed=self.seed,
        )
        self.train_dataset = letters_dataset(mode="train", **dataset_kwargs)
        self.val_dataset = letters_dataset(mode="val", **dataset_kwargs)
        self.test_dataset = letters_dataset(mode="test", **dataset_kwargs)

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        loader_kwargs = dict(
            batch_size=self.batch_size,
            shuffle=True,  # set to False for the testing pipeline
            num_workers=4,
            pin_memory=True,
        )
        return DataLoader(dataset=self.train_dataset, **loader_kwargs)

    def val_dataloader(self):
        """Return the validation dataloader (fixed batch size, no shuffling)."""
        loader_kwargs = dict(
            batch_size=10,  # TODO: fix, this is temp
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )
        return DataLoader(dataset=self.val_dataset, **loader_kwargs)

    def test_dataloader(self):
        """Return the test dataloader (fixed batch size, no shuffling)."""
        loader_kwargs = dict(
            # Set to 1 for predict_step in the testing pipeline.
            batch_size=10,  # TODO: fix, this is temp
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )
        return DataLoader(dataset=self.test_dataset, **loader_kwargs)


class letters_batch_dataset(Dataset):
    """Letter point-cloud pairs stored with a leading singleton axis.

    Same split layout as the per-letter loops in ``get_denoising_samples``:
    train holds every letter ("X" and "Y" once and unrotated, the others
    ``num_rotations`` times with random rotations), val holds rotated "X",
    test holds rotated "Y".

    Sample format:
      * non-conditional: ``(source, target)`` CPU tensors of shape
        ``(1, num_points, 2)``.
      * conditional: ``(source, target, condition)`` NumPy arrays of shape
        ``(num_points, 2)`` / ``(num_points, num_conditions)``.

    NOTE(review): the conditional format has no leading axis and yields
    4-tuples from ``__getitem__``, while ``collate_fn_pad`` in this file
    unpacks 3-tuples — confirm whether conditional mode is meant to be used
    with that collate function.
    """

    def __init__(
        self,
        noise_scale=0.05,
        source_noise_scale=0.5,
        num_rotations=10,
        ivp_batch_size=None,
        conditional=False,
        seed=0,
        mode="train",
    ) -> None:
        """Store the configuration and eagerly generate all samples.

        Args:
            noise_scale: Std of the Gaussian noise added to the target points.
            source_noise_scale: Std of the extra Gaussian noise that turns a
                target cloud into its source cloud.
            num_rotations: Number of random rotations per rotated letter.
            ivp_batch_size: If set, number of points drawn (with replacement)
                per item in train mode; ``None`` returns all points.
            conditional: Whether to attach a one-hot condition to each sample.
            seed: Seed passed to ``torch.manual_seed`` and ``np.random.seed``
                before the samples are generated.
            mode: One of ``"train"``, ``"val"`` or ``"test"``.
        """
        self.noise_scale = noise_scale
        self.source_noise_scale = source_noise_scale
        self.num_rotations = num_rotations
        self.ivp_batch_size = ivp_batch_size
        self.conditional = conditional
        self.seed = seed

        assert mode in [
            "train",
            "val",
            "test",
        ], "Invalid mode. Must be either 'train' or 'val' or 'test'"
        self.mode = mode

        self.alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

        if self.conditional:
            # One condition slot per generated sample, laid out as
            # [train | val | test].
            self.num_train_conditions = (
                len(self.alphabet) - 2
            ) * self.num_rotations + 2
            self.num_val_conditions = self.num_rotations
            self.num_test_conditions = self.num_rotations
            self.num_conditions = (
                self.num_train_conditions
                + self.num_val_conditions
                + self.num_test_conditions
            )

        self.samples = self.get_denoising_samples(
            mode, num_rotations=num_rotations, seed=seed
        )
        self.num_samples = len(self.samples)

    def char_sampler(self, char, font_size=100, rotation=0.0):
        """Render ``char`` and return its pixels as a rotated, noisy 2-D point cloud.

        Args:
            char: Text to draw (a single letter in this module).
            font_size: Font size; the canvas is ``2 * font_size`` pixels square.
            rotation: Rotation angle in radians applied to the final points.

        Returns:
            Float tensor of shape ``(num_points, 2)``. It lives on the GPU when
            CUDA is available and on the CPU otherwise.
        """
        font = ImageFont.truetype(
            str(parent_parent_directory) + "/arial.ttf", font_size
        )
        image_size = 2 * font_size
        img = Image.new("L", (image_size, image_size), color=0)
        d = ImageDraw.Draw(img)
        d.text((0, 0), char, fill=(255), font=font)

        # Image row 0 is the top of the canvas; flip the rows vertically.
        img = np.flipud(np.array(img))

        # One point per non-black pixel.
        grid = np.indices(img.shape).T
        mask = img > 0
        samples = torch.from_numpy(grid[mask]).float()

        # Min-max normalise each axis to [0, 1] and remember the aspect ratio.
        min_vals = samples.min(axis=0).values
        max_vals = samples.max(axis=0).values
        range_vals = max_vals - min_vals
        samples = (samples - min_vals) / range_vals
        ratio = range_vals[0] / range_vals[1]

        # Map to [-3, 3], then shrink the shorter axis so the glyph keeps its
        # original aspect ratio.
        samples = samples * 2 - 1
        samples *= 3.0

        if ratio.item() < 1.0:
            samples[:, 0] *= ratio
        else:
            samples[:, 1] /= ratio

        samples = samples + self.noise_scale * torch.randn_like(samples)

        # Use the GPU when one is present (this keeps the random stream of the
        # source noise unchanged on CUDA machines) but fall back to the CPU
        # instead of crashing when CUDA is unavailable.
        if torch.cuda.is_available():
            samples = samples.cuda()

        M = torch.tensor(
            [
                [np.cos(rotation), -np.sin(rotation)],
                [np.sin(rotation), np.cos(rotation)],
            ],
            device=samples.device,
            dtype=samples.dtype,
        )
        samples = samples @ M.T

        return samples

    def _make_sample(self, char, rotate, cond):
        """Build one sample tuple for ``char`` in this dataset's storage format.

        A random rotation is drawn only when ``rotate`` is true, so the order
        of random draws matches the per-split loops this helper replaces.
        ``cond`` is the hot column of the condition and is ignored unless the
        dataset is conditional.
        """
        if rotate:
            rotation = 2.0 * torch.pi * torch.rand(1)
            target_samples = self.char_sampler(char, rotation=rotation.item())
        else:
            target_samples = self.char_sampler(char)
        source_samples = (
            target_samples
            + self.source_noise_scale * torch.randn_like(target_samples)
        )

        if not self.conditional:
            return (
                source_samples.unsqueeze(0).cpu(),
                target_samples.unsqueeze(0).cpu(),
            )

        condition = np.zeros((source_samples.shape[0], self.num_conditions))
        condition[:, cond] = 1
        return (
            source_samples.cpu().numpy(),
            target_samples.cpu().numpy(),
            condition,
        )

    def get_denoising_samples(self, mode="train", num_rotations=10, seed=0):
        """Generate every sample of the requested split.

        Args:
            mode: ``"train"``, ``"val"`` or ``"test"``.
            num_rotations: Number of random rotations per rotated letter.
            seed: Seed for the torch and NumPy global generators.

        Returns:
            List of sample tuples in the format described on the class.

        Raises:
            ValueError: If ``mode`` is not a known split.
        """
        torch.manual_seed(seed)
        np.random.seed(seed)

        # ``plan`` lists (letter, rotate?) in generation order; ``first_cond``
        # is the condition column of the first planned sample.
        if mode == "train":
            first_cond = 0
            plan = []
            for letter in self.alphabet:
                if letter == "Y" or letter == "X":
                    # Held-out letters appear once, unrotated, in training.
                    plan.append((letter, False))
                else:
                    plan.extend([(letter, True)] * num_rotations)
        elif mode == "val":
            first_cond = self.num_train_conditions if self.conditional else None
            plan = [("X", True)] * num_rotations
        elif mode == "test":
            first_cond = (
                self.num_train_conditions + self.num_val_conditions
                if self.conditional
                else None
            )
            plan = [("Y", True)] * num_rotations
        else:
            raise ValueError("Invalid mode")

        samples = []
        for offset, (letter, rotate) in enumerate(plan):
            cond = first_cond + offset if self.conditional else None
            samples.append(self._make_sample(letter, rotate, cond))

        return samples

    def __len__(self):
        """Return the number of letter instances in this split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(idx, source, target)`` or ``(idx, source, target, cond)``.

        In val/test mode, or when ``ivp_batch_size`` is ``None``, the stored
        arrays are returned unchanged. Otherwise ``ivp_batch_size`` points are
        drawn with replacement along the point axis, using the same indices
        for every array of the sample.
        """
        arrays = self.samples[idx]  # one "env" (letter instance) at a time
        if self.ivp_batch_size is None or self.mode == "val" or self.mode == "test":
            return (idx, *arrays)

        if self.conditional:
            # Conditional samples are (num_points, ...) NumPy arrays.
            ivp_idx = np.random.randint(arrays[0].shape[0], size=self.ivp_batch_size)
            return (idx, *(a[ivp_idx] for a in arrays))

        # Non-conditional samples are (1, num_points, d) tensors: the points
        # live on axis 1. Indexing axis 0 (size 1) would merely repeat the
        # whole cloud ``ivp_batch_size`` times instead of subsampling it.
        ivp_idx = np.random.randint(arrays[0].shape[1], size=self.ivp_batch_size)
        point_idx = torch.as_tensor(ivp_idx, dtype=torch.long)
        return (idx, *(a[:, point_idx] for a in arrays))



def collate_fn_pad(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    """Collate variable-length point clouds by zero-padding the point axis.

    Args:
        batch: Sequence of ``(idx, x0, x1)`` items where ``x0`` and ``x1`` are
            3-D tensors whose axis 1 is the (variable) number of points.
        collate_fn_map: Unused; accepted for signature compatibility with
            callers that pass it.

    Returns:
        Tuple ``(idx, x0, x1)``: ``idx`` is the stacked sample indices, and
        ``x0`` / ``x1`` are the stacked clouds, each padded at the end of axis
        1 with zeros up to the longest ``x0`` in the batch. A singleton axis 1
        left after stacking is squeezed away.
    """
    max_len = max(item[1].shape[1] for item in batch)

    idx_pad, x0_pad, x1_pad = [], [], []
    for idx, x0, x1 in batch:
        idx_pad.append(torch.tensor(idx))
        # The pad spec runs from the last axis backwards: (left, right) for
        # the coordinate axis, the point axis, then the leading axis.
        x0_pad.append(
            torch.nn.functional.pad(
                x0, [0, 0, 0, max_len - x0.shape[1], 0, 0], mode="constant", value=0
            )
        )
        x1_pad.append(
            torch.nn.functional.pad(
                x1, [0, 0, 0, max_len - x1.shape[1], 0, 0], mode="constant", value=0
            )
        )

    # No shared-memory ``out=`` buffer is used here: the items are tuples, not
    # tensors, and one buffer cannot serve three differently shaped stacks.
    return (
        torch.stack(idx_pad, dim=0),
        torch.stack(x0_pad, dim=0).squeeze(1),
        torch.stack(x1_pad, dim=0).squeeze(1),
    )


class LettersPyGeoDatamodule(LettersDatamodule):
    """Variant of ``LettersDatamodule`` backed by ``letters_batch_dataset``.

    The training dataloader pads variable-length samples with
    ``collate_fn_pad``; the validation and test dataloaders use the default
    collate function. All dataloaders run in the main process
    (``num_workers=0``).
    """

    def __init__(
        self,
        batch_size=256,
        ivp_batch_size=None,
        noise_scale=0.05,
        source_noise_scale=0.5,
        num_rotations=10,
        conditional=False,
        seed=0,
    ) -> None:
        """Initialise the parent datamodule, then swap in the batch datasets.

        Args:
            batch_size: Batch size of the training dataloader.
            ivp_batch_size: Forwarded to the datasets.
            noise_scale: Forwarded to the datasets.
            source_noise_scale: Forwarded to the datasets.
            num_rotations: Forwarded to the datasets.
            conditional: Forwarded to the datasets.
            seed: Forwarded to the datasets (the same seed for every split).
        """
        # NOTE(review): the parent constructor already builds three
        # ``letters_dataset`` splits, which are replaced right below.
        super().__init__(
            batch_size=batch_size,
            ivp_batch_size=ivp_batch_size,
            noise_scale=noise_scale,
            source_noise_scale=source_noise_scale,
            num_rotations=num_rotations,
            conditional=conditional,
            seed=seed,
        )

        self.save_hyperparameters(logger=True)

        # Every split shares these settings.
        dataset_kwargs = dict(
            noise_scale=self.noise_scale,
            source_noise_scale=self.source_noise_scale,
            num_rotations=self.num_rotations,
            ivp_batch_size=self.ivp_batch_size,
            conditional=self.conditional,
            seed=self.seed,
        )
        self.train_dataset = letters_batch_dataset(mode="train", **dataset_kwargs)
        self.val_dataset = letters_batch_dataset(mode="val", **dataset_kwargs)
        self.test_dataset = letters_batch_dataset(mode="test", **dataset_kwargs)

    def train_dataloader(self):
        """Return the shuffled training dataloader with zero-padding collation."""
        loader_kwargs = dict(
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=collate_fn_pad,
            num_workers=0,
            pin_memory=True,
        )
        return DataLoader(dataset=self.train_dataset, **loader_kwargs)

    def val_dataloader(self):
        """Return the validation dataloader (default collation, no shuffling)."""
        loader_kwargs = dict(
            batch_size=10,  # TODO: fix, this is temp
            shuffle=False,
            num_workers=0,
            pin_memory=True,
        )
        return DataLoader(dataset=self.val_dataset, **loader_kwargs)

    def test_dataloader(self):
        """Return the test dataloader (default collation, no shuffling)."""
        loader_kwargs = dict(
            batch_size=10,  # TODO: fix, this is temp
            shuffle=False,
            num_workers=0,
            pin_memory=True,
        )
        return DataLoader(dataset=self.test_dataset, **loader_kwargs)