import time
import torch
from tqdm import tqdm
from termcolor import colored



class LLM:
    """
    A class representing the LLM (currently supports Llama, ...).

    This class implements the generic prefill/decode driver. The model-specific
    members it calls (``word_embedding``, ``layernorm``, ``wqkv``,
    ``position_embedd``, ``prefill_attention``, ``decode_attention``, ``wo``,
    ``mlp``, ``lm``, ``parameter_move``, ``init_kv_cache``, ``kv_cache``,
    ``layers`` and size attributes such as ``num_layers``, ``num_heads``,
    ``num_key_value_heads``, ``head_dim``, ``hidden_size``, ``num_gpus``,
    ``norm_weight``, ``norm_variance_epsilon``) are not defined here and must
    be provided elsewhere (presumably by a subclass or model loader) before
    ``generate`` is called.
    """

    # Number of tokens handled per layernorm/MLP chunk during prefill; chunking
    # along the sequence axis lowers peak memory consumption for long prompts.
    PREFILL_CHUNK_SIZE = 8192

    def __init__(
        self, 
        model_name: str,
        max_length: int,
        dtype: torch.dtype,
        device_map: str
    ) -> None:
        """ Initializes the LLM.
        Args:
            model_name (str): The name of the model.
            max_length (int): The maximum length (prefill+decode) of sequences.
            dtype (torch.dtype): The data type for model computations.
            device_map (str): The device for the model; supports 'cuda:x' or 'auto' (automatically use all visible GPUs).
        """

        self.model_name = model_name
        self.max_length = max_length
        self.dtype = dtype
        self.device_map = device_map


    def layer_prefill(self, layer_idx, hidden_states):
        """ Run one decoder layer over the whole prompt (prefill phase).

        Sequences in the batch are processed one at a time, and the layernorm
        and MLP steps are additionally chunked along the sequence axis
        (``PREFILL_CHUNK_SIZE`` tokens per chunk) to keep peak memory low.
        ``hidden_states`` is overwritten in place and also returned.

        Args:
            layer_idx (int): Index into ``self.layers``; also passed to the KV cache.
            hidden_states (torch.Tensor): Activations of shape (bsz, seq_len, dim).
        Returns:
            torch.Tensor: The same ``hidden_states`` tensor holding the layer output.
        """
        print(f'Layer = {layer_idx}')

        bsz, seq_len, dim = hidden_states.shape
        layer = self.layers[layer_idx]
        chunk_size = self.PREFILL_CHUNK_SIZE

        for start_bdx in range(bsz):
            torch.cuda.empty_cache()

            residual = hidden_states[start_bdx:start_bdx+1, :, :].clone()

            # input layernorm, chunked for lower memory consumption
            for start_idx in range(0, seq_len, chunk_size):
                end_idx = min(seq_len, start_idx + chunk_size)
                hidden_states[start_bdx:start_bdx+1, start_idx:end_idx, :] = self.layernorm(hidden_states[start_bdx:start_bdx+1, start_idx:end_idx, :], 
                                                                                        layer.input_layernorm_variance_epsilon, 
                                                                                        layer.input_layernorm_weight)
            
            query_states, key_states, value_states = self.wqkv(hidden_states[start_bdx:start_bdx+1, :, :], layer)
            query_states, key_states = self.position_embedd(query_states, key_states)

            query_states = query_states.view(1, seq_len, self.num_heads, self.head_dim)       # reshape [bs, seq_len, dim] => [bs, seq_len, head, head_dim]
            key_states = key_states.view(1, seq_len, self.num_key_value_heads, self.head_dim)
            value_states = value_states.view(1, seq_len, self.num_key_value_heads, self.head_dim)

            key_states, value_states = self.kv_cache.prefill_update_kv_cache(query_states, key_states, value_states, layer_idx, start_bdx)    # kv cache update

            temp_attn_out = self.prefill_attention(query_states, key_states, value_states)
            del query_states, key_states, value_states
            torch.cuda.empty_cache()

            hidden_states[start_bdx:start_bdx+1, :, :] = self.wo(temp_attn_out, layer, temp_attn_out.shape[0], seq_len, dim)
            del temp_attn_out
            torch.cuda.empty_cache()
            
            hidden_states[start_bdx:start_bdx+1, :, :] += residual

            # post attention
            residual = hidden_states[start_bdx:start_bdx+1, :, :].clone()

            # post-attention layernorm + MLP, chunked for lower memory consumption
            for start_idx in range(0, seq_len, chunk_size):
                end_idx = min(seq_len, start_idx + chunk_size)
                hidden_states[start_bdx:start_bdx+1, start_idx:end_idx, :] = self.layernorm(hidden_states[start_bdx:start_bdx+1, start_idx:end_idx, :], 
                                                                                        layer.post_attention_layernorm_variance_epsilon, 
                                                                                        layer.post_attention_layernorm_weight)
                hidden_states[start_bdx:start_bdx+1, start_idx:end_idx, :] = self.mlp(hidden_states[start_bdx:start_bdx+1, start_idx:end_idx, :], layer)   
            
            hidden_states[start_bdx:start_bdx+1, :, :] += residual

            del residual
            torch.cuda.empty_cache()
                                                                                                   
        return hidden_states


    def layer_decode(self, layer_idx, hidden_states):
        """ Run one decoder layer for a decode step.

        Applies input layernorm -> QKV projection -> position embedding ->
        KV-cache update -> attention -> output projection, adds the residual,
        then post-attention layernorm -> MLP and a second residual add.

        Args:
            layer_idx (int): Index into ``self.layers``; also passed to the KV cache and attention.
            hidden_states (torch.Tensor): Activations of shape (bsz, seq_len, dim).
        Returns:
            torch.Tensor: A new tensor with the layer output (input is not modified in place).
        """
        # print(f'Layer = {layer_idx}')

        residual = hidden_states
        bsz, seq_len, dim = hidden_states.shape
        layer = self.layers[layer_idx]

        hidden_states = self.layernorm(hidden_states[:, :, :], layer.input_layernorm_variance_epsilon, layer.input_layernorm_weight)            # input layernorm
        
        query_states, key_states, value_states = self.wqkv(hidden_states, layer)                                                                # Wqkv
        query_states, key_states = self.position_embedd(query_states, key_states)                                                               # position embedding

        query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)                                                                # reshape [bs, seq_len, dim] => [bs, seq_len, head, head_dim]
        key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim)
        value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim)

        key_states, value_states = self.kv_cache.decode_update_kv_cache(key_states, value_states, layer_idx)                                    # kv cache update
        attn_out = self.decode_attention(query_states, key_states, value_states, layer_idx)                                                     # attention
        hidden_states = self.wo(attn_out, layer, bsz, seq_len, dim)                                                                             # Wo
        
        hidden_states = residual + hidden_states                                                                                                # residual
        residual = hidden_states
        hidden_states = self.layernorm(hidden_states, layer.post_attention_layernorm_variance_epsilon, layer.post_attention_layernorm_weight)   # output layernorm
        hidden_states = self.mlp(hidden_states, layer)                                                                                          # mlp
        hidden_states = residual + hidden_states                                                                                                # residual

        return hidden_states


    def prefill_forward(self, inputs_ids):
        """ Run the full model over the prompt and return logits for the last position.

        Args:
            inputs_ids (torch.Tensor): Prompt token ids of shape (bsz, seq_len).
        Returns:
            torch.Tensor: Output of ``self.lm`` applied to the final-normed last
            position (presumably (bsz, 1, vocab_size) — depends on ``self.lm``).
        """
        bsz, seq_len = inputs_ids.shape
        device = inputs_ids.device

        # Embed one sequence at a time into a preallocated buffer.
        hidden_states = torch.empty((bsz, seq_len, self.hidden_size), dtype=self.dtype, device=device)
        for start_bdx in range(0, bsz, 1):
            end_bdx = min(bsz, start_bdx + 1)
            hidden_states[start_bdx:end_bdx, :, :] = self.word_embedding(inputs_ids[start_bdx:end_bdx])

        if self.num_gpus > 1:
            for ldx in range(self.num_layers):
                hidden_states = self.layer_prefill(ldx, hidden_states)
                hidden_states = self.parameter_move(hidden_states, ldx)
                torch.cuda.empty_cache()
            hidden_states = hidden_states.to(self.layers[0].device)
        else:
            for ldx in range(self.num_layers):
                hidden_states = self.layer_prefill(ldx, hidden_states)
                torch.cuda.empty_cache()
        
        # Only the last position is needed to pick the first generated token.
        hidden_states = self.layernorm(hidden_states[:, -1:, :].contiguous(), self.norm_variance_epsilon, self.norm_weight)
        logits = self.lm(hidden_states)
        
        return logits
        

    def decode_forward(self, inputs_ids):
        """ Run the full model for one decode step.

        Args:
            inputs_ids (torch.Tensor): Token ids produced by the previous step
                (in ``inference`` this is the argmax of the previous logits).
        Returns:
            torch.Tensor: Output of ``self.lm`` for the last position.
        """
        hidden_states = self.word_embedding(inputs_ids)

        if self.num_gpus > 1:
            for ldx in range(self.num_layers):
                hidden_states = self.layer_decode(ldx, hidden_states)
                hidden_states = self.parameter_move(hidden_states, ldx)
            hidden_states = hidden_states.to(self.layers[0].device)
        else:
            for ldx in range(self.num_layers):
                hidden_states = self.layer_decode(ldx, hidden_states)
        
        hidden_states = self.layernorm(hidden_states[:, -1:, :], self.norm_variance_epsilon, self.norm_weight)
        logits = self.lm(hidden_states)
        
        return logits


    @staticmethod
    def _synchronize():
        """Wait for queued CUDA work to finish; no-op when CUDA is unavailable."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()


    def inference(self, inputs_ids):
        """ Greedy generation: one prefill step followed by ``max_new_length - 1`` decode steps.

        Prints decode latency/throughput when at least one decode step ran.

        Args:
            inputs_ids (torch.Tensor): Prompt token ids.
        Returns:
            list[list[int]]: Generated token ids per sequence, built by
            concatenating each step's argmax along the last dimension.
        """
        outputs_ids = []    # one tensor of next-token ids per step

        logits = self.prefill_forward(inputs_ids=inputs_ids)
        output_ids = logits.argmax(dim=-1)
        outputs_ids.append(output_ids)

        self._synchronize()

        num_decode_steps = self.max_new_length - 1
        decode_start = time.perf_counter()

        for _ in range(num_decode_steps):
            logits = self.decode_forward(inputs_ids=output_ids)
            output_ids = logits.argmax(dim=-1)
            outputs_ids.append(output_ids)

        # CUDA kernels launch asynchronously: wait for the last decode step to
        # finish so the measured interval covers the actual GPU work.
        self._synchronize()
        decode_end = time.perf_counter()

        elapsed = decode_end - decode_start
        # With max_new_length <= 1 no decode step runs; skip the stats instead
        # of dividing by zero (or reporting a negative latency).
        if num_decode_steps > 0 and elapsed > 0:
            print(colored(
                f"Latency: {round(elapsed * 1000 / num_decode_steps, 2)} ms/iter, "
                f"Throughput: {round(self.batch_size * num_decode_steps / elapsed, 2)} tokens/s",
                'red'
            ))
        
        outputs_ids = torch.cat(outputs_ids, dim=-1).tolist()
        
        return outputs_ids


    def generate(self, attention_type, inputs_ids, max_new_length):
        """ LLM Inference.
        Args:
            attention_type (str): Attention variant to use; stored on ``self.attention_type``.
            inputs_ids (torch.Tensor): Prompt token ids of shape (batch_size, input_length).
            max_new_length (int): Number of tokens to generate (the prefill step
                emits the first one). ``input_length + max_new_length`` must not
                exceed ``self.max_length``.
        Returns:
            list[list[int]]: Generated token ids per sequence.
        Raises:
            AssertionError: If the prompt plus generation length exceeds ``self.max_length``.
        """

        bs, input_length = inputs_ids.shape
        assert input_length + max_new_length <= self.max_length, f"Error: input_length({input_length}) + max_new_length({max_new_length}) exceeds max_length({self.max_length})"
        self.batch_size = bs
        self.input_length = input_length
        self.max_new_length = max_new_length
        self.attention_type = attention_type

        self.init_kv_cache()

        outputs = self.inference(inputs_ids)

        return outputs