# -*- coding: UTF-8 -*-


# Import modules from other files
from utils.save_result import save_results_interaction, save_results_layers
from utils.check_vocab_tokenizer import (
    get_model
)

from harsanyi.interaction_utils import flatten
from harsanyi.aog_inference import reorganize_and_or_harsanyi
from harsanyi import AndHarsanyi, AndOrHarsanyi, AndOrHarsanyiSparsifier

# Import standard libraries and torch-related packages
from operator import itemgetter
from transformers.utils import logging
from typing import Callable, Union, List, Dict
from transformers import OPTForCausalLM, GPTJForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
import torch
import numpy as np
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
import os
import os.path as osp
import matplotlib
import argparse
import torch.nn.functional as F
import gc
import random
import rich
import pickle
import sys
import json
import argparse
import gc
# Select the non-interactive "Agg" backend so plotting works without a display.
matplotlib.use("Agg")


# Control the display of errors: only ERROR-level (and above) messages from the
# "transformers" logger are shown.
logging.get_logger("transformers").setLevel(logging.ERROR)

# Set the GPU device order: enumerate CUDA devices by PCI bus ID.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


def get_forward_function(model) -> Callable:
    """Build the forward function that is handed to the interaction calculator.

    Args:
        model: Object exposing a HuggingFace-style ``generate`` method.

    Returns:
        A callable ``forward_func(input_ids)`` that generates exactly one new
        token without sampling and returns ``(logits, hidden_states)`` for
        that generation step.
    """

    def forward_func(input_ids):
        # One greedy step (do_sample=False) keeps the output deterministic.
        output_dict = model.generate(
            input_ids,
            max_new_tokens=1,
            output_scores=True,
            return_dict_in_generate=True,
            output_hidden_states=True,
            do_sample=False,
        )

        # Scores of the single generated step; the original note gives the
        # shape as [1, vocab_size] for a one-row input.
        logits = output_dict["scores"][0]
        # Hidden states of that same step.
        # NOTE(review): original note says (layers, 1, tokens, dim) -- confirm.
        hidden_states = output_dict["hidden_states"][0]
        return logits, hidden_states

    return forward_func


def get_input_ids(prompt, tokenizer, device):
    """Tokenize ``prompt`` and return its token ids moved to ``device``.

    The tokenizer is called with ``return_tensors="pt"`` and only the
    ``input_ids`` field of its output is kept.
    """
    encoding = tokenizer(prompt, return_tensors="pt")
    token_ids = encoding.input_ids
    return token_ids.to(device)


def get_baseline_id(tokenizer, model_name) -> int:
    """Return the id of the token used to mask input positions.

    Model names containing "qwen", "llama-3" or "olmo", and the exact name
    "llama-moe-v2-3_8b-2_8-sft", use the tokenizer's pad token id; every other
    model uses the unknown-token id. Matching is case-sensitive.
    """
    pad_token_families = ("qwen", "llama-3", "olmo")
    uses_pad_token = (
        any(family in model_name for family in pad_token_families)
        or model_name == "llama-moe-v2-3_8b-2_8-sft"
    )
    if not uses_pad_token:
        return tokenizer.unk_token_id
    return tokenizer.pad_token_id


def get_ground_truth_label(model, input_ids):
    """
    Generate one token without sampling and report it as the label to explain.

    Returns:
        (label, sequence): ``label`` is the arg-max index (a Python int) over
        the first row of the first step's scores; ``sequence`` is the first
        entry of the returned ``sequences``.
    """
    generation = model.generate(
        input_ids,
        max_new_tokens=1,
        do_sample=False,
        output_scores=True,
        return_dict_in_generate=True,
    )

    first_step_scores = generation["scores"][0]
    label = first_step_scores[0].argmax().item()  # index of the highest score
    sequence = generation["sequences"][0]

    assert isinstance(label, int)
    return label, sequence


