import torch, csv
from collections import defaultdict

class DMSData(torch.nn.Module):
    """DMS (presumably deep mutational scanning) measurements for one reference sequence.

    Reads a CSV whose rows are ``Mutant,Index,Effect,Value`` (header row
    skipped) and exposes the data in two forms:

    * ``dms_tensor`` / ``dms_mask``: dense float tensors of shape
      ``(L', len(vocab), F)`` indexed by position, vocab index of the mutant
      token and feature channel. ``L' = len(ref_seq)``, plus one extra row
      for each of BOS/EOS when enabled. ``Index`` is used directly as a
      0-based position into ``ref_seq``.
    * ``dms_seqs`` / ``dms_feature_labels`` / ``dms_feature_masks``: one
      encoded mutant sequence per distinct mutant (see ``process_sequences``).

    Args:
        dms_path: path to the CSV file.
        vocab: vocabulary object providing ``__len__``, ``get_idx``,
            ``encode_line``, ``bos`` and ``eos``.
        ref_seq: reference sequence, one token character per position.
        prepend_bos: add a leading BOS slot / column.
        append_eos: add a trailing EOS slot / column.
        seleted_feats: optional collection of effect names to keep.

    Raises:
        ValueError: if no feature channel remains, or if the same mutant and
            feature appear twice with different values.
    """

    def __init__(self, dms_path, vocab, ref_seq, prepend_bos=False, append_eos=False, seleted_feats=None):
        # nn.Module bookkeeping (_parameters, _modules, ...) must exist before
        # repr(), .to(), state_dict() or registration as a submodule can work.
        super().__init__()
        self.ref_seq = ref_seq
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        self.vocab = vocab

        data = self.read(dms_path)
        self.dms_raw_data = data

        # Keep effects in order of first appearance in the file. set() order
        # depends on per-process string-hash randomization, which would make
        # the channel axis differ between runs.
        channels = list(dict.fromkeys(eff for _, _, eff, _ in data))
        if seleted_feats is not None:
            channels = [c for c in channels if c in seleted_feats]
            print("DMS Features", channels)
        if not channels:
            raise ValueError(f"No DMS features available from {dms_path!r}")
        self.channels = channels
        self.channels_map = {name: i for i, name in enumerate(channels)}

        # Dense (position, token, feature) values plus a 0/1 "measured" mask.
        n_tok, n_feat = len(vocab), len(channels)
        t = torch.zeros((len(ref_seq), n_tok, n_feat))
        mask = torch.zeros_like(t)
        for mut, pos, eff, v in data:
            c = self.channels_map.get(eff)
            if c is None:  # feature filtered out by seleted_feats
                continue
            p, x = int(pos), vocab.get_idx(mut)
            t[p, x, c] = float(v)
            mask[p, x, c] = 1

        # Special-token positions carry no measurements (all zeros, unmasked).
        pad = torch.zeros((1, n_tok, n_feat))
        if prepend_bos:
            t, mask = torch.cat([pad, t], dim=0), torch.cat([pad, mask], dim=0)
        if append_eos:
            t, mask = torch.cat([t, pad], dim=0), torch.cat([mask, pad], dim=0)

        self.dms_tensor = t
        self.dms_mask = mask

        self.dms_seqs, self.dms_feature_labels, self.dms_feature_masks = self.process_sequences()

    def build_dataset(self, desc):
        """Return ``[(str(i), seq, f"{i} {desc}"), ...]`` for every sequence.

        Uses ``self.all_seqs`` if something has assigned it. Nothing in this
        class does, so by default this falls back to the encoded mutant
        sequences in ``self.dms_seqs`` (one row per mutant).
        NOTE(review): confirm whether callers expect encoded tensors or raw
        strings as ``seq``.
        """
        seqs = getattr(self, "all_seqs", None)
        if seqs is None:
            seqs = self.dms_seqs
        return [(str(i), seq, str(i) + " " + desc) for i, seq in enumerate(seqs)]

    def process_sequences(self, ):
        """Encode one single-mutant sequence per distinct mutant.

        Returns:
            ``(seqs, labels, masks)``:

            * ``seqs`` stacks ``vocab.encode_line`` of each mutant sequence,
              in order of first appearance in the CSV, with a BOS/EOS column
              added when enabled.
            * ``labels`` is a float tensor of shape ``(N, F)`` holding each
              feature's value, or 0 where the feature was not measured.
            * ``masks`` is a bool tensor of shape ``(N, F)`` marking measured
              entries.

            ``F`` follows the order of ``self.channels``.

        Raises:
            ValueError: if a mutant/feature pair appears twice with different values.
        """
        # mutant sequence string -> {feature name: value}
        muts = defaultdict(dict)
        for mut, pos, eff, v in self.dms_raw_data:
            if eff not in self.channels_map:
                continue

            chars = list(self.ref_seq)
            chars[int(pos)] = mut
            seq = "".join(chars)
            value = float(v)

            # Check for a conflicting duplicate *before* overwriting it.
            prev = muts[seq].get(eff)
            if prev is not None and prev != value:
                raise ValueError(
                    f"Conflicting values for mutant {seq!r}, feature {eff!r}: {prev} vs {value}"
                )
            muts[seq][eff] = value

        seq_strings = list(muts)
        seqs = torch.stack([self.vocab.encode_line(s) for s in seq_strings], dim=0)

        if self.prepend_bos:
            bos = seqs.new_full((seqs.size(0), 1), self.vocab.bos())
            seqs = torch.cat([bos, seqs], dim=1)
        if self.append_eos:
            eos = seqs.new_full((seqs.size(0), 1), self.vocab.eos())
            seqs = torch.cat([seqs, eos], dim=1)

        n_seqs, n_feat = len(seq_strings), len(self.channels)
        labels = torch.zeros((n_seqs, n_feat))
        masks = torch.zeros((n_seqs, n_feat), dtype=torch.bool)
        for i, s in enumerate(seq_strings):
            for eff, value in muts[s].items():
                c = self.channels_map[eff]
                labels[i, c] = value
                masks[i, c] = True

        return seqs, labels, masks

    def read(self, path):
        """Return the CSV data rows as ``(mutant, index, effect, value)`` string tuples.

        The first row (header ``Mutant,Index,Effect,Value``) and blank lines
        are skipped. Every other row must have exactly four columns.
        """
        data = []
        with open(path, newline='') as csvfile:
            reader = csv.reader(csvfile, delimiter=',')
            next(reader, None)  # header
            for row in reader:
                if not row:  # csv.reader yields [] for blank lines
                    continue
                mut, idx, eff, v = row
                data.append((mut, idx, eff, v))
        return data


