import sys
import io
import base64
import zipfile
import numpy as np
from typing import Optional, List, Tuple, Union
import torch
from ._base_model import BaseModel


# Optional dependency on the evo2 / vortex stack. When it cannot be imported, a
# placeholder ``Evo2`` base class is defined so this module (and ``myEvo2``)
# still import. ``sample``, ``logits_to_logprobs``, ``GenerationOutput`` and
# ``print_rank_0`` are NOT defined in that case, so any code path that uses
# them will raise NameError.
try:
    from evo2.models import Evo2
    from vortex.model.sample import sample
    from vortex.model.generation import logits_to_logprobs, GenerationOutput
    from vortex.model.utils import print_rank_0
except ImportError:

    class Evo2:
        # Placeholder base class used only when evo2 is not installed.
        pass

    print("The package evo2 is not installed. Please install from https://github.com/ArcInstitute/evo2?tab=readme-ov-file#setup.")


def debase64(logits_base64: str) -> torch.Tensor:
    """Decode a base64-encoded zip archive holding a ``.npy`` array into a tensor.

    Spaces and newlines in ``logits_base64`` are ignored, so wrapped
    base64 text is accepted. Only the FIRST member of the archive is read.

    Args:
        logits_base64: Base64 text of a zip file whose first member is a
            ``.npy`` file.

    Returns:
        The array from the first archive member as a ``torch.float32`` tensor.

    Raises:
        IndexError: If the archive has no members.
    """
    data_base64 = logits_base64.replace("\n", "").replace(" ", "")
    data_bytes = base64.b64decode(data_base64)

    # Close both the archive and the member handle once the array is loaded.
    with zipfile.ZipFile(io.BytesIO(data_bytes)) as zip_file:
        name = zip_file.namelist()[0]
        with zip_file.open(name) as member:
            array = np.load(member)

    return torch.from_numpy(array).to(torch.float32)


def prepare_batch(
    seqs: List[str],
    tokenizer: object,
    prepend_bos: bool = False,
    device: str = "cuda:0",
) -> Tuple[torch.Tensor, List[int]]:
    """
    Tokenize ``seqs`` and stack them into one right-padded tensor batch.

    Each row is ``tokenizer.eod_id`` (only when ``prepend_bos``), followed by the
    sequence's token ids, then ``tokenizer.pad_id`` up to the longest row.
    Padding is computed from token counts rather than character counts, so
    tokenizers that do not emit exactly one token per character also work.

    Args:
        seqs: Sequences to tokenize; must be non-empty.
        tokenizer: Object providing ``tokenize(str)``, ``eod_id`` and ``pad_id``.
        prepend_bos: Prepend ``tokenizer.eod_id`` to every row.
        device: Device on which the batch tensor is created.

    Returns:
        ``(input_ids, seq_lengths)``: a ``(batch, max_len)`` ``torch.long`` tensor
        on ``device``, and the character length of each input sequence
        (the BOS token is not counted).

    Raises:
        ValueError: If ``seqs`` is empty.
    """
    if not seqs:
        raise ValueError("prepare_batch() requires at least one sequence")

    seq_lengths = [len(seq) for seq in seqs]
    bos = [tokenizer.eod_id] * int(prepend_bos)

    token_rows = [bos + list(tokenizer.tokenize(seq)) for seq in seqs]
    max_len = max(len(row) for row in token_rows)
    padded = [row + [tokenizer.pad_id] * (max_len - len(row)) for row in token_rows]

    # Build the whole batch in one call, directly on the target device.
    input_ids = torch.tensor(padded, dtype=torch.long, device=device)
    return input_ids, seq_lengths


