import re

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np

from sequence_models.constants import PROTEIN_ALPHABET, PAD

# Index of the padding token in PROTEIN_ALPHABET; token batches compare against it
# to find padded positions.
pad_idx = PROTEIN_ALPHABET.index(PAD)


def _to_esm_tokens(src, to_esm, alphabet):
    """Re-tokenize a padded batch of PROTEIN_ALPHABET indices into ESM token indices.

    Each output row is ``[cls, to_esm[tok] for every non-pad tok ..., eos]`` followed
    by ``alphabet.padding_idx`` up to a common width of ``src.shape[1] + 2``.

    Args:
        src: 2-D tensor ``(n, ell)`` of PROTEIN_ALPHABET token indices, padded with
            ``pad_idx``.
        to_esm: mapping from a PROTEIN_ALPHABET index (Python int) to an ESM token index.
        alphabet: ESM alphabet object providing ``padding_idx``, ``cls_idx`` and
            ``eos_idx``.

    Returns:
        LongTensor of shape ``(n, ell + 2)``.
    """
    n, ell = src.shape
    esm_src = torch.full((n, ell + 2), alphabet.padding_idx, dtype=torch.long)
    esm_src[:, 0] = alphabet.cls_idx
    for i, row in enumerate(src):
        tokens = [to_esm[s.item()] for s in row if s != pad_idx] + [alphabet.eos_idx]
        esm_src[i, 1:len(tokens) + 1] = torch.tensor(tokens, dtype=torch.long)
    return esm_src


def embed(model, dl, device, batch_size,
          seq2seq=True, contacts=False, to_esm=None,
          alphabet=None, max_len=None,
          batch_sampler=None, repr_layer=33, num_workers=8):
    """Embed every batch of ``dl`` with a frozen model and return a DataLoader over the results.

    The model is put in eval mode and run under ``torch.no_grad()``. Each sequence's
    embedding (and, depending on the mode, its target) is trimmed to its unpadded
    length. The results are wrapped in an ``EmbeddedDataset`` and batched with an
    ``EmbeddingCollater``.

    Args:
        model: embedding model. When ``to_esm`` is None it is called as
            ``model(src, input_mask=...)`` and must return a tensor. Otherwise it is
            called ESM-style, ``model(tokens, repr_layers=[repr_layer],
            return_contacts=False)``, and ``['representations'][repr_layer]`` is used.
        dl: iterable of ``(src, tgt, mask)`` batches. ``src`` holds PROTEIN_ALPHABET
            indices padded with ``pad_idx``; ``mask`` is not used.
        device: device the model runs on.
        batch_size: batch size of the returned loader. It is ignored when
            ``batch_sampler`` is given.
        seq2seq: targets are per-position and are cut to the sequence length.
        contacts: targets are pairwise and are cut to ``ell x ell``. Only consulted
            when ``seq2seq`` is False.
        to_esm: mapping from PROTEIN_ALPHABET index to ESM token index. Providing it
            switches to the ESM code path.
        alphabet: ESM alphabet (``padding_idx``, ``cls_idx``, ``eos_idx``). Used only
            on the ESM path.
        max_len: forwarded to ``EmbeddingCollater`` for random cropping.
        batch_sampler: optional batch sampler for the returned loader.
        repr_layer: ESM layer whose representations are extracted (default 33).
        num_workers: number of worker processes for the returned loader (default 8).

    Returns:
        DataLoader yielding ``EmbeddingCollater`` batches. Without ``batch_sampler``
        it shuffles with ``batch_size``.
    """
    # 1022 residues plus the CLS/EOS tokens added by _to_esm_tokens gives 1024
    # positions, presumably the ESM model's maximum input length.
    esm_max_residues = 1022
    use_esm = to_esm is not None
    # ESM output position 0 belongs to the prepended CLS token, so residue
    # representations start at index 1.
    offset = 1 if use_esm else 0

    model = model.eval()
    embeddings = []
    y = []
    for src, tgt, _ in dl:
        # (n, ell, 1) float mask: 1 for real tokens, 0 for padding.
        input_mask = (src != pad_idx).float().unsqueeze(-1)
        if use_esm:
            src = _to_esm_tokens(src[:, :esm_max_residues], to_esm, alphabet)
            input_mask = input_mask[:, :esm_max_residues]
        src = src.to(device)
        with torch.no_grad():
            if use_esm:
                out = model(src, repr_layers=[repr_layer], return_contacts=False)
                embedding = out['representations'][repr_layer].detach().cpu()
            else:
                embedding = model(src, input_mask=input_mask.to(device)).detach().cpu()
        for e, t, m in zip(embedding, tgt, input_mask):
            ell = int(m.sum())
            embeddings.append(e[offset:offset + ell, :])
            if seq2seq:
                y.append(t[:ell])
            elif contacts:
                y.append(t[:ell, :ell])
            else:
                y.append(t)

    ds = EmbeddedDataset(embeddings, y)
    collater = EmbeddingCollater(seq2seq=seq2seq, contacts=contacts, max_len=max_len)
    # DataLoader forbids combining batch_sampler with batch_size/shuffle.
    if batch_sampler is not None:
        sampling = {'batch_sampler': batch_sampler}
    else:
        sampling = {'batch_size': batch_size, 'shuffle': True}
    return DataLoader(ds, collate_fn=collater, num_workers=num_workers,
                      pin_memory=False, **sampling)


