import os
import json
import argparse
import datetime
import re

import pandas as pd
from transformers import AutoTokenizer
from transformers.generation.utils import GenerationConfig

from datastore import get_data_store

# Evaluation conversations: a JSON list of entries carrying 'system_id' and 'messages'.
data_file_path = '../datas/system_benchmark_eval_datas.json'

# Local Hugging Face snapshot directory for each CLI model name
# (THUDM/glm-4-9b-chat, meta-llama/Meta-Llama-3.1-8B-Instruct, Qwen/Qwen2-72B-Instruct).
# NOTE(review): '/path/to/' looks like a placeholder for the real cache root.
checkpoint_paths = {
    'glm-9b': '/path/to/.cache/huggingface/hub/models--THUDM--glm-4-9b-chat/snapshots/aae8bd74af5c6dff63a49d7fbdcc89349ebf87aa/',
    'llama31-8b': '/path/to/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/8c22764a7e3675c50d4c7c9a4edb474456022b16/',
    'qwen-72b': '/path/to/.cache/huggingface/hub/models--Qwen--Qwen2-72B-Instruct/snapshots/1af63c698f59c4235668ec9c1395468cb7cd7e79/',
}

# Number of special tokens each model's chat template puts at the beginning;
# workflow() records index (value - 1) as an "extra" split.
extra_tokens = {
    'glm-9b': 2,
    'llama31-8b': 25,
    'qwen-72b': 3,
}

# CLI model name -> results folder / file prefix under ../output (see parse_xls).
KEY_MAP = {
    'qwen-72b': 'qwen2_72b',
    'llama31-8b': 'llama3_8b',
    'glm-9b': 'glm_9b_client',
}

# Tokens subtracted from the system-prompt split index when --replace turns the
# system prompt into a user turn (presumably trailing template tokens; verify per model).
REPLACE_TOKEN = {
    'qwen-72b': 2,
    'llama31-8b': 2,
    'glm-9b': 0,
}

# Benchmark size: number of system ids and turns per id. ENTRY_NUMBER is the
# expected row count of the analysis details sheet.
TOTAL_SYSTEM_ID = 500
TURN_NUMBER = 5
ENTRY_NUMBER = TURN_NUMBER * TOTAL_SYSTEM_ID

# Width unit for separator lines in (currently commented-out) debug prints.
placeholder = 30

# System ids whose results already exist in the save directory; filled by
# main() and consulted by workflow() to skip finished ids.
cached_sid = set()

def load_examples(dataset_filepath):
    """Load and return the parsed JSON content of *dataset_filepath*.

    The file is decoded as UTF-8 and closed before returning, rather than
    leaving the handle for the garbage collector to close.

    Args:
        dataset_filepath: Path to a UTF-8 encoded JSON file.

    Returns:
        The deserialized JSON value. For the benchmark file this is iterated as
        a sequence of dicts with 'system_id' and 'messages' keys.
    """
    with open(dataset_filepath, encoding="utf-8") as f:
        return json.load(f)


def parse_xls(key, sheet_name='详情', root_dir='../output'):
    """Read one sheet of a model's analysis workbook.

    The workbook lives at
    ``<root_dir>/<KEY_MAP[key]>/<KEY_MAP[key]>_analysis.xlsx``. When the default
    '详情' ("details") sheet is read, its row count must equal ``ENTRY_NUMBER``
    (TOTAL_SYSTEM_ID * TURN_NUMBER rows).

    Args:
        key: CLI model name, i.e. a key of ``KEY_MAP``.
        sheet_name: Name of the sheet to read; defaults to the details sheet.
        root_dir: Directory that holds the per-model output folders.

    Returns:
        A ``pandas.DataFrame`` with the sheet's contents.

    Raises:
        KeyError: If *key* is not in ``KEY_MAP``.
        ValueError: If the details sheet does not contain ``ENTRY_NUMBER`` rows.
    """
    dir_name = KEY_MAP[key]
    file_path = os.path.join(root_dir, dir_name, f'{dir_name}_analysis.xlsx')
    df = pd.read_excel(file_path, sheet_name=sheet_name)
    # Validate explicitly instead of with ``assert``: asserts are stripped under
    # ``python -O``. A short sheet would then make callers' positional row
    # lookups read wrong or missing rows.
    if sheet_name == '详情' and len(df) != ENTRY_NUMBER:
        raise ValueError(
            f'Reading error: {len(df)} entries found, expected {ENTRY_NUMBER} entries ({file_path})'
        )
    return df


