import os
from pathlib import Path
from random import randrange

import pandas as pd
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

# Cap the OpenMP and MKL intra-op thread pools at one thread per process.
# NOTE(review): presumably this avoids CPU thread oversubscription when several
# processes (e.g. DataLoader workers) share a machine -- confirm the intent.
os.environ.update({"OMP_NUM_THREADS": "1", "MKL_NUM_THREADS": "1"})


def identity(x):
    """Return ``x`` unchanged.

    Handy as a no-op ``collate_fn``: the DataLoader then yields the raw list
    of samples and collation can be applied elsewhere.
    """
    unchanged = x
    return unchanged


class DistributedDalaloaderWrapper:
    """Wrap a ``DataLoader`` and apply ``collate_fn`` to every batch it yields.

    The wrapped loader is expected to hand back raw batches (for example when
    it was built with an identity ``collate_fn``); collation happens as batches
    are pulled from this wrapper. Length, dataset access and epoch control are
    forwarded to the wrapped loader.
    """

    def __init__(self, dataloader: DataLoader, collate_fn):
        self.collate_fn = collate_fn
        self.dataloader = dataloader

    def _epoch_iterator(self, it):
        """Yield ``self.collate_fn(batch)`` for each batch produced by ``it``."""
        yield from (self.collate_fn(raw_batch) for raw_batch in it)

    def __iter__(self):
        # The loader iterator is created right away (when iter() is called);
        # collation itself is deferred until each batch is requested.
        loader_iterator = iter(self.dataloader)
        return self._epoch_iterator(loader_iterator)

    def __len__(self):
        """Delegate to ``len(self.dataloader)``."""
        return len(self.dataloader)

    @property
    def dataset(self):
        """The dataset held by the wrapped loader."""
        return self.dataloader.dataset

    def set_epoch(self, epoch: int):
        """Forward ``epoch`` to the wrapped loader's sampler.

        The sampler must implement ``set_epoch`` (as ``DistributedSampler`` does).
        """
        sampler = self.dataloader.sampler
        sampler.set_epoch(epoch)


def universal_collater(batch):
    """Transpose a batch of sequence-like samples into one list per field.

    The number of output lists is ``len(batch[0])``; the i-th list collects
    every sample's i-th element in batch order.

    Raises:
        IndexError: if ``batch`` is empty, or a later sample has more
            elements than the first one.
    """
    field_count = len(batch[0])
    columns = [list() for _ in range(field_count)]
    for sample in batch:
        for position, value in enumerate(sample):
            columns[position].append(value)
    return columns


def universal_dict_collater(batch):
    """Merge a batch of dict samples into a dict of per-key lists.

    Keys (and their order) come from the first sample; each value lists that
    key's entries across ``batch`` in batch order. Keys present only on later
    samples are ignored, and a sample missing one of the keys raises
    ``KeyError``.
    """
    reference_keys = batch[0].keys()
    return {key: [sample[key] for sample in batch] for key in reference_keys}


