import json
import torch
from transformers import AutoTokenizer
from AutoCompressors.auto_compressor import LlamaAutoCompressorModel
import math
import argparse
from tqdm import tqdm
import random
import numpy as np
import os
from QA_text_dataset_autocompressor import InferDataset
from chat import apply_chat_template
from icecream import ic as pprint

# Run on CUDA when a GPU is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-dataset cap on generated answer length (passed as ``max_new_tokens``).
# Keys are dataset directory names, parsed from --data_path in main().
dataset2maxgen = dict(
    NQ=100,
    HQA=32,
    WikiQA=32,
)

def seed_everything(TORCH_SEED):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        TORCH_SEED: Integer seed applied to every generator below. It is also
            exported as ``PYTHONHASHSEED``. Note that this environment variable
            does not change the hash seed of the already-running interpreter.
    """
    seed = TORCH_SEED
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every generator is seeded with the same value.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def get_args():
    """Parse command-line options for the AutoCompressor QA inference run.

    Returns:
        argparse.Namespace holding data/model paths, output filename,
        document-count limits, instruction name, batch size, seed and
        compression rate.
    """
    parser = argparse.ArgumentParser()
    # (flag, keyword arguments for add_argument), registered in this order.
    option_specs = [
        ('--data_path', {'type': str, 'default': '/path/to/project/data/NQ/test.jsonl'}),
        ('--model_path', {'type': str, 'default': '/path/to/project/llmodel/AutoCompressor-Llama-2-7b-6k'}),
        ('--output_filename', {'type': str, 'default': 'NQ_outputs.jsonl'}),
        ('--max_doc_tokens', {'type': int, 'default': 5120}),
        ('--max_num_documents', {'type': int, 'default': 20}),
        ('--min_num_documents', {'type': int, 'default': 20}),
        ('--random_num_documents', {'action': 'store_true'}),
        ('--instruction_name', {'type': str, 'default': 'base'}),
        ('--batch_size', {'type': int, 'default': 1}),
        ('--seed', {'type': int, 'default': 42}),
        ('--compression_rate', {'type': int, 'default': 4}),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args()

def get_segments_len(input_ids, segment_size=50):
    """Split the length of ``input_ids`` into consecutive segment lengths.

    Every segment has ``segment_size`` tokens except possibly the last, which
    holds the remainder. The lengths always sum to ``input_ids.size(1)``.

    Args:
        input_ids: Tensor whose dimension 1 is the sequence length.
        segment_size: Maximum tokens per segment.

    Returns:
        List of ints. It is empty for a zero-length input, and ``[length]``
        when the input fits in a single segment.
    """
    total = input_ids.size(1)
    num_full, tail = divmod(total, segment_size)
    lengths = [segment_size] * num_full
    if tail:
        lengths.append(tail)
    return lengths

def load_data(args, model, tokenizer):
    """Build a non-shuffled DataLoader over ``InferDataset`` for inference.

    Args:
        args: Parsed CLI namespace from ``get_args``. Reads ``data_path``,
            ``max_doc_tokens``, ``max_num_documents``, ``min_num_documents``,
            ``random_num_documents``, ``instruction_name`` and ``batch_size``.
            ``args`` is not modified.
        model: Forwarded to ``InferDataset`` (its use is defined there).
        tokenizer: Forwarded to ``InferDataset``.

    Returns:
        ``torch.utils.data.DataLoader`` with ``shuffle=False`` that batches
        with the dataset's own ``collate_fn``.
    """
    # Honor --instruction_name (default 'base'). This function previously
    # overwrote it with 'base', which silently ignored the flag and mutated args.
    dataset = InferDataset(
        filepath=args.data_path,
        model=model,
        tokenizer=tokenizer,
        max_doc_tokens=args.max_doc_tokens,
        max_num_documents=args.max_num_documents,
        min_num_documents=args.min_num_documents,
        random_num_documents=args.random_num_documents,
        instruction_name=args.instruction_name,
    )

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
    )

def generate_answer(context, prompt, model, tokenizer, eos_token_id, dataset_name, compression_rate):
    """Compress ``context`` into summary vectors, then greedily answer ``prompt``.

    The context is tokenized without special tokens and run through the
    AutoCompressor in segments of ``compression_rate * 50`` tokens. The model
    returns a ``softprompt``, which is passed to greedy ``generate`` on the
    tokenized ``prompt``.

    Args:
        context: Text to compress (``main`` passes ``batch['prompt']``).
        prompt: Text to generate from (``main`` passes ``batch['question']``).
        model: AutoCompressor model. Its forward must accept
            ``segment_lengths``/``output_softprompt`` and its ``generate``
            must accept ``softprompt``.
        tokenizer: Tokenizer matching ``model``.
        eos_token_id: Stop-token ids computed by ``main``. They are currently
            NOT forwarded to ``generate`` (see the disabled arguments below).
        dataset_name: Key into ``dataset2maxgen`` selecting ``max_new_tokens``.
        compression_rate: Segment-size multiplier. Each segment is
            ``compression_rate * 50`` context tokens.

    Returns:
        Tuple ``(generation, num_context_tokens, num_summary_tokens)``.
        ``generation`` is the list returned by ``tokenizer.batch_decode``.

    Raises:
        KeyError: If ``dataset_name`` has no entry in ``dataset2maxgen``. This
            is checked before any model computation.
    """
    try:
        max_new_tokens = dataset2maxgen[dataset_name]
    except KeyError:
        raise KeyError(
            f"No max generation length configured for dataset {dataset_name!r}; "
            f"expected one of {sorted(dataset2maxgen)}"
        ) from None

    # Put both tensors on the module-level `device`. The previous `.cuda()`
    # call crashed on CPU-only hosts even though the model lives on `device`.
    prompt_tokens = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids.to(device)
    context_tokens = tokenizer(context, add_special_tokens=False, return_tensors="pt").input_ids.to(device)

    # NOTE(review): 50 presumably equals model.config.summary_length (summary
    # vectors per segment), making compression_rate a token ratio. Confirm.
    segment_size = compression_rate * 50
    segment_lengths = get_segments_len(context_tokens, segment_size)
    print(f"Compression rate: {compression_rate}, Segment size: {segment_size}")
    print(f"Input length: {context_tokens.size(1)}, Segments: {len(segment_lengths)}")

    summary_vectors = model(
        context_tokens,
        segment_lengths=segment_lengths,
        output_softprompt=True,
    ).softprompt

    outputs = model.generate(
        prompt_tokens,
        do_sample=False,
        softprompt=summary_vectors,
        max_new_tokens=max_new_tokens,
        # NOTE(review): custom stop tokens are disabled. Re-enabling these
        # lines would change evaluation outputs.
        # eos_token_id=eos_token_id,
        # begin_suppress_tokens=eos_token_id,
    )
    # The output includes the prompt tokens; keep only the newly generated ones.
    new_tokens = outputs[:, prompt_tokens.size(1):]

    generation = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
    pprint(max_new_tokens, generation)
    return generation, context_tokens.size(1), summary_vectors.shape[1]

def _resolve_eos_token_ids(model, tokenizer):
    """Return a new list of stop-token ids: the model's EOS id(s) plus the newline token.

    Uses ``model.generation_config.eos_token_id`` when available and not None.
    Otherwise it falls back to ``tokenizer.eos_token_id``. The result is always
    a fresh list, so the model's generation config is never mutated.
    """
    generation_config = getattr(model, "generation_config", None)
    eos = getattr(generation_config, "eos_token_id", None)
    if eos is None:
        eos = tokenizer.eos_token_id
    if eos is None:
        eos_ids = []
    elif isinstance(eos, int):
        eos_ids = [eos]
    else:
        # Copy: appending to the config's own list would leak the newline id
        # into model.generation_config for every later generate() call.
        eos_ids = list(eos)
    # Take the last id of the encoded "\n". Presumably some tokenizers emit a
    # prefix piece before the newline token.
    eos_ids.append(tokenizer.encode("\n", add_special_tokens=False)[-1])
    return eos_ids


def _write_outputs(path, questions, generations, answers, original_counts, compressed_counts):
    """Write one JSON object per line (UTF-8, non-ASCII kept) pairing each question with its results."""
    with open(path, 'w', encoding='utf-8') as f:
        for q, g, a, o, c in zip(questions, generations, answers, original_counts, compressed_counts):
            item = {'question': q, 'generation': g, 'answers': a, 'original_num_tokens': o, 'compressed_num_tokens': c}
            f.write(json.dumps(item, ensure_ascii=False) + '\n')


def main():
    """Run AutoCompressor QA inference over ``--data_path`` and write JSONL results.

    Each output line holds the question, the generation, the gold answers,
    and the context length before and after compression.
    """
    args = get_args()
    seed_everything(args.seed)
    pprint(args)

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    model = LlamaAutoCompressorModel.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
    ).eval().to(device)
    pprint(model.config.summary_length)

    eos_token_id = _resolve_eos_token_ids(model, tokenizer)
    data_loader = load_data(args, model, tokenizer)

    # The dataset name is the data file's parent directory, e.g.
    # ".../data/NQ/test.jsonl" -> "NQ". generate_answer uses it to pick max_new_tokens.
    dataset_name = args.data_path.split("/")[-2]

    total_generations = []
    total_questions = []
    total_answers = []
    total_original_num_tokens = []
    total_compressed_num_tokens = []

    with torch.no_grad():
        for batch in tqdm(data_loader):
            batch_generation, original_num_tokens, compressed_num_tokens = generate_answer(
                batch['prompt'],
                batch['question'],
                model,
                tokenizer,
                eos_token_id,
                dataset_name,
                args.compression_rate,
            )
            num_items = len(batch_generation)
            total_generations.extend(batch_generation)
            total_questions.extend(batch['question'])
            total_answers.extend(batch['answers'])
            # generate_answer reports one count per batch. Repeat it per item
            # so the zip() in _write_outputs stays aligned and does not
            # silently drop rows when batch_size > 1.
            total_original_num_tokens.extend([original_num_tokens] * num_items)
            total_compressed_num_tokens.extend([compressed_num_tokens] * num_items)

    _write_outputs(
        args.output_filename,
        total_questions,
        total_generations,
        total_answers,
        total_original_num_tokens,
        total_compressed_num_tokens,
    )

# Only run inference when executed as a script, not when imported.
if __name__ == '__main__':
    main()