import torch
import torch.nn.functional as F
import transformers
import pandas as pd

from .utils import Timing


class Generator:
    """Base class pairing a draft model with a target model for assisted decoding.

    Holds both models, one ``transformers.StaticCache`` per model (plus a second
    draft-model cache, ``draft_reencode_cache``) and the end-of-text token ids.
    Subclasses implement :meth:`generate`.
    """

    def __init__(self, draft_model, target_model, max_cache_len=2048):
        """Store the models and allocate their caches.

        Args:
            draft_model: model whose ``config``, ``device`` and ``dtype`` are used
                to build the draft caches; its ``config.eos_token_id`` defines
                the end-of-text ids.
            target_model: model whose ``config``, ``device`` and ``dtype`` are
                used to build the target cache.
            max_cache_len: ``max_cache_len`` passed to every ``StaticCache``.
        """
        self.draft_model = draft_model
        self.target_model = target_model
        self.draft_cache = self._new_static_cache(draft_model, max_cache_len)
        self.target_cache = self._new_static_cache(target_model, max_cache_len)
        # NOTE(review): not read by any code visible in this file chunk.
        self.draft_reencode_cache = self._new_static_cache(draft_model, max_cache_len)
        self.eot_ids = self._eos_ids_as_tensor(draft_model.config.eos_token_id, draft_model.device)

    @staticmethod
    def _new_static_cache(model, max_cache_len):
        """Build a batch-size-1 ``StaticCache`` matching ``model``'s config, device and dtype."""
        return transformers.StaticCache(
            model.config, max_batch_size=1, max_cache_len=max_cache_len, device=model.device, dtype=model.dtype
        )

    @staticmethod
    def _eos_ids_as_tensor(eos_token_id, device):
        """Return the EOS ids as a 1-D int64 tensor on ``device``.

        Accepts ``None`` (no EOS configured -> empty tensor), a single int, or a
        sequence of ints, so downstream membership tests always see a 1-D tensor.
        """
        if eos_token_id is None:
            eos_token_id = []
        return torch.as_tensor(eos_token_id, dtype=torch.long).reshape(-1).to(device)

    def __repr__(self):
        """Return ``ClassName(<draft name_or_path>, <target name_or_path>)``."""
        return f"{self.__class__.__name__}({self.draft_model.config._name_or_path}, {self.target_model.config._name_or_path})"

    def generate(self, *args, **kwargs):
        """Generate tokens; must be overridden by subclasses."""
        raise NotImplementedError


