import os
# GPU index exposed to this process. The environment variable is assigned
# here, ahead of the `torch` import further down in this file.
# NOTE(review): presumably so CUDA only enumerates this one device - confirm.
DEVICE = "2"
os.environ["CUDA_VISIBLE_DEVICES"] = DEVICE
import json
import argparse
#from sklearn.feature_extraction.text import TfidfVectorizer
#from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from nltk import sent_tokenize
import re
import numpy as np
import string
import torch
import yaml

import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
from utils import *
import inseq
from inseq.commands.attribute_context.attribute_context import AttributeContextArgs, attribute_context
from myattribute import attribute_context_with_model

def remove_citations(sent):
    """Strip bracketed numeric citation markers such as ``[1]`` from *sent*.

    The opening part of each marker (``[`` plus digits, with an optional
    leading space) is removed first; afterwards every `` |`` sequence and
    every remaining ``]`` character is deleted, whether or not it belonged
    to a citation.
    """
    # Markers preceded by a space go first so that the space is consumed too.
    cleaned = re.sub(r" \[\d+", "", sent)
    cleaned = re.sub(r"\[\d+", "", cleaned)
    for leftover in (" |", "]"):
        cleaned = cleaned.replace(leftover, "")
    return cleaned

def load_model(model_name_or_path):
    """Load a causal language model together with its tokenizer.

    The tokenizer is created with left-side padding; the model is loaded in
    float16 with ``device_map="auto"``.

    Args:
        model_name_or_path: Identifier or path forwarded to both
            ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
    lm = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return lm, tok


def mirage_cite(res_mirage, cti_threshold, start_pos_sent, end_pos_sent, topk_CCI, doc_seps):
    """Aggregate filtered CCI scores per document for one generated sentence.

    Args:
        res_mirage: Mapping with ``'input_context_tokens'`` (one entry per
            context token) and ``'cci_scores'``, an iterable of mappings that
            each provide ``'cti_idx'``, ``'cti_score'`` and
            ``'input_context_scores'`` (one score per context token).
        cti_threshold: Minimum ``'cti_score'`` an entry needs to contribute.
        start_pos_sent: First generated-token index of the sentence (inclusive).
        end_pos_sent: End generated-token index of the sentence (exclusive).
        topk_CCI: CCI filter. ``> 0`` keeps the k highest context scores
            (all of them when k exceeds the number of context tokens);
            ``< 0`` keeps scores within the top ``-k`` percent of the
            ``[min, max]`` value range; ``0`` keeps scores at or above the mean.
        doc_seps: Indexable of truthy/falsy flags, one per context token; a
            truthy flag closes the current document segment.

    Returns:
        A list with one summed score per document segment, in context order.
        A trailing segment is emitted at the last token even when that token
        is not flagged as a separator.
    """
    sum_value = np.zeros(len(res_mirage['input_context_tokens']))

    for entry in res_mirage['cci_scores']:
        # CTI filtering: only generated tokens inside this sentence's span.
        if not (start_pos_sent <= entry["cti_idx"] < end_pos_sent):
            continue
        if entry['cti_score'] >= cti_threshold:
            # CCI focus: zero out every context score below the cut-off.
            CCI_value = np.array(entry['input_context_scores'])
            if topk_CCI == 0:
                cci_threshold = np.mean(CCI_value)
            elif topk_CCI < 0:
                cci_threshold = (1 + topk_CCI / 100) * np.max(CCI_value) - topk_CCI / 100 * np.min(CCI_value)
            else:
                # Clamp k so that asking for more tokens than exist keeps all
                # of them instead of raising IndexError.
                k = min(topk_CCI, CCI_value.size)
                cci_threshold = np.sort(CCI_value)[-k]
            CCI_value[CCI_value < cci_threshold] = 0

            sum_value += CCI_value

        # NOTE(review): iteration stops at the first in-span entry below the
        # threshold, so later in-span entries are never considered. Kept as
        # in the original - confirm this is intended.
        if entry['cti_score'] < cti_threshold:
            break

    # Collapse token-level sums into one value per document segment.
    res = []
    running = 0
    last_idx = len(sum_value) - 1
    for pos, value in enumerate(sum_value):
        running += value
        if doc_seps[pos] or pos == last_idx:  # reached a document separator
            res.append(running)
            running = 0
    return res

def make_doc_prompt(item):
    """Render ``item['docs']`` as the context string fed to the model.

    The result starts with ``"Docs:"`` and then contains one line per
    document: ``"<title>:<text>"`` when the title is non-empty, otherwise
    just ``"<text>"``, each terminated by a newline.
    """
    pieces = ["Docs:"]
    for entry in item['docs']:
        heading = "" if entry['title'] == "" else f"{entry['title']}:"
        pieces.append(f"{heading}{entry['text']}\n")
    return "".join(pieces)

def get_system_prompt(with_cred):
    """Return the system prompt, choosing the credibility-aware variant when *with_cred* is truthy."""
    # The credibility variant carries different instructions, although they
    # are not quite what was originally expected.
    credibility_prompt = "You are an assistant who can answer questions based on the given passages. Each passage has a credibility score that indicates the relevance and accuracy of the passage to the question. Your answer need to combine multiple passages and their credibility."
    plain_prompt = "You're a helpful AI assistant. The assistant answers questions based on given passages.\n"
    return credibility_prompt if with_cred else plain_prompt

def get_shots(args):
    """Return the entire text of the demonstrations file located at ``args.demos``."""
    with open(args.demos, 'r') as demo_file:
        return demo_file.read()

def main():
    """Attribute pre-generated answers to their context documents and add citations.

    Driven by command-line flags plus defaults loaded from the YAML file given
    by ``--config``. Attributes that are read below but not declared as flags
    (``args.model``, ``args.with_cred``, ``args.demos``, ``args.temperature``,
    ``args.top_p``, ``args.max_new_tokens``) are expected to come from that
    config file.

    Steps:
        1. Load the JSON-lines file ``--f`` and collapse newlines in every
           ``item['output']``. Answer generation is disabled in this script,
           so running without ``--f_with_ans`` exits with status -1.
        2. Unless ``--only_cite`` is set, call ``attribute_context_with_model``
           for every non-empty answer, passing a per-example ``save_path``
           under ``--internal_dir``. Examples that raise are recorded and
           skipped afterwards.
        3. Read each per-example JSON back, score the documents for every
           sentence with ``mirage_cite`` and prefix the sentence with up to
           ``--at_most_citations`` ``[n]`` markers.
        4. Write the annotated examples as JSON under ``--save_dir``.

    ``--internal_dir`` and ``--save_dir`` are concatenated directly with the
    file name, so they need to end with a path separator.

    Raises:
        ValueError: If ``--CTI`` is negative.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--f", type=str, help="Input data file")
    parser.add_argument("--config", type=str, help="Configuration file")
    # Meaning of the two parameters below: for CCI a positive value is a count
    # (top k) and a negative value is a percentage.
    parser.add_argument("--CTI", type=int, default=1,
                        help="CTI filtering strategy: How many standard deviations over average")
    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%'; a bare '%' makes --help fail.
    parser.add_argument("--CCI", type=int, default=-5,
                        help="CCI filtering strategy: Top k if k > 0; Top (-k)%% if k < 0")

    parser.add_argument("--seed", type=int, default=42, help="Seed for random stuffs")
    parser.add_argument("--at_most_citations", type=int, default=3,
                        help="At most take this many documents (mostly for precision)")

    parser.add_argument("--f_with_ans", action="store_true",
                        help="Whether input data file already has LLM generations.")
    parser.add_argument("--only_cite", action="store_true",
                        help="Only re-generate citations with new CTI and CCI thresholds")

    parser.add_argument("--save_dir", type=str, help="output data dir")
    parser.add_argument("--internal_dir", type=str, help="internal data dir")
    parser.add_argument("--zero_shot", action="store_true")

    # First pass only locates --config; its content then becomes the parser
    # defaults, and the second pass lets explicit flags override it.
    args = parser.parse_args()
    if args.config is not None:
        with open(args.config) as config_file:
            # An empty YAML file loads as None; treat it as "no defaults".
            config = yaml.safe_load(config_file) or {}
    else:
        config = {}
    parser.set_defaults(**config)
    args = parser.parse_args()

    if args.only_cite:
        assert args.f_with_ans, "--only_cite can only used when the input data contains the LLM outputs, namely setting --f_with_ans"

    # Reject an invalid CTI setting before any expensive work starts.
    if args.CTI < 0:
        raise ValueError('CTI filtering parameter should be equal or larger than 0.')

    np.random.seed(args.seed)

    # CTI and CCI parameters
    topk_CTI = args.CTI
    # topk_CTI = 1 # 1 means over average+1SD
    # topk_CTI = 0 # 0 means over average

    topk_CCI = args.CCI
    # topk_CCI = -5 # -5 means range top5%
    # topk_CCI = 3 # 3 means top 3
    # topk_CCI = 0 # 0 means average (not used)

    cite_idx_acs = False  # whether MIRAGE citations in ascending order

    model, tokenizer = load_model(args.model)
    print(args.f)
    with open(args.f, "r") as f:
        # One JSON object per line; blank lines are ignored.
        data = [json.loads(line) for line in f if line.strip()]

    print(len(data))
    if not args.f_with_ans:
        prefix = args.model.lower().replace('/', '_') + "-" + args.f.split("/")[-1].split(".")[0] + "-" + \
                 args.config.split("/")[-1].split(".")[0] + '-seed' + str(args.seed)
    else:
        prefix = "".join(args.f.split("/")[-1].split(".")[:-1])
    print(prefix)

    # First, generate and save LLM generation
    # If already have LLM generation
    # Handle newlines, or generate answers; this first step can be skipped.
    # In this code, generation always happens first, followed by PECoRe.
    if args.f_with_ans:
        for item in tqdm(data):
            item['output'] = item['output'].strip()
            # Collapse runs of newlines (longest first) into single spaces.
            for i in range(10):
                r_tmp = "\n" * (10 - i)
                item['output'] = item['output'].replace(r_tmp, " ")
    else:
        # This definitely needs to be modified: the snippet below is a
        # disabled reference for the generation step and is never executed.
        # If it is re-enabled, `data` should afterwards also be dumped to
        # "data_input_with_ans/<prefix>.json" (that dump used to sit,
        # unreachable, after the exit call below).
        '''
        for idx, item in enumerate(tqdm(data)):
            doc_list = item['docs']
            input_context_text = "".join(
                [make_doc_prompt(doc, doc_id, args.doc_prompt, use_shorter=None) for doc_id, doc in
                 enumerate(doc_list)])
            input_current_text = item['question']
            input_template = args.demo_prompt.replace("{INST}", args.instruction).replace("{Q}", "{current}").replace(
                "{A}</s>", "").replace("{A}", "").replace("{D}", "{context}").rstrip()

            prompt = input_template.replace("{current}", input_current_text).replace("{context}", input_context_text)
            prompt_len = len(tokenizer.tokenize(prompt))
            item['output'] = generate_answer(prompt, model, tokenizer,
                                             min(args.max_new_tokens, args.max_length - prompt_len), args.temperature,
                                             args.top_p)

            item['output'] = item['output'].strip()
            for i in range(10):
                r_tmp = "\n" * (10 - i)
                item['output'] = item['output'].replace(r_tmp, " ")
        '''
        print("Answer generation is disabled in this script; rerun with --f_with_ans.")
        exit(-1)

    # Second, analyze model internals with MIRAGE
    save_dir_mirage = args.internal_dir
    os.makedirs(save_dir_mirage, exist_ok=True)

    print("key start:")
    blows = []  # indices whose attribution raised an exception
    if not args.only_cite:
        # Core step: compute the model-internals (intermediate) results.
        # Load model
        # NOTE(review): device_map 'cuda:3' is requested although
        # CUDA_VISIBLE_DEVICES is set to a single device at the top of the
        # file; `model` is already loaded here - confirm how inseq treats
        # model_kwargs for a pre-loaded model.
        model_mirage = inseq.load_model(
            model,
            "saliency",
            model_kwargs={"device_map": 'cuda:3', "torch_dtype": torch.float16},
            tokenizer_kwargs={"use_fast": False},
        )
        # Model-family tweaks are keyed off substrings of the config path.
        if "qwen" in args.config :
            model_mirage.tokenizer.add_special_tokens({'pad_token': model_mirage.tokenizer.eos_token})
        if "3" in args.config:
            model_mirage.tokenizer.unk_token_id = model_mirage.tokenizer.eos_token_id
        stop = []
        # eos_token_id may be a single id or a list of ids; flatten it so the
        # set() below never receives an unhashable list element.
        eos_token_id = model.config.eos_token_id
        eos_token_ids = list(eos_token_id) if isinstance(eos_token_id, (list, tuple)) else [eos_token_id]
        stop_token_ids = list(
            set([tokenizer._convert_token_to_id(stop_token) for stop_token in stop] + eos_token_ids))
        if tokenizer.unk_token_id in stop_token_ids:
            stop_token_ids.remove(tokenizer.unk_token_id)
        print(stop_token_ids)
        decoder_input_output_separator = ' '
        special_tokens_to_keep = []
        if "zephyr" in args.model.lower():
            decoder_input_output_separator = '\n '
            special_tokens_to_keep = ["</s>"]

        for idx, item in enumerate(tqdm(data)):
            if item["output"] == "":
                continue
            input_context_text = make_doc_prompt(item)
            input_current_text = item['question']
            if args.zero_shot:
                shots = ""
            else:
                shots = get_shots(args)
            input_template = get_system_prompt(args.with_cred) + shots + "\n\n" + \
                             item['template'].replace("{question}","{current}").replace("{doc_prompt}","{context}").rstrip()
            contextless_input_current_text = input_template.replace("Docs:{context}\n", "")
            output_current_text = item["output"]

            save_path = save_dir_mirage + prefix + '-' + str(idx) + '.json'
            lm_rag_prompting_example = AttributeContextArgs(
                model_name_or_path=args.model,
                input_context_text=input_context_text,
                input_current_text=input_current_text,
                output_template="{current}",
                input_template=input_template,
                contextless_input_current_text=contextless_input_current_text,
                show_intermediate_outputs=False,
                attributed_fn="contrast_prob_diff",
                context_sensitivity_std_threshold=0,
                output_current_text=output_current_text,
                attribution_method="saliency",
                attribution_kwargs={"logprob": True},
                save_path=save_path,
                tokenizer_kwargs={"use_fast": False},
                model_kwargs={
                    "device_map": 'auto',
                    "torch_dtype": torch.float16,
                    "max_memory": get_max_memory(),
                    "load_in_8bit": False,
                },
                generation_kwargs={
                    "do_sample": True,
                    "temperature": args.temperature,
                    "top_p": args.top_p,
                    "max_new_tokens": args.max_new_tokens,
                    "num_return_sequences": 1,
                    "eos_token_id": stop_token_ids,
                },
                decoder_input_output_separator=decoder_input_output_separator,
                special_tokens_to_keep=special_tokens_to_keep,
                show_viz=False,
            )

            # Best effort: a failing example is reported and remembered so the
            # citation pass can skip it, instead of aborting the whole run.
            try:
                gen = attribute_context_with_model(lm_rag_prompting_example, model_mirage)
                if idx < 3:
                    print(lm_rag_prompting_example)
                    print(gen)
            except Exception as e:
                print(f"{e} in {idx}.")
                print(item)
                blows.append(idx)

        print(blows)

    # Load the tokenizer
    # Third, perform the citation step.
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    tokenizer.padding_side = "left"

    # The token that closes a document in the context depends on the model
    # family, which is again inferred from the config path.
    # check num and index of '\n' in the retrieved docs (i.e. <0x0A> in Llama, zephyr, mistral)
    # e.g. num should constantly be 5 on ELI5
    if "qwen" in args.config or "3" in args.config:
        sep_token = ".\u010a"
    else:
        sep_token = '<0x0A>'

    new_data = []
    num_empty = 0
    missing = []  # indices without a saved MIRAGE result (only_cite mode)
    for idx, item in enumerate(tqdm(data)):
        if item["output"] == "":
            # Empty answers are passed through untouched.
            new_data.append(item)
            num_empty += 1
            continue
        if idx in blows:
            continue

        item["output"] = item["output"].strip()
        output = item["output"]

        # read MIRAGE json results
        read_path = save_dir_mirage + prefix + '-' + str(idx) + '.json'
        if args.only_cite and not os.path.exists(read_path):
            # With --only_cite the failed-index list from the attribution pass
            # is not available, so examples without a saved result are
            # skipped here the same way failed ones are skipped above.
            print(f"No MIRAGE result at {read_path}; skipping {idx}.")
            missing.append(idx)
            continue
        with open(read_path) as r:
            res_mirage = json.load(r)

        cti_threshold = np.mean(res_mirage["cti_scores"]) + topk_CTI * np.std(res_mirage["cti_scores"])

        sents = sent_tokenize(output)
        context_tokens = res_mirage["input_context_tokens"]
        doc_seps = np.array(context_tokens) == sep_token
        # Series.value_counts replaces the deprecated top-level
        # pd.value_counts; a missing separator still raises KeyError.
        num_doc = pd.Series(context_tokens).value_counts()[sep_token]

        print(num_doc)

        output_pieces = []
        start_pos_sent = 0
        item["cite_info"] = []
        for sent in sents:
            o_sent = sent
            # The sentence's span is tracked in tokenizer tokens.
            end_pos_sent = start_pos_sent + len(tokenizer.tokenize(sent))

            # e.g. Filtered CCI values for each doc, e.g. [0, 0, 20, 3, 0]; always length == num_doc
            cite_result_mirage = mirage_cite(res_mirage, cti_threshold, start_pos_sent, end_pos_sent, topk_CCI,
                                             doc_seps)
            print(cite_result_mirage)

            start_pos_sent = end_pos_sent

            sent = remove_citations(sent)

            # Keep documents with a non-zero score, highest score first
            # (ties keep document order), capped at --at_most_citations.
            scored_docs = [(doc_id, score) for doc_id, score in enumerate(cite_result_mirage) if score]
            scored_docs.sort(key=lambda pair: pair[1], reverse=True)
            best_doc_id = [doc_id for doc_id, _ in scored_docs[: args.at_most_citations]]

            if cite_idx_acs:
                best_doc_id = sorted(best_doc_id)

            # Citation markers are 1-based and placed before the sentence.
            best_doc_id_str = "".join("[" + str(doc_id + 1) + "]" for doc_id in best_doc_id)
            sent = best_doc_id_str + " " + sent
            item["cite_info"].append((o_sent, cite_result_mirage, best_doc_id))

            output_pieces.append(sent + " ")

        item['cited_output'] = "".join(output_pieces).rstrip().rstrip(",")
        print("\n-----")
        print("Output with MIRAGE AA:" + item['cited_output'])
        new_data.append(item)

    print("num_empty:")
    print(num_empty)
    if missing:
        print("skipped (no MIRAGE result):")
        print(missing)
    data = new_data

    tag = f".mirage"
    tag += f"_CTI_{topk_CTI}"
    tag += f"_CCI_{topk_CCI}"

    if cite_idx_acs:
        tag += '_acs'

    save_dir_AA = args.save_dir
    os.makedirs(save_dir_AA, exist_ok=True)
    with open(save_dir_AA + prefix + f"{tag}.json", 'w') as out_file:
        json.dump(data, out_file, indent=4)





# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()