def converation_generator(sysmeg_id):
    """Yield the non-assistant messages of the conversation with id *sysmeg_id*.

    The dataset at ``data_file_path`` is reloaded on every call and scanned for
    the first entry whose 'system_id' equals *sysmeg_id*. Assistant messages
    (the ground-truth answers) are skipped. Every other message dict is yielded
    unchanged and in order, so mutations made by the consumer affect the
    loaded dict itself.

    Because this is a generator, the lookup, the print and any ``ValueError``
    only happen once iteration starts.

    Raises:
        ValueError: If no entry carries the requested 'system_id'.
    """
    target = next(
        (item for item in load_examples(data_file_path) if item['system_id'] == sysmeg_id),
        None,
    )
    if target is None:
        raise ValueError(f"System message with id {sysmeg_id} not found")
    print("System message ID:", sysmeg_id)
    yield from (msg for msg in target['messages'] if msg['role'] != 'assistant')

def get_model_type(model_name, offload):
    """Return the model class for *model_name*, importing its module lazily.

    Args:
        model_name: One of 'glm-9b', 'llama31-8b' or 'qwen-72b'.
        offload: Only consulted for 'qwen-72b'. When truthy, the class comes from
            the CPU-offload modeling module.

    Returns:
        The (uninstantiated) model class.

    Raises:
        ValueError: If *model_name* is not supported.
    """
    if model_name == 'qwen-72b':
        if not offload:
            from modeling_qwen2 import Qwen2ForCausalLM
        else:
            from modeling_qwen2_offload import Qwen2ForCausalLM
        return Qwen2ForCausalLM

    if model_name == 'llama31-8b':
        from modeling_llama import LlamaForCausalLM
        return LlamaForCausalLM

    if model_name == 'glm-9b':
        from modeling_chatglm import ChatGLMForConditionalGeneration
        return ChatGLMForConditionalGeneration

    raise ValueError(f"Model name {model_name} not found")

def _apply_chat_template(tokenizer, messages):
    """Tokenize *messages* with the chat template (no generation prompt) as a PyTorch tensor."""
    return tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False, return_tensors="pt")


def workflow(arg, model, tokenizer, generation_config, datastore):
    """Replay one benchmark conversation through *model* and save the datastore.

    Steps:
      1. Return immediately if ``arg.id`` is in ``cached_sid``, unless
         ``arg.ignore_cache`` is set.
      2. Rebuild the conversation turn by turn. System and user messages come
         from ``converation_generator(arg.id)``. Assistant replies are read from
         the 'answer' column of the model's analysis workbook (``parse_xls``).
         After each appended message the chat template is re-applied, and the
         index of its last token is recorded with ``datastore.add_split_index``.
      3. Run one ``model.generate`` over the whole conversation, print the
         decoded segment between consecutive split points, and save the
         datastore as ``sid<arg.id>`` under ``arg.save_path``.

    Args:
        arg: Parsed CLI namespace. Reads ``id``, ``model``, ``replace``,
            ``offload``, ``ignore_cache`` and ``save_path``.
        model: Object exposing ``generate(inputs=..., generation_config=...)``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``decode``.
        generation_config: Generation config whose ``max_length`` is the number
            of tokens to allow beyond the prompt. It is raised to
            ``prompt_length + max_length`` for the generate call and restored
            afterwards, so the same config can be reused across ids.
        datastore: Object returned by ``get_data_store()``; cleared at the start.
    """
    if arg.id in cached_sid and not arg.ignore_cache:
        print(f"System message {arg.id} already cached, ignore")
        return

    datastore.clear()
    # Special tokens at the beginning of the template form their own "extra" split.
    datastore.add_split_index(extra_tokens[arg.model] - 1, extra=True)

    # Budget on top of the prompt. main() reuses one generation_config for every
    # id, so the value overwritten below must be restored after generation.
    # Otherwise each subsequent id would start from an inflated budget.
    generation_length = generation_config.max_length

    messages = []
    model_output = parse_xls(arg.model)
    conver_turn = 0

    for message in converation_generator(arg.id):
        if message['role'] == 'system':
            print("System message:", message)
            messages.append(message)
            input_length = None
            if arg.replace:
                # Present the system prompt as a user turn; the split point moves
                # back by REPLACE_TOKEN[arg.model] tokens (per-model template offset).
                messages[-1]["role"] = "user"
                input_length = _apply_chat_template(tokenizer, messages).shape[-1]
                datastore.add_split_index(input_length - 1 - REPLACE_TOKEN[arg.model])
            else:
                input_length = _apply_chat_template(tokenizer, messages).shape[-1]
                datastore.add_split_index(input_length - 1)
            continue

        if messages and messages[-1]["role"] == message["role"]:
            # Two consecutive same-role turns are only expected when the system
            # prompt was rewritten as a user turn; merge it with the first user message.
            print(messages[-1]["role"])
            assert len(messages) == 1 and messages[-1]["role"] == "user" and arg.replace
            messages[-1]["content"] += message["content"]
        else:
            messages.append(message)

        datastore.add_split_index(_apply_chat_template(tokenizer, messages).shape[-1] - 1)

        # Assumes the details sheet stores TURN_NUMBER consecutive rows per
        # system id, with ids starting at 1.
        ins_loc = (arg.id - 1) * TURN_NUMBER + conver_turn
        output_text = model_output.loc[ins_loc, 'answer']

        print(datetime.datetime.now(), "Output text:", output_text, flush=True)

        messages.append({"role": "assistant", "content": output_text})
        datastore.add_split_index(_apply_chat_template(tokenizer, messages).shape[-1] - 1)

        conver_turn += 1

    # Recorded indices point at each segment's last token; +1 gives exclusive
    # end offsets for slicing the generated sequence.
    split_indices = [v + 1 for v in datastore.get_split_indices()]

    tokenized_chat = _apply_chat_template(tokenizer, messages)

    generation_config.max_length = tokenized_chat.shape[-1] + generation_length
    kwargs = {
        'inputs': tokenized_chat.to('cuda') if not arg.offload else tokenized_chat,
        'generation_config': generation_config,
    }

    print("+++++++ Generate +++++++")
    print(messages)
    try:
        outputs = model.generate(**kwargs)
    finally:
        # Restore the per-call budget so the next id starts from the same value.
        generation_config.max_length = generation_length

    for i, idx in enumerate(split_indices):
        if i == 0:
            if datastore.has_extra:
                # Segment 0 is the special-token prefix registered with extra=True above.
                continue
            print(f'===== Split {i} =====', tokenizer.decode(outputs[0, :idx], skip_special_tokens=False), sep='\n')
        else:
            print(f'===== Split {i} =====', tokenizer.decode(outputs[0, split_indices[i-1]:idx], skip_special_tokens=False), sep='\n')

    datastore.save_data(arg.save_path, file_name=f'sid{arg.id}')