def get_mask_input_function_for_llm() -> Callable:
    """Return the function that builds masked copies of a tokenized prompt."""

    def mask_input_fn(input_ids, baseline_id, S_list):
        """Create one masked copy of ``input_ids`` per entry of ``S_list``.

        Args:
            input_ids: Token ids with shape (1, n_token).
            baseline_id: Value written into every masked position.
            S_list: One collection of positions per output row; those
                positions are kept, all other positions are masked.

        Returns:
            Tensor with shape (len(S_list), n_token); the input is not modified.
        """
        assert input_ids.shape[0] == 1
        assert len(input_ids.shape) == 2

        n_samples, seq_len = len(S_list), input_ids.shape[1]
        # clone() gives each row its own storage (expand alone only creates a view)
        masked_input_ids = input_ids.expand(n_samples, seq_len).clone()

        # keep[row, pos] is True where the original token survives
        keep = torch.zeros_like(
            masked_input_ids, dtype=torch.bool, device=input_ids.device)
        for row, S in enumerate(S_list):
            keep[row, S] = True

        masked_input_ids[~keep] = baseline_id
        return masked_input_ids

    return mask_input_fn


def get_player_description(tokenizer, input_ids, players):
    """Decode each player into a whitespace-stripped string for visualization.

    Every entry of ``players`` is used to index the first row of ``input_ids``;
    the selected ids are decoded with ``tokenizer.decode``.
    """
    return [
        tokenizer.decode(input_ids[0][player]).strip()
        for player in players
    ]


def experiment_sparsifier(
    calculator: Dict,
    save_folder: str,
    device: list,
    sparsify_kwargs: Dict
):
    """Sparsify the interactions of a finished calculator and save its masks.

    Args:
        calculator: Mapping with at least the keys "masks", "rewards", "p_N"
            and "p_empty"; it is passed as-is to ``AndOrHarsanyiSparsifier``.
        save_folder: Folder receiving "masks.pth"; the sparsifier is given the
            sub-folder "sparsify_verbose" as its verbose folder.
        device: Forwarded unchanged to the sparsifier.
        sparsify_kwargs: Extra keyword arguments for the sparsifier.

    Returns:
        ``(masks, I_and, I_or, rewards, p_N, p_empty)`` where ``I_and`` / ``I_or``
        are the sparsifier's interactions after ``reorganize_and_or_harsanyi``.
    """
    masks, rewards, p_N, p_empty = (
        calculator[key] for key in ('masks', 'rewards', 'p_N', 'p_empty')
    )

    verbose_dir = osp.join(save_folder, "sparsify_verbose")
    sparsifier = AndOrHarsanyiSparsifier(
        calculator=calculator, device=device, **sparsify_kwargs)
    sparsifier.sparsify(verbose_folder=verbose_dir)

    # Reading out the result needs no gradients.
    with torch.no_grad():
        raw_and, raw_or = sparsifier.get_interaction()
        I_and, I_or = reorganize_and_or_harsanyi(masks, raw_and, raw_or)

    masks_path = osp.join(save_folder, "masks.pth")
    torch.save(masks, masks_path)
    return masks, I_and, I_or, rewards, p_N, p_empty


# Function to calculate the interaction value I(S)
def experiment_and_or(
    forward_func: Callable,
    selected_dim: str,
    input_ids: torch.LongTensor,
    baseline_id: Union[torch.LongTensor, int],
    label: int,
    all_players: List,
    background: Union[None, List],
    mask_input_fn: Callable,
    calc_bs: int,
    device: list,
    save_folder: str,
    hf_model
):
    """Run ``AndOrHarsanyi`` on one sample and package its outputs in a dict.

    All arguments except ``save_folder`` (unused here) are forwarded to the
    ``AndOrHarsanyi`` constructor; ``attribute()`` is then run without
    gradient tracking.

    Returns:
        Dict with the keys "masks", "rewards", "p_N", "p_empty", "v_N",
        "v_empty", "reward2Iand" and "reward2Ior", all read from the calculator.
    """
    harsanyi = AndOrHarsanyi(
        model=forward_func,
        selected_dim=selected_dim,
        x=input_ids,
        baseline=baseline_id,
        y=label,
        all_players=all_players,
        background=background,
        mask_input_fn=mask_input_fn,
        calc_bs=calc_bs,
        verbose=1,
        device=device,
        hf_model=hf_model,
    )
    with torch.no_grad():
        harsanyi.attribute()

    masks = harsanyi.get_masks()
    rewards = harsanyi.get_rewards()
    p_N, p_empty = harsanyi.get_p()

    return dict(
        masks=masks,
        rewards=rewards,
        p_N=p_N,
        p_empty=p_empty,
        v_N=harsanyi.v_N,
        v_empty=harsanyi.v_empty,
        reward2Iand=harsanyi.reward2Iand,
        reward2Ior=harsanyi.reward2Ior,
    )
    