class VoxCelebTrain(Dataset):
    """Training dataset returning fixed-length waveforms with speaker labels.

    Each row of ``args.meta_csv_file`` names a ``.pt`` tensor file (column
    ``col_sample``) and an integer label (column ``col_label``). The first
    three ``/``-separated components of the stored path are dropped and the
    remainder is resolved under ``args.pt_dir``. Every waveform is cut to, or
    right-padded with zeros to, exactly ``args.batch_length`` samples along
    its last axis; longer waveforms are cropped at a random offset when
    ``random_crop`` is true and from the start otherwise.
    """

    def __init__(self, args, col_sample="filepath", col_label="label", random_crop=True):
        self.df = pd.read_csv(args.meta_csv_file)
        self.batch_length = args.batch_length
        self.col_sample = col_sample
        self.col_label = col_label
        self.random_crop = random_crop
        self.pt_dir = args.pt_dir

    def __len__(self):
        return len(self.df)

    def _resolve_path(self, stored_path):
        """Map a CSV path onto ``pt_dir`` after dropping its first three components."""
        return str(Path(self.pt_dir) / "/".join(stored_path.split("/")[3:]))

    def _random_crop(self, waveform):
        """Return a random ``batch_length``-sample window of ``waveform`` along dim 1.

        Every start index in ``[0, waveform.shape[1] - batch_length]`` can be
        chosen, including the one whose window ends on the final sample.

        Raises:
            ValueError: if ``waveform`` is shorter than ``batch_length``.
        """
        max_start_index = waveform.shape[1] - self.batch_length
        if max_start_index < 0:
            raise ValueError(
                f"Cannot crop waveform of length {waveform.shape[1]} "
                f"to {self.batch_length} samples"
            )
        # randrange's upper bound is exclusive; +1 keeps the last window reachable.
        random_start = randrange(max_start_index + 1)
        return waveform[:, random_start : random_start + self.batch_length]

    def __getitem__(self, idx):
        """Return ``{"waveform", "padding_mask", "label"}`` for row ``idx``.

        ``padding_mask`` is a ``(1, batch_length)`` bool tensor that is True
        on zero-padded positions; ``label`` is a 1-element int64 tensor.
        """
        filename = self._resolve_path(self.df.loc[idx, self.col_sample])
        # Load onto the CPU whatever device the tensor was saved from, so the
        # load also works inside CPU-only DataLoader worker processes.
        waveform = torch.load(filename, map_location="cpu")
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)  # add a leading channel axis
        speaker_id = torch.tensor([int(self.df.loc[idx, self.col_label])], dtype=torch.long)
        padding_mask = torch.full((1, self.batch_length), fill_value=False, dtype=torch.bool)
        length = waveform.shape[-1]
        if length > self.batch_length:
            if self.random_crop:
                waveform = self._random_crop(waveform)  # train
            else:
                waveform = waveform[:, : self.batch_length]  # validation
        elif length < self.batch_length:
            padding_length = self.batch_length - length
            waveform = torch.nn.functional.pad(waveform, (0, padding_length), "constant", value=0.0)
            padding_mask[:, -padding_length:] = True
        return {
            "waveform": waveform,
            "padding_mask": padding_mask,
            "label": speaker_id,
        }


def pad_collate_fn(batch):
    """Collate waveform pairs, right-padding every waveform to one shared length.

    Each sample is a dict with ``"waveform1"`` and ``"waveform2"`` tensors
    (time on the last axis) and a ``"label"`` tensor. All waveforms of both
    kinds are zero-padded to the longest waveform in the whole batch. The
    input sample dicts are left unmodified.

    Returns:
        dict with ``"waveform1"`` / ``"waveform2"`` (stacked, dim 1 squeezed),
        ``"padding_mask1"`` / ``"padding_mask2"`` (int64, 1 on padded
        positions, stacked and dim 1 squeezed) and ``"label"`` (stacked, dim 1
        squeezed).
    """
    max_length = max(
        max(sample["waveform1"].shape[-1] for sample in batch),
        max(sample["waveform2"].shape[-1] for sample in batch),
    )

    def pad_and_get_mask(waveform, max_length):
        """Zero-pad ``waveform`` on the right to ``max_length``; also return a (1, max_length) int64 mask."""
        padding_length = max_length - waveform.shape[-1]
        padding_mask = torch.zeros((1, max_length), dtype=torch.long)
        if padding_length > 0:
            waveform = F.pad(waveform, (0, padding_length), "constant", value=0.0)
            padding_mask[:, -padding_length:] = 1
        return waveform, padding_mask

    # Collect padded tensors in fresh lists instead of writing back into the
    # caller's sample dicts.
    waveforms1, waveforms2, masks1, masks2 = [], [], [], []
    for sample in batch:
        waveform1, padding_mask1 = pad_and_get_mask(sample["waveform1"], max_length)
        waveform2, padding_mask2 = pad_and_get_mask(sample["waveform2"], max_length)
        waveforms1.append(waveform1)
        waveforms2.append(waveform2)
        masks1.append(padding_mask1)
        masks2.append(padding_mask2)

    return {
        "waveform1": torch.stack(waveforms1).squeeze(1),
        "waveform2": torch.stack(waveforms2).squeeze(1),
        "padding_mask1": torch.stack(masks1).squeeze(1),
        "padding_mask2": torch.stack(masks2).squeeze(1),
        "label": torch.stack([sample["label"] for sample in batch]).squeeze(1),
    }


