# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import time
from typing import List, Optional
import torch
import torch.distributed as dist
from tqdm import tqdm

from model import DenseAttentionModel, RingAttentionModel, StarAttentionModel
from lring.utils import load_ring_model_tokenizer, ring_prefill_and_decode
import torch.distributed as dist
from minference import MInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from modeling_llama_sirllm import LlamaForCausalLM

def read_jsonl(filename, num_lines=-1):
    """Parse a JSON-lines file into a list of objects.

    Args:
        filename: path to a UTF-8 encoded jsonl file.
        num_lines: maximum number of records to return; any negative value
            (default -1) reads the whole file, 0 returns an empty list.

    Returns:
        List of decoded JSON values, in file order. Blank lines are skipped.
    """
    lines = []
    if num_lines == 0:
        return lines
    # Predictions are written with encoding='utf-8' in main(); read them back the same way
    # instead of relying on the platform default encoding.
    with open(filename, encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of failing in json.loads
            lines.append(json.loads(line))
            # Check after appending so exactly `num_lines` records are returned.
            if num_lines > 0 and len(lines) >= num_lines:
                break
    return lines


def init_distributed():
    """Set up torch.distributed when launched with a RANK env var.

    Returns:
        (rank, world_size). Without RANK in the environment the process is
        treated as a single-process run and (0, 1) is returned without touching
        torch.distributed. Otherwise an NCCL process group is initialized and the
        rank/world size reported by it are returned (and printed).
    """
    if 'RANK' not in os.environ:
        return 0, 1

    dist.init_process_group('nccl')
    proc_rank = dist.get_rank()
    n_procs = dist.get_world_size()
    print(f'[run_star_attn_inference.init_distributed] Rank: {proc_rank}, World size: {n_procs}')
    return proc_rank, n_procs


def get_resume_point(input_data, output_file):
    """Drop the samples of `input_data` that already have a saved prediction.

    If the samples carry an 'index' key (judged from the first sample), a sample
    is considered done when its index appears in the 'index' field of a record in
    `output_file`. Otherwise the first len(saved records) samples are skipped,
    which assumes predictions were written in input order (as main() does).

    Args:
        input_data: list of input sample dicts.
        output_file: path to the jsonl file holding previously saved predictions.

    Returns:
        The samples still to be processed. `input_data` is returned unchanged when
        it is empty or `output_file` does not exist.
    """
    # Guard the empty case before peeking at input_data[0] (previously IndexError).
    if not input_data or not os.path.exists(output_file):
        return input_data

    output_data = read_jsonl(output_file)
    if 'index' in input_data[0]:
        # A set makes membership O(1); the old list lookup made this O(n*m).
        # NOTE(review): assumes index values are hashable (ints/strings from JSON).
        done_indices = {record.get('index') for record in output_data}
        return [sample for sample in input_data if sample['index'] not in done_indices]
    return input_data[len(output_data):]


def load_model(
    model_path,
    attn_type,
    tokens_to_generate,
    block_size=-1,
    anchor_block_size=-1,
    stop_words=None,
):
    """Load the SirLLM Llama model, its tokenizer and the stop-token ids.

    Args:
        model_path: local checkpoint directory, passed to ``from_pretrained``.
        attn_type: accepted so existing call sites keep working; not used here.
        tokens_to_generate: accepted for call-site compatibility; not used here.
        block_size: accepted for call-site compatibility; not used here.
        anchor_block_size: accepted for call-site compatibility; not used here.
        stop_words: optional list of strings that should end generation. ``None``
            (the default) means no extra stop words.

    Returns:
        (model, tokenizer, stop_ids) where ``stop_ids`` holds the last token id of
        each stop word plus the tokenizer's EOS id (if it has one).
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    model = LlamaForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        attn_implementation='flash_attention_2',
    )

    # Only the final sub-token of each stop word is used as the stop signal, so a
    # multi-token stop word matches on its last piece alone.
    # `stop_words or []` avoids iterating None (the default).
    stop_ids = [tokenizer.encode(s)[-1] for s in (stop_words or [])]
    # Always stop at EOS: otherwise tokens generated after EOS would be decoded into
    # the prediction (EOS itself is dropped by skip_special_tokens, the rest is not).
    eos_id = getattr(tokenizer, 'eos_token_id', None)
    if eos_id is not None and eos_id not in stop_ids:
        stop_ids.append(eos_id)
    return model, tokenizer, stop_ids

from sir_llm.eval_utils import Evaluator
from sir_llm.enable_streaming_llm import enable_streaming_llm

class SirLLMConfig:
    """Settings for the SirLLM streaming KV cache and ``Evaluator``.

    Every value defaults to the previously hard-coded setting. Callers can override
    any of them with keyword arguments. The attributes are plain and mutable, so
    they can still be reassigned after construction (e.g. ``max_gen_len``).
    """

    def __init__(
        self,
        *,
        start_size=4,
        token_entropy_size=6000,
        recent_size=1000,
        max_gen_len=20,
        decay_ratio=1,
    ):
        # Forwarded to enable_streaming_llm(start_size=...).
        self.start_size = start_size
        # Forwarded to enable_streaming_llm(token_entropy_size=...); 16384 was also tried.
        self.token_entropy_size = token_entropy_size
        # Forwarded to enable_streaming_llm(recent_size=...).
        self.recent_size = recent_size
        # Generation length limit; presumably read by Evaluator -- TODO confirm.
        self.max_gen_len = max_gen_len
        # Not read in this file; presumably consumed by Evaluator -- TODO confirm.
        self.decay_ratio = decay_ratio

BS = 4096  # prefill chunk length: the prompt is fed to the model BS tokens at a time


def _chunked_prefill(kv_cache, generator, input_ids, chunk_size):
    """Run the prompt through the model in chunks, evicting cache entries as needed.

    Before each chunk, ``kv_cache.evict_for_space_token_entropy`` is asked to make
    room for the chunk's tokens. The chunk is then processed by the SirLLM
    ``Evaluator``, and the token entropies it returns are accumulated.

    Returns:
        (past_key_values, logits) as returned by the evaluator for the last chunk.
    """
    seq_len = input_ids.shape[-1]
    past_key_values = None
    token_entropy = None
    logits = None
    for start in range(0, seq_len, chunk_size):
        end = min(seq_len, start + chunk_size)
        evicted = kv_cache.evict_for_space_token_entropy(past_key_values, token_entropy, end - start)
        past_key_values = evicted[0]
        token_entropy = evicted[1]
        past_key_values, chunk_entropy, logits = generator._greedy_generate_token_entropy_simple(
            input_ids[:, start:end],
            continue_len=1,
            past_key_values=past_key_values,
            token_entropy=token_entropy,
        )
        # NOTE(review): `+=` semantics depend on the evaluator's entropy type
        # (presumably list concatenation) -- TODO confirm.
        if token_entropy is None:
            token_entropy = chunk_entropy
        else:
            token_entropy += chunk_entropy
    return past_key_values, logits


@torch.no_grad()
def gen(model, tokenizer, input_ids, max_new_tokens, eos_token_ids):
    """Greedily generate token ids, using a SirLLM-managed KV cache for the prompt.

    Args:
        model: causal LM, passed to ``enable_streaming_llm`` and ``Evaluator`` and
            called directly for each decoding step.
        tokenizer: passed through to ``Evaluator``.
        input_ids: prompt token ids. Slicing ``[:, b:e]`` and ``.item()`` on the
            argmax imply a (1, seq_len) tensor.
        max_new_tokens: generation budget; values <= 0 yield an empty list.
        eos_token_ids: ids that stop generation. The stop token itself is never
            included in the result, including when it is the first token.

    Returns:
        List of generated token ids (ints).

    Raises:
        ValueError: if ``input_ids`` contains no tokens (there would be no logits).
    """
    if input_ids.shape[-1] == 0:
        raise ValueError('gen() requires a non-empty prompt')
    if max_new_tokens <= 0:
        return []

    s_config = SirLLMConfig()
    s_config.max_gen_len = max_new_tokens
    kv_cache = enable_streaming_llm(
        model,
        start_size=s_config.start_size,
        recent_size=s_config.recent_size,
        token_entropy_size=s_config.token_entropy_size,
    )
    generator = Evaluator(model, tokenizer, s_config)
    past_key_values, logits = _chunked_prefill(kv_cache, generator, input_ids, BS)

    stop_ids = set(eos_token_ids)
    generated_tokens = []
    next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    for step in range(max_new_tokens):
        token = next_id.item()
        # Check every token, including the first one, against the stop ids.
        if token in stop_ids:
            break
        generated_tokens.append(token)
        if step == max_new_tokens - 1:
            break  # budget reached; skip a forward pass whose output would be discarded
        output = model(next_id, past_key_values=past_key_values)
        past_key_values = output.past_key_values
        next_id = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    return generated_tokens

def _build_record(input_sample, pred):
    """Build the JSON-serializable output record for one sample and its prediction."""
    return {
        'index': input_sample.get('index', -1),
        'pred': pred,
        'input_context': input_sample['input_context'],
        'input_query': input_sample['input_query'],
        'outputs': (
            input_sample['outputs'] if 'outputs' in input_sample else [input_sample['output']]
        ),
        'others': input_sample.get('others', {}),
        'truncation': input_sample.get('truncation', -1),
        'length': input_sample.get('length', -1),
    }


def main(
    model_path: str,
    attn_type: str,
    tokens_to_generate: int,
    input_file: str,
    output_file: str,
    num_samples: int = -1,
    stop_words: Optional[List[str]] = None,
    block_size: int = -1,
    anchor_block_size: int = -1,
    use_cache: bool = False,
):
    """Run greedy SirLLM inference over a jsonl dataset and save the predictions.

    For each sample, ``input_context + input_query`` is tokenized and passed to
    ``gen``. The decoded prediction, together with the sample's metadata, is
    written to ``output_file`` as one JSON line by rank 0.

    Args:
        model_path: path to the model checkpoint
        attn_type: kept for CLI compatibility; forwarded to load_model, which ignores it
        tokens_to_generate: maximum number of tokens to generate per sample
        input_file: path to the input jsonl file
        output_file: path to the output jsonl file where the generated predictions will be saved
        num_samples: if > 0, only the first ``num_samples`` inputs are used. Default: -1 (all)
        stop_words: list of stop words for generation. Default: None
        block_size: kept for CLI compatibility; forwarded to load_model, which ignores it
        anchor_block_size: kept for CLI compatibility; forwarded to load_model, which ignores it
        use_cache: resume from last generation if the output file already exists. Default: False
    """
    from contextlib import nullcontext

    rank, _ = init_distributed()
    is_main_rank = rank == 0
    process_start_time = time.time()

    # Load data
    input_data = read_jsonl(input_file)
    if num_samples > 0:
        input_data = input_data[:num_samples]

    # Resume from last generation
    output_file_mode = 'wt'
    if os.path.exists(output_file) and use_cache:
        if is_main_rank:
            print(f'Resuming from last generation. Output file already exists: {output_file}')
        input_data = get_resume_point(input_data, output_file)
        output_file_mode = 'at'

    # Load model
    model, tokenizer, stop_ids = load_model(
        model_path,
        attn_type,
        tokens_to_generate,
        block_size,
        anchor_block_size,
        stop_words=stop_words,
    )
    max_new_tokens = tokens_to_generate

    inference_start_time = time.time()

    # Only rank 0 opens the output file. A non-zero rank opening it in 'wt' mode
    # would truncate whatever rank 0 had already written.
    # buffering=1 flushes after every line so intermediate generations are visible.
    fout_ctx = (
        open(output_file, output_file_mode, encoding='utf-8', buffering=1)
        if is_main_rank
        else nullcontext()
    )
    with fout_ctx as fout:
        for input_sample in tqdm(input_data, total=len(input_data)):
            input_str = input_sample['input_context'] + input_sample['input_query']
            # Put the prompt on the model's device. device_map="cuda" in load_model does
            # not follow the rank, so cuda:{rank} would mismatch for rank > 0.
            input_ids = tokenizer(
                input_str, return_tensors='pt', add_special_tokens=False
            ).input_ids.to(model.device)
            generated_ids = gen(model, tokenizer, input_ids, max_new_tokens, stop_ids)
            pred = tokenizer.decode(generated_ids, skip_special_tokens=True)
            if is_main_rank:
                fout.write(json.dumps(_build_record(input_sample, pred)) + '\n')

    if is_main_rank:
        end_time = time.time()
        print(f'Total time: {round((end_time - process_start_time) / 60, 1)} minutes')
        print(f'Inference time: {round((end_time - inference_start_time) / 60, 1)} minutes')



if __name__ == '__main__':
    # CLI entry point: parse arguments, validate paths, then run main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', required=True, help='path to the model checkpoint')
    parser.add_argument('--attn_type', required=True, help='type of attention')
    parser.add_argument('--block_size', type=int, default=-1, help='block size for star attention')
    parser.add_argument('--anchor_block_size', type=int, default=-1, help='anchor block size for star attention')
    parser.add_argument('--tokens_to_generate', type=int, required=True, help='number of tokens to generate')
    parser.add_argument('--stop_words', default='', help='comma separated stop words for generation')
    parser.add_argument('--input_path', required=True, help='path to the input jsonl file')
    parser.add_argument('--num_samples', type=int, default=-1, help='number of samples to use from the input file')
    parser.add_argument(
        '--output_path', required=True, help='path to the jsonl file where the generated predictions will be saved'
    )
    parser.add_argument(
        '--use_cache', action='store_true', help='resume from last generation if the output file already exists'
    )
    args = parser.parse_args()

    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f'Invalid model path: {args.model_path}')

    if not os.path.exists(args.input_path):
        raise FileNotFoundError(f'Invalid input path: {args.input_path}')

    # For a bare filename (e.g. "preds.jsonl") dirname is '' and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when there is one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    main(
        args.model_path,
        args.attn_type,
        args.tokens_to_generate,
        args.input_path,
        args.output_path,
        num_samples=args.num_samples,
        # Empty entries (e.g. from the default '') are dropped.
        stop_words=list(filter(None, args.stop_words.split(','))),
        block_size=args.block_size,
        anchor_block_size=args.anchor_block_size,
        use_cache=args.use_cache,
    )