def find_max_folder_with_file(experiment_folder):
    """Return the largest numeric sub-folder index that contains "I_and.pth".

    Only direct sub-directories of ``experiment_folder`` whose names consist
    of decimal digits are considered.

    Args:
        experiment_folder: Existing directory to scan.

    Returns:
        The index as an ``int``, or ``None`` when no numeric sub-folder
        contains "I_and.pth".
    """
    indexed_dirs = []
    for name in os.listdir(experiment_folder):
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as "²" that int() cannot parse.
        if name.isdecimal() and os.path.isdir(os.path.join(experiment_folder, name)):
            # Remember the real directory name; rebuilding it from the integer
            # would miss zero-padded folders such as "07".
            indexed_dirs.append((int(name), name))

    # Check from the largest-numbered folder downwards.
    for index, name in sorted(indexed_dirs, reverse=True):
        if os.path.exists(os.path.join(experiment_folder, name, "I_and.pth")):
            return index
    return None


def template_main(device, model, model_name, tokenizer, forward_func,
                  experiment_folder, data_list: list, sparsify_kwargs):
    """Explain every not-yet-processed prompt of one template and save the results.

    For each prompt the model's next token (generated without sampling) is
    taken as the label, the AND-OR Harsanyi calculator and sparsifier are run,
    and all artefacts are written to ``<experiment_folder>/<prompt index>/``.
    Processing resumes after the highest-numbered sub-folder that already
    contains "I_and.pth". One line per prompt is appended to
    ``<experiment_folder>/accuracy.txt``.

    Args:
        device: Sequence of devices; ``device[0]`` receives the tokenized
            prompt and the whole sequence is forwarded to the calculators.
        model: Model exposing ``generate``.
        model_name: Selects the masking token and is stored in the result log.
        tokenizer: Tokenizer matching ``model``.
        forward_func: Forward function passed to the calculator.
        experiment_folder: Output root for this template (created if missing).
        data_list: Dicts with the keys "prompt", "all_players" and "answer".
        sparsify_kwargs: Keyword arguments for the sparsifier.

    Returns:
        Fraction of the prompts processed in this call whose generated token
        was found in the reference answer, or ``None`` when nothing was
        processed (for example because everything was already done).
    """
    baseline_id = get_baseline_id(tokenizer, model_name)  # Get the ID for the token used for masking
    mask_input_fn = get_mask_input_function_for_llm()  # Get the function used for masking

    # Accuracy tracking for the prompts handled in this call
    answer_is_correct_flag_list = []

    os.makedirs(experiment_folder, exist_ok=True)

    # Resume after the last prompt whose results were completely written.
    last_finished = find_max_folder_with_file(experiment_folder)
    start_index = 0 if last_finished is None else last_finished + 1

    # Append mode creates the file when it is missing, and the context manager
    # closes the log even when a prompt raises.
    with open(osp.join(experiment_folder, 'accuracy.txt'), 'a') as f_log:
        for actual_index, data in enumerate(data_list[start_index:], start=start_index):
            print(f"Actual Index: {actual_index}")

            prompt = data["prompt"]
            all_players = data["all_players"]
            prompt_correct_answer = data["answer"]

            save_folder = osp.join(experiment_folder, str(actual_index))
            input_ids = get_input_ids(prompt, tokenizer, device[0])  # tokenized input
            # The ID of the model's output label, which is also the index for the selected dimension
            label, sequence = get_ground_truth_label(model, input_ids)

            # Check if the model's answer is correct.
            # NOTE(review): [1:] drops the first character of the token string,
            # presumably the tokenizer's word-boundary marker. A one-character
            # token without such a marker becomes "" and then matches any
            # answer -- confirm for each tokenizer in use.
            generate_token = tokenizer.convert_ids_to_tokens(label)[1:]
            answer_is_correct_flag = (
                generate_token in str(prompt_correct_answer))
            print(f"True answer: {prompt_correct_answer}")
            print(f"Generate token: {generate_token}")
            answer_is_correct_flag_list.append(answer_is_correct_flag)

            print(f"Prompt {actual_index}: {answer_is_correct_flag}",
                  file=f_log, flush=True)

            key_players_dict = dict()

            # Build the set once instead of once per token position.
            player_positions = set(flatten(all_players))
            background = [i for i in range(input_ids.shape[-1])
                          if i not in player_positions]
            descriptions = get_player_description(tokenizer, input_ids, all_players)

            print(f"Prompt:%{prompt}%")
            print("all_players:", all_players)
            print("descriptions:", descriptions)
            print("key_players_dict:", key_players_dict)
            sys.stdout.flush()

            # >>>> Step 3 -- Explain (Compute Harsanyi interactions)
            os.makedirs(save_folder, exist_ok=True)
            new_calculator = experiment_and_or(
                forward_func=forward_func,
                selected_dim="gt-log-odds-v0",
                # selected_dim="gt-v0",
                input_ids=input_ids,
                baseline_id=baseline_id,
                label=label,
                all_players=all_players,
                background=background,
                mask_input_fn=mask_input_fn,
                calc_bs=128,  # For WLL, considering the hidden_size structure, this must be 1
                save_folder=save_folder,
                device=device,
                hf_model=model
            )
            torch.save(descriptions, osp.join(save_folder, "descriptions.pth"))
            torch.save(all_players, osp.join(save_folder, "all_players.pth"))
            torch.save(key_players_dict, osp.join(
                save_folder, "key_players_dict.pth"))

            result = tokenizer.decode(label)
            sequence_text = tokenizer.decode(sequence)
            torch.save(result, osp.join(save_folder, "label_decode.pth"))
            torch.save(sequence_text, osp.join(save_folder, "sequence.pth"))

            masks, I_and, I_or, rewards, _p_N, _p_empty = \
                experiment_sparsifier(
                    calculator=new_calculator,
                    save_folder=save_folder,
                    device=device,
                    sparsify_kwargs=sparsify_kwargs
                )

            # "I_and.pth" doubles as the completion marker used when resuming.
            torch.save(I_and, osp.join(save_folder, "I_and.pth"))
            torch.save(I_or, osp.join(save_folder, "I_or.pth"))

            # Variables needed to log the results for this part:
            resultlog_dict = {
                "save_folder": save_folder,
                "model_name": model_name, "prompt": "",
                "label": "", "result": result, "sequence": sequence_text,
                "description": descriptions, "all_player": all_players, "rewards": rewards,
                "I_and": I_and, "I_or": I_or, "masks": masks}
            save_results_interaction(resultlog_dict)
            del resultlog_dict

            clear_cuda_memory()

    # The caller still owns `model` and `tokenizer`; deleting the local names
    # here would free nothing, so only the allocator cache is released.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if not answer_is_correct_flag_list:
        return None
    return sum(answer_is_correct_flag_list) / len(answer_is_correct_flag_list)


