# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os
import time
from typing import List, Optional
import torch
import torch.distributed as dist
from tqdm import tqdm

from model import DenseAttentionModel, RingAttentionModel, StarAttentionModel
from lring.utils import load_ring_model_tokenizer, ring_prefill_and_decode
import torch.distributed as dist
from minference import MInference
from transformers import AutoTokenizer, AutoModelForCausalLM
from modeling_llama_locret import LlamaForCausalLM

def read_jsonl(filename, num_lines=-1):
    """Read a JSON-Lines file into a list of parsed records.

    Args:
        filename: path to the ``.jsonl`` file, decoded as UTF-8.
        num_lines: maximum number of records to return. Any negative value
            (the default) reads the whole file.

    Returns:
        The decoded JSON values in file order. Blank lines are skipped.
    """
    records = []
    with open(filename, encoding='utf-8') as f:
        for line in f:
            # Check before parsing so exactly ``num_lines`` records come back.
            if 0 <= num_lines <= len(records):
                break
            # Tolerate blank lines (e.g. a trailing newline) instead of failing in json.loads.
            if not line.strip():
                continue
            records.append(json.loads(line))
    return records


def init_distributed():
    """Create the NCCL process group when the ``RANK`` environment variable is set.

    Returns:
        A ``(rank, world_size)`` tuple. It is ``(0, 1)`` when ``RANK`` is not
        set; otherwise it holds the values reported by the new process group.
    """
    if 'RANK' not in os.environ:
        # Single-process run: no process group is created.
        return 0, 1

    dist.init_process_group('nccl')
    rank, world_size = dist.get_rank(), dist.get_world_size()
    print(f'[run_star_attn_inference.init_distributed] Rank: {rank}, World size: {world_size}')
    return rank, world_size


def get_resume_point(input_data, output_file):
    """Return the samples of ``input_data`` that do not yet have a saved prediction.

    If the first input record has an ``'index'`` key, samples whose index
    appears in ``output_file`` are dropped. Otherwise the first
    ``len(predictions)`` samples are assumed to be done and are skipped.

    Args:
        input_data: list of input sample dicts.
        output_file: path to a jsonl file of previously written predictions.

    Returns:
        The remaining samples. ``input_data`` is returned unchanged when it is
        empty or when ``output_file`` does not exist.
    """
    # Guarding on empty input avoids an IndexError on input_data[0] below.
    if not input_data or not os.path.exists(output_file):
        return input_data

    output_data = read_jsonl(output_file)
    if 'index' in input_data[0]:
        # A set gives O(1) membership tests instead of scanning every prediction
        # per sample. Assumes index values are hashable (ints/strings) — confirm.
        done = {row['index'] for row in output_data}
        return [sample for sample in input_data if sample['index'] not in done]
    return input_data[len(output_data):]


def load_model(
    model_path,
    attn_type,
    tokens_to_generate,
    block_size=-1,
    anchor_block_size=-1,
    stop_words=None,
    locret_path='yourpath.bin',
):
    """Load the tokenizer, the Locret Llama model and the stop-token ids.

    The base weights come from ``model_path``. On top of them, only the
    checkpoint entries of ``locret_path`` whose key contains ``'fc'`` are
    loaded, non-strictly.

    Args:
        model_path: Hugging Face model directory for the tokenizer and base weights.
        attn_type, tokens_to_generate, block_size, anchor_block_size: accepted
            for interface compatibility; not used here.
        stop_words: optional list of stop strings. Each contributes the last
            token id of its encoding.
        locret_path: path to the Locret checkpoint (a ``torch.load``-able dict).

    Returns:
        A ``(model, tokenizer, stop_ids)`` tuple.
    """
    import warnings

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    model = LlamaForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        attn_implementation='sdpa',
    )

    # NOTE(review): torch.load unpickles arbitrary objects; only use trusted checkpoints.
    # map_location='cpu' avoids depending on the device the checkpoint was saved from;
    # load_state_dict copies the tensors onto the model's device.
    ckpt = torch.load(locret_path, map_location='cpu')
    fc_weights = {k: v for k, v in ckpt.items() if 'fc' in k}
    if not fc_weights:
        warnings.warn(f"No 'fc' weights found in Locret checkpoint {locret_path!r}; nothing was loaded.")
    # strict=False because only a subset of parameters is provided; unexpected keys
    # would mean the names do not match the model and the weights were dropped.
    load_result = model.load_state_dict(fc_weights, strict=False)
    if load_result.unexpected_keys:
        warnings.warn(f'Locret checkpoint keys not present in the model: {load_result.unexpected_keys}')

    # stop_words defaults to None; treat that as "no stop words" instead of failing to iterate.
    stop_ids = [tokenizer.encode(s)[-1] for s in (stop_words or [])]
    return model, tokenizer, stop_ids