class SpecGeneratorGreedy(Generator):
    """Token-by-token generator that falls back to the target model on low draft confidence.

    At every step the draft model proposes its argmax token. If the draft's top
    probability is at least ``max_prob_threshold`` the proposal is kept;
    otherwise the target model is run on all tokens it has not yet seen and its
    argmax token is used instead.
    """

    def __init__(self, draft_model, target_model, max_cache_len=2048, max_prob_threshold=0.5):
        """Store the confidence threshold and initialise models and caches.

        Args:
            draft_model: model queried at every step.
            target_model: model queried only when the draft is not confident.
            max_cache_len: forwarded to the base class for cache allocation.
            max_prob_threshold: minimum draft top-1 probability required to
                accept the draft token without consulting the target model.
        """
        self.max_prob_threshold = max_prob_threshold
        super().__init__(draft_model, target_model, max_cache_len)

    def __repr__(self):
        """Return the base representation followed by the threshold."""
        return super().__repr__() + f"  max_prob_threshold={self.max_prob_threshold:.5f}"

    @torch.inference_mode()
    def generate(self, input_ids, max_new_tokens):
        """Generate up to ``max_new_tokens`` tokens after ``input_ids``.

        Resets and fills ``self.time_draft``, ``self.time_target``,
        ``self.draft_calls``, ``self.target_calls`` and ``self.log`` (one dict
        per generated token with keys ``draft_token``, ``draft_prob`` and
        ``generated_token``).

        Args:
            input_ids: 2-D prompt tensor; a single row is expected because the
                chosen token is read with ``.item()``.
            max_new_tokens: upper bound on the number of generated tokens.

        Returns:
            Tensor ``[batch, generated_length]`` holding only the new tokens.
            When generation stops on an EOS id, that EOS token is included.
            Empty along the last dimension when ``max_new_tokens <= 0``.
        """
        max_prob_threshold = self.max_prob_threshold
        draft_model = self.draft_model
        target_model = self.target_model

        self.time_draft, self.time_target = 0.0, 0.0  # timers for draft and target model
        self.target_calls = 0
        self.draft_calls = 0
        self.log = []  # saves key stats during generation steps

        original_input_len = input_ids.shape[-1]
        if max_new_tokens <= 0:
            # Nothing to generate; the loop below would leave `cursor` unbound.
            return input_ids[:, original_input_len:]

        # config.eos_token_id may be None, a single int, or a list of ints.
        eos_token_id = draft_model.config.eos_token_id
        if eos_token_id is None:
            eos_ids = frozenset()
        elif isinstance(eos_token_id, int):
            eos_ids = frozenset((eos_token_id,))
        else:
            eos_ids = frozenset(eos_token_id)

        # Prompt followed by zero-filled slots for the tokens to be generated.
        result = F.pad(input_ids.clone(), (0, max_new_tokens), value=0)
        draft_seen_tokens = 0  # number of leading tokens already fed to the draft model
        target_seen_tokens = 0  # number of leading tokens already fed to the target model

        for cursor in range(original_input_len, original_input_len + max_new_tokens):
            # Feed the draft model every token it has not processed yet.
            draft_input_ids = result[:, draft_seen_tokens:cursor]
            draft_new = draft_input_ids.shape[1]
            draft_cache_position = torch.arange(draft_seen_tokens, draft_seen_tokens + draft_new).to(draft_model.device)
            with Timing(synchronize=True) as t:
                draft_outputs = draft_model(
                    input_ids=draft_input_ids,
                    attention_mask=torch.ones_like(draft_input_ids),
                    position_ids=draft_cache_position.reshape(1, -1),
                    past_key_values=self.draft_cache,
                    cache_position=draft_cache_position,
                    use_cache=True,
                    num_logits_to_keep=1,
                )
            self.time_draft += t.elapsed
            self.draft_calls += 1
            draft_seen_tokens += draft_new

            next_token_p = draft_outputs.logits.softmax(dim=-1)
            max_prob = next_token_p.max(dim=-1).values[:, -1]
            draft_token = draft_outputs.logits[:, -1].argmax(dim=-1).unsqueeze(-1)

            if max_prob.item() >= max_prob_threshold:
                argmax_token = draft_token
            else:
                # Draft is not confident: let the target model catch up on all
                # tokens it has not seen and take its prediction instead.
                with Timing(synchronize=True) as t:
                    target_input_ids = result[:, target_seen_tokens:cursor]
                    target_new = target_input_ids.shape[1]
                    target_cache_position = torch.arange(target_seen_tokens, target_seen_tokens + target_new).to(target_model.device)
                    target_outputs = target_model(
                        input_ids=target_input_ids,
                        attention_mask=torch.ones_like(target_input_ids),
                        position_ids=target_cache_position.reshape(1, -1),
                        output_hidden_states=True,
                        past_key_values=self.target_cache,
                        cache_position=target_cache_position,
                        use_cache=True,
                        num_logits_to_keep=1,
                    )
                self.time_target += t.elapsed
                self.target_calls += 1
                target_seen_tokens += target_new
                argmax_token = target_outputs.logits[:, -1].argmax(dim=-1).unsqueeze(-1)

            result[:, cursor] = argmax_token

            self.log.append(
                dict(
                    draft_token=draft_token,
                    draft_prob=max_prob,
                    generated_token=argmax_token,
                )
            )

            if argmax_token.item() in eos_ids:
                break

        return result[:, original_input_len : cursor + 1]