class TemporalDMSDataset(torch.utils.data.Dataset):
    """Expose encoded DMS sequences as examples tagged with a fixed time.

    Each item returns the sequence as both ``input_ids`` and ``labels``. It
    also carries a float ``attention_mask`` that is 0 wherever the token
    equals ``vocab.pad()``, and ``input_time`` set to
    ``float(forcasting_time)`` for every item.

    Args:
        dms_seqs: indexable collection of encoded sequences (e.g. the
            ``dms_seqs`` tensor built by ``DMSData``). Items must support
            ``!= vocab.pad()`` returning something with ``.float()``,
            presumably tensors.
        vocab: object providing ``pad()``.
        forcasting_time: time value attached to every example.
    """

    def __init__(self, dms_seqs, vocab, forcasting_time=0) -> None:
        super().__init__()
        self.src_dataset = dms_seqs
        self.vocab = vocab
        self.forcasting_time = forcasting_time

    def __len__(self):
        return len(self.src_dataset)

    def __getitem__(self, index):
        # Fetch the item once so every field is derived from the same object,
        # and a lazy or expensive source is only indexed once per example.
        seq = self.src_dataset[index]
        return {
            "index": index,
            "attention_mask": (seq != self.vocab.pad()).float(),
            "input_ids": seq,
            "labels": seq,
            # Same time for every example, not read per sequence.
            "input_time": float(self.forcasting_time),
        }