class EVOTokenizer:
    """Character-level tokenizer: each character is encoded as its code point (``ord``).

    Token id 0 serves as both end-of-document and end-of-sequence
    (``eod_id`` / ``eos_id``) and id 1 is the padding id. Decoding clamps ids into
    ``[32, vocab_size]`` before ``chr``, so ids below 32 (including the special
    ids 0 and 1) decode to a space.
    """

    def __init__(self, vocab_size):
        self.name = "CharLevelTokenizer"
        self._vocab_size = vocab_size
        self.eod_id = 0
        self.eos_id = 0
        self.pad_id = 1

    def clamp(self, n):
        """Clamp ``n`` into ``[32, vocab_size]`` before it is passed to ``chr``."""
        return max(32, min(n, self.vocab_size))

    @property
    def vocab_size(self):
        return self._vocab_size

    @property
    def vocab(self):
        raise NotImplementedError

    @property
    def inv_vocab(self):
        raise NotImplementedError

    def decode_token(self, token: int) -> str:
        """Decode one token id to a single character (after clamping)."""
        return chr(self.clamp(token))

    def tokenize(self, text: str) -> List[int]:
        """Encode ``text`` as a list of code points."""
        return [ord(char) for char in text]

    def tokenize_batch(self, text_batch: Union[List[str], str]):
        """Tokenize a list of strings (returns a list of lists) or a single string."""
        if isinstance(text_batch, list):
            return [self.tokenize(s) for s in text_batch]
        return self.tokenize(text_batch)

    def detokenize(self, token_ids) -> str:
        """Decode an iterable of token ids into one string."""
        return "".join(map(self.decode_token, token_ids))

    def detokenize_batch(
        self, token_ids: Union[List[List[int]], List[int], torch.Tensor]
    ):
        """Decode a batch of token-id sequences.

        * list of sequences -> list of strings
        * 2-D tensor ``(batch, length)`` -> list of strings
        * 1-D tensor -> a single string (one unbatched sequence)
        * anything else -> passed to :meth:`detokenize` as one sequence
        """
        if isinstance(token_ids, torch.Tensor):
            if token_ids.dim() == 1:
                # A single sequence; iterating its .tolist() would yield bare ints.
                return self.detokenize(token_ids.tolist())
            token_ids = token_ids.tolist()
        if isinstance(token_ids, list):
            return [self.detokenize(s) for s in token_ids]
        return self.detokenize(token_ids)

    @property
    def eod(self):
        return self.eod_id

    # Alias so callers can use either name, ``eos`` or ``eod``.
    @property
    def eos(self):
        return self.eod_id

    def encode(self, seq: str, add_special_tokens: bool = False) -> List[int]:
        """Encode ``seq``; ``add_special_tokens`` is accepted for API compatibility and ignored."""
        return self.tokenize(seq)

    def __call__(
        self,
        sequences: List[str],
        add_special_tokens: bool = False,
        return_tensors: str = "pt",
    ) -> dict:
        """Encode ``sequences`` and return ``{"input_ids": ...}``.

        ``return_tensors="pt"`` yields a ``torch.long`` tensor and ``"np"`` an
        ``np.int32`` array; any other value leaves a list of lists. Tensor/array
        conversion requires all sequences to have the same length.
        """
        input_ids = [self.encode(seq, add_special_tokens) for seq in sequences]

        if return_tensors == "pt":
            input_ids = torch.tensor(input_ids, dtype=torch.long)
        elif return_tensors == "np":
            input_ids = np.array(input_ids, dtype=np.int32)

        return {"input_ids": input_ids}


