from typing import Dict, List, Optional, Tuple

import numpy as np
import torch


# Translation tables are built once at import time. They cover both upper- and
# lower-case bases so soft-masked (lower-case) input is complemented too while
# its case is preserved. Characters outside the tables (e.g. "N") pass through.
_DNA_COMPLEMENT = str.maketrans("ACGTacgt", "TGCAtgca")
_RNA_COMPLEMENT = str.maketrans("ACGUacgu", "UGCAugca")


def reverse_complement(sequence, seq_type="dna"):
    """Return the reverse complement of a nucleotide string.

    Args:
        sequence: Nucleotide string. Upper- and lower-case bases are both
            complemented and keep their case; any other character is left
            unchanged (it is only moved by the reversal).
        seq_type: ``"dna"`` (A<->T, C<->G) or ``"rna"`` (A<->U, C<->G).

    Returns:
        The reversed string with every known base replaced by its complement.

    Raises:
        ValueError: If ``seq_type`` is neither ``"dna"`` nor ``"rna"``.
    """
    if seq_type == "dna":
        table = _DNA_COMPLEMENT
    elif seq_type == "rna":
        table = _RNA_COMPLEMENT
    else:
        raise ValueError(f"Invalid sequence type: {seq_type}")
    return sequence[::-1].translate(table)


def logits_to_logprobs(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
) -> torch.Tensor:
    """
    Look up the log-probability assigned to each token of a single sequence.

    The first entry of `input_ids` is dropped and the last position of `logits`
    is dropped, so the logits at position i are paired with the token at
    position i + 1.

    `logits` may be 3-D, 2-D (a leading dimension of size 1 is added) or 1-D
    (reshaped to `len(input_ids)` rows first). The result has one row and
    `len(input_ids) - 1` columns.

    When `input_ids` holds at most one token there is nothing to look up; a
    one-element tensor holding log(1 / vocab_size) is returned instead.
    """
    targets = input_ids[1:].clone()
    n_targets = len(targets)

    # Bring the logits to a three-dimensional layout with a leading axis of 1.
    n_dims = logits.dim()
    if n_dims == 1:
        logits = logits.reshape(n_targets + 1, -1).unsqueeze(0)
    elif n_dims == 2:
        logits = logits.unsqueeze(0)

    if n_targets == 0:
        vocab_size = logits.shape[-1]
        return torch.tensor([1 / vocab_size]).log()

    # Normalise over the vocabulary, then discard the final position, which
    # has no following token to score.
    per_position = torch.log_softmax(logits, dim=-1)[:, :-1]
    assert per_position.shape[1] == n_targets

    # The index needs the same number of dimensions as the gathered tensor.
    index = targets[None, ..., None]
    picked = per_position.gather(2, index)
    return picked.squeeze(-1)


def compute_seq_score(
    seq: str,
    logits: torch.Tensor,
    tokenizer: object,
) -> Tuple[np.floating, int]:
    """
    Compute sequence score from logits.

    Args:
        seq: The sequence to score. A string is tokenized with
            ``tokenizer.encode(seq, add_special_tokens=False)``; any other
            value is passed to ``logits_to_logprobs`` as the token ids as-is.
        logits: Logits for one sequence, either 2-D or 3-D with a leading
            dimension of size 1.
        tokenizer: Object exposing ``encode``; only used when ``seq`` is a
            string.

    Returns:
        A ``(mean_logprob, n_scored)`` tuple: the mean of the per-token
        log-probabilities produced by ``logits_to_logprobs`` and how many
        values went into that mean.

    Raises:
        AssertionError: If ``logits`` is not 2-D after removing a leading
            singleton dimension, or if ``len(seq)`` is not a multiple of the
            token count.
        ZeroDivisionError: If a string ``seq`` tokenizes to zero tokens.
    """
    # Only strip a leading batch axis. An unconditional squeeze(0) would also
    # collapse a (1, vocab) matrix -- the logits of a one-token sequence --
    # down to 1-D and trip the assertion below.
    if logits.ndim == 3:
        logits = logits.squeeze(0)
    assert logits.ndim == 2, f"logits dimension should be 2, but got {logits.ndim}"

    if isinstance(seq, str):
        input_ids = tokenizer.encode(seq, add_special_tokens=False)
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        # Every token is expected to cover the same number of characters.
        assert len(seq) % input_ids.shape[0] == 0, (
            f"Sequence length {len(seq)} is not divisible by input_ids length {input_ids.shape[0]}"
        )
    else:
        input_ids = seq

    logprobs = logits_to_logprobs(logits, input_ids)
    logprobs = logprobs.float().cpu().numpy().flatten()

    return np.mean(logprobs), logprobs.shape[0]