class EmbeddedDataset(Dataset):
    """Map-style dataset over parallel sequences of embeddings and targets.

    Args:
        src: indexable sequence of embeddings.
        tgt: indexable sequence of targets, aligned with ``src``.

    Raises:
        ValueError: if ``src`` and ``tgt`` have different lengths.
    """

    def __init__(self, src, tgt):
        super().__init__()
        # Explicit check rather than assert so it still runs under ``python -O``.
        if len(src) != len(tgt):
            raise ValueError(
                'src and tgt must have the same length, got %d and %d' % (len(src), len(tgt)))
        self.src = src
        self.tgt = tgt

    def __getitem__(self, idx):
        """Return the ``(embedding, target)`` pair at ``idx``."""
        return self.src[idx], self.tgt[idx]

    def __len__(self):
        return len(self.src)


class EmbeddingCollater(object):
    """Collate ``(embedding, target)`` pairs into padded batch tensors.

    Each embedding is indexed along its first axis by position, and ``len()`` gives
    its length ``ell``. If ``max_len`` is set, embeddings longer than it are cropped
    to a uniformly random contiguous window of ``max_len`` positions. Per-position
    (``seq2seq``) and pairwise (``contacts``) targets are cropped to the same window.

    Args:
        seq2seq: targets are per-position, shape ``(ell,)``, padded with -100.
        contacts: targets are pairwise, shape ``(ell, ell)``, padded with 0. Only
            consulted when ``seq2seq`` is False.
        max_len: optional maximum length after cropping.
    """

    def __init__(self, seq2seq=True, contacts=False, max_len=None):
        self.seq2seq = seq2seq
        self.contacts = contacts
        self.max_len = max_len

    def __call__(self, batch):
        """Return ``(src, tgt, input_mask)`` for a list of ``(embedding, target)`` pairs.

        ``src`` is a zero-padded float tensor ``(n, L, d)``, where ``L`` is the
        longest length after cropping and ``d`` is the embedding's last dimension.
        The contents of ``tgt`` and ``input_mask`` depend on the mode:

        * seq2seq: ``tgt`` is a long ``(n, L)`` tensor padded with -100.
          ``input_mask`` is a float tensor that is 1 wherever ``tgt != -100``, so
          positions whose target is already -100 are masked too.
        * contacts: ``tgt`` is ``(n, L, L)``, zero-padded, with the targets' dtype.
          ``input_mask`` is a bool tensor that is True inside each sample's
          ``ell x ell`` block.
        * otherwise: ``tgt`` is ``torch.tensor(targets).view(-1, 1)``.
          ``input_mask`` is a float ``(n, L)`` tensor marking valid positions.
        """
        src, tgt = (list(part) for part in zip(*batch))
        ells = [len(s) for s in src]
        max_len = max(ells)
        if self.max_len is not None and max_len > self.max_len:
            for i, ell in enumerate(ells):
                if ell <= self.max_len:
                    continue
                # There are ell - max_len + 1 valid start offsets; the +1 keeps the
                # window that ends at the final position reachable.
                start = np.random.choice(ell - self.max_len + 1)
                stop = start + self.max_len
                src[i] = src[i][start:stop]
                if self.seq2seq:
                    tgt[i] = tgt[i][start:stop]
                elif self.contacts:
                    tgt[i] = tgt[i][start:stop, start:stop]
                ells[i] = self.max_len
            max_len = self.max_len

        n = len(src)
        padded_src = torch.zeros(n, max_len, src[0].shape[-1])
        for i, (s, ell) in enumerate(zip(src, ells)):
            padded_src[i, :ell] = s

        if self.seq2seq:
            padded_tgt = torch.full((n, max_len), -100, dtype=torch.long)
            for i, (t, ell) in enumerate(zip(tgt, ells)):
                padded_tgt[i, :ell] = t
            input_mask = (padded_tgt != -100).float()
        elif self.contacts:
            padded_tgt = torch.zeros(n, max_len, max_len, dtype=tgt[0].dtype)
            input_mask = torch.zeros(n, max_len, max_len, dtype=torch.bool)
            for i, (t, ell) in enumerate(zip(tgt, ells)):
                padded_tgt[i, :ell, :ell] = t
                input_mask[i, :ell, :ell] = True
        else:
            padded_tgt = torch.tensor(tgt).view(-1, 1)
            input_mask = torch.zeros(n, max_len)
            for i, ell in enumerate(ells):
                input_mask[i, :ell] = 1.0
        return padded_src, padded_tgt, input_mask
