import os
import sys
import logging
import tarfile
import pickle
import time
import copy
import torch
import json
import yaml
import tqdm
import dill 
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from nltk.corpus import stopwords

from datetime import datetime

from datasets import load_dataset
from functools import wraps
from pathlib import Path
from lib.models import ModelAndTokenizer
from lib import nethook


# Project-wide directories and URLs come from lib/globals.yml, resolved relative
# to the current working directory at import time.
with open("lib/globals.yml", "r") as stream:
    data = yaml.safe_load(stream)

RESULTS_DIR, DATA_DIR, STATS_DIR, HPARAMS_DIR = (
    Path(data[key])
    for key in ("RESULTS_DIR", "DATA_DIR", "STATS_DIR", "HPARAMS_DIR")
)
REMOTE_ROOT_URL = data["REMOTE_ROOT_URL"]
# Fallback download location used by KnownsDataset when known_1000.json is missing.
REMOTE_URL = f"{REMOTE_ROOT_URL}/data/dsets/known_1000.json"



def set_utils(args):
    """Create the run directory, configure logging, and back up code and arguments.

    The first two '/'-separated components of ``args.save_root`` are used as
    ``<parent>`` and ``<name>``. ``args.save_root`` is rewritten in place to
    ``"<parent>/<YYYYmmdd_HHMM>_<name>_<model>"``, followed by suffixes for
    ``reverse``, a non-default ``dataset_type``, ``except_stopword`` and the target
    sample index / index range. The function also:

    * creates ``jobs/`` and the new save root, plus ``results/`` and ``inp_info/``
      under it (stored as ``args.save_result_root`` / ``args.save_inpinfo_root``);
    * adds a stream handler and a ``log.txt`` file handler to the root logger;
    * in debug runs (``"debug"`` in the save root), wraps every callable defined in a
      module loaded from under the current working directory with
      ``measure_time_and_memory``, so timings accumulate in the returned saver;
    * archives the CWD entries whose name contains ``".py"`` (non-directories) plus
      ``lib/`` into ``sources.tar``, and dumps ``vars(args)`` as JSON to ``args.txt``.

    If a non-debug save root already exists, it prints a warning and stops in ``pdb``.

    Returns:
        ``(logger, func_time_saver)``: the root logger and a ``function_time_saver``.
        The saver is always returned; it only collects data in debug runs.
    """
    os.makedirs("jobs", exist_ok=True)

    # Save Folder: refuse to silently reuse an existing non-debug folder.
    if "debug" not in args.save_root:
        if os.path.isdir(args.save_root) is True:
            print("Check your path!")
            import pdb; pdb.set_trace()

    curr_time = datetime.now().strftime('%Y%m%d_%H%M')
    root_parts = args.save_root.split("/")
    model_name = args.model.split("/")[-1]  # unchanged when there is no "/"

    args.save_root = "{}/{}_{}_{}".format(root_parts[0], curr_time, root_parts[1], model_name)
    if args.reverse:
        args.save_root += '_rev'
    if args.dataset_type != "rome_factual_1000":
        args.save_root += ("_" + args.dataset_type)
    if args.except_stopword:
        args.save_root += '_exStWd'
    if args.target_sample_index is not None:
        args.save_root += '_R{}'.format(args.target_sample_index)
    if args.target_sample_index_list is not None:
        idx_list = args.target_sample_index_list.split(", ")
        args.save_root += '_R{}-{}'.format(idx_list[0], idx_list[-1])

    os.makedirs(args.save_root, exist_ok=True)

    # Logging
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    file_handler = logging.FileHandler(os.path.join(args.save_root, "log.txt"))
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # Timer: create the saver unconditionally so the return value is always bound
    # (previously non-debug runs hit UnboundLocalError at the return statement).
    func_time_saver = function_time_saver()
    if "debug" in args.save_root:
        current_dir = os.path.abspath(os.getcwd())
        imported_modules = {}
        # Snapshot sys.modules so a lazy import cannot mutate it mid-iteration.
        for name, module in list(sys.modules.items()):
            if module is None:
                continue
            module_file = getattr(module, '__file__', None)
            if module_file and module_file.startswith(current_dir):
                imported_modules[name] = module
        for module_name, module in imported_modules.items():
            for name, func in list(vars(module).items()):
                if not callable(func) or getattr(func, '__is_decorated__', False):
                    continue
                if getattr(func, '__module__', None) != module_name:
                    continue
                if getattr(func, '__name__', None) == "measure_time_and_memory":
                    continue
                # NOTE(review): classes are callable too, so they are also replaced by
                # wrapper functions here; isinstance()/subclassing on them breaks in debug runs.
                # A single wrap suffices: the decorator returns already-wrapped callables as-is.
                decorated_func = measure_time_and_memory(logger=logger, func_time_saver=func_time_saver)(func)
                setattr(module, name, decorated_func)

    # Code Backup
    with tarfile.open(os.path.join(args.save_root, 'sources.tar'), 'w') as tar:
        for entry in os.listdir(os.getcwd()):
            if ".py" in entry and not os.path.isdir(entry):
                tar.add(entry)
        tar.add("lib")

    # Arg Backup
    with open(os.path.join(args.save_root, 'args.txt'), 'w') as f:
        json.dump(dict(vars(args)), f, indent=2)

    args.save_result_root = os.path.join(args.save_root, "results")
    args.save_inpinfo_root = os.path.join(args.save_root, "inp_info")
    os.makedirs(args.save_result_root, exist_ok=True)
    os.makedirs(args.save_inpinfo_root, exist_ok=True)

    return logger, func_time_saver


