import os
import torch
import random
import numpy as np
from typing import List, Tuple
from torch.utils.data import Dataset


# DNA codon -> one-letter amino-acid code ('*' marks the stop codons
# TAA, TAG and TGA). Codons are enumerated with the bases ordered T, C, A, G
# at every position (first base varying slowest) and paired, in order, with
# the 64 letters of the amino-acid string below.
_CODON_BASES = 'TCAG'
_CODON_AMINO_ACIDS = (
    'FFLLSSSSYY**CC*W'  # codons starting with T
    'LLLLPPPPHHQQRRRR'  # codons starting with C
    'IIIMTTTTNNKKSSRR'  # codons starting with A
    'VVVVAAAADDEEGGGG'  # codons starting with G
)
CODON2AA = dict(zip(
    (
        first + second + third
        for first in _CODON_BASES
        for second in _CODON_BASES
        for third in _CODON_BASES
    ),
    _CODON_AMINO_ACIDS,
))


def DNA2AA(DNA: str):
    """Translate a DNA string into its amino-acid string, codon by codon.

    The input is read in consecutive, non-overlapping triplets starting at
    index 0; each triplet is looked up in ``CODON2AA`` and the resulting
    letters are concatenated.

    Raises:
        AssertionError: if ``len(DNA)`` is not a multiple of 3.
        KeyError: if a triplet is not a key of ``CODON2AA``.
    """
    n_bases = len(DNA)
    assert n_bases % 3 == 0, f'Invalid DNA length: {n_bases}'

    residues = []
    for start in range(0, n_bases, 3):
        codon = DNA[start:start + 3]
        residues.append(CODON2AA[codon])
    return ''.join(residues)


def DNA2RNA(DNA: str):
    """Return ``DNA`` with every uppercase 'T' replaced by 'U'.

    All other characters (including lowercase 't') are left untouched.
    """
    # Splitting on 'T' and re-joining with 'U' substitutes every occurrence.
    pieces = DNA.split('T')
    return 'U'.join(pieces)


def load_embedding(seqs: List[str], folder: str):
    """Load the stored embedding for each DNA sequence's translated protein.

    Every sequence in ``seqs`` is translated with ``DNA2AA`` and the array
    saved at ``<folder>/<amino-acid string>.npy`` is loaded. The loaded arrays
    are stacked along a new leading axis, converted to a tensor, and
    ``squeeze(1)`` is applied to the result.

    Raises:
        AssertionError: if the ``.npy`` file for a sequence does not exist
            (or, via ``DNA2AA``, if a sequence length is not a multiple of 3).
    """
    def _load_one(dna):
        # Embedding files are named after the amino-acid translation.
        npy_path = os.path.join(folder, f"{DNA2AA(dna)}.npy")
        assert os.path.exists(npy_path), f'File not found: {npy_path}'
        return np.load(npy_path)

    stacked = np.array([_load_one(dna) for dna in seqs])
    return torch.tensor(stacked).squeeze(1)


class RankingDataset(Dataset):
    """Dataset of sequence groups and their ranking labels, read from ``<root>/<k>/``.

    Every ``.txt`` file is one sample: its non-empty lines (whitespace
    stripped) are the sequences of that sample.

    * Training mode reads ``train/``. An item is the sample's sequences
      (optionally shuffled) together with, for each returned sequence, its
      line position in the file.
    * Test mode reads ``test_data/`` and ``test_label/``, pairing files by
      name. The label of the j-th data sequence is its position inside the
      matching label file.

    NOTE: samples are kept in the order ``os.listdir`` returns the files,
    which is arbitrary and may differ between machines.
    """

    def __init__(self, root: str, k: int, test_mode: bool = False, shuffle: bool = False):
        """Read all samples under ``<root>/<k>/`` into memory.

        Args:
            root: Directory containing one sub-folder per group size.
            k: Group size; selects the sub-folder ``str(k)`` and, in test
                mode, the number of sequences per sample that get a label.
            test_mode: Read ``test_data``/``test_label`` instead of ``train``.
            shuffle: In training mode, shuffle the sequences of a sample each
                time it is fetched. Ignored in test mode.

        Raises:
            AssertionError: in test mode, if the two test folders do not hold
                the same number of ``.txt`` files.
            ValueError: in test mode, if a data sequence is missing from the
                corresponding label file.
        """
        folder = os.path.join(root, str(k))

        if test_mode:
            files_label = [f for f in os.listdir(os.path.join(folder, 'test_label')) if f.endswith('.txt')]
            files_data = [f for f in os.listdir(os.path.join(folder, 'test_data')) if f.endswith('.txt')]
            assert len(files_label) == len(files_data), 'The number of test data and test label should be the same'

            # Data and label files are paired by file name, so the label file
            # list drives both reads.
            self.data = [self.file2seqs(os.path.join(folder, 'test_data', name)) for name in files_label]
            labels = [self.file2seqs(os.path.join(folder, 'test_label', name)) for name in files_label]
            # Label j = position of the j-th data sequence in the label file
            # (first occurrence if the label file repeats a sequence).
            self.labels = [
                tuple(label_seqs.index(data_seqs[j]) for j in range(k))
                for data_seqs, label_seqs in zip(self.data, labels)
            ]
        else:
            files_train = [f for f in os.listdir(os.path.join(folder, 'train')) if f.endswith('.txt')]
            self.data = [self.file2seqs(os.path.join(folder, 'train', name)) for name in files_train]

        self.num_samples = len(self.data)
        self.test_mode = test_mode
        self.folder = folder
        self.shuffle = shuffle

    def file2seqs(self, file_path: str):
        """Return the non-empty, whitespace-stripped lines of ``file_path`` as a tuple."""
        with open(file_path, 'r') as f:
            stripped = (line.strip() for line in f)
            return tuple(seq for seq in stripped if seq)

    def __len__(self):
        """Return the number of samples loaded."""
        return self.num_samples

    def sample(self, idx):
        """Return ``(sequences, labels)`` for training sample ``idx``.

        ``labels[i]`` is the original line position of ``sequences[i]``. When
        ``self.shuffle`` is set the sequences are returned in random order.
        """
        # TODO: more sampling strategies
        seqs = self.data[idx]
        # Shuffle positions rather than the strings themselves: the labels are
        # then a true permutation even if a sample contains the same sequence
        # twice, and no string lookups are needed to recover them.
        order = list(range(len(seqs)))
        if self.shuffle:
            random.shuffle(order)
        return tuple(seqs[i] for i in order), tuple(order)

    def __getitem__(self, idx) -> Tuple[Tuple[str, ...], Tuple[int, ...]]:
        """Return ``(sequences, labels)``: stored pair in test mode, ``sample(idx)`` otherwise."""
        if self.test_mode:
            return self.data[idx], self.labels[idx]
        return self.sample(idx)