def _block_max_pool(sc, block_size):
    """Replace every full block of ``block_size`` positions along dim 1 by its maximum, keeping the length.

    A trailing partial block (shorter than ``block_size``) is left unpooled.
    Expects a 3-D tensor.
    """
    full_len = (sc.shape[1] // block_size) * block_size
    if full_len == 0:
        # Nothing to pool; also avoids reshaping a 0-element tensor with -1, which torch rejects.
        return sc
    head, tail = sc[:, :full_len, :], sc[:, full_len:, :]
    pooled = head.reshape(head.shape[0], full_len // block_size, block_size, head.shape[-1])
    pooled = pooled.max(dim=2, keepdim=True).values.repeat(1, 1, block_size, 1)
    pooled = pooled.reshape(head.shape[0], full_len, head.shape[-1])
    return torch.cat((pooled, tail), dim=1)


@torch.no_grad()
def generate(model, input_ids, max_new_tokens, eos_token_ids, prefix_len, block_size=128, budget_size=6000, local_len=100, chunk_size=4096, stabilizers=2500):
    """Chunked prefill with per-chunk KV-cache pruning, followed by greedy decoding.

    All but the last ``local_len`` prompt tokens are prefilled in chunks of
    ``chunk_size``. After each chunk, every layer keeps at most ``budget_size``
    cache positions, chosen per head by top-k on accumulated scores. The scores
    are read from ``output.attentions`` (presumably Locret retaining-head
    scores — confirm against modeling_llama_locret) and treated as
    ``(batch, seq, heads)``. The last ``local_len`` tokens are then prefilled
    without pruning, and tokens are decoded greedily.

    NOTE(review): the model is called without explicit ``position_ids``, so the
    positions it uses depend on how the model derives them from the (pruned)
    cache — confirm this is intended.

    Args:
        model: causal LM exposing ``config.num_hidden_layers`` and returning
            ``past_key_values`` / ``attentions`` / ``logits``.
        input_ids: prompt ids of shape ``(1, seq_len)``. Decoding uses
            ``.item()``, so the batch size must be 1.
        max_new_tokens: upper bound on the number of appended tokens.
        eos_token_ids: ids that stop generation; the stop token itself is not appended.
        prefix_len: leading positions whose score is forced to the dtype max,
            so they are always kept.
        block_size: scores are max-pooled over consecutive blocks of this size before top-k.
        budget_size: maximum number of cache positions kept per layer after each chunk.
        local_len: trailing prompt tokens prefilled after all pruning.
        chunk_size: tokens per prefill chunk.
        stabilizers: on every chunk except the last, this many most recent
            positions are forced to be kept.

    Returns:
        ``input_ids`` with the generated tokens appended along the last dim.
    """
    seq_len = input_ids.shape[-1]
    num_layers = model.config.num_hidden_layers
    eos_ids = set(eos_token_ids)
    prefill_end = seq_len - local_len
    max_score = None

    past_key_values = None
    # One accumulated score tensor per layer (was hard-coded to 100 layers).
    scores = [None] * num_layers
    # End of the pruned prefill. It stays 0 when the prompt fits in the local window,
    # which previously left `e` undefined and raised NameError.
    e = 0

    for b in range(0, prefill_end, chunk_size):
        e = min(b + chunk_size, prefill_end)
        output = model(input_ids[:, b:e], use_cache=True, past_key_values=past_key_values, output_attentions=True)
        past_key_values = output.past_key_values
        is_last_chunk = b + chunk_size >= prefill_end

        pruned_kv_cache = []
        kv_shape = past_key_values[0][0].shape
        for j in range(num_layers):
            # Scores stay on the device the model produced them on (previously forced to cuda:0).
            cur_score = output.attentions[j][:, :e - b, :]
            scores[j] = cur_score if scores[j] is None else torch.cat((scores[j], cur_score), dim=-2)

            sc = scores[j].clone()
            max_score = torch.finfo(sc.dtype).max
            selected_num = min(budget_size, sc.shape[-2])
            sc[:, :prefix_len] = max_score
            # `stabilizers > 0` guard: `-0:` would select (and protect) every position.
            if not is_last_chunk and stabilizers > 0:
                sc[:, -stabilizers:, :] = max_score
            sc = _block_max_pool(sc, block_size)

            # Top-k over positions for each head -> (heads, selected_num), kept in positional order.
            selected = torch.topk(sc[0, :, :], k=selected_num, dim=-2)[1].transpose(0, 1).sort().values
            scores[j] = torch.gather(scores[j], 1, selected.transpose(0, 1).unsqueeze(0).to(sc.device))
            kv_index = selected.unsqueeze(0).unsqueeze(-1).repeat(kv_shape[0], 1, 1, kv_shape[3])
            k = torch.gather(past_key_values[j][0], 2, kv_index.to(past_key_values[j][0].device))
            v = torch.gather(past_key_values[j][1], 2, kv_index.to(past_key_values[j][1].device))
            pruned_kv_cache.append((k, v))
        past_key_values = pruned_kv_cache
        del pruned_kv_cache
        torch.cuda.empty_cache()

    # Prefill the unpruned local window (the whole prompt if no chunk was pruned).
    output = model(input_ids[:, e:], use_cache=True, past_key_values=past_key_values, output_attentions=True)
    past_key_values = output.past_key_values

    generated = []
    next_token = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    for step in range(max_new_tokens):
        if step > 0:
            output = model(next_token, past_key_values=past_key_values)
            past_key_values = output.past_key_values
            next_token = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        token = next_token.item()
        # The first token is now EOS-checked too, and max_new_tokens <= 0 yields nothing.
        if token in eos_ids:
            break
        generated.append(token)

    generated_tokens = torch.tensor(generated, device=input_ids.device, dtype=input_ids.dtype).unsqueeze(0)
    return torch.cat((input_ids, generated_tokens), dim=-1)

def main(
    model_path: str,
    attn_type: str,
    tokens_to_generate: int,
    input_file: str,
    output_file: str,
    num_samples: int = -1,
    stop_words: Optional[List[str]] = None,
    block_size: int = -1,
    anchor_block_size: int = -1,
    use_cache: bool = False,
):
    """Run Locret-pruned greedy generation over a jsonl file and save the predictions.

    For each sample, ``input_context + input_query`` is tokenized. The query
    text up to the first ``"<|"`` is inserted again after the first
    ``query_insert_pos`` prompt tokens. The first
    ``query_insert_pos + len(query)`` tokens are passed to ``generate`` as
    ``prefix_len``.

    Args:
        model_path: path to the model checkpoint
        attn_type: forwarded to ``load_model``; not used by the Locret path
        tokens_to_generate: maximum number of tokens to generate per sample
        input_file: path to the input jsonl file
        output_file: path to the output jsonl file where the generated predictions will be saved
        num_samples: if > 0, only the first ``num_samples`` input samples are used. Default: -1 (all)
        stop_words: list of stop words for generation. Default: None
        block_size: forwarded to ``load_model``; not used by the Locret path
        anchor_block_size: forwarded to ``load_model``; not used by the Locret path
        use_cache: resume from last generation if the output file already exists. Default: False
    """
    import contextlib

    # Number of leading prompt tokens after which the query copy is inserted.
    # Presumably the length of the chat-template header before the context — TODO confirm per tokenizer.
    query_insert_pos = 16

    rank, _ = init_distributed()
    device = f"cuda:{rank}"

    if rank == 0:
        process_start_time = time.time()

    # Load data
    input_data = read_jsonl(input_file)
    if num_samples > 0:
        input_data = input_data[:num_samples]

    # Resume from last generation
    output_file_mode = 'wt'
    if os.path.exists(output_file) and use_cache:
        print(f'Resuming from last generation. Output file already exists: {output_file}')
        input_data = get_resume_point(input_data, output_file)
        output_file_mode = 'at'

    # Load model
    model, tokenizer, stop_ids = load_model(
        model_path,
        attn_type,
        tokens_to_generate,
        block_size,
        anchor_block_size,
        stop_words=stop_words,
    )
    max_new_tokens = tokens_to_generate

    if rank == 0:
        inference_start_time = time.time()
    # (A leftover `input_data = input_data[:20]` was removed: it silently overrode --num_samples.)

    # Only rank 0 writes predictions, so only rank 0 opens the file. Other ranks opening it
    # in 'wt' mode would truncate rank 0's output.
    # buffering=1 flushes after every line so intermediate generations are visible.
    if rank == 0:
        fout_ctx = open(output_file, output_file_mode, encoding='utf-8', buffering=1)
    else:
        fout_ctx = contextlib.nullcontext()

    with fout_ctx as fout:
        for input_sample in tqdm(input_data, total=len(input_data)):
            input_str = input_sample['input_context'] + input_sample['input_query']
            input_ids = tokenizer(input_str, return_tensors='pt', add_special_tokens=False).to(device).input_ids

            # Query text up to the first special-token marker "<|".
            query_text = input_sample['input_query'].split("<|")[0]
            query_ids = tokenizer(query_text, return_tensors='pt', add_special_tokens=False).to(device).input_ids

            input_ids = torch.cat(
                [input_ids[:, :query_insert_pos], query_ids, input_ids[:, query_insert_pos:]], dim=-1
            )
            prefix_len = query_insert_pos + query_ids.shape[-1]
            input_len = input_ids.shape[-1]
            output_ids = generate(model, input_ids, max_new_tokens, stop_ids, prefix_len)
            generated_ids = output_ids[0, input_len:]

            pred = tokenizer.decode(generated_ids, skip_special_tokens=True)
            if rank == 0:
                fout.write(
                    json.dumps(
                        {
                            'index': input_sample.get('index', -1),
                            'pred': pred,
                            'input_context': input_sample['input_context'],
                            'input_query': input_sample['input_query'],
                            'outputs': (
                                input_sample['outputs'] if 'outputs' in input_sample else [input_sample['output']]
                            ),
                            'others': input_sample.get('others', {}),
                            'truncation': input_sample.get('truncation', -1),
                            'length': input_sample.get('length', -1),
                        }
                    )
                    + '\n'
                )

    if rank == 0:
        end_time = time.time()
        print(f'Total time: {round((end_time - process_start_time) / 60, 1)} minutes')
        print(f'Inference time: {round((end_time - inference_start_time) / 60, 1)} minutes')



if __name__ == '__main__':
    # Command-line entry point: parse arguments, validate paths, then run main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', required=True, help='path to the model checkpoint')
    parser.add_argument('--attn_type', required=True, help='type of attention')
    parser.add_argument('--block_size', type=int, default=-1, help='block size for star attention')
    parser.add_argument('--anchor_block_size', type=int, default=-1, help='anchor block size for star attention')
    parser.add_argument('--tokens_to_generate', type=int, required=True, help='number of tokens to generate')
    parser.add_argument('--stop_words', default='', help='comma separated stop words for generation')
    parser.add_argument('--input_path', required=True, help='path to the input jsonl file')
    parser.add_argument('--num_samples', type=int, default=-1, help='number of samples to use from the input file')
    parser.add_argument(
        '--output_path', required=True, help='path to the jsonl file where the generated predictions will be saved'
    )
    parser.add_argument(
        '--use_cache', action='store_true', help='resume from last generation if the output file already exists'
    )
    args = parser.parse_args()

    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f'Invalid model path: {args.model_path}')

    if not os.path.exists(args.input_path):
        raise FileNotFoundError(f'Invalid input path: {args.input_path}')

    # os.makedirs('') raises, so only create a directory when the output path names one
    # (e.g. a bare 'preds.jsonl' goes to the current directory).
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    main(
        args.model_path,
        args.attn_type,
        args.tokens_to_generate,
        args.input_path,
        args.output_path,
        num_samples=args.num_samples,
        # Empty entries (e.g. from the default '') are dropped.
        stop_words=list(filter(None, args.stop_words.split(','))),
        block_size=args.block_size,
        anchor_block_size=args.anchor_block_size,
        use_cache=args.use_cache,
    )