class SpecGeneratorB(Generator):
    """Windowed, lossy speculative decoding.

    Each round the draft model greedily proposes ``window`` tokens, the target
    model scores them in one forward pass, and the longest prefix of proposals
    whose relative target probability passes ``min_prob_threshold`` is kept. The
    target model's own argmax is then appended after the accepted prefix.
    """

    def __init__(self, draft_model, target_model, max_cache_len=2048, min_prob_threshold=0.0, window=8, max_target_calls=200):
        """Store decoding hyper-parameters and initialise models and caches.

        Args:
            draft_model: model proposing tokens.
            target_model: model validating the proposals.
            max_cache_len: forwarded to the base class for cache allocation.
            min_prob_threshold: a draft token is accepted when
                ``p_target(draft token) / max(p_target)`` (clipped to [0, 1]) is
                at least this value.
            window: number of draft tokens proposed per round.
            max_target_calls: generation stops once the number of target-model
                calls exceeds this value.
        """
        super().__init__(draft_model, target_model, max_cache_len)
        self.min_prob_threshold = min_prob_threshold
        self.window = window
        self.max_target_calls = max_target_calls

    def __repr__(self):
        """Return the base representation followed by threshold and window."""
        return super().__repr__() + f"  min_prob_threshold={self.min_prob_threshold:.5f}, window={self.window}"

    @torch.inference_mode()
    def generate(self, input_ids, max_new_tokens):
        """Run assisted generation with lossy speculative decoding.

        Resets and fills ``self.log`` (one dict of CPU tensors per round),
        ``self.time_draft``, ``self.time_target``, ``self.draft_calls`` and
        ``self.target_calls``.

        Args:
            input_ids: prompt tensor; only row 0 is used.
            max_new_tokens: maximum number of tokens to return.

        Returns:
            torch.Tensor [1, generated_length]: the newly generated tokens. When
            an end-of-text id is produced within the first ``max_new_tokens``
            tokens, the output stops right before it (the EOS token is not
            included).
        """
        min_prob_threshold = self.min_prob_threshold
        window = self.window
        max_target_calls = self.max_target_calls
        draft_model = self.draft_model
        target_model = self.target_model

        self.log = []  # saves key stats during generation steps
        self.time_draft, self.time_target = 0.0, 0.0  # timers for draft and target model
        self.draft_calls = 0
        self.target_calls = 0

        original_input_len = input_ids.shape[-1]
        if max_new_tokens <= 0:
            # Nothing requested: skip the model calls entirely.
            return input_ids[0, original_input_len:].reshape(1, -1)

        # 1-D buffer for prompt + generation; `window` extra slots hold drafts
        # that overshoot max_new_tokens in the final round.
        result = F.pad(input_ids[0], (0, max_new_tokens + window), value=0)
        # 1-D copy of the EOS ids on the buffer's device for torch.isin.
        eot_ids = self.eot_ids.reshape(-1).to(result.device)
        draft_seen_tokens = 0  # number of leading tokens already fed to the draft model
        target_seen_tokens = 0  # number of leading tokens already fed to the target model

        cursor = original_input_len  # next insertion position for the draft model
        cursor_t = original_input_len  # next insertion position for the target model  # TODO consider combining the cursors
        draft_probas_all = torch.ones_like(result).to(draft_model.dtype)

        while True:

            # DRAFTING STEP: propose `window` tokens greedily, one per draft call.
            for _ in range(window):
                draft_input_ids = result[draft_seen_tokens:cursor].unsqueeze(0)
                draft_new = draft_input_ids.shape[1]
                draft_cache_position = torch.arange(draft_seen_tokens, draft_seen_tokens + draft_new).to(draft_model.device)
                with Timing(synchronize=True) as t:
                    draft_outputs = draft_model(
                        input_ids=draft_input_ids,
                        attention_mask=torch.ones_like(draft_input_ids),
                        position_ids=draft_cache_position.reshape(1, -1),
                        past_key_values=self.draft_cache,
                        cache_position=draft_cache_position,
                        use_cache=True,
                        num_logits_to_keep=1,
                    )
                self.time_draft += t.elapsed
                self.draft_calls += 1
                draft_seen_tokens += draft_new

                next_token_p = draft_outputs.logits.softmax(dim=-1)
                max_prob = next_token_p.max(dim=-1).values[:, -1]

                # TODO consider early drafting breaks / dynamic window sizing
                # when max_prob falls below a threshold.

                result[cursor] = draft_outputs.logits[:, -1].argmax(dim=-1)
                draft_probas_all[cursor] = max_prob
                cursor += 1

            # VALIDATION STEP: one target pass over everything it has not seen.
            num_drafted = cursor - cursor_t
            target_input_ids = result[target_seen_tokens:cursor].view(1, -1)
            target_new = target_input_ids.shape[1]
            target_cache_position = torch.arange(target_seen_tokens, target_seen_tokens + target_new).to(target_model.device)
            with Timing(synchronize=True) as t:
                target_outputs = target_model(
                    input_ids=target_input_ids,
                    attention_mask=torch.ones_like(target_input_ids),
                    position_ids=target_cache_position.reshape(1, -1),
                    output_hidden_states=True,
                    past_key_values=self.target_cache,
                    cache_position=target_cache_position,
                    use_cache=True,
                    # one distribution per drafted token plus one for the extra token
                    num_logits_to_keep=num_drafted + 1,
                )
            self.time_target += t.elapsed
            self.target_calls += 1
            target_seen_tokens += target_new

            # Matching. Every quantity is computed unconditionally so that an
            # empty draft (window == 0) yields empty tensors and
            # num_accepted == 0 instead of unbound names.
            target_softmax = target_outputs.logits.softmax(dim=-1)
            target_max, target_argmax = target_softmax[0].max(dim=-1)
            draft_tokens = result[cursor_t:cursor]
            draft_probas = draft_probas_all[cursor_t:cursor]
            positions = torch.arange(num_drafted, device=target_softmax.device)
            target_probas = target_softmax[0, positions, draft_tokens.to(target_softmax.device)]
            adjusted_probas = (target_probas / target_max[:-1]).clip(0, 1)  # TODO explore smarter ways
            accept_mask = adjusted_probas >= min_prob_threshold
            rejected = torch.nonzero(~accept_mask, as_tuple=True)[0]
            # Accept the prefix before the first rejection, or everything.
            num_accepted = int(rejected[0]) if rejected.numel() > 0 else num_drafted

            # Log before the result buffer is adjusted. draft_tokens and
            # draft_probas are views into buffers that later rounds overwrite,
            # so they are cloned (on CPU, .cpu() alone would not copy).
            self.log.append(
                dict(
                    draft_tokens=draft_tokens.clone().cpu(),
                    draft_probas=draft_probas.clone().cpu(),
                    target_probas=target_probas.cpu(),
                    adjusted_probas=adjusted_probas.cpu(),
                    target_argmax=target_argmax.cpu(),
                    target_max=target_max.cpu(),
                    accept_mask=accept_mask.cpu(),
                    num_accepted=num_accepted + 1,
                )
            )

            # Adjust cursors: roll back over the rejected draft tokens.
            rollback = min(0, num_accepted - window)
            cursor_t += num_accepted
            target_seen_tokens += rollback
            cursor += rollback
            # The draft model has not yet consumed its own last proposal, hence
            # the one-token-smaller rollback.
            draft_seen_tokens += min(0, num_accepted - window + 1)

            result[cursor_t] = target_softmax[0, num_accepted].argmax(dim=-1)  # extra token from the target model
            cursor_t += 1
            cursor += 1

            # Check stopping conditions. Only the first max_new_tokens generated
            # tokens are searched so the output never exceeds the requested
            # length, and the hit test uses numel() because position 0 is a
            # valid (falsy) index.
            generated_end = min(cursor, original_input_len + max_new_tokens)
            eot_positions = torch.nonzero(torch.isin(result[original_input_len:generated_end], eot_ids), as_tuple=True)[0]
            if eot_positions.numel() > 0:
                generated_tokens = result[original_input_len : original_input_len + int(eot_positions[0])]
                break
            if cursor_t - original_input_len >= max_new_tokens:
                generated_tokens = result[original_input_len : original_input_len + max_new_tokens]
                break

            if self.target_calls > max_target_calls:
                generated_tokens = result[original_input_len:cursor_t]
                break

        return generated_tokens.reshape(1, -1)