import random
from copy import deepcopy
from queue import PriorityQueue
from typing import Optional, List, Tuple

import torch
import numpy as np

from .kv_cache import FastKVCache as KVCache
from token_module import KVCache as TKVCache
from token_module import Token
from token_module import TreeAttentionManager as TTreeAttentionManager


def reversed_group_indices(arr):
    """Return every index of *arr*, run by run, shuffled inside each run.

    A run is a maximal stretch of consecutive elements that do not compare
    unequal (``!=``) to the first element of the stretch. Runs keep their
    original order; the indices of one run are listed in descending order
    and then shuffled in place with ``random.shuffle``.
    """
    total = len(arr)
    ordered = []
    run_begin = 0

    while run_begin < total:
        # Grow the run while elements still match the run's first element.
        run_end = run_begin + 1
        while run_end < total and not (arr[run_end] != arr[run_begin]):
            run_end += 1

        members = list(reversed(range(run_begin, run_end)))
        random.shuffle(members)
        ordered += members
        run_begin = run_end

    return ordered


class TreeAttentionManager(TTreeAttentionManager):
    """Module-local subclass of ``token_module.TreeAttentionManager``.

    Adds no attributes or methods of its own.
    """


class EAGLETreeAttentionManager:
    def __init__(
            self,
            base_model,
            draft_model,
            lm_head,
            draft_lm_head,
            temperature=0.0,
            threshold=0.0,  # not used
            top_draft=4,
            top_node=32,
            depth=8,
            top_base=16,
    ):
        """Store models and decoding settings and allocate both KV-cache pools.

        Args:
            base_model: Model used for the prompt pass and for verification.
                Must expose ``device``, ``dtype`` and a ``config`` with
                ``num_hidden_layers``, ``num_key_value_heads``, ``head_dim``
                and ``max_position_embeddings``.
            draft_model: Model used to grow the draft tree. Must expose
                ``dtype`` and ``config.draft_num_hidden_layers``.
            lm_head: Head applied to base-model hidden states; its
                ``weight.dtype`` is the dtype hidden states are cast to
                before either head is applied.
            draft_lm_head: Head applied to draft-model hidden states.
            temperature: ``0`` selects argmax decoding; a positive value
                selects sampling from the temperature-scaled softmax.
            threshold: Accepted for interface compatibility; not used.
            top_draft: Number of tree nodes kept at every draft depth.
            top_node: Number of top tokens expanded from every node.
            depth: Number of draft-model steps per drafting round.
            top_base: Maximum number of tree nodes verified per round.

        Raises:
            AssertionError: If ``top_node < top_draft``.
        """
        assert top_node >= top_draft, "top_node must be greater than or equal to top_draft"

        self.base_model = base_model
        self.draft_model = draft_model
        self.lm_head = lm_head
        self.draft_lm_head = draft_lm_head

        self.temperature = temperature
        self.top_draft = top_draft
        self.top_node = top_node
        self.depth = depth
        self.top_base = top_base

        self.device = base_model.device
        self.dtype = base_model.dtype
        self.draft_dtype = draft_model.dtype
        self.lm_head_dtype = lm_head.weight.dtype

        self.num_hidden_layers = base_model.config.num_hidden_layers
        self.num_key_value_heads = base_model.config.num_key_value_heads
        self.head_dim = base_model.config.head_dim
        # Attention-mask history rows are padded to this width.
        self.max_cache_len = base_model.config.max_position_embeddings * 2

        self.base_kvcache_pool = KVCache(
            self.num_hidden_layers,
            self.num_key_value_heads,
            self.max_cache_len,
            self.head_dim,
            self.device,
            self.dtype,
        )

        # The draft pool reuses the base model's head count and head size.
        self.draft_kvcache_pool = KVCache(
            draft_model.config.draft_num_hidden_layers,
            self.num_key_value_heads,
            self.max_cache_len,
            self.head_dim,
            self.device,
            self.draft_dtype,
        )

        self.base_kv_cache_indices = []
        self.draft_kv_cache_indices = []

        self.input_ids = None  # (1, top_draft or 1)
        self.input_features = None  # (1, top_draft, hidden_size) or (1, 1, hidden_size)
        self.position_ids = None  # (1, top_draft or 1)

        self.last_output_id = None  # int
        # Position of the last verified token; the other methods read and
        # write this name, so it is declared here alongside the rest.
        self.last_position_id = None  # int
        # Kept only for backward compatibility: no method of this class reads it.
        self.last_base_position_id = None  # int
        self.last_hidden_state = None  # (1, hidden_size)

        self.draft_attention_mask = None  # (batch_size, 1, query_length, key_value_length)
        self.base_attention_mask = None  # (batch_size, 1, query_length, key_value_length)
        self.cumulative_scores = None  # (top_draft or 1)

        self.input_ids_history = []  # List[torch.Tensor (top_draft or 1)]
        self.parent_indices_history = []  # List[torch.Tensor (top_draft or 1)]
        self.cumulative_scores_history = []  # List[torch.Tensor (top_draft or 1)]
        self.base_attention_mask_history = []  # List[torch.Tensor (1, 1, 1 or top_draft, self.max_cache_len)]

        self.first_draft = True

        self.output_buffer = []

    @torch.no_grad()
    def initialize(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
    ):
        assert input_ids.size(0) == 1, "Only one input sequence is supported"

        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        position_ids = torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0)

        base_kv_cache_indices = self.base_kvcache_pool.allocate(input_ids.shape[1])
        use_cache = True

        output = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=self.base_kvcache_pool,
            past_key_value_indices=base_kv_cache_indices,
            use_cache=use_cache,
            output_hidden_states=True,
        )
        hidden_states = output.last_hidden_state
        draft_hidden_states = output.last_hidden_state

        b, t, _ = hidden_states.size()
        logits = self.lm_head(hidden_states[:, -1:, :].to(self.lm_head_dtype))
        b, _, v = logits.size()

        if self.temperature > 0:
            logits = logits / self.temperature
            probs = torch.nn.functional.softmax(logits, dim=-1)
            output_id = torch.multinomial(probs.reshape(b, v), num_samples=1).flatten()[0]
        else:
            output_id = torch.argmax(logits, dim=-1).flatten()[0]

        self.cumulative_scores = torch.zeros(1, device=self.device, dtype=self.dtype)  # (1)
        self.input_ids = torch.cat([input_ids, output_id.view(1, 1)], dim=1)  # (1, ctx_len + 1)
        self.input_features = torch.cat([
            torch.zeros_like(draft_hidden_states[:, :1]),
            draft_hidden_states,
        ], dim=1)  # (1, ctx_len + 1, hidden_size)
        self.position_ids = torch.arange(t + 1, device=input_ids.device).unsqueeze(0)  # (1, ctx_len + 1)

        min_dtype = torch.finfo(self.dtype).min
        causal_mask = torch.full(
            (t + 1, t + 1),
            fill_value=min_dtype,
            dtype=self.dtype,
            device=self.device,
        )
        causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask = causal_mask[None, None, :, :].expand(1, 1, -1, -1)

        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )
        self.draft_attention_mask = causal_mask.clone()
        self.base_attention_mask = causal_mask.clone()

        self.input_ids_history.append(output_id.view(1))
        self.parent_indices_history.append(torch.zeros(1, device=self.device, dtype=self.dtype))
        self.cumulative_scores_history.append(torch.zeros(1, device=self.device, dtype=self.dtype))
        self.base_attention_mask_history.append(
            torch.nn.functional.pad(
                self.base_attention_mask[:, :, -1:],  # (1, 1, 1, t + 1)
                (0, self.max_cache_len - t - 1, 0, 0, 0, 0, 0, 0),
                value=torch.finfo(self.dtype).min,
            )
        )

        self.last_output_id = output_id.item()
        self.last_position_id = t - 1

        self.base_kv_cache_indices = base_kv_cache_indices

        self.output_buffer = [output_id.item()]
        self.first_draft = True

    @torch.no_grad()
    # @profile
    def draft_create(self):
        draft_kv_cache_indices = deepcopy(self.draft_kv_cache_indices)
        new_draft_kv_cache_indices = self.draft_kvcache_pool.allocate(self.input_ids.size(1) + (self.depth - 1) * self.top_draft)
        self.draft_kv_cache_indices = draft_kv_cache_indices + new_draft_kv_cache_indices[:self.input_ids.size(1)]
        del_draft_kv_cache_indices = new_draft_kv_cache_indices[self.input_ids.size(1):]

        position_ids = (self.position_ids[0, -1] + 1).item()
        for depth in range(self.depth):
            hidden_states = self.draft_model(
                input_ids=self.input_ids,
                input_features=self.input_features,
                attention_mask=self.draft_attention_mask,
                position_ids=self.position_ids,
                past_key_values=self.draft_kvcache_pool,
                past_key_value_indices=self.draft_kv_cache_indices,
                use_cache=True,
                shift_tokens=False,
                cut_last_token=False,
            )[0]

            if depth == 0:
                self.input_ids = self.input_ids[:, -1:]
                self.draft_attention_mask = self.draft_attention_mask[:, :, -1:]
                self.base_attention_mask = self.base_attention_mask[:, :, -1:]
                self.position_ids = self.position_ids[:, -1:]
                hidden_states = hidden_states[:, -1:, :]

            logits = self.draft_lm_head(hidden_states.to(self.lm_head_dtype)).float()
            logits = torch.nn.functional.softmax(
                logits ,
                dim=-1,
            )
            sampled_m_probs, sampled_m_indices = torch.topk(logits[0], self.top_node, dim=-1)

            cumulative_score = self.cumulative_scores.unsqueeze(-1) + torch.log(sampled_m_probs)
            top_cumulative_score, top_cumulative_indices = torch.topk(cumulative_score.flatten(), self.top_draft)
            self.cumulative_scores = top_cumulative_score.view(self.top_draft)

            self.parent_indices_history.append(top_cumulative_indices // self.top_node)

            self.input_ids_history.append(sampled_m_indices.flatten()[top_cumulative_indices])
            self.cumulative_scores_history.append(self.cumulative_scores)

            self.input_ids = sampled_m_indices.flatten()[top_cumulative_indices].view(1, self.top_draft)
            self.input_features = hidden_states[:, top_cumulative_indices // self.top_node]
            self.position_ids = torch.full((1, self.top_draft), fill_value=self.position_ids[0, -1] + 1, device=self.device, dtype=self.dtype)
            position_ids += 1

            new_base_attention_mask = torch.full(
                (1, 1, self.top_draft, self.top_draft),
                fill_value=torch.finfo(self.dtype).min,
                dtype=self.dtype,
                device=self.device,
            )
            new_base_attention_mask.diagonal(dim1=-2, dim2=-1).fill_(0)

            new_draft_attention_mask = torch.full(
                (1, 1, self.top_draft, self.top_draft),
                fill_value=torch.finfo(self.dtype).min,
                dtype=self.dtype,
                device=self.device,
            )
            new_draft_attention_mask.diagonal(dim1=-2, dim2=-1).fill_(0)

            draft_attention_mask = torch.cat([
                self.draft_attention_mask[:, :, top_cumulative_indices // self.top_node, :],
                new_draft_attention_mask,
            ], dim=-1)
            base_attention_mask = torch.cat([
                self.base_attention_mask[:, :, top_cumulative_indices // self.top_node, :],
                new_base_attention_mask,
            ], dim=-1)
            self.draft_attention_mask = draft_attention_mask
            self.base_attention_mask = base_attention_mask

            self.base_attention_mask_history.append(
                torch.nn.functional.pad(
                    self.base_attention_mask,
                    (0, self.max_cache_len - self.base_attention_mask.size(3), 0, 0, 0, 0, 0, 0),
                    value=torch.finfo(self.dtype).min,
                )
            )

            self.first_draft = False

        self.draft_kvcache_pool.free(del_draft_kv_cache_indices)

    @torch.no_grad()
    # @profile
    def base_check(self):
        input_ids_history = torch.cat(self.input_ids_history, dim=0)
        parent_indices_history = torch.cat(self.parent_indices_history, dim=0)
        cumulative_scores_history = torch.cat(self.cumulative_scores_history, dim=0)
        base_attention_mask_history = torch.cat(self.base_attention_mask_history, dim=2)

        top_base = min(self.top_base, input_ids_history.size(0))
        top_cumulative_score, top_cumulative_indices = torch.topk(cumulative_scores_history, top_base)
        set_top_cumulative_indices = set(top_cumulative_indices.tolist())
        input_mask = torch.tensor([i in set_top_cumulative_indices for i in range(input_ids_history.size(0))], device=self.device, dtype=torch.bool)
        top_position_ids = self.last_position_id + 1 + (top_cumulative_indices.unsqueeze(0) + self.top_draft - 1) // self.top_draft

        _, position_sort_indices = torch.sort(top_position_ids.flatten(), stable=True)
        ct_top_cumulative_indices = top_cumulative_indices[position_sort_indices]
        top_position_ids = top_position_ids.flatten()[position_sort_indices].unsqueeze(0)
        top_input_ids = input_ids_history[ct_top_cumulative_indices].unsqueeze(0)
        top_base_attention_mask = torch.cat([
            base_attention_mask_history[:, :, ct_top_cumulative_indices, :len(self.base_kv_cache_indices)],
            base_attention_mask_history[:, :, ct_top_cumulative_indices, len(self.base_kv_cache_indices):len(self.base_kv_cache_indices) + input_ids_history.size(0)][..., input_mask],
        ], dim=-1).to(self.dtype)
        parent_indices = parent_indices_history[ct_top_cumulative_indices]

        new_base_kv_cache_indices = self.base_kvcache_pool.allocate(top_input_ids.size(1))
        base_kv_cache_indices = self.base_kv_cache_indices + new_base_kv_cache_indices

        _, _, q, k = top_base_attention_mask.size()
        assert q == top_input_ids.size(1), "Query length must be equal to the number of input tokens"
        assert k == len(base_kv_cache_indices), "Key length must be equal to the number of key-value caches"

        output = self.base_model(
            input_ids=top_input_ids,
            attention_mask=top_base_attention_mask,
            position_ids=top_position_ids,
            past_key_values=self.base_kvcache_pool,
            past_key_value_indices=base_kv_cache_indices,
            use_cache=True,
            output_hidden_states=True,
        )
        base_last_hidden_state = output.last_hidden_state
        base_hidden_states = output.last_hidden_state

        base_logits = self.lm_head(base_last_hidden_state.to(self.lm_head_dtype))

        remaining_base_kv_cache_indices = []

        input_ids = []
        input_features = []
        position_ids = []
        cur_indices = torch.tensor([0]+list(range(self.top_draft)) * self.depth, device=self.device)[input_mask]

        top_input_ids = top_input_ids[0].tolist()
        top_position_ids = top_position_ids[0].tolist()
        parent_indices = parent_indices.tolist()

        if self.temperature > 0:
            base_logits = base_logits / self.temperature
            probs = torch.nn.functional.softmax(base_logits, dim=-1).cpu()

            last_parent_idx = 0
            last_parent_prob = probs[0, last_parent_idx]
            remaining_base_kv_cache_indices.append(new_base_kv_cache_indices[0])

            self.last_hidden_state = base_hidden_states[:, 0, :]
            self.last_position_id = top_position_ids[0]

            for i in reversed_group_indices(top_position_ids)[1:]:
                top_input_id = top_input_ids[i]
                top_position_id = top_position_ids[i]
                parent_idx = parent_indices[i]

                if self.last_position_id + 1 == top_position_id and last_parent_idx == parent_idx:
                    px = last_parent_prob[top_input_id].item()
                    qx = 1.0
                    acp = px / qx
                    u = random.uniform(0, 1)
                    if u <= acp:
                        input_ids.append(top_input_id)
                        input_features.append(self.last_hidden_state)
                        position_ids.append(self.last_position_id + 1)

                        last_parent_prob = probs[0, i]
                        self.last_hidden_state = base_hidden_states[:, i, :]
                        self.last_position_id = top_position_id
                        last_parent_idx = cur_indices[i].item()

                        self.output_buffer.append(top_input_id)
                        remaining_base_kv_cache_indices.append(new_base_kv_cache_indices[i])
                    else:
                        last_parent_prob[top_input_id] = 0
                        last_parent_prob = last_parent_prob / last_parent_prob.sum()

            self.last_output_id = torch.multinomial(last_parent_prob, num_samples=1).item()

            input_ids.append(self.last_output_id)
            input_features.append(self.last_hidden_state)
            position_ids.append(self.last_position_id + 1)

            self.output_buffer.append(self.last_output_id)
        else:
            base_output_ids = torch.argmax(base_logits, dim=-1).flatten()
            last_parent_idx = 0

            for i in range(len(top_input_ids)):
                top_input_id = top_input_ids[i]
                top_position_id = top_position_ids[i]
                parent_idx = parent_indices[i]

                if self.last_output_id == top_input_id and self.last_position_id + 1 == top_position_id and last_parent_idx == parent_idx:
                    self.last_output_id = base_output_ids[i].item()
                    self.last_hidden_state = base_hidden_states[:, i, :]
                    self.last_position_id = top_position_id
                    last_parent_idx = cur_indices[i].item()

                    input_ids.append(self.last_output_id)
                    input_features.append(self.last_hidden_state)
                    position_ids.append(self.last_position_id + 1)

                    self.output_buffer.append(self.last_output_id)
                    remaining_base_kv_cache_indices.append(new_base_kv_cache_indices[i])

        self.input_ids = torch.tensor(input_ids, device=self.device).unsqueeze(0)
        self.input_features = torch.stack(input_features, dim=1)
        self.position_ids = torch.tensor(position_ids, device=self.device).unsqueeze(0)

        new_base_attention_mask = torch.full(
            (self.input_ids.size(1), self.input_ids.size(1) + 1),
            fill_value=torch.finfo(self.dtype).min,
            dtype=self.dtype,
            device=self.device
        )
        new_base_attention_mask = torch.triu(new_base_attention_mask, diagonal=2)
        new_base_attention_mask = new_base_attention_mask[None, None, :, :].expand(1, 1, -1, -1)
        self.base_attention_mask = torch.cat([
            self.base_attention_mask[:, :, :1, :len(self.base_kv_cache_indices)].repeat(1, 1, self.input_ids.size(1), 1),
            new_base_attention_mask,
        ], dim=-1)
        self.draft_attention_mask = self.base_attention_mask.clone()
        self.cumulative_scores = torch.zeros(1, device=self.device, dtype=self.dtype)

        self.base_kv_cache_indices.extend(remaining_base_kv_cache_indices)
        self.base_kvcache_pool.free(set(new_base_kv_cache_indices) - set(remaining_base_kv_cache_indices), remaining_base_kv_cache_indices)

        self.input_ids_history = [torch.tensor(self.last_output_id, device=self.device).unsqueeze(0)]
        self.parent_indices_history = [torch.zeros(1, device=self.device, dtype=self.dtype)]
        self.cumulative_scores_history = [torch.zeros(1, device=self.device, dtype=self.dtype)]  # List[torch.Tensor (top_draft or 1)]
        self.base_attention_mask_history = [
            torch.nn.functional.pad(
                self.base_attention_mask[:, :, -1:],  # (1, 1, 1, t + 1)
                (0, self.max_cache_len - self.base_attention_mask.size(3), 0, 0, 0, 0, 0, 0),
                value=torch.finfo(self.dtype).min,
            )
        ]  # List[torch.Tensor (1, 1, 1 or top_draft, self.max_cache_len)]

        self.first_draft = True

        output_ids = deepcopy(self.output_buffer)
        self.output_buffer = []

        return output_ids

    def print_times(self):
        """Placeholder that does nothing and returns ``None``."""
        return None