def convert2graph(paths, save_path=None, debug=False):
    """Render each path in ``paths`` on a block x token grid and save one image per path.

    Every (block, token) cell of the grid gets one node per sub-layer type in
    ``node_type``. The cells cover every block/token index touched by *any* path,
    so all figures share the same extent. Edges come from the path entries.

    Args:
        paths: mapping ``path_idx -> path``. Each path is a sequence of
            ``(srcs, dsts)`` pairs whose node ids are split as
            ``"<type>_<token>_<block>"``. ``srcs[0] == "ROOT"`` marks an edge with no
            drawn source. ``dsts`` must hold exactly one node id.
        save_path: format string with one ``{}`` slot filled with ``path_idx``.
            Defaults to ``"debug{}.png"``.
        debug: if True, write to ``"z_" + <last '/' component of save_path>`` and
            remove any existing file with that name first.

    Returns:
        None. Nothing is drawn when ``paths`` contains no edges.
    """
    if save_path is None:
        save_path = "debug{}.png"
    if debug:
        save_path = "z_" + save_path.split("/")[-1]
    # Mathtext label "$<type>_{<block>}^{<token>}$".
    node_name_temp = '${}_{{{}}}^{{{}}}$'

    node_type = ["IA", "Q", "K", "V", "KH", "VH", "A", "OA", "IM", "OM", "OB"]
    # (x, y) offset of each node type inside one (block, token) cell.
    node_type_position = {
        "IA": (0, 3),
        "Q" : (1, 0),
        "K" : (1, 1),
        "V" : (1, 2),
        "KH": (2, 1),
        "VH": (2, 2),
        "A" : (3, 1),
        "OA": (4, 2),
        "IM": (5, 3),
        "OM": (6, 2),
        "OB": (7, 3)
    }
    xys = np.asarray(list(node_type_position.values()))
    x_size = xys[:, 0].max() - xys[:, 0].min() + 1
    y_size = xys[:, 1].max() - xys[:, 1].min() + 1

    # Collect every block / token index touched by any path.
    all_bidx = []
    all_tidx = []
    for path in paths.values():
        for srcs, dsts in path:
            if len(dsts) != 1:
                print("There cannot exist multiple dst nodes!")
                import pdb; pdb.set_trace()
            if srcs[0] != "ROOT":
                for src in srcs:
                    _, s_ti, s_bi = src.split("_")
                    all_bidx.append(int(s_bi))
                    all_tidx.append(int(s_ti))
            _, d_ti, d_bi = dsts[0].split("_")
            all_bidx.append(int(d_bi))
            all_tidx.append(int(d_ti))
    if not all_bidx:
        return

    min_bidx, max_bidx = min(all_bidx), max(all_bidx)
    min_tidx, max_tidx = min(all_tidx), max(all_tidx)
    n_blocks = max_bidx + 1 - min_bidx
    n_tokens = max_tidx + 1 - min_tidx

    for path_idx, path in paths.items():
        G = nx.DiGraph()
        node_positions = {}
        # Plain local names: writes via locals() are not reliable (and with
        # Python 3.13's snapshot semantics the previous read-back raised KeyError).
        for block_idx in range(min_bidx, max_bidx + 1):
            rel_x_idx = block_idx - min_bidx
            for token_idx in range(min_tidx, max_tidx + 1):
                rel_y_idx = token_idx - min_tidx
                for curr_n_type in node_type:
                    node_name = node_name_temp.format(curr_n_type, block_idx, token_idx)
                    G.add_node(node_name)
                    base_x, base_y = node_type_position[curr_n_type]
                    node_positions[node_name] = (
                        base_x + rel_x_idx * x_size,
                        base_y + rel_y_idx * y_size,
                    )

        for srcs, dsts in path:
            if srcs[0] == "ROOT":
                continue
            d_br, d_ti, d_bi = dsts[0].split("_")
            dst_node = node_name_temp.format(d_br, d_bi, d_ti)
            for src in srcs:
                s_br, s_ti, s_bi = src.split("_")
                G.add_edge(node_name_temp.format(s_br, s_bi, s_ti), dst_node)

        plt.figure()
        nx.draw(G,
                pos=node_positions,
                with_labels=True,
                labels={node: node for node in G.nodes()},
                node_size=800, node_color='lightblue', font_size=20, arrowsize=10)

        # Grid lines between cells.
        # NOTE(review): axhline/axvline interpret xmin/xmax (ymin/ymax) as axes
        # fractions in [0, 1], not data coordinates, so these lines span the whole
        # axes; confirm that is intended.
        for rel_y_idx in range(n_tokens + 1):
            plt.axhline(
                y=-0.5 + y_size * rel_y_idx,
                xmin=-0.5,
                xmax=x_size * n_blocks + 0.5,
                color='gray',
                linestyle='-',
                linewidth=1)
        for rel_x_idx in range(n_blocks + 1):
            plt.axvline(
                x=-0.5 + x_size * rel_x_idx,
                ymin=-0.5,
                ymax=y_size * n_tokens + 0.5,
                color='gray',
                linestyle='-',
                linewidth=1)

        # Per-cell captions.
        for block_idx in range(min_bidx, max_bidx + 1):
            for token_idx in range(min_tidx, max_tidx + 1):
                plt.text(
                    (block_idx - min_bidx) * x_size + 5,
                    (token_idx - min_tidx) * y_size,
                    '{:2d}th Block {:2d}th Token'.format(block_idx, token_idx), fontweight='bold', fontsize=20, ha="center")

        plt.gca().set_aspect('equal', adjustable='box')
        plt.xlim(-1, x_size * n_blocks + 1)
        plt.ylim(-1, y_size * n_tokens + 0.5)

        # Scale the figure with the larger grid dimension; widen it by a further 25%.
        fig = plt.gcf()
        new_figsize = fig.get_size_inches() * max(n_tokens, n_blocks)
        new_figsize[0] = new_figsize[0] * 1.25
        fig.set_size_inches(new_figsize)

        out_path = save_path.format(path_idx)
        if debug and os.path.isfile(out_path):
            os.remove(out_path)
        plt.savefig(out_path)
        plt.close()