def modaility_map(seq_type: str, seq: str):
    """Convert a DNA sequence into the representation named by ``seq_type``.

    ``'AA'`` translates the sequence with ``DNA2AA``; ``'RNA'`` converts it
    with ``DNA2RNA``; any other value returns ``seq`` unchanged.
    """
    if seq_type == 'AA':
        return DNA2AA(seq)
    if seq_type == 'RNA':
        return DNA2RNA(seq)
    return seq


class EmbedMapper:
    """Map sequences to precomputed embeddings stored as ``.npy`` files.

    The index file ``<root>/<seq_type>_<model_name>_index.csv`` holds one
    ``sequence,relative_path`` pair per line; the paths are resolved against
    ``<root>/<seq_type>_<model_name>/``.
    """

    def __init__(self, root: str, model_name: str, seq_type: str):
        """Read the index file and build the sequence -> file path table.

        Args:
            root: Directory holding the index file and the embedding folder.
            model_name: Model identifier used in the index/folder names.
            seq_type: Sequence modality used in the index/folder names and
                passed to ``modaility_map`` at lookup time.

        Raises:
            ValueError: if a non-blank index line does not consist of exactly
                two comma-separated fields.
        """
        index_path = os.path.join(root, f'{seq_type}_{model_name}_index.csv')
        embed_dir = os.path.join(root, f'{seq_type}_{model_name}')
        self.seq_type = seq_type

        self.embed_dict = {}
        with open(index_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing empty line)
                    # instead of failing on the tuple unpacking below.
                    continue
                seq, embed_path = line.split(',')
                self.embed_dict[seq] = os.path.join(embed_dir, embed_path)

    def __getitem__(self, seqs: List[Tuple[str]]):
        """Load the embeddings for a batch of sequence groups.

        Args:
            seqs: List of ``k`` tuples, each holding ``bs`` DNA strings; tuple
                ``seqs[i]`` contributes element ``i`` of every batch entry.
                (Presumably the layout produced by default DataLoader
                collation of tuple samples -- confirm against the caller.)

        Returns:
            Tensor built from a ``bs`` x ``k`` nested list of flattened
            embeddings, i.e. indexed as ``[batch, position, feature]`` when
            all embeddings have the same flattened length.

        Raises:
            AssertionError: if ``seqs`` is not a list of tuples of strings.
            ValueError: if a (modality-mapped) sequence is not in the index.
        """
        assert isinstance(seqs, list) and isinstance(seqs[0], tuple) and isinstance(seqs[0][0], str), 'Invalid seqs type'
        bs = len(seqs[0])
        embeds = [[] for _ in range(bs)]
        for ki in seqs:
            for b, seq in enumerate(ki):
                # The index is keyed by the sequence in this mapper's modality.
                seq = modaility_map(self.seq_type, seq)
                if seq not in self.embed_dict:
                    raise ValueError(f"Sequence {seq} not found in the embedding dictionary")
                np_embed = np.load(self.embed_dict[seq])
                if len(np_embed.shape) != 1:
                    # Flatten so every embedding is a 1-D feature vector.
                    np_embed = np_embed.reshape(-1)
                embeds[b].append(np_embed)
        return torch.tensor(np.array(embeds))