def clear_cuda_memory():
    """Collect unreachable Python objects, then empty PyTorch's CUDA cache.

    Garbage is collected first so that memory owned by objects freed in this
    pass is already back in the caching allocator when the cache is emptied;
    in the opposite order that memory would stay cached.
    """
    gc.collect()
    torch.cuda.empty_cache()


def seed_everything(seed: int = 0) -> None:
    """
    Seed every random number generator used here so that runs are repeatable.

    Covers Python's ``random``, PyTorch (CPU and all CUDA devices) and NumPy,
    exports ``PYTHONHASHSEED``, and sets cuDNN to deterministic mode with
    benchmarking turned off.
    """
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # If using dgl, set its seed as well.
    # dgl.seed(seed)


def get_save_path(dataset, template, model_name, save_path):
    """Return the results folder for one dataset / model / template combination.

    The model name is normalized by replacing "_" with "-", removing "-hf"
    and lower-casing, in that order.
    """
    # FIXME: Rewrite the experiment path
    normalized_name = model_name.replace("_", "-").replace("-hf", "").lower()
    path_parts = [
        save_path,
        f'saved-results_{dataset}',
        normalized_name,
        f"template_{template}",
        "results",
    ]
    return os.path.join(*path_parts)


if __name__ == '__main__':
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "expandable_segments:True"

    main_path = './LLM-Sensitivity'

    model_names = ["llama-2-7b"]

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string, including "False",
        into True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Generate sample")
    parser.add_argument('--dataset', default="mmlu", type=str)
    # One or more template numbers, e.g. `--templates 1 3 5`.
    # (`type=list` would split the raw argument into single characters.)
    parser.add_argument('--templates', default=[1, 2, 3, 4, 5], type=int, nargs='+')
    parser.add_argument('--device_num0', default=0, type=int, help="gpu")
    parser.add_argument('--device_num1', default=1, type=int, help="gpu")
    parser.add_argument('--qthres', default=0.04, type=float)
    parser.add_argument('--qstd', default="vN-v0", type=str)
    parser.add_argument('--train_lr', default=1e-6, type=float, help="set the learning rate for sparsify.")
    parser.add_argument("--epoch", default=20000, type=int, help='the number of iterations for training model.')
    parser.add_argument('--reward_function', type=str, default='gt-v0')

    parser.add_argument('--softmax', default=1000, help='all or [int]; Number of tokens for softmax; "all" means all tokens are included')
    # The flag used to be parsed but ignored while the script always seeded
    # with 0; the default is 0 so that runs without --seed are unchanged.
    parser.add_argument('--seed', default=0, type=int)

    parser.add_argument('--q_tricks', type=_str2bool, default=True, help=r'Update q = q - \bar(q)')
    parser.add_argument('--piece_wise', type=_str2bool, default=False, help='Double the epochs, use sparsify-loss for the second half, and then use a new sparsify-loss for each update')
    parser.add_argument("--sparsify_loss", default="l1_on_0.99", type=str,
                        help="use which type of loss to sparsify and or interactions l1_on_0.9")

    args = parser.parse_args()
    seed_everything(args.seed)

    sparsify_kwargs = {"trick": "pq", "loss": args.sparsify_loss, "qthres": args.qthres,
                       "qstd": args.qstd, "lr": args.train_lr, "niter": args.epoch, "q_tricks": args.q_tricks, "piece_wise": args.piece_wise}
    # NOTE(review): --device_num0 / --device_num1 are parsed but the device
    # list is hard-coded; kept as-is because get_model receives all four ids.
    # Confirm whether the flags should take effect.
    device_num = [0, 1, 2, 3]
    print(f"Device No. {device_num}")
    device = (
        torch.device(f"cuda:{device_num[0]}"),
        torch.device(f"cuda:{device_num[1]}"),
    )

    # >>>> Step 2 -- Prepare Model (LLM)
    for model_name in model_names:
        print("Loading model..." + f"{model_name}")
        model, tokenizer = get_model(device_num, model_name)
        forward_func = get_forward_function(model)
        for template in args.templates:
            # >>>> Step 1 --  Prepare Data (Prompt, the sample to explain)
            print("Preparing data...")
            file_path = os.path.join(main_path, 'explain_demo', 'data', args.dataset, "model_data", model_name, f"template_{template}.jsonl")
            with open(file_path, 'r', encoding='utf-8') as file:
                # Blank lines (e.g. a trailing newline) are skipped instead of
                # crashing json.loads.
                data_list = [json.loads(line) for line in map(str.strip, file) if line]

            save_folder = get_save_path(
                dataset=args.dataset, template=template, model_name=model_name, save_path=main_path)
            os.makedirs(save_folder, exist_ok=True)

            # Skip templates whose last prompt already has saved results.
            pos = find_max_folder_with_file(save_folder)
            if pos is not None and len(data_list) == pos + 1:
                continue

            with open(osp.join(save_folder, 'structure.txt'), 'w') as f:
                print(model, file=f)

            # Each data_list entry carries the reference answer used for the
            # accuracy log written by template_main.
            template_main(
                device=device, model=model, model_name=model_name, tokenizer=tokenizer, forward_func=forward_func,
                experiment_folder=save_folder, data_list=data_list, sparsify_kwargs=sparsify_kwargs)

        # Release GPU memory after each model: drop the references first, then
        # collect garbage, then empty the allocator cache.
        del model, tokenizer, forward_func
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()