def exclude_subsets(candidate, exclude_target):
    """Drop every candidate that contains all elements of some exclusion target.

    Args:
        candidate: list of iterables of hashable elements.
        exclude_target: sized collection of iterables. A candidate is removed when
            any of them is a subset of it.

    Returns:
        ``candidate`` itself when ``exclude_target`` is empty. Otherwise a new list of
        the surviving candidates, in their original order.
    """
    # len() rather than truthiness so numpy arrays are accepted as well.
    if len(exclude_target) == 0:
        return candidate
    # ex is usually smaller than cand when the search proceeds in small steps.
    # Build each exclusion set once instead of once per candidate.
    exclude_sets = [set(ex) for ex in exclude_target]
    new_candidate = []
    for cand in candidate:
        cand_set = set(cand)
        # Real subset test: the former ``len(set(ex) | cand) == len(cand)`` check
        # gave wrong answers when ``cand`` contained duplicate elements.
        if not any(ex <= cand_set for ex in exclude_sets):
            new_candidate.append(cand)
    return new_candidate

def get_data(args, data_path="known_1000.json", reverse=False):
    """Return the dataset selected by ``args.dataset_type``.

    Args:
        args: run arguments. Only ``args.dataset_type`` is read.
        data_path: unused. Kept for backward compatibility; the loaders read from
            ``DATA_DIR``.
        reverse: forwarded to the dataset, which then stores its items in reverse order.

    Returns:
        ``KnownsDataset`` for ``"rome_factual_1000"``, ``LamatrexDataset`` for
        ``"lama_trex"``.

    Raises:
        NotImplementedError: for ``"sst2"``, which has no loader yet.
        ValueError: for any other dataset type.

    Both error cases previously dropped into pdb or returned an unbound variable
    (UnboundLocalError).
    """
    if args.dataset_type == "rome_factual_1000":
        return KnownsDataset(DATA_DIR, reverse=reverse)
    if args.dataset_type == "lama_trex":
        return LamatrexDataset(DATA_DIR, reverse=reverse)
    if args.dataset_type == "sst2":
        # TODO: SST-2 / TREC loading was only prototyped (see version history).
        raise NotImplementedError("dataset_type 'sst2' is not supported yet")
    raise ValueError("Unknown dataset_type: {!r}".format(args.dataset_type))

def get_model(args):
    """Load ``args.model`` as a ``ModelAndTokenizer`` and optionally persist it.

    * Models whose name contains ``"20b"`` load with ``torch_dtype=torch.float16``.
      Others pass ``torch_dtype=None``.
    * If ``args.model_save`` is set, the state dict is written to
      ``<model_save>/<model name with '/' replaced by '_'>/model.pth`` and the
      tokenizer to the ``tokenizer`` sub-directory next to it.
    * If ``args.except_stopword`` is set, ``get_stopwords`` populates ``args.stwd_ids``.
    """
    dtype = torch.float16 if "20b" in args.model else None
    mt = ModelAndTokenizer(args.model, low_cpu_mem_usage=False, torch_dtype=dtype)

    save_root = args.model_save
    if save_root is not None:
        os.makedirs(save_root, exist_ok=True)
        target_dir = os.path.join(save_root, args.model.replace("/", "_"))
        os.makedirs(target_dir, exist_ok=True)
        torch.save(mt.model.state_dict(), os.path.join(target_dir, "model.pth"))
        mt.tokenizer.save_pretrained(os.path.join(target_dir, "tokenizer"))

    if args.except_stopword:
        get_stopwords(args, mt)

    return mt


