import os
import pandas as pd
import re
from tqdm import tqdm
from collections import Counter
from torch.utils.data import Dataset


def get_rna_type_from_name(name: str) -> str:
    """Map a dataset name (or file path) to an RNA type label.

    The name is lower-cased and scanned for keywords; the first rule whose
    keyword occurs as a substring decides the label.

    Args:
        name: Dataset name or path to inspect.

    Returns:
        One of "mRNA-splicing", "tRNA", "Aptamer", "Ribozyme", or
        "mRNA-coding" when no keyword matches.
    """
    # NOTE: simplified keyword mapping - adjust it if the real naming
    # convention differs. Rule order matters: earlier rules take precedence.
    keyword_rules = (
        (("mrna", "splice", "splicing"), "mRNA-splicing"),
        (("trna",), "tRNA"),
        (("aptamer",), "Aptamer"),
        (("ribozyme",), "Ribozyme"),
    )
    lowered = name.lower()
    for keywords, label in keyword_rules:
        if any(keyword in lowered for keyword in keywords):
            return label
    return "mRNA-coding"


class RNAGymDataset(Dataset):
    """Dataset of mutational-scan assays read from one or more CSV files.

    Every CSV must provide the columns ``sequence``, ``mutant`` and
    ``DMS_score``. Per row only the assay name, the mutation string and the
    score are kept; the unmutated sequence is reconstructed once per assay
    (keyed by the CSV file's basename) and the mutated sequence is rebuilt
    from it on access.

    Args:
        path: A ``.csv`` file, or a directory that is walked recursively for
            ``.csv`` files. ``str`` or ``os.PathLike``.
        seq_type: When ``"dna"``, every ``U`` in a returned sequence is
            replaced by ``T``; any other value leaves sequences untouched.
        num_samples: Optional per-file row limit. An ``int`` keeps that many
            leading rows; any other non-``None`` value is treated as a
            fraction of the file's rows.
    """

    # A substitution such as "C46A": original base, 1-based position, new base.
    # Both bases are mandatory: without them a mutation can be neither applied
    # nor reverted.
    _MUTATION_RE = re.compile(r"([ACGTUN])(\d+)([ACGTUN])")

    def __init__(self, path, seq_type="dna", num_samples=None):
        super().__init__()
        self.num_samples = num_samples
        self.seq_type = seq_type
        self.path = path
        self.data = self._load_data(path)

    def _load_data(self, path):
        """Discover, order and parse the CSV files under ``path``.

        Populates ``self.origins`` (assay name -> unmutated sequence) and
        ``self.data_info`` (assay name -> summary string), prints the
        summary, and returns the list of ``[assay, mutant, score]`` rows.
        """
        path = os.fspath(path)  # accept pathlib.Path as well as str
        csv_files = []
        if path.endswith(".csv"):
            csv_files.append(path)
        else:
            for root, _, files in os.walk(path):
                csv_files.extend(
                    os.path.join(root, name)
                    for name in files
                    if name.endswith(".csv")
                )

        # Order files by RNA type (rarest type first), then by path. Ties in
        # the type counts are broken by type name so the order does not depend
        # on the directory traversal order of the filesystem.
        type_counts = Counter(get_rna_type_from_name(f) for f in csv_files)
        ordered_types = sorted(type_counts, key=lambda t: (type_counts[t], t))
        type_rank = {rna_type: i for i, rna_type in enumerate(ordered_types)}
        csv_files.sort(key=lambda f: (type_rank[get_rna_type_from_name(f)], f))

        data_info = dict()
        result = []
        self.origins = dict()
        for file in tqdm(csv_files, desc="Loading RNAGym", dynamic_ncols=True):
            protein, rows, summary = self._load_file(file)
            result.extend(rows)
            data_info[protein] = summary
        print("-" * 5, "Data Info", "-" * 5)
        print("\n".join([f"{k}: {v}" for k, v in data_info.items()]))
        self.data_info = data_info
        return result

    def _limit_samples(self, table):
        """Return the leading rows of ``table`` allowed by ``self.num_samples``."""
        if self.num_samples is None:
            return table
        if isinstance(self.num_samples, int):
            return table.head(self.num_samples)
        # Non-int limits are interpreted as a fraction of the file's rows.
        return table.head(int(self.num_samples * table.shape[0]))

    def _load_file(self, file):
        """Parse one CSV file into rows and record its origin sequence.

        Requires ``self.origins`` to exist; the first origin reconstructed
        for an assay name is stored there and later ones are only compared
        against it.

        Args:
            file: Path of the CSV file to read.

        Returns:
            ``(protein, rows, summary)`` where ``protein`` is the file's
            basename up to the first ``.``, ``rows`` is a list of
            ``[protein, mutant, DMS_score]`` and ``summary`` is the
            human-readable description stored in ``data_info``.
        """
        protein = os.path.basename(file).split(".")[0]
        table = self._limit_samples(pd.read_csv(file, header=0))

        # Pull each column out once; per-row ``.iloc`` lookups are far slower
        # and ``to_numpy`` preserves the scalar types ``.iloc`` would yield.
        sequences = table["sequence"].to_numpy()
        mutants = table["mutant"].to_numpy()
        scores = table["DMS_score"].to_numpy()

        rows = []
        warned = False
        for sequence, mutant, score in zip(sequences, mutants, scores):
            # Derive original (unmutated) sequence (origin)
            origin = self._extract_origin(sequence, mutant)
            rows.append([protein, mutant, score])
            # The first origin seen for an assay wins.
            known = self.origins.setdefault(protein, origin)
            if known != origin and not warned:
                # Report once per file instead of once per offending row.
                print(
                    f"Warning: inconsistent origins for {protein} - using first occurrence"
                )
                warned = True

        # An empty (or fully truncated) file has no first sequence to measure.
        length = len(sequences[0]) if len(sequences) else 0
        summary = f"{len(table)} seqs with length {length}"
        if self.num_samples is not None:
            summary += f" (limited to {self.num_samples} samples)"
        return protein, rows, summary

    def _extract_origin(self, mutated_sequence: str, mutant_info: str) -> str:
        """
        Reconstruct the original (unmutated) sequence from a mutated sequence.

        Mutations that cannot be parsed, fall outside the sequence, or do not
        match the base actually present are reported and skipped.

        Args:
            mutated_sequence: The mutated sequence
            mutant_info: Mutation string, e.g., "C46A,C49A" or "C46A";
                a missing value (NaN/None) means "no mutation"

        Returns:
            origin: The reconstructed original sequence
        """
        if pd.isna(mutant_info):
            return mutated_sequence

        # Create a mutable copy and reverse-apply each mutation
        origin = list(mutated_sequence)

        for mutation in str(mutant_info).strip('"').split(","):
            mutation = mutation.strip()
            if not mutation:
                continue

            # Parse mutation like "C46A" -> position 46, from C to A
            match = self._MUTATION_RE.match(mutation)
            if not match:
                print(f"Warning: cannot parse mutation '{mutation}', skipping")
                continue

            original_base, position, mutated_base = match.groups()
            position = int(position) - 1  # 0-based index

            # Position "0" would become -1 and wrap around to the last base,
            # so the lower bound has to be checked explicitly.
            in_range = 0 <= position < len(origin)
            if in_range and origin[position] == mutated_base:
                # Replace mutated base with original base
                origin[position] = original_base
            else:
                actual = origin[position] if in_range else 'N/A'
                print(
                    f"Warning: mutation '{mutation}' mismatches at position {position + 1} (expected: {mutated_base}, actual: {actual})"
                )

        return "".join(origin)

    def _generate_mutated(self, origin: str, mutant_info: str) -> str:
        """Apply mutations onto the origin to generate the mutated sequence.

        Mutations that cannot be parsed, fall outside the sequence, or whose
        original base does not match the origin are reported and skipped.

        Args:
            origin: The original (unmutated) sequence
            mutant_info: Mutation string like "C46A,C49A" or "C46A"
        Returns:
            mutated_sequence: The sequence after applying mutations
        """
        # If empty/invalid, return origin directly
        if pd.isna(mutant_info) or not mutant_info:
            return origin
        # Convert to list for in-place edits
        mutated = list(origin)
        # Normalize, split and apply each mutation
        for mutation in str(mutant_info).strip('"').split(","):
            mutation = mutation.strip()
            if not mutation:
                continue
            # Parse mutation like "C46A" -> original base C, position 46, mutated base A
            match = self._MUTATION_RE.match(mutation)
            if not match:
                print(f"Warning: cannot parse mutation '{mutation}', skipped.")
                continue
            original_base, position_str, mutated_base = match.groups()
            position = int(position_str) - 1  # 0-based index
            # Validation
            if not (0 <= position < len(mutated)):
                print(f"Warning: mutation '{mutation}' position {position + 1} exceeds sequence length {len(mutated)}, skipped.")
                continue
            # Check base matches the expected original base
            if mutated[position] != original_base:
                print(
                    f"Warning: mutation '{mutation}' mismatch at position {position + 1}: expected '{original_base}', got '{mutated[position]}'"
                )
                continue

            # Apply mutation
            mutated[position] = mutated_base
        return "".join(mutated)

    def __getitem__(self, idx):
        """Return ``[assay name, mutated sequence, score]`` for row ``idx``."""
        protein, mutant_info, label = self.data[idx]
        mutated_sequence = self._generate_mutated(self.origins[protein], mutant_info)
        if self.seq_type == "dna":
            mutated_sequence = mutated_sequence.replace("U", "T")
        return [protein, mutated_sequence, label]

    def __len__(self):
        """Return the number of loaded rows."""
        return len(self.data)
