import torch
from torch.utils.data import Dataset, DataLoader
import pickle
import random
from copy import copy


class PairDataset(Dataset):
    """Dataset that pairs every sample with context samples of the same sequence.

    Samples are grouped by their entry in ``seqs``.  Within each group the
    (clamped) ``y`` values are split into buckets, one per ``(start, end)``
    pair in ``affinity_range``; a sample belongs to a bucket when
    ``start <= y < end``.  Fetching item ``idx`` returns the sample itself
    plus one randomly drawn context sample from every bucket of the same
    sequence.

    Context samples are drawn without replacement: each (sequence, bucket)
    pair keeps a shuffled pool that is refilled only once it has been
    exhausted.  The target sample is not excluded from the pools, so it may
    be returned as its own context.

    Args:
        x: Array-like of inputs, converted with ``torch.tensor``.
        y: Array-like of one scalar per sample; converted to a float column
            of shape ``(len(y), 1)`` and clamped to
            ``[Y_CLAMP_MIN, Y_CLAMP_MAX]``.
        seqs: Indexable of hashable sequence identifiers, one per sample.
        affinity_range: Iterable of ``(start, end)`` half-open ranges over
            the clamped ``y`` values.
        return_seqs: When true, ``__getitem__`` additionally returns the
            sequence identifier of the sample.

    Note:
        Bucketing happens after clamping and ``end`` is exclusive, so a
        range whose ``end`` equals ``Y_CLAMP_MAX`` never contains the values
        that were clamped to the maximum.
    """

    # Bounds applied to y before bucketing; override in a subclass to change.
    Y_CLAMP_MIN = -2.5
    Y_CLAMP_MAX = 6.5

    def __init__(self, x, y, seqs, affinity_range, return_seqs=False):
        super().__init__()
        # Materialise once: the ranges are re-iterated for every sequence, so
        # a one-shot iterable would silently give empty buckets later on.
        affinity_range = list(affinity_range)

        # Map each sequence identifier to the positions of its samples.
        positions_by_seq = {}
        for i in range(len(x)):
            positions_by_seq.setdefault(seqs[i], []).append(i)

        x = torch.tensor(x)
        y = torch.tensor(y).float().unsqueeze(1)
        y.clamp_(min=self.Y_CLAMP_MIN, max=self.Y_CLAMP_MAX)

        # For every sequence, split its sample positions into one list per
        # affinity range (positions keep their original ascending order).
        seq_to_idxs = {}
        for seq, positions in positions_by_seq.items():
            seq_y = y[positions].flatten()
            buckets = []
            for start, end in affinity_range:
                in_range = (seq_y >= start) & (seq_y < end)
                local_idxs = torch.arange(len(seq_y))[in_range].tolist()
                buckets.append([positions[j] for j in local_idxs])
            seq_to_idxs[seq] = buckets

        self.x = x.detach()
        self.y = y.detach()
        self.seqs = seqs
        self.seq_to_idxs = seq_to_idxs
        # One independent pool per bucket; ``[[]] * n`` would alias a single
        # list object across all slots.
        self.avail_idxs = {seq: [[] for _ in affinity_range] for seq in seq_to_idxs}
        self.return_seqs = return_seqs

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return context samples and the target sample at ``idx``.

        Returns:
            ``(context_x, context_y, x, y)`` where the context tensors hold
            one row per affinity range, followed by the sequence identifier
            when ``return_seqs`` is set.

        Raises:
            IndexError: If one of the affinity ranges contains no sample for
                the sequence of ``idx``.
        """
        seq = self.seqs[idx]
        pools = self.avail_idxs[seq]
        context_idxs = []
        for bucket_no, pool in enumerate(pools):
            if not pool:
                # Pool exhausted (or never filled): start a fresh shuffled pass.
                pool = copy(self.seq_to_idxs[seq][bucket_no])
                if not pool:
                    raise IndexError(
                        f"sequence {seq!r} has no samples in affinity range "
                        f"#{bucket_no}; cannot draw a context sample"
                    )
                random.shuffle(pool)
                pools[bucket_no] = pool
            context_idxs.append(pool.pop())

        item = (self.x[context_idxs], self.y[context_idxs], self.x[idx], self.y[idx])
        if self.return_seqs:
            return item + (seq,)
        return item


def get_dataloaders(batch_size, affinity_range=((-50, 0),), data_file='data.pickle'):
    """Load the pickled data set and wrap both splits in shuffling DataLoaders.

    Args:
        batch_size: Batch size used for both loaders.
        affinity_range: Iterable of ``(start, end)`` ranges forwarded to
            ``PairDataset`` for both splits.
        data_file: Path to a pickle holding the 7-tuple
            ``(x_train, x_test, y_train, y_test, train_seqs, test_seqs,
            token_to_idx)``.

    Returns:
        A 6-tuple ``(train_dataloader, test_dataloader, max_token + 1,
        x_train.shape[1], y_test.std(), idx_to_token)`` where ``max_token``
        is the largest value found in ``x_train`` or ``x_test`` (presumably
        making the third element the vocabulary size and the fourth the
        sequence length -- TODO confirm) and ``idx_to_token`` is the inverse
        of ``token_to_idx``.  The test loader's dataset also yields the
        sequence identifier of each sample.
    """
    # NOTE: pickle can execute arbitrary code while loading; only use this
    # with trusted data files.
    with open(data_file, 'rb') as fh:
        x_train, x_test, y_train, y_test, train_seqs, test_seqs, token_to_idx = pickle.load(fh)

    idx_to_token = {idx: token for token, idx in token_to_idx.items()}

    train_dataset = PairDataset(x_train, y_train, train_seqs, affinity_range)
    test_dataset = PairDataset(x_test, y_test, test_seqs, affinity_range, return_seqs=True)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    n_tokens = max(x_train.max(), x_test.max()) + 1
    return train_dataloader, test_dataloader, n_tokens, x_train.shape[1], y_test.std(), idx_to_token


def get_for_eval(data_file='data.pickle'):
    """Load the statistics of the pickled data set needed at evaluation time.

    Args:
        data_file: Path to a pickle holding a 7-tuple whose first three
            entries are ``x_train``, ``x_test`` and ``y_train`` and whose
            last entry is ``token_to_idx``; the remaining entries are
            ignored.

    Returns:
        A 5-tuple ``(token_to_idx, y_train.mean(), y_train.std(),
        max_token + 1, x_train.shape[1])`` where ``max_token`` is the largest
        value found in ``x_train`` or ``x_test``.
    """
    # NOTE: pickle can execute arbitrary code while loading; only use this
    # with trusted data files.
    with open(data_file, 'rb') as fh:
        x_train, x_test, y_train, _, _, _, token_to_idx = pickle.load(fh)

    n_tokens = max(x_train.max(), x_test.max()) + 1
    return token_to_idx, y_train.mean(), y_train.std(), n_tokens, x_train.shape[1]