def get_noise_level(args, mt, fpath_nlv="./knowns_n_lv.pt"):
    """Return the embedding noise level for the current dataset, cached on disk.

    The value is ``collect_embedding_std`` over the ``"subject"`` field of every
    dataset item. It is computed once, saved with ``torch.save``, and reloaded
    with ``torch.load`` on later calls.

    Args:
        args: run arguments. Only ``args.dataset_type`` is read.
        mt: model wrapper forwarded to ``collect_embedding_std``. Only used on a
            cache miss.
        fpath_nlv: cache file for ``"rome_factual_1000"``. ``"lama_trex"`` always
            uses ``"./lama_trex_n_lv.pt"`` and ignores this argument.

    Raises:
        NotImplementedError: for ``"sst2"``.
        ValueError: for any other dataset type.

    Both error cases previously ended in an UnboundLocalError.
    """
    if args.dataset_type == "rome_factual_1000":
        cache_path = fpath_nlv
    elif args.dataset_type == "lama_trex":
        cache_path = "./lama_trex_n_lv.pt"
    elif args.dataset_type == "sst2":
        raise NotImplementedError("noise level for dataset_type 'sst2' is not supported yet")
    else:
        raise ValueError("Unknown dataset_type: {!r}".format(args.dataset_type))

    if os.path.isfile(cache_path):
        return torch.load(cache_path)

    dataset_cls = KnownsDataset if args.dataset_type == "rome_factual_1000" else LamatrexDataset
    knowns = dataset_cls(DATA_DIR)  # Dataset of known facts
    noise_level = collect_embedding_std(mt, [k["subject"] for k in knowns])
    torch.save(noise_level, cache_path)
    return noise_level

def collect_embedding_std(mt, subjects):
    """Return the std of the embedding-layer output over all ``subjects`` as a float.

    Each subject string is tokenized on its own with ``make_inputs`` and run through
    ``mt.model``. The output of ``layername(mt.model, 0, "embed")`` is captured with
    ``nethook.Trace``. The first batch element of every capture is concatenated
    along dim 0, and one scalar std over all values is returned.
    """
    alldata = []
    # Inference only: no_grad skips recording autograd state during each forward
    # pass. The traced values are unchanged.
    with torch.no_grad():
        for s in tqdm.tqdm(subjects, desc="collecting_std.."):
            inp = make_inputs(mt.tokenizer, [s])
            with nethook.Trace(mt.model, layername(mt.model, 0, "embed")) as t:
                mt.model(**inp)
                alldata.append(t.output[0])
    alldata = torch.cat(alldata)
    noise_level = alldata.std().item()
    return noise_level

def layername(model, num, kind=None, check_subl_inp=False, verf_type="all"):
    """Return the dotted module path of a block or sub-layer inside ``model``.

    The naming scheme is chosen by the model's top-level attribute:
    ``transformer`` (paths ``transformer.wte`` / ``transformer.h.<num>...``) or
    ``gpt_neox`` (paths ``gpt_neox.embed_in`` / ``gpt_neox.layers.<num>...``).

    Args:
        model: model object inspected with ``hasattr``.
        num: block index.
        kind: None for the whole block, ``"embed"`` for the token embedding, or a
            sub-module name such as ``"attn"`` / ``"mlp"``. For ``gpt_neox``,
            ``"attn"`` maps to ``"attention"``.
        check_subl_inp: (``transformer`` with ``verf_type="all"`` only) name the
            LayerNorm in front of the sub-layer instead: ``ln_1`` for ``"attn"``,
            ``ln_2`` for ``"mlp"``.
        verf_type: ``"all"`` or ``"attn"``. With ``"attn"``, any ``kind`` containing
            ``"attn_"`` maps to the block's ``ln_1``.

    Raises:
        ValueError: ``check_subl_inp`` with a ``kind`` other than None/attn/mlp.
        AssertionError: the model has neither ``transformer`` nor ``gpt_neox``.
    """
    if hasattr(model, "transformer"):
        if kind == "embed":
            return "transformer.wte"

        name = f'transformer.h.{num}'
        if verf_type == "all":
            if not check_subl_inp:
                return name if kind is None else name + "." + kind
            ln_by_kind = {"mlp": "ln_2", "attn": "ln_1"}
            if kind is None:
                return name
            if kind not in ln_by_kind:
                raise ValueError("check_subl_inp supports kind 'mlp' or 'attn', got {!r}".format(kind))
            return name + "." + ln_by_kind[kind]

        if check_subl_inp:
            print("Check!! -> check_subl_inp option cannot operate!!")
        if verf_type == "attn":
            # Guard against kind=None, which previously raised TypeError on `in`.
            if kind is not None and "attn_" in kind:
                return name + ".ln_1"
            return name if kind is None else name + "." + kind
        import pdb; pdb.set_trace()
        return name
    if hasattr(model, "gpt_neox"):
        if verf_type != "all":
            import pdb; pdb.set_trace()
        if kind == "embed":
            return "gpt_neox.embed_in"
        if kind == "attn":
            kind = "attention"
        return f'gpt_neox.layers.{num}{"" if kind is None else "." + kind}'
    # Explicit raise: a bare `assert False` disappears under `python -O`.
    raise AssertionError("unknown transformer structure")