def left_padding(sequence, padding_char="A", multiple=6):
    """Prefix ``sequence`` with ``padding_char`` until its length is a multiple of ``multiple``.

    A sequence whose length already is a multiple is returned unchanged.
    """
    # Number of characters missing to reach the next multiple.
    shortfall = -len(sequence) % multiple
    if shortfall == 0:
        return sequence
    return padding_char * shortfall + sequence


def left_truncation(sequence, multiple=6):
    """Drop leading characters so the length of ``sequence`` is a multiple of ``multiple``.

    A sequence whose length already is a multiple is returned unchanged.
    """
    surplus = len(sequence) % multiple
    return sequence[surplus:] if surplus else sequence


def tackle_sequences(
    sequences: List[str], tokenizer_tackle: str, kmer: int
) -> List[str]:
    """Make every sequence length a multiple of ``kmer``.

    Args:
        sequences: Sequences to adjust. May be empty.
        tokenizer_tackle: ``"padding"`` to left-pad with ``left_padding`` or
            ``"truncation"`` to left-truncate with ``left_truncation``.
        kmer: Required length multiple. With ``kmer == 1`` the input list is
            returned untouched and ``tokenizer_tackle`` is not inspected.

    Returns:
        The adjusted sequences (a new list unless ``kmer == 1``).

    Raises:
        ValueError: If ``tokenizer_tackle`` is not a known strategy.
        AssertionError: If any adjusted sequence still has a length that is
            not a multiple of ``kmer``.
    """
    if kmer == 1:
        return sequences

    if tokenizer_tackle == "padding":
        sequences = [left_padding(sequence, multiple=kmer) for sequence in sequences]
    elif tokenizer_tackle == "truncation":
        sequences = [left_truncation(sequence, multiple=kmer) for sequence in sequences]
    else:
        raise ValueError(f"Invalid tokenizer_tackle: {tokenizer_tackle}")

    # Sanity-check every sequence rather than only the first one; iterating
    # also keeps an empty batch from failing on ``sequences[0]``.
    for sequence in sequences:
        L = len(sequence)
        assert L % kmer == 0, f"Sequence length mismatch: {L} % {kmer} != 0"

    return sequences


