import queue

import numpy as np
import torch
import torch.cuda
import torch.multiprocessing as tmp
import torch.nn.functional as F


def copy_dict_to_device(data: dict, device: torch.device) -> dict:
    """Move every value of ``data`` to ``device`` in place and return ``data``.

    The same dict object is returned (not a copy); each value is replaced by
    the result of its ``.to(device)`` call.
    """
    # Snapshot the keys first so the loop is clearly independent of the
    # value reassignments happening inside it.
    for name in list(data):
        data[name] = data[name].to(device)
    return data


def collate_function(batch, ignore_index):
    """Collate next-token-prediction samples into a padded, packing-aware batch.

    Each element of ``batch`` is indexed as ``sample[0]`` / ``sample[1]``:

    * If ``sample[0]`` is a list, it holds several token tensors that are
      packed into one row, and ``sample[1]`` holds their lengths (one each).
    * Otherwise ``sample[0]`` is a single token tensor and ``sample[1]`` is
      its length.

    (Token tensors are presumably 1-D; TODO confirm against the loader.)
    For every sequence the source is ``seq[:-1]`` and the target is
    ``seq[1:]``, so each effective length is the declared length minus one.
    In packed samples the shift is applied per sub-sequence, so no target
    crosses a sub-sequence boundary.

    Args:
        batch: Iterable of ``(tokens, lengths)`` samples as described above.
        ignore_index: Padding value used for ``trg_seq``.

    Returns:
        dict with keys:
            src_seq: source tokens stacked per row, right-padded with 0.
            seqlens: int32 tensor of effective lengths of every (sub-)sequence,
                in batch order.
            cu_seqlens: int32 cumulative sum of ``seqlens`` with a leading 0.
            indices: flat indices of real (non-padding) positions when the
                padded ``(batch, max_seqlen)`` layout is viewed row-major.
            position_ids: ``(batch, max_seqlen)`` positions that restart at 0
                for every packed sub-sequence; 0 on padding.
            attention_mask: ``(batch, max_seqlen)`` bool mask, True on real
                positions and False on padding.
            trg_seq: target tokens stacked per row, right-padded with
                ``ignore_index``.

    Raises:
        AssertionError: if the padded width differs from the longest declared
            sample length (i.e. some declared length does not match its tensor).
    """
    src_seq = []
    trg_seq = []
    seqlens = []
    positions = []
    sample_length = []
    for sample in batch:
        if isinstance(sample[0], list):
            src_parts = [s[:-1] for s in sample[0]]
            trg_parts = [s[1:] for s in sample[0]]

            src_seq.append(torch.cat(src_parts, dim=0))
            trg_seq.append(torch.cat(trg_parts, dim=0))
            seqlens_in_sample = (np.array(sample[1]) - 1).tolist()
            seqlens.extend(seqlens_in_sample)
            positions.append(seqlens_in_sample)
            sample_length.append(sum(seqlens_in_sample))
        else:
            seqlen = sample[1] - 1
            src_seq.append(sample[0][:-1])
            trg_seq.append(sample[0][1:])
            seqlens.append(seqlen)
            positions.append([seqlen])
            sample_length.append(seqlen)

    src_seq = torch.nn.utils.rnn.pad_sequence(src_seq, batch_first=True, padding_value=0)
    trg_seq = torch.nn.utils.rnn.pad_sequence(trg_seq, batch_first=True, padding_value=ignore_index)
    seqlens_tensor = torch.tensor(seqlens, dtype=torch.int32)

    max_seqlen = src_seq.size(1)
    longest = max(sample_length)
    if max_seqlen != longest:
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept for callers that already expect that type.
        raise AssertionError(
            f'padded width {max_seqlen} != longest declared sample length {longest}')

    cu_seqlens = F.pad(torch.cumsum(seqlens_tensor, dim=0, dtype=torch.int32), (1, 0))
    # Row i occupies flat slots [i * max_seqlen, (i + 1) * max_seqlen); keep only its real tokens.
    indices = torch.cat([torch.arange(i * max_seqlen, i * max_seqlen + l) for i, l in enumerate(sample_length)])

    # Positions restart at 0 at the start of every packed sub-sequence.
    seqlen_arange = [torch.cat([torch.arange(l) for l in sample_lens], dim=0) for sample_lens in positions]
    position_ids = torch.nn.utils.rnn.pad_sequence(seqlen_arange, batch_first=True, padding_value=0)
    # Column j of row i is real iff j < sample_length[i] (rows are right-padded).
    attention_mask = torch.arange(max_seqlen).unsqueeze(0) < torch.tensor(sample_length).unsqueeze(1)

    return {
        'src_seq': src_seq,
        'seqlens': seqlens_tensor,
        'cu_seqlens': cu_seqlens,
        'indices': indices,
        'position_ids': position_ids,
        'attention_mask': attention_mask,
        'trg_seq': trg_seq,
    }