def decode_tokens(tokenizer, token_array):
    """Decode every token id in ``token_array`` into its own string.

    Inputs exposing a ``shape`` with more than one dimension are decoded row by
    row (recursively), which yields nested lists of strings.
    """
    is_batched = hasattr(token_array, "shape") and len(token_array.shape) > 1
    if not is_batched:
        return [tokenizer.decode([token]) for token in token_array]
    return [decode_tokens(tokenizer, row) for row in token_array]


def find_path_set_idx(path_set, B):
    """Return the keys of ``path_set`` that hold at least one edge whose source is ``B``.

    ``path_set`` maps an index to a sequence of edges. An edge matches when its first
    element equals ``B.tolist()``. Keys are returned in the mapping's iteration order.
    """
    return [
        set_idx
        for set_idx, set_edges in path_set.items()
        if any(edge[0] == B.tolist() for edge in set_edges)
    ]


def predict_from_input(model, inp, multipred=False, end_symbol=[], mt=None, force_idx=None, use_mean=False, stwd_mask=None):
    """Run ``model`` on ``inp`` and return its next-token prediction(s).

    Single-step mode (``multipred is False``):
        Takes the softmax of ``model(**inp)["logits"][:, -1]``, i.e. the logits at the
        last position (presumably shaped (batch, seq, vocab); TODO confirm).
        ``use_mean`` averages the probabilities over the batch and keeps a batch
        dim of 1.
        - ``stwd_mask`` given, no ``force_idx``: tokens are sorted by descending
          probability, and the first one whose mask entry selects it (taken from
          the flattened masked selection, i.e. normally the first row) is the
          prediction. ``p`` is its probability in the first row. NOTE(review):
          assumes ``stwd_mask`` is a boolean vector over the vocabulary.
        - no ``stwd_mask``: argmax per row, or, with ``force_idx`` (a tensor), the
          probability of that fixed token, repeated for every row.
        Returns ``(preds.unsqueeze(0), p.unsqueeze(0), out)``, where ``out`` is the
        raw logits.

    Multi-step mode (``multipred`` truthy and not ``False``):
        Greedily appends the argmax token and re-runs the model until the decoded
        token is in ``end_symbol``. Needs ``mt`` for its tokenizer. Unfinished: it
        stops at a breakpoint before returning ``(iter_preds, iter_p)``, a 2-tuple,
        unlike the 3-tuple of single-step mode.

    Unsupported option combinations drop into ``pdb``.
    """
    # multipred option makes outputs until resulting one sentence output
    # Don't reduce the batch size...: https://discuss.pytorch.org/t/why-is-the-output-of-a-linear-layer-different-when-the-batch-size-is-1/93515
    if multipred is False:
        out = model(**inp)["logits"]
        if use_mean:
            # Average over the batch, then restore a leading batch dim of 1.
            probs = torch.softmax(out[:, -1], dim=1).mean(dim=0).unsqueeze(0)
        else:
            probs = torch.softmax(out[:, -1], dim=1)
        
        if (stwd_mask is not None):
            if force_idx is None:
                # Token ids per row, ordered by descending probability.
                desc_idx = torch.argsort(probs, dim=1, descending=True)
                # Mask reordered to line up with desc_idx.
                sorted_stwd_mask = stwd_mask[desc_idx]
                # Boolean indexing flattens; [0] = highest-ranked selected token (first row first).
                preds = desc_idx[sorted_stwd_mask][0]
                p = probs[:, preds][0]
            else: 
                import pdb; pdb.set_trace()
        else:   
            if force_idx is None:
                p, preds = torch.max(probs, dim=1)
            else:
                # Report the probability of a fixed token instead of the argmax.
                p = probs[:, force_idx]
                preds = force_idx.repeat(p.shape[0])
        return preds.unsqueeze(0), p.unsqueeze(0), out
    else:
        iter_inp = copy.deepcopy(inp)
        # import pdb; pdb.set_trace()
        iter_p = []
        iter_preds = []

        # Greedy decoding loop; exits only when the decoded token is in end_symbol.
        while 1:
            out = model(**iter_inp)["logits"]
            probs = torch.softmax(out[:, -1], dim=1)
            if use_mean:
                import pdb; pdb.set_trace()

            if force_idx is None:
                p, preds = torch.max(probs, dim=1)
            else:
                import pdb; pdb.set_trace()

            iter_p.append(p)
            iter_preds.append(preds)
            next_token = preds.unsqueeze(-1)
            buffer_ids = torch.cat((iter_inp["input_ids"], next_token), dim=1)

            # Extending the mask by repeating its last column is only valid when the
            # mask is uniform (no padding), hence the breakpoint otherwise.
            if (torch.unique(iter_inp["attention_mask"]).shape[0]==1) is False:
                import pdb; pdb.set_trace()
            buffer_mask = torch.cat((iter_inp["attention_mask"], iter_inp["attention_mask"][:, -1].unsqueeze(-1)), dim=-1)
            iter_inp = {"input_ids":buffer_ids, "attention_mask":buffer_mask}

            # Only the first row's token is decoded below, so all rows must agree.
            if torch.unique(next_token).shape[0] != 1:
                import pdb; pdb.set_trace()
            # NOTE(review): Hugging Face's decode() keyword is `clean_up_tokenization_spaces`;
            # `clean_up_tokenization` is probably not honoured. Confirm before relying on it.
            currsym = mt.tokenizer.decode(next_token[0], skip_special_tokens=True, clean_up_tokenization=True)
            if currsym in end_symbol:
                break
        iter_preds = torch.vstack(iter_preds)
        iter_p = torch.vstack(iter_p)
        print("ToDo: check!")
        import pdb; pdb.set_trace()
        return iter_preds, iter_p
    

