from functools import partial
from typing import List, Optional, Tuple, Union

import pytorch_lightning as pl
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from avr.data.vasr.dataset import VasrDataset


class VasrDataModule(pl.LightningDataModule):
    """Lightning data module serving the VASR train / val / test splits.

    Datasets are built with Hydra's ``instantiate`` from the
    ``cfg.avr.data.analogy.vasr`` config nodes, and data loaders from the
    ``cfg.torch.data_loader`` nodes. Every loader collates its batches with
    ``my_collate``.
    """

    def __init__(self, cfg: DictConfig, collate_to_tensor: bool = False):
        """Store the configuration; datasets remain unset until ``setup``.

        Args:
            cfg: Config holding the dataset and data-loader nodes.
            collate_to_tensor: Forwarded to ``my_collate`` as ``to_tensor``
                by every loader this module creates.
        """
        super().__init__()
        self.cfg: DictConfig = cfg
        self.collate_to_tensor = collate_to_tensor
        # Populated by ``setup``.
        self.train_dataset: Optional[VasrDataset] = None
        self.val_dataset: Optional[VasrDataset] = None
        self.test_dataset: Optional[VasrDataset] = None

    def setup(self, stage: Optional[str] = None):
        """Instantiate the datasets for all three splits.

        ``stage`` is accepted to match the hook signature but is not
        consulted: every call builds train, val and test.
        """
        splits = self.cfg.avr.data.analogy.vasr
        self.train_dataset = instantiate(splits.train)
        self.val_dataset = instantiate(splits.val)
        self.test_dataset = instantiate(splits.test)

    def _make_loader(self, loader_cfg, dataset, **overrides) -> DataLoader:
        """Instantiate ``loader_cfg`` for ``dataset`` using the shared collate function."""
        collate = partial(my_collate, to_tensor=self.collate_to_tensor)
        return instantiate(
            loader_cfg,
            dataset=dataset,
            **overrides,
            collate_fn=collate,
        )

    def train_dataloader(self) -> DataLoader:
        """Return the training loader; shuffling is always switched on."""
        loader_cfg = self.cfg.torch.data_loader.train
        return self._make_loader(loader_cfg, self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the validation loader."""
        loader_cfg = self.cfg.torch.data_loader.val
        return self._make_loader(loader_cfg, self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        """Return the test loader."""
        loader_cfg = self.cfg.torch.data_loader.test
        return self._make_loader(loader_cfg, self.test_dataset)


def my_collate(
    batch: List[Tuple[List[torch.Tensor], List[torch.Tensor], int]],
    to_tensor: bool = False,
) -> Tuple[
    Union[torch.Tensor, List[List[torch.Tensor]]],
    Union[torch.Tensor, List[List[torch.Tensor]]],
    torch.Tensor,
]:
    """Collate ``(context, answers, target)`` samples into a batch.

    Args:
        batch: Samples to merge. Each sample is a tuple of a list of context
            tensors, a list of answer tensors and an integer target.
        to_tensor: When ``True``, each sample's tensor list is stacked along
            a new leading dimension and the per-sample results are stacked
            again, giving one tensor for the contexts and one for the
            answers. This relies on ``torch.stack``, so every tensor being
            stacked must have the same shape and every sample must carry the
            same number of tensors. When ``False``, the per-sample lists are
            passed through untouched (one list entry per sample).

    Returns:
        A ``(context, answers, target)`` tuple. ``context`` and ``answers``
        are tensors when ``to_tensor`` is set and lists of per-sample tensor
        lists otherwise. ``target`` is always a 1-D ``torch.long`` tensor with
        one entry per sample.
    """
    if to_tensor:
        context = torch.stack([torch.stack(item[0], dim=0) for item in batch], dim=0)
        answers = torch.stack([torch.stack(item[1], dim=0) for item in batch], dim=0)
    else:
        context = [item[0] for item in batch]
        answers = [item[1] for item in batch]
    # torch.tensor with an explicit dtype replaces the legacy
    # torch.LongTensor constructor; the result is the same for int targets.
    target = torch.tensor([item[2] for item in batch], dtype=torch.long)
    return context, answers, target