def _buffer_process(sample_loader, buffer, device, batch_size, ignore_index):
    while True:
        batch = []
        try:
            for i in range(batch_size):
                sample = next(sample_loader)
                batch.append(sample)
            batch = collate_function(batch, ignore_index)
            batch = copy_dict_to_device(batch, device)
            buffer.put(batch)
        except StopIteration:
            break
    raise StopIteration


class PrefetchCollator():
    """Synchronous iterator that turns a sample stream into collated batches.

    Each ``next()`` pulls ``batch_size`` samples from ``sample_loader``,
    collates them with ``collate_function`` (padding targets with
    ``ignore_index``) and moves every value to ``device``. Iteration stops as
    soon as the loader runs dry; samples already pulled for an incomplete
    final batch are discarded.
    """

    def __init__(self, sample_loader, device, batch_size: int, buffer_size: int = 3, ignore_index: int = -100):
        self.sample_loader = iter(sample_loader)
        self.device = device
        self.batch_size = batch_size
        # Not used by this synchronous class; presumably kept for signature
        # parity with BufferedPrefetchCollator.
        self.buffer_size = buffer_size
        self.ignore_index = ignore_index

    def __iter__(self):
        return self

    def __next__(self):
        # A list comprehension is not a generator, so StopIteration from the
        # loader escapes it unchanged and ends iteration here.
        samples = [next(self.sample_loader) for _ in range(self.batch_size)]
        collated = collate_function(samples, self.ignore_index)
        return copy_dict_to_device(collated, self.device)

class BufferedPrefetchCollator():
    """Iterator whose batches are collated ahead of time in a background process.

    A daemon child process runs ``_buffer_process``. That process pulls
    samples from ``sample_loader``, collates ``batch_size`` of them at a time,
    moves them to ``device`` and puts them on a queue bounded to
    ``buffer_size`` batches. ``__next__`` takes batches from that queue.

    ``ctx`` is the multiprocessing context used for the queue and process;
    it defaults to torch.multiprocessing's 'spawn' context.

    NOTE(review): ``sample_loader`` is turned into an iterator and passed to
    the child as a process argument, so under 'spawn' it must be picklable
    (plain generators are not). The child advances its own copy;
    ``self.sample_loader`` in the parent is never consumed.
    """

    # Seconds to wait on the queue before re-checking whether the producer exited.
    _POLL_INTERVAL = 0.5

    def __init__(self, sample_loader, device, batch_size: int, buffer_size: int = 3, ignore_index: int = -100,
                 ctx=None):
        self.sample_loader = iter(sample_loader)
        self.device = device
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.ignore_index = ignore_index

        if ctx is None:
            self.ctx = tmp.get_context('spawn')
        else:
            self.ctx = ctx
        self.buffer = self.ctx.Queue(maxsize=buffer_size)

        self.process = self.ctx.Process(target=_buffer_process,
                                        args=(self.sample_loader, self.buffer, device, batch_size, ignore_index))
        self.process.daemon = True
        self.process.start()

    def __del__(self):
        # Attributes may be missing if __init__ failed part-way through.
        buffer = getattr(self, 'buffer', None)
        if buffer is not None:
            buffer.close()
        process = getattr(self, 'process', None)
        if process is not None and process.is_alive():
            process.terminate()
            process.join()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next buffered batch, or raise StopIteration when none remain.

        The producer sends no end-of-data marker, so a plain blocking ``get()``
        would hang forever after the last batch. Instead the queue is polled
        with a timeout. Whenever it comes up empty, the producer's liveness is
        checked. Once the producer has exited, for any reason, iteration ends
        after whatever it already queued has been returned.
        """
        while True:
            try:
                return self.buffer.get(timeout=self._POLL_INTERVAL)
            except queue.Empty:
                if not self.process.is_alive():
                    break
            except (ValueError, OSError, EOFError):
                # Queue closed or its pipe broken: treat as end of data. (The
                # previous bare `except:` also swallowed KeyboardInterrupt.)
                raise StopIteration from None
        # Producer has exited; take one last non-blocking look for anything it
        # put right before exiting.
        try:
            return self.buffer.get_nowait()
        except (queue.Empty, ValueError, OSError, EOFError):
            raise StopIteration from None