def predict_from_normal_and_noise_input(
        prompt, flow_tracer, y, logger,
        mt, n_lev, num_noise_sample=3, num_normal_sample=3, noise_type="other",
        end_symbol=(".", "?"), out_num=1, correct_check_only=False):
    """Check the clean prediction for ``prompt`` and find a noise seed that changes it.

    Steps:
      1. Batch ``num_normal_sample`` copies of ``prompt`` and decode the clean answer
         from ``flow_tracer.trace_normal``.
      2. If ``correct_check_only``, return whether ``y`` is in ``answer[0]``.
         Otherwise, if ``y`` is not in ``answer[0]``, return ``(-1, -1, -1, -1, -1)``.
      3. Call ``flow_tracer.trace_corrupted`` with seeds 0, 1, 2, ... until the
         corrupted prediction differs from the clean one.

    Args:
        logger: unused; progress messages go through the root ``logging`` module.
        end_symbol: unused.
        out_num: only ``1`` is implemented.

    Returns:
        ``(flow_tracer, answer, curr_total_token_num, curr_total_block_num, normal_inp)``.
        The token count is the padded prompt length. The block count is
        ``len(transformer.h)`` or ``len(gpt_neox.layers)``.

    Raises:
        NotImplementedError: ``out_num != 1``. This previously crashed with
            UnboundLocalError on ``answer``.
        ValueError: ``mt.model`` has neither ``transformer`` nor ``gpt_neox``.
    """
    # num_noise_sample and num_normal_sample should be same.
    # their outputs are slightly different (it maybe the batch-wise operation in hardware-level.)
    if out_num != 1:
        raise NotImplementedError("out_num={} is not supported; only out_num=1 is implemented".format(out_num))

    normal_inp = make_inputs(mt.tokenizer, [prompt] * (num_normal_sample))

    answer_t = flow_tracer.trace_normal(mt.model, normal_inp)
    answer = decode_tokens(mt.tokenizer, answer_t)
    if correct_check_only:
        return y in answer[0]
    if y not in answer[0]:
        return -1, -1, -1, -1, -1

    # NOTE(review): this loop never ends if no seed changes the prediction; consider a cap.
    noise_rand_seed = 0
    while True:
        corrupted_answer_t, _used_token = flow_tracer.trace_corrupted(
            mt.model, mt.tokenizer, prompt, noise_level=n_lev, rand_seed=noise_rand_seed, noise_type=noise_type, num_noise_sample=num_noise_sample)

        if answer_t.item() != corrupted_answer_t.item():
            break
        noise_rand_seed += 1
        logging.info("[Noise Finder] Changing the seed to find noise that changes the output... Current Seed:{}, Ans:{}".format(
            noise_rand_seed, answer[0]))

    curr_total_token_num = normal_inp["input_ids"].shape[-1]

    if hasattr(mt.model, "transformer"):
        curr_total_block_num = len(mt.model.transformer.h)
    elif hasattr(mt.model, "gpt_neox"):
        curr_total_block_num = len(mt.model.gpt_neox.layers)
    else:
        raise ValueError("unknown transformer structure")

    return flow_tracer, answer, curr_total_token_num, curr_total_block_num, normal_inp