class VoxCelebTest(Dataset):
    """Evaluation dataset of labelled waveform pairs.

    ``state == "val"`` reads ``args.val_csv_file``; any other ``state`` reads
    ``args.test_csv_file``. Each row names two ``.pt`` tensor files (columns
    ``original_file`` and ``compared_file``) and an integer label (column
    ``col_label``). Waveforms are returned unpadded, at their stored length.
    """

    def __init__(self, args, state, original_file="original_file", compared_file="compared_file", col_label="labels"):
        meta_csv_file = args.val_csv_file if state == "val" else args.test_csv_file
        self.df = pd.read_csv(meta_csv_file)
        self.col_original_file = original_file
        self.col_compared_file = compared_file
        self.col_label = col_label
        self.pt_dir = args.pt_dir

    def __len__(self):
        return len(self.df)

    def _load_waveform(self, stored_path):
        """Load the tensor for a CSV path onto the CPU.

        The first three ``/``-separated components of ``stored_path`` are
        dropped and the rest is resolved under ``pt_dir``. A 1-D tensor gets
        a leading channel axis.
        """
        path = Path(self.pt_dir) / "/".join(stored_path.split("/")[3:])
        # map_location="cpu" makes loading independent of the saving device.
        waveform = torch.load(str(path), map_location="cpu")
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        return waveform

    def __getitem__(self, idx):
        """Return ``{"waveform1", "waveform2", "label"}`` for row ``idx``; label is a 1-element int64 tensor."""
        label = torch.tensor([int(self.df.loc[idx, self.col_label])], dtype=torch.long)
        waveform1 = self._load_waveform(self.df.loc[idx, self.col_original_file])
        waveform2 = self._load_waveform(self.df.loc[idx, self.col_compared_file])
        return {
            "waveform1": waveform1,
            "waveform2": waveform2,
            "label": label,
        }


class DataloaderFactory:
    """Build distributed DataLoaders for the VoxCeleb train / val / test splits."""

    def __init__(self, args, mode="finetune"):
        self.args = args
        self.mode = mode  # stored for callers; build() does not read it

    def build(self, state: str = "train", bs: int = 1, fold: int = 1):
        """Return a ``DistributedDalaloaderWrapper`` over the dataset for ``state``.

        Args:
            state: ``"train"`` builds a shuffled ``VoxCelebTrain`` loader that
                collates with ``universal_dict_collater``; any other value
                builds an unshuffled ``VoxCelebTest`` loader (``state`` is
                passed through) that collates with ``pad_collate_fn``.
            bs: batch size.
            fold: accepted for interface compatibility; currently unused.
        """
        is_train = state == "train"
        if is_train:
            dataset = VoxCelebTrain(args=self.args)
            collate_fn = universal_dict_collater
        else:
            dataset = VoxCelebTest(args=self.args, state=state)
            collate_fn = pad_collate_fn
        # No num_replicas/rank are given, so DistributedSampler takes them from
        # the (already initialized) default process group.
        sampler = DistributedSampler(dataset, shuffle=is_train)
        num_workers = self.args.num_workers
        # DataLoader raises ValueError if a multiprocessing_context is supplied
        # with num_workers == 0, so only request "fork" for multi-process loading.
        context = mp.get_context("fork") if num_workers > 0 else None
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=bs,
            drop_last=False,
            num_workers=num_workers,
            # Batches come back as raw sample lists; the wrapper applies collate_fn.
            collate_fn=identity,
            sampler=sampler,
            pin_memory=True,
            multiprocessing_context=context,
        )

        return DistributedDalaloaderWrapper(dataloader, collate_fn)