def _load_cached_sids(save_path):
    """Add to ``cached_sid`` the id of every file in *save_path* whose name starts with ``sid<N>.npy``."""
    pattern = re.compile(r"sid(\d+)\.npy")
    if not os.path.exists(save_path):
        return
    for file_name in os.listdir(save_path):
        match = pattern.match(file_name)
        if match:
            cached_sid.add(int(match.group(1)))


def main():
    """CLI entry point: load the model once, then run ``workflow`` for one or all ids.

    ``--id -1`` processes system ids ``1..TOTAL_SYSTEM_ID``; any other value
    processes only that id. With ``--replace``, ``_replace`` is appended to the
    save path in both modes, so replaced-prompt results never share a cache
    directory with regular ones. Ids already saved in the final save directory
    are skipped unless ``--ignore_cache`` is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--id", type=int, default=287, help="System message ID, -1 means all")
    parser.add_argument("--save_path", type=str, default='glm', help="Path to save the data")
    parser.add_argument("--model", type=str, default='glm-9b', choices=['glm-9b', 'llama31-8b', 'qwen-72b'], help="Model name")
    parser.add_argument("--ignore_cache", action='store_true', help="Ignore cache")
    parser.add_argument("--replace", action='store_true', help="Replace system message as user message")
    parser.add_argument("--offload", action='store_true', help="Offload to CPU, only for qwen-72b")
    arg = parser.parse_args()

    if arg.offload and arg.model != 'qwen-72b':
        raise ValueError("Offload only works for qwen-72b")

    model_cls = get_model_type(arg.model, arg.offload)
    checkpoint_path = checkpoint_paths[arg.model]

    model = model_cls.from_pretrained(checkpoint_path, device_map="auto", torch_dtype='auto')
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True)

    generation_config = GenerationConfig.from_pretrained(checkpoint_path)
    # workflow() sets max_length = prompt length + this value, i.e. a budget of
    # one token beyond the prompt.
    generation_config.max_length = 1
    print(generation_config)

    datastore = get_data_store()

    # Apply the suffix before scanning the cache so that both the scan and the
    # saves use the same directory in single-id and all-id modes.
    if arg.replace:
        arg.save_path += '_replace'
    _load_cached_sids(arg.save_path)

    if arg.id == -1:  # -1 means all
        from tqdm import tqdm
        for sid in tqdm(range(1, TOTAL_SYSTEM_ID + 1)):
            arg.id = sid
            workflow(arg, model, tokenizer, generation_config, datastore)
    else:
        workflow(arg, model, tokenizer, generation_config, datastore)
    

if __name__ == "__main__":
    # Run the command-line workflow only when executed as a script, not on import.
    main()
    #debug()