class Generator:
    """Token-by-token sampler around a model that exposes a replayable inference cache.

    ``model`` is called as ``model(x, inference_params_dict=...)`` and must return
    ``(logits, inference_params_dict)``. Sampling uses ``top_k``/``top_p``/
    ``temperature`` via vortex's ``sample``.
    """

    # (cache group, attribute) pairs holding the state dicts inside an
    # ``inference_params_dict``; all of them are moved when a replayed cache
    # lives on a different device than the prompt.
    _CACHE_STATE_FIELDS = (
        ("mha", "key_value_memory_dict"),
        ("hcl", "fir_state_dict"),
        ("hcl", "state_dict"),
        ("hcm", "fir_inner_state_dict"),
        ("hcm", "fir_state_dict"),
        ("hcm", "state_dict"),
        ("hcs", "fir_state_dict"),
        ("hcs", "fir_inner_state_dict"),
        ("hcs", "state_dict"),
    )

    def __init__(self, model, tokenizer, top_k=50, top_p=0.7, temperature=1):
        self.model = model
        self.tokenizer = tokenizer
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        # Stop strings used only to truncate the verbose printout of the output.
        self.untils = ["\n\n"]

    def _move_cache_to_device(self, inference_params_dict, device):
        """Move every cached state tensor in ``inference_params_dict`` onto ``device`` in place."""
        for group, attr in self._CACHE_STATE_FIELDS:
            state = getattr(inference_params_dict[group], attr)
            for key, data in state.items():
                state[key] = data.to(device)

    @staticmethod
    def _mask_to_nucleotides(logits):
        """Return a copy of ``logits`` with every id except A/T/C/G set to -inf."""
        keep_indices = torch.tensor(
            [ord("A"), ord("T"), ord("C"), ord("G")], device=logits.device
        )
        mask = torch.ones(logits.shape[-1], dtype=torch.bool, device=logits.device)
        mask[keep_indices] = False
        masked = logits.clone().detach()
        masked[..., mask] = -float("inf")
        return masked

    @staticmethod
    def _ends_with_eos(generated, eos_token_ids):
        """True when every row of ``generated`` ends with the EOS id sequence."""
        eos_len = eos_token_ids.numel()
        if eos_len == 0 or generated.shape[1] < eos_len:
            return False
        return bool((generated[:, -eos_len:] == eos_token_ids).all())

    def generate(
        self,
        device: str,
        input_string: str = None,
        input_ids: torch.Tensor = None,
        num_tokens: int = 32,
        cached_generation: bool = True,
        force_prompt_threshold: int = 500,
        print_generation: bool = True,
        verbose: bool = False,
        skip_special_tokens: bool = False,
        stop_at_eos: bool = True,
        max_seqlen: int = None,
        inference_params_dict: dict = None,
        token_callback=None,
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """
        A version of the vortex generate() method that enables passing in and that returns the
        `inference_params_dict` for replaying cached sampling from a given state.

        Args:
            device: Device for the EOS ids and for a prompt given as ``input_string``.
            input_string: Prompt text; used only when ``input_ids`` is None.
            input_ids: Pre-tokenized ``(batch, length)`` prompt tensor.
            num_tokens: Maximum number of tokens to sample.
            cached_generation: Create an inference cache when none is supplied.
            force_prompt_threshold: Prompts longer than this are split. Only the first
                ``force_prompt_threshold`` tokens go into the first forward pass; the
                rest are teacher-forced one token per step and are not returned.
            print_generation: With ``verbose`` and batch size 1, echo each token.
            verbose: Print the prompt, the output and (on CUDA) memory usage.
            skip_special_tokens: Restrict sampling to ``A``, ``T``, ``C`` and ``G``.
            stop_at_eos: Stop once every row's generated tokens end with the EOS id(s).
            max_seqlen: Keep only the last ``max_seqlen`` prompt tokens.
            inference_params_dict: Cache from a previous call. Generation resumes from
                it, after moving it to the prompt's device if needed.
            token_callback: Called with the full ``logits`` after every step; a truthy
                return value stops generation.

        Returns:
            ``(generation, scores, inference_params_dict)``:
            - ``generation`` is the ``(batch, n)`` sampled ids.
            - ``scores`` is the ``(batch, n, vocab_size)`` raw logits each token was
              sampled from.
            - ``n <= num_tokens`` is the number of tokens actually produced.
        """
        eos = self.tokenizer.eos
        if isinstance(eos, int):
            eos_token_ids = torch.LongTensor([eos]).to(device)
        else:
            eos_token_ids = torch.as_tensor(
                self.tokenizer.tokenize(eos), dtype=torch.long, device=device
            )

        if input_ids is None:
            prompt_ids = self.tokenizer.tokenize(input_string)
            if isinstance(prompt_ids, list):
                prompt_ids = torch.LongTensor(prompt_ids)
            prompt_ids = prompt_ids.unsqueeze(0).to(device)
        else:
            prompt_ids = input_ids
        x = prompt_ids

        if max_seqlen is not None:
            x = x[:, -max_seqlen:]

        num_tokens = int(num_tokens)
        batch_size = x.shape[0]

        prompt_length = x.shape[1]
        prompt_forcing = prompt_length > force_prompt_threshold
        if prompt_forcing:
            forced_prompt_length = prompt_length - force_prompt_threshold
            x_force = x[:, force_prompt_threshold:]
            x = x[:, :force_prompt_threshold]
        else:
            forced_prompt_length = 0
        tot_length = prompt_length + num_tokens

        generation = torch.empty(
            batch_size, num_tokens, dtype=torch.long, device=x.device
        )
        scores = torch.empty(
            batch_size,
            num_tokens,
            self.tokenizer.vocab_size,
            dtype=torch.float,
            device=x.device,
        )

        if inference_params_dict is not None:
            cached_generation = True
            prefilled = True
            # Ensure the replayed cache lives on the same device as the prompt.
            if any(
                data.device != x.device
                for data in inference_params_dict["hcl"].fir_state_dict.values()
            ):
                self._move_cache_to_device(inference_params_dict, x.device)
            inference_params_dict["mha"].max_batch_size = batch_size
        elif cached_generation:
            try:
                inference_params_dict = self.model.initialize_inference_params(
                    max_seq_len=tot_length
                )
            except TypeError:
                # Fallback for model versions that spell the keyword ``max_seqlen``.
                inference_params_dict = self.model.initialize_inference_params(
                    max_seqlen=tot_length
                )
            inference_params_dict["mha"].max_batch_size = batch_size
            prefilled = False
        else:
            inference_params_dict = None
            prefilled = False

        if verbose:
            if x.device.type == "cuda":
                mem_after_tok = torch.cuda.memory_allocated(device=x.device) / 1e9
                print_rank_0(f"Memory after tokenization: {mem_after_tok} GB")
            print_rank_0("Starting generation...")
            if input_string is not None:
                print_rank_0("Prompt: " + input_string)
            else:
                print_rank_0(f"Prompt ids: {input_ids} {input_ids.shape}")

        n_generated = 0  # number of columns of generation/scores filled so far
        for step in range(forced_prompt_length + num_tokens):
            post_prefill = prefilled or (cached_generation and step > 0)
            forcing = step < forced_prompt_length

            with torch.inference_mode():
                logits, inference_params_dict = self.model(
                    x,
                    inference_params_dict=inference_params_dict,
                )

            last_logits = logits[:, -1]

            if forcing:
                new_idx = x_force[:, step]
            else:
                sample_logits = (
                    self._mask_to_nucleotides(last_logits)
                    if skip_special_tokens
                    else last_logits
                )
                new_idx = sample(
                    sample_logits,
                    top_k=self.top_k,
                    top_p=self.top_p,
                    temperature=self.temperature,
                )

            if print_generation and verbose and batch_size == 1:
                print(
                    f"{self.tokenizer.detokenize([new_idx.item()])}",
                    end=" ",
                    flush=True,
                )

            if not forcing:
                scores[:, n_generated] = last_logits
                generation[:, n_generated] = new_idx
                n_generated += 1

            # NOTE(review): with a freshly created cache, step 0 is not
            # post-prefill, so the next forward receives the whole prompt plus
            # the first token; confirm the model's cache handling expects this.
            if post_prefill:
                x = new_idx[:, None]
            else:
                x = torch.cat([x, new_idx[:, None]], dim=-1)

            if token_callback is not None and token_callback(logits):
                break

            if (
                stop_at_eos
                and not forcing
                and self._ends_with_eos(generation[:, :n_generated], eos_token_ids)
            ):
                print("Stopping generation at EOS")
                break

        generation = generation[:, :n_generated]
        scores = scores[:, :n_generated]

        if verbose:
            outputs = self.tokenizer.detokenize_batch(generation)
            if isinstance(outputs, str):
                outputs = [outputs]
            truncated = []
            for text in outputs:
                for until in self.untils:
                    if until in text:
                        text = text.split(until)[0]
                        break
                truncated.append(text)
            y = truncated[0] if len(truncated) == 1 else truncated

            print(f"\nInput: {input_string}, Output: {y}")

            if x.device.type == "cuda":
                mem_end = torch.cuda.memory_allocated(device=x.device) / 1e9
                print(f"Memory after generation: {mem_end} GB")

        return (
            generation,
            scores,
            inference_params_dict,
        )


class myEvo2(Evo2):
    """``Evo2`` subclass whose ``generate`` uses the local :class:`Generator`.

    The local generator returns the inference cache of each batch so sampling
    can be replayed from a saved state.
    """

    @torch.no_grad()
    def generate(
        self,
        prompt_seqs: List[str],
        n_tokens: int = 500,
        temperature: float = 1.0,
        top_k: int = 4,
        top_p: float = 1.0,
        batched: bool = True,
        prepend_bos: bool = False,
        verbose: int = 1,
        device: str = "cuda:0",
        **kwargs,
    ) -> Tuple["GenerationOutput", List[Optional[dict]]]:
        """
        Generate sequences from a list of prompts.

        Prompts are run as a single batch when ``batched`` is set and all prompts
        have the same length; otherwise each prompt is generated on its own.
        Extra ``kwargs`` are forwarded to :meth:`Generator.generate`, e.g.
        ``cached_generation``, ``skip_special_tokens`` or ``force_prompt_threshold``.
        ``force_prompt_threshold`` defaults to 500 there; prompts longer than that
        are teacher-forced to help avoid OOM errors.

        Returns:
            ``(GenerationOutput(sequences, logits, logprobs_mean), infos)``, where
            ``infos`` holds the inference cache returned for each generated batch.

        Raises:
            ValueError: If ``prompt_seqs`` is empty.
        """
        if not prompt_seqs:
            raise ValueError("prompt_seqs must contain at least one prompt")

        self.model.eval()

        g = Generator(
            self.model,
            self.tokenizer,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )

        uniform_lengths = all(len(s) == len(prompt_seqs[0]) for s in prompt_seqs)

        if batched and uniform_lengths:
            batches = [prompt_seqs]
        else:
            sys.stderr.write("WARNING: Batched generation is turned off.\n")
            batches = [[prompt_seq] for prompt_seq in prompt_seqs]

        input_ids_list = [
            prepare_batch(
                batch,
                self.tokenizer,
                prepend_bos=prepend_bos,
                device=device,
            )[0]
            for batch in batches
        ]

        generated_seqs, generated_scores, logitss = [], [], []
        last_k_info = []
        for input_ids in input_ids_list:
            batch_size = input_ids.shape[0]

            output_ids, logits, info = g.generate(
                input_ids=input_ids,
                num_tokens=n_tokens,
                device=device,
                print_generation=(verbose > 1),
                verbose=(verbose > 1),
                stop_at_eos=False,
                **kwargs,
            )

            if verbose > 1:
                print("input_ids.shape", input_ids.shape)
                print("output_ids.shape", output_ids.shape)
                print("logits.shape", logits.shape)

            generated_seqs_batch = list(self.tokenizer.detokenize_batch(output_ids))
            assert len(generated_seqs_batch) == batch_size
            generated_seqs += generated_seqs_batch
            logitss.append(logits)

            # logits[:, k] is the distribution output_ids[:, k] was sampled from,
            # so they are already aligned; no BOS / last-position trimming.
            logprobs = (
                torch.log_softmax(logits.float(), dim=-1)
                .gather(-1, output_ids.unsqueeze(-1))
                .squeeze(-1)
            )
            logprobs = logprobs.cpu().numpy()

            generated_scores += [np.mean(logprobs[idx]) for idx in range(batch_size)]
            last_k_info.append(info)

        assert len(generated_seqs) == len(generated_scores) == len(prompt_seqs)
        if verbose:
            for seq, score, prompt in zip(
                generated_seqs, generated_scores, prompt_seqs
            ):
                print(f'Prompt: "{prompt}",\tOutput: "{seq}",\tScore: {score}')

        return GenerationOutput(
            sequences=generated_seqs, logits=logitss, logprobs_mean=generated_scores
        ), last_k_info


def evo1_logits_to_logprobs(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    trim_bos: bool = True,
) -> torch.Tensor:
    """
    Look up the log-probability of each token in ``input_ids`` under ``logits``.

    ``logits`` is ``(batch, length, vocab)``. It is normalized with a log-softmax
    over the vocab axis, then indexed with ``input_ids`` to give a
    ``(batch, positions)`` tensor.

    With ``trim_bos`` the last prediction and the first id are dropped, so
    position ``t``'s prediction is scored against token ``t + 1``. Use this when
    the leading token is a BOS prepended by the tokenizer.
    """
    log_probs = logits.log_softmax(dim=-1)
    targets = input_ids
    if trim_bos:
        log_probs, targets = log_probs[:, :-1], targets[:, 1:]
    assert log_probs.shape[1] == targets.shape[1]

    picked = log_probs.gather(dim=2, index=targets.unsqueeze(-1))
    return picked.squeeze(-1)


class EVO2(BaseModel):
    """EVO2 model implementation backed by :class:`myEvo2` and :class:`EVOTokenizer`."""

    def __init__(self, cfg):
        self.MODEL_NAME = cfg.model_name  # e.g. "evo2-7b"

        super().__init__(cfg)
        if getattr(cfg, "evo1_score", False):
            # Opt-in Evo1-style scoring: replaces this instance's score_sequences.
            self.score_sequences = self._score_sequences
            self.addbos = getattr(cfg, "addbos", False)

    def load_model(self):
        """Load the EVO2 model and tokenizer"""
        self.tokenizer = EVOTokenizer(vocab_size=512)
        self.model = myEvo2(self.MODEL_NAME.replace("-", "_"))

    def _score_sequences(
        self,
        sequences: List[str],
    ) -> Tuple[List[float], List[int]]:
        """Score each sequence by its mean per-token log-probability.

        Sequences are right-padded into one batch (with a BOS token when
        ``self.addbos``). Only each sequence's own positions are averaged, so
        padding does not affect scores in mixed-length batches.

        Returns:
            ``(scores, lens_logits)``: one mean log-probability per sequence, and
            the (padded) logits length of each batch row.
        """
        input_ids, seq_lengths = prepare_batch(
            sequences,
            self.tokenizer,
            prepend_bos=self.addbos,
            device="cuda",
        )
        assert len(seq_lengths) == input_ids.shape[0]

        with torch.inference_mode():
            logits = self.model(input_ids)[0][0]

        # Per-position log-probabilities of the batch (padding included).
        logprobs = (
            evo1_logits_to_logprobs(logits, input_ids, trim_bos=self.addbos)
            .float()
            .cpu()
            .numpy()
        )

        # The first seq_lengths[i] positions of row i are real tokens, with or
        # without BOS trimming; everything after is padding from prepare_batch.
        scores = [np.mean(row[:length]) for row, length in zip(logprobs, seq_lengths)]
        lens_logits = [len(logit) for logit in logits]

        return scores, lens_logits

    def generate_sequences(
        self,
        sequences: List[str],
        num_tokens: int,
        temperature: float,
        top_k: int,
        top_p: float,
        **kwargs,
    ) -> List[str]:
        """Generate sequences"""
        return self._generate(
            sequences, num_tokens, temperature, top_k, top_p, **kwargs
        )

    def _generate(
        self,
        sequences: List[str],
        num_tokens: int,
        temperature: float,
        top_k: int,
        top_p: float,
        **kwargs,
    ) -> List[str]:
        """Generate ``num_tokens`` tokens after each prompt with a cached, batched sampler."""
        output, _ = self.model.generate(
            prompt_seqs=sequences,
            n_tokens=num_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            batched=True,
            cached_generation=True,
            verbose=False,
            **kwargs,
        )

        return output.sequences

    def get_logits(self, sequences: List[str]):
        """Return ``{sequence: logits}``, each trimmed to that sequence's length.

        Mixed-length inputs are right-padded into one batch. Each sequence's
        logits keep only its own positions; assuming a causal (left-to-right)
        model, right padding does not affect them.
        """
        assert isinstance(sequences, list), "sequences must be a list"
        input_ids, seq_lengths = prepare_batch(
            sequences,
            self.tokenizer,
            prepend_bos=False,
            device="cuda",
        )

        with torch.inference_mode():
            outputs = self.model(input_ids)[0][0]
            logits = outputs.cpu()

        return {
            seq: logit[:length]
            for seq, logit, length in zip(sequences, logits, seq_lengths)
        }


class EVO2_7B_BASE(EVO2):
    """EVO2 wrapper pinned to the ``evo2-7b-base`` checkpoint."""

    _CHECKPOINT = "evo2-7b-base"

    def __init__(self, cfg):
        # Overwrite the configured name before EVO2.__init__ reads cfg.model_name.
        cfg.model_name = self._CHECKPOINT
        super().__init__(cfg)


class EVO2_7B(EVO2):
    """EVO2 wrapper pinned to the ``evo2-7b`` checkpoint."""

    _CHECKPOINT = "evo2-7b"

    def __init__(self, cfg):
        # Overwrite the configured name before EVO2.__init__ reads cfg.model_name.
        cfg.model_name = self._CHECKPOINT
        super().__init__(cfg)


class EVO2_40B(EVO2):
    """EVO2 wrapper pinned to the ``evo2-40b`` checkpoint."""

    _CHECKPOINT = "evo2-40b"

    def __init__(self, cfg):
        # Overwrite the configured name before EVO2.__init__ reads cfg.model_name.
        cfg.model_name = self._CHECKPOINT
        super().__init__(cfg)