def load_saver(saver, args, curr_meta_arg, except_list=("save_result_root",)):
    """Restore a previously dumped run state into ``saver`` and return its meta args.

    The newest file reported by ``saver.parse_load_saver_file(args)`` is loaded from
    ``args.load_path_saver`` with ``dill``. Its ``args``, ``curr_meta_arg`` and
    ``saver`` entries are compared against the current ones with
    ``saver.dict_equal_checker``, and then ``saver.set_values`` applies the stored
    saver state.

    Args:
        except_list: currently unused; kept for API compatibility.

    Returns:
        The loaded ``curr_meta_arg``. When ``args.load_path_saver`` is None, nothing
        is loaded and the given ``curr_meta_arg`` is returned unchanged (this
        previously crashed in ``os.path.join(None, ...)``).

    Raises:
        FileNotFoundError: ``parse_load_saver_file`` returned no files.
    """
    if args.load_path_saver is None:
        print("[Not Loaded] Please specify the path!")
        return curr_meta_arg
    print("[SAVER Load] {}".format(args.load_path_saver))

    load_list = saver.parse_load_saver_file(args)
    if not load_list:
        raise FileNotFoundError("No saver file found in {}".format(args.load_path_saver))

    # SECURITY: dill, like pickle, can execute arbitrary code while loading;
    # only load files produced by this project.
    with open(os.path.join(args.load_path_saver, load_list[-1]), "rb") as f:
        loaded_data = dill.load(f)

    loaded_args_dict = vars(loaded_data["args"])
    args_dict = vars(args)
    print("\t args is loaded")
    saver.dict_equal_checker(args_dict, loaded_args_dict, warning_template="\t\t {} is different!")

    print("\t curr_meta_arg is loaded")
    saver.dict_equal_checker(curr_meta_arg, loaded_data["curr_meta_arg"])
    curr_meta_arg = loaded_data["curr_meta_arg"]

    print("\t saver is loaded")
    saver.dict_equal_checker(saver.__dict__, loaded_data["saver"])
    saver.set_values(loaded_data["saver"])
    return curr_meta_arg

# Utilities for dealing with tokens
def make_inputs(tokenizer, prompts, device="cuda", pass_encoding=False):
    """Left-pad a batch of prompts into ``input_ids`` / ``attention_mask`` tensors.

    Args:
        tokenizer: provides ``encode``, ``all_special_tokens`` and ``all_special_ids``.
        prompts: strings, or already-encoded id lists when ``pass_encoding`` is not
            ``False``.
        device: device the tensors are moved to.

    Returns:
        ``{"input_ids": ..., "attention_mask": ...}``. Padding goes on the left and
        uses the id of ``"[PAD]"`` if that is a special token, otherwise 0. The mask
        is 0 on padding and 1 on real tokens.
    """
    if pass_encoding is False:
        token_lists = [tokenizer.encode(p) for p in prompts]
    else:
        token_lists = prompts
    longest = max(len(tokens) for tokens in token_lists)

    special_tokens = tokenizer.all_special_tokens
    pad_id = tokenizer.all_special_ids[special_tokens.index("[PAD]")] if "[PAD]" in special_tokens else 0

    padded_ids = []
    masks = []
    for tokens in token_lists:
        n_pad = longest - len(tokens)
        padded_ids.append([pad_id] * n_pad + tokens)
        masks.append([0] * n_pad + [1] * len(tokens))

    return {
        "input_ids": torch.tensor(padded_ids).to(device),
        "attention_mask": torch.tensor(masks).to(device),
    }


class KnownsDataset(torch.utils.data.Dataset):
    """List-backed dataset loaded from ``<data_dir>/known_1000.json``.

    If the file is missing, it is downloaded from ``REMOTE_URL`` first. Items are
    the decoded JSON entries, optionally in reverse order.
    """

    def __init__(self, data_dir: str,  reverse=False, *args, **kwargs):
        data_dir = Path(data_dir)
        known_loc = data_dir / "known_1000.json"
        if not known_loc.exists():
            print(f"{known_loc} does not exist. Downloading from {REMOTE_URL}")
            data_dir.mkdir(exist_ok=True, parents=True)
            torch.hub.download_url_to_file(REMOTE_URL, known_loc)

        # JSON is UTF-8 by specification; don't depend on the platform default encoding.
        with open(known_loc, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        if reverse:
            self.data = self.data[::-1]
        print(f"Loaded dataset with {len(self)} elements")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]
    

class LamatrexDataset(torch.utils.data.Dataset):
    """List-backed dataset loaded from ``<data_dir>/lama_trex.json``.

    Items are the decoded JSON entries, optionally in reverse order.

    Raises:
        FileNotFoundError: the JSON file does not exist. This is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still catch it.
    """

    def __init__(self, data_dir: str,  reverse=False, *args, **kwargs):
        data_dir = Path(data_dir)
        known_loc = data_dir / "lama_trex.json"
        if not known_loc.exists():
            raise FileNotFoundError(f"{known_loc} does not exist")

        # JSON is UTF-8 by specification; don't depend on the platform default encoding.
        with open(known_loc, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        if reverse:
            self.data = self.data[::-1]
        print(f"Loaded dataset with {len(self)} elements")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]