class BaseModel:
    """Base class providing generation and scoring flows around subclass hooks.

    Subclasses implement :meth:`load_model` and :meth:`generate_sequences`.
    The scoring path reads ``self.model`` and ``self.tokenizer``; this class
    never assigns them, so a subclass has to (presumably in ``load_model``).

    Lengths passed as ``new_length`` are converted to a token count with
    ``new_length // self.kmer``.
    """

    def __init__(self, cfg, load_model=True):
        """Store ``cfg``, read optional settings from it and optionally load the model.

        Args:
            cfg: Configuration object. ``seed`` (default 42), ``kmer``
                (default 1) and ``tokenizer_tackle`` (default ``"padding"``)
                are read with ``getattr``.
            load_model: When true, :meth:`load_model` is called immediately.
        """
        self.cfg = cfg
        # External embedding/sample caches removed

        self.seed = getattr(cfg, "seed", 42)
        self.kmer = getattr(cfg, "kmer", 1)
        self.tokenizer_tackle = getattr(cfg, "tokenizer_tackle", "padding")

        if load_model:
            self.load_model()

    def load_model(self):
        """Hook: Load model and tokenizer"""
        raise NotImplementedError

    def before_score_sequences(
        self, sequences: List[str]
    ):
        """Hook: Process sequences before scoring"""
        return sequences

    def before_generate_sequences(
        self, sequences: List[str]
    ):
        """Hook: Process sequences before generation"""
        return sequences

    def generate_sequences(
        self,
        sequences: Optional[List[str]] = None,
        num_tokens: int = 0,
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        **kwargs,
    ) -> List[str]:
        """Hook: Generate sequences.

        The parameters mirror the keyword arguments that
        :meth:`forward_generation` passes, so an un-overridden hook fails
        with ``NotImplementedError`` rather than a signature ``TypeError``.
        Implementations are expected to return one string per input sequence.
        """
        raise NotImplementedError

    def forward_generation(
        self,
        sequences: List[str],
        new_length: int = 0,
        **kwargs,
    ) -> List[str]:
        """Generate sequences using forward generation method.

        Args:
            sequences: Prompt sequences.
            new_length: Requested length; ``new_length // self.kmer`` tokens
                are requested from :meth:`generate_sequences`.
            **kwargs: ``with_original`` (default False), ``temperature``
                (1.0), ``top_k`` (0) and ``top_p`` (1.0) are consumed here;
                everything else is forwarded to :meth:`generate_sequences`.

        Returns:
            ``sequences`` unchanged when the token count is zero. Otherwise
            one string per prompt: the generated text, prefixed with the
            (pre-processed) prompt when ``with_original`` is true.
        """
        num_tokens = new_length // self.kmer
        if num_tokens == 0:
            return sequences

        # Generation settings
        with_original = kwargs.pop("with_original", False)
        temperature = kwargs.pop("temperature", 1.0)
        top_k = kwargs.pop("top_k", 0)
        top_p = kwargs.pop("top_p", 1.0)

        # Process sequences before generation
        sequences = self.before_generate_sequences(sequences)

        # Directly generate without external caching
        generated_seqs = self.generate_sequences(
            sequences=sequences,
            num_tokens=num_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            **kwargs,
        )

        # Combine results. Indexing (rather than zip) keeps a too-short
        # result from the hook loud instead of silently truncating.
        if with_original:
            return [seq + generated_seqs[i] for i, seq in enumerate(sequences)]
        return [generated_seqs[i] for i, _ in enumerate(sequences)]

    def reverse_generation(
        self, sequences: List[str], new_length: int = 0, **kwargs
    ) -> List[str]:
        """
        Generate sequences using the reverse generation method.

        Each prompt is reverse-complemented, extended with
        :meth:`forward_generation`, and the extension is reverse-complemented
        back so that it reads as context preceding the prompt.

        Args:
            sequences: Prompt sequences.
            new_length: Requested length of the preceding context.
            **kwargs: ``with_original`` (default False) is consumed here; the
                rest is forwarded to :meth:`forward_generation`.

        Returns:
            ``sequences`` unchanged when ``new_length // self.kmer`` is zero.
            Otherwise one string per prompt: the generated context, followed
            by the prompt when ``with_original`` is true.
        """
        # Use the same "nothing to generate" test as forward_generation.
        # Checking only ``new_length == 0`` let 0 < new_length < kmer fall
        # through, where forward_generation hands back its input and the
        # prompt ended up concatenated with itself.
        if new_length // self.kmer == 0:
            return sequences

        # Get the with_original flag
        with_original = kwargs.pop("with_original", False)

        # Step 1: Get reverse complements of input sequences, bs x len(seq)
        rc_sequences = [reverse_complement(sequence) for sequence in sequences]

        # Step 2: Generate from reverse complements to expand context, bs x new_length
        rc_generated = self.forward_generation(
            sequences=rc_sequences,
            new_length=new_length,
            **kwargs,
        )

        # Step 3: Get reverse complements of the generated context, bs x new_length
        expanded_context = [reverse_complement(gen) for gen in rc_generated]

        if with_original:
            return [gen + seq for gen, seq in zip(expanded_context, sequences)]
        return expanded_context

    def base_prompt_generation(
        self,
        sequences: List[str],
        new_length: int = 0,
        **kwargs,
    ):
        """
        First expand context by generating from reverse complement, then use the expanded context to generate forward.

        Args:
            sequences: Input DNA sequences
            new_length: Length of new sequence to generate
            **kwargs:
                extra_length: Length of the preceding context to generate.
                    ``-1`` (also used when the argument is omitted) means
                    "same as ``new_length``".
                with_original: Include the input sequence in the output
                    (default True).
                with_expanded: Also include the generated preceding context
                    (default True); requires ``with_original``.
                Remaining entries are forwarded to both generation steps.

        Returns:
            ``sequences`` unchanged when neither step would generate a token.
            Otherwise one string per input: the forward-generated text,
            preceded by the input (``with_original``) and by the generated
            context (``with_expanded``). A step that generates no tokens
            contributes an empty string.

        Raises:
            AssertionError: If ``with_expanded`` is true while
                ``with_original`` is false.
        """
        _kwargs = kwargs.copy()

        # For backward compatibility
        extra_length = _kwargs.pop("extra_length", -1)
        if extra_length == -1:
            extra_length = new_length

        extra_tokens = extra_length // self.kmer
        new_tokens = new_length // self.kmer
        if extra_tokens == 0 and new_tokens == 0:
            return sequences

        with_original = _kwargs.pop("with_original", True)
        with_expanded = _kwargs.pop("with_expanded", True)
        assert with_original or not with_expanded, (
            "with_expanded is only allowed when with_original is True"
        )

        # Step 1: Generate from reverse complements to expand context, bs x extra_length.
        # reverse_generation returns its input when there is nothing to
        # generate, which would duplicate the prompt below, so that case is
        # handled here with an empty context.
        if extra_tokens == 0:
            expanded_context = [""] * len(sequences)
        else:
            expanded_context = self.reverse_generation(
                sequences=sequences,
                new_length=extra_length,
                **_kwargs,
            )

        # Step 2: Combine original sequences with the expanded context, bs x (extra_length + len(seq))
        enhanced_sequences = [
            context + seq for context, seq in zip(expanded_context, sequences)
        ]

        # Step 3: Generate forward from the enhanced sequences, bs x new_length.
        # Same reasoning as step 1: forward_generation echoes its input when
        # asked for zero tokens.
        if new_tokens == 0:
            generated = [""] * len(sequences)
        else:
            generated = self.forward_generation(
                sequences=enhanced_sequences,
                new_length=new_length,
                **_kwargs,
            )

        if not with_original:
            return generated
        prefixes = enhanced_sequences if with_expanded else sequences
        return [prefix + gen for prefix, gen in zip(prefixes, generated)]

    def seq2token(self, sequences, tokenizer, device):
        """Default seq2token implementation.

        Calls ``tokenizer`` on ``sequences`` without special tokens,
        requesting PyTorch tensors, and moves the result to ``device``.
        """
        inputs = tokenizer(
            sequences,
            add_special_tokens=False,
            return_tensors="pt",
        ).to(device)
        return inputs

    def get_logits(self, sequences: List[str]):
        """Run the model on ``sequences`` and return their logits.

        ``self.tokenizer.bos_token`` is prepended to every sequence before
        tokenization, and the first position of the model's logits is dropped
        afterwards. The device is taken from ``self.model.device``, falling
        back to ``"cuda"``.

        Returns:
            Dict mapping each input sequence to its logits on the CPU.
            Identical input strings share a single entry.
        """
        # Add bos token only for logits computation
        sequences_with_bos = [
            self.tokenizer.bos_token + sequence for sequence in sequences
        ]
        inputs = self.seq2token(
            sequences_with_bos, self.tokenizer, getattr(self.model, "device", "cuda")
        )
        with torch.inference_mode():
            outputs = self.model(**inputs)
            logits = outputs.logits[:, 1:].cpu()  # remove the bos token

        return {seq: logit for seq, logit in zip(sequences, logits)}

    def score_sequences(
        self,
        sequences: List[str],
    ) -> Tuple[List[np.floating], List[int]]:
        """Score sequences with the model.

        Args:
            sequences: Sequences to score; a single string is treated as a
                one-element list.

        Returns:
            ``(scores, lens_logits)``: for every sequence (after
            :meth:`before_score_sequences`) the two values returned by
            ``compute_seq_score``.
        """
        if isinstance(sequences, str):
            sequences = [sequences]
        sequences = self.before_score_sequences(sequences)

        # Directly compute logits for all sequences (no external cache)
        all_embeds = self.get_logits(sequences)
        scores, lens_logits = [], []
        for seq in sequences:
            score, len_logits = compute_seq_score(
                seq,
                all_embeds[seq],
                self.tokenizer,
            )
            scores.append(score)
            lens_logits.append(len_logits)

        return scores, lens_logits