class function_time_saver(object):
    """Accumulate per-function timing and GPU-memory totals across calls."""

    def __init__(self):
        # func_name -> totals dict (see ``save``).
        self.function_timer = {}

    def save(self, func_name, val_dict):
        """Add one call's measurements from ``val_dict`` to the totals for ``func_name``.

        The numeric fields are summed. ``func_file_path`` keeps the value from the
        first call.
        """
        summed_fields = ("total_time", "cpu_time", "gpu_time", "gpu_mem_u", "gpu_mem_r")
        record = self.function_timer.get(func_name)
        if record is None:
            record = {field: 0 for field in summed_fields}
            record["func_file_path"] = val_dict["func_file_path"]
            self.function_timer[func_name] = record
        for field in summed_fields:
            record[field] += val_dict[field]

    def logging(self, logger, header=None):
        """Log one summary line per function in ascending total-time order, then a separator.

        ``header``, when given, is logged before the per-function lines.
        """
        names = list(self.function_timer.keys())
        order = np.argsort([self.function_timer[name]["total_time"] for name in names])
        ordered_names = np.asarray(names)[order]

        if header is not None:
            logger.info(header)
        for name in ordered_names:
            record = self.function_timer[name]
            logger.info(
                "\t[Function Log] '{}' => Total time: {:.4f}s (CPU time: {:.4f}s, GPU time: {:.4f}s) | GPU mem: (U) {:.2f}MB (R) {:.2f}MB\n\t\tFunction Path: '{}'".format(
                    name,
                    record["total_time"],
                    record["cpu_time"],
                    record["gpu_time"],
                    record["gpu_mem_u"],
                    record["gpu_mem_r"],
                    record["func_file_path"],
                )
            )
        logger.info("==============================\n\n")


def measure_time_and_memory(logger=None, func_time_saver=None):
    """Build a decorator that records wall-clock, CUDA time and CUDA memory deltas per call.

    Args:
        logger: currently unused (per-call logging is disabled); kept for API
            compatibility.
        func_time_saver: object with ``save(func_name, val_dict)`` that receives one
            record per completed call. If None, measurements are taken but discarded.

    Returns:
        A decorator. Callables already marked ``__is_decorated__`` are returned
        unchanged, so applying the decorator twice does not double-wrap.

    Record fields (``val_dict``): ``total_time`` (union of the CPU and CUDA wall-clock
    windows), ``cpu_time``, ``gpu_time`` (CUDA event time, in seconds),
    ``gpu_mem_u`` / ``gpu_mem_r`` (change in allocated / reserved CUDA memory, MiB),
    and ``func_file_path`` (absolute path of the defining module, or None if unknown).
    Without CUDA, the GPU fields are 0.0.
    """
    def decorator(func):
        if getattr(func, '__is_decorated__', False):
            return func

        @wraps(func)
        def wrapper(*args, **kwargs):
            use_cuda = torch.cuda.is_available()

            # Measure CPU start time
            cpu_start_time = time.time()

            # Measure GPU start time (if CUDA is available)
            if use_cuda:
                torch.cuda.synchronize()  # Wait for all CUDA kernels to finish
                gpu_start_event = torch.cuda.Event(enable_timing=True)
                gpu_end_event = torch.cuda.Event(enable_timing=True)
                gpu_start_event.record()
                gpu_start_abs_time = time.time()

                initial_memory = torch.cuda.memory_allocated()
                initial_reserved = torch.cuda.memory_reserved()

            result = func(*args, **kwargs)

            # Measure CPU end time
            cpu_end_time = time.time()
            cpu_elapsed_time = cpu_end_time - cpu_start_time

            if use_cuda:
                torch.cuda.synchronize()  # Wait for all CUDA kernels to finish
                gpu_end_event.record()
                torch.cuda.synchronize()  # Wait for the events to be recorded
                gpu_end_abs_time = time.time()
                gpu_elapsed_time = gpu_start_event.elapsed_time(gpu_end_event) / 1000.0  # ms -> s
                memory_usage = (torch.cuda.memory_allocated() - initial_memory) / (1024 * 1024)
                memory_reserved = (torch.cuda.memory_reserved() - initial_reserved) / (1024 * 1024)
                total_start = min(cpu_start_time, gpu_start_abs_time)
                total_end = max(cpu_end_time, gpu_end_abs_time)
            else:
                # Previously this path read an unassigned gpu_start_abs_time (NameError).
                gpu_elapsed_time = 0.0
                memory_usage = 0.0
                memory_reserved = 0.0
                total_start = cpu_start_time
                total_end = cpu_end_time
            total_elapsed_time = total_end - total_start

            # Absolute path of the function's module, when it is known.
            func_module = sys.modules.get(getattr(func, "__module__", None))
            module_file = getattr(func_module, "__file__", None)
            func_file_path = os.path.abspath(module_file) if module_file else None

            val_dict = {
                "total_time": total_elapsed_time,
                "cpu_time": cpu_elapsed_time,
                "gpu_time": gpu_elapsed_time,
                "gpu_mem_u": memory_usage,
                "gpu_mem_r": memory_reserved,
                "func_file_path": func_file_path
            }
            if func_time_saver is not None:
                func_time_saver.save(func.__name__, val_dict)

            return result

        wrapper.__is_decorated__ = True
        return wrapper
    return decorator

def get_stopwords(args, mt):
    """Store the ids of single-token English stopwords in ``args.stwd_ids``.

    Every NLTK English stopword is encoded both with a leading space and bare,
    without special tokens. Any variant that encodes to exactly one token
    contributes that id. The result is a sorted, de-duplicated list.
    """
    single_token_ids = set()
    for word in stopwords.words('english'):
        for variant in (' ' + word, word):
            encoded = mt.tokenizer.encode(variant, add_special_tokens=False)
            if len(encoded) == 1:
                single_token_ids.add(encoded[0])

    args.stwd_ids = sorted(single_token_ids)