import os
import argparse
import json
import torch
from tqdm import tqdm
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from adares_llms.adares_modeling_llama import LlamaForCausalLM
from adares_llms.adares_modeling_qwen2 import Qwen2ForCausalLM
from adares_llms.adares_modeling_phi3 import Phi3ForCausalLM
from adares_llms.adares_modeling_gemma3 import Gemma3ForCausalLM

from utils import now, flatten

# Static settings shared by the model wrapper and the evaluation script below.
CONFIG = {
    "name": "Adaptive Residual for Large Language Models",
    "version": "1.0",

    # Checkpoint location for every supported model. The keys are the canonical
    # model names; each "Your path" placeholder has to be replaced before running,
    # because the value is handed to `from_pretrained` when the model is loaded.
    "llms":{
        "Llama3-8B":{ "path": "Your path" },
        "Qwen2.5-7B":{ "path": "Your path" },
        "Phi3-3.8B":{ "path": "Your path" },
        "Phi3-14B":{ "path": "Your path" },
        "Gemma3-4B":{ "path": "Your path" },
        "Gemma3-12B":{ "path": "Your path" },
    },

    # Data file templates. The "{}" slot is filled with str.format (the script
    # inserts "_<canonical model name>"). Note the keys are hyphenated.
    "data":{
        "ZsRE":                  "./data/zsre_edit{}.json",
        "CounterFact":           "./data/counterfact_edit{}.json",
        "ConflictQA-PopQA":      "./data/conflictQA/conflictQA-popQA{}.json",
        "ConflictQA-StrategyQA": "./data/conflictQA/conflictQA-strategyQA{}.json"
    }
}

''' Class of Dataset '''
class ZsRE(Dataset):
    """ZsRE records loaded from a JSON file and served as ``(index, record)`` pairs.

    Args:
        data_fp: path of the UTF-8 encoded JSON file to load.
        args: optional run configuration; stored on the instance, not read here.
    """

    def __init__(self, data_fp: str, args=None) -> None:
        super().__init__()
        self.args = args
        with open(data_fp, mode='r', encoding='utf-8') as handle:
            loaded = json.load(handle)
        # Kept exactly as parsed; presumably a list of dict records -- TODO confirm.
        self._data = loaded

    def __len__(self):
        """Return the number of loaded records."""
        return len(self._data)

    def __getitem__(self, idx):
        """Return the record at ``idx`` together with the index that was requested."""
        record = self._data[idx]
        return (idx, record)

class CounterFact(Dataset):
    """CounterFact records loaded from a JSON file and served as ``(index, record)`` pairs.

    Args:
        data_fp: path of the UTF-8 encoded JSON file to load.
        args: optional run configuration; stored on the instance, not read here.
    """

    def __init__(self, data_fp: str, args=None):
        self.args = args
        with open(data_fp, mode="r", encoding='utf-8') as handle:
            loaded = json.load(handle)
        # Kept exactly as parsed; presumably a list of dict records -- TODO confirm.
        self._data = loaded

    def __len__(self):
        """Return the number of loaded records."""
        return len(self._data)

    def __getitem__(self, idx):
        """Return the record at ``idx`` together with the index that was requested."""
        record = self._data[idx]
        return (idx, record)

class ConflictQA_PopQA(Dataset):
    """ConflictQA-PopQA records loaded from a JSON file, served as ``(index, record)`` pairs.

    Args:
        data_fp: path of the UTF-8 encoded JSON file to load.
        args: optional run configuration. Accepted (and stored) so the
            constructor has the same signature as the other dataset classes
            in this file; it is not read here.
    """

    def __init__(self, data_fp: str, args=None):
        super().__init__()
        self.args = args
        with open(data_fp, "r", encoding='utf-8') as f:
            self._data = json.load(f)

    def __len__(self):
        """Return the number of loaded records."""
        return len(self._data)

    def __getitem__(self, idx):
        """Return ``(idx, record)`` for the record stored at ``idx``."""
        return (idx, self._data[idx])
    
class ConflictQA_StrategyQA(Dataset):
    """ConflictQA-StrategyQA records loaded from a JSON file, served as ``(index, record)`` pairs.

    Args:
        data_fp: path of the UTF-8 encoded JSON file to load.
        args: optional run configuration. Accepted (and stored) so the
            constructor has the same signature as the other dataset classes
            in this file; it is not read here.
    """

    def __init__(self, data_fp: str, args=None):
        super().__init__()
        self.args = args
        with open(data_fp, "r", encoding='utf-8') as f:
            self._data = json.load(f)

    def __len__(self):
        """Return the number of loaded records."""
        return len(self._data)

    def __getitem__(self, idx):
        """Return ``(idx, record)`` for the record stored at ``idx``."""
        return (idx, self._data[idx])



''' Class of Dataloader '''
class ZsREDataLoader(DataLoader):
    """DataLoader that turns ZsRE records into the string lists used for evaluation.

    Every dataset item must be an ``(index, record)`` pair. The record keys read
    are ``triples``, ``prompt``, ``target_new``, ``paraphrase_prompts``,
    ``locality_prompt`` and ``locality_ground_truth``.

    Args:
        dataset: source of ``(index, record)`` items.
        batch_size: number of items per batch.
        shuffle: forwarded to ``DataLoader``.
        args: optional namespace; only ``args.guide`` is read. When ``args`` is
            omitted or has no ``guide`` attribute, the unguided format is used.
    """

    def __init__(self, dataset: "ZsRE", batch_size, shuffle=True, args=None,):
        super().__init__(
            dataset=dataset, batch_size=batch_size,
            collate_fn=self.zsre_collate_fn, shuffle=shuffle,
        )
        self.instruct = 'Answer the question based on the given triplet knowledge, and give the answer directly. '
        self.args = args

    @staticmethod
    def _join(left, right):
        """Concatenate two parallel string lists element-wise with a single space."""
        return [a + ' ' + b for a, b in zip(left, right)]

    def zsre_collate_fn(self, batch):
        """Build all prompt/answer strings for one batch.

        Returns:
            A 20-tuple, in this order:
            ``batch_idx, triples,``
            ``edit_queries, rep_queries, loc_queries,``
            ``edit_prompts, rep_prompts, loc_prompts,``
            ``edit_ans, rep_ans, loc_ans,``
            ``edit_inputs_wotrp, rep_inputs_wotrp, loc_inputs_wotrp,``
            ``edit_inputs, rep_inputs, loc_inputs,``
            ``edit_pred, rep_pred, loc_pred``.
            ``*_prompts`` are triple + query, ``*_inputs`` are prompt + answer and
            ``*_inputs_wotrp`` are query + answer without the triple.
        """
        # Tolerate the constructor default (args=None): treat it as "no guidance".
        guide = bool(getattr(self.args, 'guide', False))
        join = self._join
        batch_idx = [item[0] for item in batch]
        records = [item[1] for item in batch]

        def with_question_tag(texts):
            # Guided mode labels the query explicitly.
            return ['Question: ' + t for t in texts] if guide else list(texts)

        def build_prompts(queries):
            # Guided mode also appends an explicit answer cue after the query.
            prompts = join(triples, queries)
            return [p + ' Answer: ' for p in prompts] if guide else prompts

        if guide:
            triples = [self.instruct + 'Triple Knowledge: ' + r["triples"] for r in records]
        else:
            triples = [r["triples"] for r in records]
        # Edit (efficacy) inputs
        edit_queries = with_question_tag(r["prompt"][0] for r in records)
        edit_prompts = build_prompts(edit_queries)
        edit_ans = [r["target_new"] for r in records]
        edit_inputs_wotrp = join(edit_queries, edit_ans)
        edit_inputs = join(edit_prompts, edit_ans)
        edit_pred = [r["prompt"][1] for r in records]
        # Rephrase (generalization) inputs share the edit answers
        rep_queries = with_question_tag(r["paraphrase_prompts"][0] for r in records)
        rep_prompts = build_prompts(rep_queries)
        rep_ans = edit_ans
        rep_inputs_wotrp = join(rep_queries, edit_ans)
        rep_inputs = join(rep_prompts, rep_ans)
        rep_pred = [r["paraphrase_prompts"][1] for r in records]
        # Locality inputs never get the "Question:" / "Answer:" decoration
        loc_queries = [r["locality_prompt"][0] for r in records]
        loc_ans = [r["locality_ground_truth"] for r in records]
        loc_prompts = join(triples, loc_queries)
        loc_inputs_wotrp = join(loc_queries, loc_ans)
        loc_inputs = join(loc_prompts, loc_ans)
        # Last element of the locality entry; presumably the unedited model's
        # token ids for the locality answer -- TODO confirm against the data files.
        loc_pred = [r["locality_prompt"][-1] for r in records]
        return (batch_idx, triples,
                edit_queries, rep_queries, loc_queries,
                edit_prompts, rep_prompts, loc_prompts,
                edit_ans, rep_ans, loc_ans,
                edit_inputs_wotrp, rep_inputs_wotrp, loc_inputs_wotrp,
                edit_inputs, rep_inputs, loc_inputs,
                edit_pred, rep_pred, loc_pred)

class CounterFactDataLoader(DataLoader):
    """DataLoader that turns CounterFact records into the string lists used for evaluation.

    Every dataset item must be an ``(index, record)`` pair. The record keys read
    are ``triples``, ``prompt``, ``target_new``, ``rephrase_prompt``,
    ``locality_prompt`` and ``locality_ground_truth``.

    Args:
        dataset: source of ``(index, record)`` items.
        batch_size: number of items per batch.
        shuffle: forwarded to ``DataLoader``.
        args: optional namespace; only ``args.guide`` is read. When ``args`` is
            omitted or has no ``guide`` attribute, the unguided format is used.
    """

    def __init__(self, dataset: "CounterFact", batch_size, shuffle=True, args=None,):
        super().__init__(
            dataset=dataset, batch_size=batch_size,
            collate_fn=self.cf_collate_fn, shuffle=shuffle,
        )
        self.instruct = 'Answer the question based on the given triplet knowledge, and give the answer directly. '
        self.args = args

    @staticmethod
    def _join(left, right):
        """Concatenate two parallel string lists element-wise with a single space."""
        return [a + ' ' + b for a, b in zip(left, right)]

    def cf_collate_fn(self, batch):
        """Build all prompt/answer strings for one batch.

        Returns:
            A 20-tuple, in this order:
            ``batch_idx, triples,``
            ``edit_queries, rep_queries, loc_queries,``
            ``edit_prompts, rep_prompts, loc_prompts,``
            ``edit_ans, rep_ans, loc_ans,``
            ``edit_inputs_wotrp, rep_inputs_wotrp, loc_inputs_wotrp,``
            ``edit_inputs, rep_inputs, loc_inputs,``
            ``edit_pred, rep_pred, loc_pred``.
            ``*_prompts`` are triple + query, ``*_inputs`` are prompt + answer and
            ``*_inputs_wotrp`` are query + answer without the triple.
        """
        # Tolerate the constructor default (args=None): treat it as "no guidance".
        guide = bool(getattr(self.args, 'guide', False))
        join = self._join
        batch_idx = [item[0] for item in batch]
        records = [item[1] for item in batch]

        def with_question_tag(texts):
            # Guided mode labels the query explicitly.
            return ['Question: ' + t for t in texts] if guide else list(texts)

        def build_prompts(queries):
            # Guided mode also appends an explicit answer cue after the query.
            prompts = join(triples, queries)
            return [p + ' Answer: ' for p in prompts] if guide else prompts

        if guide:
            triples = [self.instruct + 'Triple Knowledge: ' + r["triples"] for r in records]
        else:
            triples = [r["triples"] for r in records]
        # Edit (efficacy) inputs
        edit_queries = with_question_tag(r["prompt"][0] for r in records)
        edit_prompts = build_prompts(edit_queries)
        edit_ans = [r["target_new"] for r in records]
        edit_inputs_wotrp = join(edit_queries, edit_ans)
        edit_inputs = join(edit_prompts, edit_ans)
        edit_pred = [r["prompt"][1] for r in records]
        # Rephrase (generalization) inputs share the edit answers
        rep_queries = with_question_tag(r["rephrase_prompt"][0] for r in records)
        rep_prompts = build_prompts(rep_queries)
        rep_ans = edit_ans
        rep_inputs_wotrp = join(rep_queries, edit_ans)
        rep_inputs = join(rep_prompts, rep_ans)
        rep_pred = [r["rephrase_prompt"][1] for r in records]
        # Locality inputs never get the "Question:" / "Answer:" decoration
        loc_queries = [r["locality_prompt"][0] for r in records]
        loc_prompts = join(triples, loc_queries)
        loc_ans = [r["locality_ground_truth"] for r in records]
        loc_inputs_wotrp = join(loc_queries, loc_ans)
        loc_inputs = join(loc_prompts, loc_ans)
        # Last element of the locality entry; presumably the unedited model's
        # token ids for the locality answer -- TODO confirm against the data files.
        loc_pred = [r["locality_prompt"][-1] for r in records]
        return (batch_idx, triples,
                edit_queries, rep_queries, loc_queries,
                edit_prompts, rep_prompts, loc_prompts,
                edit_ans, rep_ans, loc_ans,
                edit_inputs_wotrp, rep_inputs_wotrp, loc_inputs_wotrp,
                edit_inputs, rep_inputs, loc_inputs,
                edit_pred, rep_pred, loc_pred)

class ConflictQA_PopQA_DataLoader(DataLoader):
    """DataLoader that collates ConflictQA-PopQA records into plain prompt/answer strings.

    Every dataset item must be an ``(index, record)`` pair. The record keys read
    are ``question``, ``parametric_memory``, ``memory_answer``,
    ``counter_memory`` and ``counter_answer``.

    Args:
        dataset: source of ``(index, record)`` items.
        batch_size: number of items per batch.
        shuffle: forwarded to ``DataLoader``.
    """

    def __init__(self, dataset: "ConflictQA_PopQA", batch_size, shuffle=True):
        super().__init__(
            dataset=dataset, batch_size=batch_size,
            collate_fn=self.ConflictQA_popQA_collate_fn, shuffle=shuffle
        )

    @staticmethod
    def _join(left, right):
        """Concatenate two parallel string lists element-wise with a single space."""
        return [a + ' ' + b for a, b in zip(left, right)]

    def ConflictQA_popQA_collate_fn(self, batch):
        """Build the memory-side and counter-memory-side strings for one batch.

        Returns:
            A 9-tuple, in this order:
            ``query, memory, context,``
            ``memory_inputs_woctx, context_inputs_woctx,``
            ``memory_inputs, memory_prompts,``
            ``context_inputs, context_prompts``.
        """
        join = self._join
        records = [item[1] for item in batch]
        query = [r["question"] for r in records]
        # Parametric-memory side
        memory = [r["parametric_memory"] for r in records]
        memory_answer = [r["memory_answer"] for r in records]
        memory_prompts = join(memory, query)
        memory_inputs = join(memory_prompts, memory_answer)
        # Counter-memory side
        context = [r["counter_memory"] for r in records]
        context_answer = [r["counter_answer"] for r in records]
        context_prompts = join(context, query)
        # The counter-memory prompt is completed with the counter answer. It used
        # to be completed with memory_answer, which disagreed with
        # context_inputs_woctx below and with the memory-side pairing above.
        context_inputs = join(context_prompts, context_answer)
        # Same query/answer pairs without any supporting passage
        memory_inputs_woctx = join(query, memory_answer)
        context_inputs_woctx = join(query, context_answer)
        return (query,
                memory, context,
                memory_inputs_woctx, context_inputs_woctx,
                memory_inputs, memory_prompts,
                context_inputs, context_prompts)

class ConflictQA_StrategyQA_DataLoader(DataLoader):
    """DataLoader that collates ConflictQA-StrategyQA records into plain prompt/answer strings.

    NOTE(review): a second class with this same name is defined further down in
    this file, so at import time the module-level name refers to that later
    definition rather than to this one.

    Every dataset item must be an ``(index, record)`` pair. The record keys read
    are ``question``, ``parametric_memory``, ``memory_answer``,
    ``counter_memory`` and ``counter_answer``.

    Args:
        dataset: source of ``(index, record)`` items.
        batch_size: number of items per batch.
        shuffle: forwarded to ``DataLoader``.
    """

    def __init__(self, dataset: "ConflictQA_StrategyQA", batch_size, shuffle=True):
        super().__init__(
            dataset=dataset, batch_size=batch_size,
            collate_fn=self.ConflictQA_strategyQA_collate_fn, shuffle=shuffle
        )

    @staticmethod
    def _join(left, right):
        """Concatenate two parallel string lists element-wise with a single space."""
        return [a + ' ' + b for a, b in zip(left, right)]

    def ConflictQA_strategyQA_collate_fn(self, batch):
        """Build the memory-side and counter-memory-side strings for one batch.

        Returns:
            A 9-tuple, in this order:
            ``query, memory, context,``
            ``memory_inputs_woctx, context_inputs_woctx,``
            ``memory_inputs, memory_prompts,``
            ``context_inputs, context_prompts``.
        """
        join = self._join
        records = [item[1] for item in batch]
        query = [r["question"] for r in records]
        # Parametric-memory side
        memory = [r["parametric_memory"] for r in records]
        memory_answer = [r["memory_answer"] for r in records]
        memory_prompts = join(memory, query)
        memory_inputs = join(memory_prompts, memory_answer)
        # Counter-memory side
        context = [r["counter_memory"] for r in records]
        context_answer = [r["counter_answer"] for r in records]
        context_prompts = join(context, query)
        # The counter-memory prompt is completed with the counter answer. It used
        # to be completed with memory_answer, which disagreed with
        # context_inputs_woctx below and with the memory-side pairing above.
        context_inputs = join(context_prompts, context_answer)
        # Same query/answer pairs without any supporting passage
        memory_inputs_woctx = join(query, memory_answer)
        context_inputs_woctx = join(query, context_answer)
        return (query,
                memory, context,
                memory_inputs_woctx, context_inputs_woctx,
                memory_inputs, memory_prompts,
                context_inputs, context_prompts)


class ConflictQA_PpQA_DataLoader(DataLoader):
    """Guide-aware DataLoader for ConflictQA-PopQA records.

    Every dataset item must be an ``(index, record)`` pair. The record keys read
    are ``question``, ``parametric_memory``, ``memory_answer``,
    ``counter_memory`` and ``counter_answer``.

    Args:
        dataset: source of ``(index, record)`` items.
        batch_size: number of items per batch.
        shuffle: forwarded to ``DataLoader``.
        args: optional namespace; only ``args.guide`` is read. When ``args`` is
            omitted or has no ``guide`` attribute, the unguided format is used.
    """

    def __init__(self, dataset: "ConflictQA_PopQA", batch_size, shuffle=True, args=None,):
        # Zero-argument super(): the explicit two-argument form referred to a
        # class name that is not defined anywhere, so construction always failed.
        super().__init__(
            dataset=dataset, batch_size=batch_size,
            collate_fn=self.ConflictQA_PopQA_collate_fn, shuffle=shuffle
        )
        self.instruct = 'Answer the question based on the given information, and give the answer directly. '
        self.args = args

    @staticmethod
    def _join(left, right):
        """Concatenate two parallel string lists element-wise with a single space."""
        return [a + ' ' + b for a, b in zip(left, right)]

    def ConflictQA_PopQA_collate_fn(self, batch):
        """Build the memory-side and counter-side strings for one batch.

        Returns:
            A 9-tuple, in this order:
            ``memory_context, counter_context,``
            ``query_with_memory_ans, query_with_counter_ans, query,``
            ``memory_inputs, memory_prompts,``
            ``counter_inputs, counter_prompts``.
        """
        # Tolerate the constructor default (args=None): treat it as "no guidance".
        guide = bool(getattr(self.args, 'guide', False))
        join = self._join
        records = [item[1] for item in batch]
        # Guided mode prefixes the passage with the instruction and tags the query.
        query_tag = 'Question: ' if guide else ''
        context_tag = self.instruct + 'Information: ' if guide else ''
        query = [query_tag + r["question"] for r in records]
        memory_context = [context_tag + r["parametric_memory"] for r in records]
        counter_context = [context_tag + r["counter_memory"] for r in records]
        memory_answer = [r["memory_answer"] for r in records]
        counter_answer = [r["counter_answer"] for r in records]
        # Query + answer without any passage
        query_with_memory_ans = join(query, memory_answer)
        query_with_counter_ans = join(query, counter_answer)
        # Passage + query, and passage + query + matching answer
        memory_prompts = join(memory_context, query)
        memory_inputs = join(memory_prompts, memory_answer)
        counter_prompts = join(counter_context, query)
        counter_inputs = join(counter_prompts, counter_answer)

        return (memory_context, counter_context,
                query_with_memory_ans, query_with_counter_ans, query,
                memory_inputs, memory_prompts,
                counter_inputs, counter_prompts)

class ConflictQA_StrategyQA_DataLoader(DataLoader):
    """Guide-aware DataLoader for ConflictQA-StrategyQA records.

    Every dataset item must be an ``(index, record)`` pair. The record keys read
    are ``question``, ``parametric_memory``, ``memory_answer``,
    ``counter_memory`` and ``counter_answer``.

    Args:
        dataset: source of ``(index, record)`` items.
        batch_size: number of items per batch.
        shuffle: forwarded to ``DataLoader``.
        args: optional namespace; only ``args.guide`` is read. When ``args`` is
            omitted or has no ``guide`` attribute, the unguided format is used.
    """

    def __init__(self, dataset: "ConflictQA_StrategyQA", batch_size, shuffle=True, args=None,):
        # Zero-argument super(): the explicit two-argument form referred to a
        # class name that is not defined anywhere, so construction always failed.
        super().__init__(
            dataset=dataset, batch_size=batch_size,
            collate_fn=self.ConflictQA_StrategyQA_collate_fn, shuffle=shuffle
        )
        self.instruct = 'Answer the question based on the given information, and give the answer directly. '
        self.args = args

    @staticmethod
    def _join(left, right):
        """Concatenate two parallel string lists element-wise with a single space."""
        return [a + ' ' + b for a, b in zip(left, right)]

    def ConflictQA_StrategyQA_collate_fn(self, batch):
        """Build the memory-side and counter-side strings for one batch.

        Returns:
            A 9-tuple, in this order:
            ``memory_context, counter_context,``
            ``query_with_memory_ans, query_with_counter_ans, query,``
            ``memory_inputs, memory_prompts,``
            ``counter_inputs, counter_prompts``.
        """
        # Tolerate the constructor default (args=None): treat it as "no guidance".
        guide = bool(getattr(self.args, 'guide', False))
        join = self._join
        records = [item[1] for item in batch]
        # Guided mode prefixes the passage with the instruction and tags the query.
        query_tag = 'Question: ' if guide else ''
        context_tag = self.instruct + 'Information: ' if guide else ''
        query = [query_tag + r["question"] for r in records]
        memory_context = [context_tag + r["parametric_memory"] for r in records]
        counter_context = [context_tag + r["counter_memory"] for r in records]
        memory_answer = [r["memory_answer"] for r in records]
        counter_answer = [r["counter_answer"] for r in records]
        # Query + answer without any passage
        query_with_memory_ans = join(query, memory_answer)
        query_with_counter_ans = join(query, counter_answer)
        # Passage + query, and passage + query + matching answer
        memory_prompts = join(memory_context, query)
        memory_inputs = join(memory_prompts, memory_answer)
        counter_prompts = join(counter_context, query)
        counter_inputs = join(counter_prompts, counter_answer)

        return (memory_context, counter_context,
                query_with_memory_ans, query_with_counter_ans, query,
                memory_inputs, memory_prompts,
                counter_inputs, counter_prompts)


''' Class of Adaptive Residual '''
class AdaRes(torch.nn.Module):
    """Frozen causal language model plus tokenizer, evaluated with the adaptive residual.

    ``args.llm_name`` is matched case-insensitively against these aliases:

    ===========================================  ==============
    alias                                        model
    ===========================================  ==============
    llama3, llama-3, llama3-8b, llama-3-8b       Llama3-8B
    qwen2.5, qwen2.5-7, qwen2.5-7b, qwen-2.5-7b  Qwen2.5-7B
    gemma3-4b, gemma-3-4b                        Gemma3-4B
    gemma3-12b, gemma-3-12b                      Gemma3-12B
    phi3-mini, phi3-3.8b                         Phi3-3.8B
    phi3-medium, phi3-14b                        Phi3-14B
    ===========================================  ==============

    Args:
        args: namespace providing ``llm_name`` (read at construction) and
            ``hlayers`` / ``topn`` (read on every forward pass).

    Raises:
        ValueError: if ``args.llm_name`` is not one of the aliases above.
    """

    # Lower-case alias -> canonical model name (the key used in CONFIG['llms']).
    _LLM_ALIASES = {
        'llama3': 'Llama3-8B', 'llama-3': 'Llama3-8B',
        'llama3-8b': 'Llama3-8B', 'llama-3-8b': 'Llama3-8B',
        'qwen2.5': 'Qwen2.5-7B', 'qwen2.5-7': 'Qwen2.5-7B',
        'qwen2.5-7b': 'Qwen2.5-7B', 'qwen-2.5-7b': 'Qwen2.5-7B',
        'gemma3-4b': 'Gemma3-4B', 'gemma-3-4b': 'Gemma3-4B',
        'gemma3-12b': 'Gemma3-12B', 'gemma-3-12b': 'Gemma3-12B',
        'phi3-mini': 'Phi3-3.8B', 'phi3-3.8b': 'Phi3-3.8B',
        'phi3-medium': 'Phi3-14B', 'phi3-14b': 'Phi3-14B',
    }

    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        self.model = None
        self.tokenizer = None
        self.llm = None
        self.config = None
        #
        alias = self.args.llm_name.lower()
        self.llm = self._LLM_ALIASES.get(alias)
        if self.llm is None:
            raise ValueError(f'The current approach does not support the LLM {alias}!')
        # Canonical model name -> modeling class that implements it.
        model_classes = {
            'Llama3-8B': LlamaForCausalLM,
            'Qwen2.5-7B': Qwen2ForCausalLM,
            'Gemma3-4B': Gemma3ForCausalLM,
            'Gemma3-12B': Gemma3ForCausalLM,
            'Phi3-3.8B': Phi3ForCausalLM,
            'Phi3-14B': Phi3ForCausalLM,
        }
        checkpoint = CONFIG['llms'][self.llm]['path']
        self.model = model_classes[self.llm].from_pretrained(checkpoint, device_map="auto")
        self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        # Set model parameters as no_grad
        for param in self.model.parameters():
            param.requires_grad = False
        # Setting the pad token: EOS doubles as padding, applied on the right
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = 'right'
        self.model.config.pad_token_id = self.model.config.eos_token_id
        self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id

    def tokenize(self, batch_inputs:list, batch_prompts:list, batch_context:list=None, batch_query:list=None):
        """Tokenize one batch of strings.

        Args:
            batch_inputs: full sequences (prompt followed by answer).
            batch_prompts: the prompt part of each sequence; only its token
                count is returned.
            batch_context: optional context strings, tokenized separately.
            batch_query: optional query strings, tokenized separately.

        Returns:
            ``(inputs, labels, context, query, inputs_length, prompts_length)``.
            ``labels`` is a detached copy of ``inputs.input_ids``; ``context`` /
            ``query`` are ``None`` when the matching argument is ``None``; the
            two lengths are per-sequence counts of non-padding tokens.
        """
        inputs = self.tokenizer(batch_inputs, padding=True, return_tensors='pt')
        context = self.tokenizer(batch_context, padding=True, return_tensors='pt') if batch_context is not None else None
        # Guard on batch_query itself: the guard used to test batch_context, so a
        # context without a query sent None to the tokenizer and a query without
        # a context was silently dropped.
        query = self.tokenizer(batch_query, padding=True, return_tensors='pt') if batch_query is not None else None
        labels = inputs.input_ids.detach().clone()
        inputs_length = inputs.attention_mask.sum(1).tolist()
        prompts_length = self.tokenizer(batch_prompts, padding=True, return_tensors='pt').attention_mask.sum(1).tolist()
        return inputs, labels, context, query, inputs_length, prompts_length

    def lambda_fn(self, attn_feats, ffn_feats, alpha, beta):
        """Re-weight the attention and FFN features.

        Computes ``(1 + r) * attn_feats + (1 - r) * ffn_feats`` with
        ``r = |alpha| / (|alpha| + |beta|)``. The result lives on the device of
        ``attn_feats``.

        NOTE(review): ``r`` is NaN when ``alpha`` and ``beta`` are both zero; no
        epsilon is added, so the numerics stay exactly as before.
        """
        attn_device = attn_feats.device
        # Move both coefficients before combining them; previously only one
        # operand of ``alpha + beta`` was moved, which fails across devices.
        alpha = torch.abs(alpha).to(attn_device)
        beta = torch.abs(beta).to(attn_device)
        ratio = alpha / (alpha + beta)
        attn_term = (1. + ratio) * attn_feats
        ffn_term = (1. - ratio.to(ffn_feats.device)) * ffn_feats
        return attn_term + ffn_term.to(attn_device)

    def forward(self, inputs, context=None, z_query=None, output_hidden_states=False):
        """Run the wrapped model in eval mode without gradients.

        ``inputs`` is unpacked as keyword arguments; ``context``, ``z_query``,
        ``args.hlayers`` (as ``H``), ``args.topn`` (as ``top_n``) and
        ``lambda_fn`` are forwarded to the wrapped model.
        """
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**inputs, context=context, z_query=z_query,
                                 H=self.args.hlayers, top_n=self.args.topn,
                                 lambda_fn = self.lambda_fn,
                                 output_hidden_states=output_hidden_states,)
        return outputs

    def to(self, device:torch.device):
        """Move the wrapped model to ``device`` and return ``self``.

        Returning ``self`` matches ``torch.nn.Module.to`` so calls can be chained.

        Raises:
            Exception: if no model has been loaded.
        """
        if self.model is None:
            raise Exception('No model specified!')
        self.model = self.model.to(device)
        return self


''' Class of Model Evaluation '''
class Evaluator(object):
    """Token-level accuracy metrics computed from greedy (argmax) predictions.

    ``model`` must offer ``tokenize(...)`` returning
    ``(inputs, labels, context, query, seq_lengths, prompt_lengths)``,
    ``eval()``, and be callable, returning either a logits tensor or an object
    with a ``.logits`` attribute.
    """

    @staticmethod
    def _as_list(value):
        """Wrap a single string into a one-element list; pass everything else through."""
        return [value,] if isinstance(value, str) else value

    @staticmethod
    def _greedy_answer_ids(outputs, num_prompt, num_answer):
        """Slice the argmax token ids that correspond to each answer span.

        The logits at position ``t`` are compared with the token at ``t + 1``,
        hence the slice is shifted one step to the left of the answer span.
        """
        logits = outputs if isinstance(outputs, torch.Tensor) else outputs.logits
        greedy = torch.argmax(logits, dim=-1).tolist()
        return [row[p - 1:p + a - 1] for row, p, a in zip(greedy, num_prompt, num_answer)]

    @staticmethod
    def _token_accuracies(answers, labels):
        """Return the per-sample fraction of matching tokens, skipping empty spans."""
        scores = []
        for ans, label in zip(answers, labels):
            matches = np.equal(ans, label)
            # An empty span has no defined accuracy (its mean would be NaN).
            if np.size(matches) == 0: continue
            scores.append(np.mean(matches))
        return scores

    @staticmethod
    def success(model, batch_inputs:list, batch_prompts:list, batch_context:list=None, batch_query:list=None):
        """Accuracy of the predicted answer tokens against the tokens of ``batch_inputs``.

        Args:
            model: see the class docstring.
            batch_inputs: prompt + answer strings (a single string is accepted).
            batch_prompts: the prompt part of each input.
            batch_context: optional context strings forwarded to the model.
            batch_query: optional query strings forwarded to the model.

        Returns:
            One accuracy per sample; samples with an empty answer span are omitted.

        Raises:
            Exception: if an answer span of the labels contains ``-100``.
        """
        batch_inputs = Evaluator._as_list(batch_inputs)
        batch_prompts = Evaluator._as_list(batch_prompts)
        batch_context = Evaluator._as_list(batch_context)
        batch_query = Evaluator._as_list(batch_query)
        inputs, labels, context, z_query, num_seq, num_prompt = model.tokenize(batch_inputs, batch_prompts, batch_context, batch_query)
        # Tensors are only moved explicitly in the single-GPU case.
        if torch.cuda.device_count() == 1:
            inputs = inputs.to('cuda:0') if inputs is not None else None
            labels = labels.to('cuda:0') if labels is not None else None
            context = context.to('cuda:0') if context is not None else None
            z_query = z_query.to('cuda:0') if z_query is not None else None
        num_answer = [total - prompt for total, prompt in zip(num_seq, num_prompt)]
        model.eval()
        with torch.no_grad(): outputs = model(inputs, context, z_query)
        answers = Evaluator._greedy_answer_ids(outputs, num_prompt, num_answer)
        labels = [row[p:p + a] for row, p, a in zip(labels.tolist(), num_prompt, num_answer)]
        # Check inside every span: ``-100 in labels`` compared -100 against whole
        # lists and therefore could never trigger.
        if any(-100 in span for span in labels): raise Exception('Error in labels when evaluation!')
        return Evaluator._token_accuracies(answers, labels)

    @staticmethod
    def efficacy(model, batch_inputs:list, batch_prompts:list, batch_context:list=None, batch_query:list=None):
        """Alias of :meth:`success`, kept because the script entry point calls it by this name."""
        return Evaluator.success(model, batch_inputs, batch_prompts, batch_context, batch_query)

    @staticmethod
    def locality(model, batch_inputs:list, batch_prompts:list, original_preds:list):
        """Agreement between current predictions and previously recorded predictions.

        Args:
            model: see the class docstring.
            batch_inputs: prompt + answer strings (a single string is accepted).
            batch_prompts: the prompt part of each input.
            original_preds: one recorded token-id sequence per sample; for an
                answer of ``k`` tokens the ``k`` ids before the last one are
                used (presumably the unedited model's output -- TODO confirm).

        Returns:
            One agreement ratio per sample; samples with an empty answer span are omitted.
        """
        batch_inputs = Evaluator._as_list(batch_inputs)
        batch_prompts = Evaluator._as_list(batch_prompts)
        original_preds = Evaluator._as_list(original_preds)
        # Get the model predictions
        inputs, _, _, _, num_seq, num_prompt = model.tokenize(batch_inputs, batch_prompts)
        if torch.cuda.device_count() == 1: inputs = inputs.to('cuda:0')
        num_answer = [total - prompt for total, prompt in zip(num_seq, num_prompt)]
        model.eval()
        with torch.no_grad(): outputs = model(inputs)
        answers = Evaluator._greedy_answer_ids(outputs, num_prompt, num_answer)
        # Obtain the original outputs
        labels = [pred[-(a + 1):-1] for pred, a in zip(original_preds, num_answer)]
        return Evaluator._token_accuracies(answers, labels)


if __name__ == '__main__':
    #
    # Running this code by:
    #   python adares.py --llm_name ... --data_name ... --hlayers ... --topn ... [--guide true] [--verbose false]
    #
    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string, including "False",
        into True, so the accepted spellings are listed explicitly.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}.')

    parser = argparse.ArgumentParser()
    parser.add_argument('--llm_name', type=str, required=True)
    parser.add_argument('--data_name', type=str, required=True, choices=['zsre', 'counterfact', 'popqa', 'strategyqa'])
    parser.add_argument('--test_bs', type=int, default=16, help='Batch size of evaluation.')
    parser.add_argument('--hlayers', nargs='+', type=int, default=None, help="The only hyperparameter: Layers to apply AdaRes.")
    parser.add_argument('--verbose', type=_str2bool, default=True, help="Whether to show the progress bar or not.")
    parser.add_argument('--topn', type=int, default=-1)
    parser.add_argument('--guide', type=_str2bool, default=False)
    args = parser.parse_args()
    # Initialize our adaptive residual
    model = AdaRes(args)
    # Load dataset and prepare data loader
    model_suffix = '_' + model.llm
    if args.data_name == 'zsre':
        data = ZsRE(CONFIG['data']['ZsRE'].format(model_suffix), args=args)
        test_loader = ZsREDataLoader(data, args.test_bs, shuffle=False, args=args)
    elif args.data_name == 'counterfact':
        data = CounterFact(CONFIG['data']['CounterFact'].format(model_suffix), args=args)
        test_loader = CounterFactDataLoader(data, args.test_bs, shuffle=False, args=args)
    elif args.data_name == 'popqa':
        # CONFIG keys are hyphenated. The dataset only needs the file path.
        data = ConflictQA_PopQA(CONFIG['data']['ConflictQA-PopQA'].format(model_suffix))
        # The guide-aware PopQA loader is defined under this (misspelled) name;
        # its batch layout is the one the indices in the loop below rely on.
        test_loader = ConflictQA_PpQA_DataLoader(data, args.test_bs, shuffle=False, args=args)
    elif args.data_name == 'strategyqa':
        data = ConflictQA_StrategyQA(CONFIG['data']['ConflictQA-StrategyQA'].format(model_suffix))
        test_loader = ConflictQA_StrategyQA_DataLoader(data, args.test_bs, shuffle=False, args=args)
    else: raise Exception(f"Currently not support the dataset: {args.data_name}!")

    print('Model settings:\n', args, '\n', '--'*15)

    # Evaluation
    efficacy_acc = list()
    rephrase_acc = list()
    locality_acc = list()
    pbar = tqdm(total=len(data), ncols=75, leave=True) if args.verbose else None
    if pbar is not None: pbar.set_description_str(desc='AdaRes->')
    if args.data_name in ('popqa', 'strategyqa'):
        for batch in test_loader:
            # Layout of the guide-aware ConflictQA loaders:
            # [7] counter_inputs, [8] counter_prompts, [1] counter_context, [4] query
            efficacy_acc.append(Evaluator.success(model, batch[7], batch[8], batch[1], batch[4]))
            if pbar is not None: pbar.update(len(batch[0]))
        if pbar is not None: pbar.close()
        print(f'Mean Efficacy: {np.mean(flatten(efficacy_acc))}\n')
    elif args.data_name in ('zsre', 'counterfact'):
        for batch in test_loader:
            # [14]/[15] edit/rephrase inputs, [5]/[6] their prompts, [1] triples,
            # [2]/[3] edit/rephrase queries; locality uses [13] inputs without
            # the triple, [4] locality queries and [19] the recorded predictions.
            efficacy_acc.append(Evaluator.success(model, batch[14], batch[5], batch[1], batch[2]))
            rephrase_acc.append(Evaluator.success(model, batch[15], batch[6], batch[1], batch[3]))
            locality_acc.append(Evaluator.locality(model, batch[13], batch[4], batch[19]))
            if pbar is not None: pbar.update(len(batch[0]))
        if pbar is not None: pbar.close()
        print(f'Mean Efficacy: {np.mean(flatten(efficacy_acc))}')
        print(f'Mean Rephrase: {np.mean(flatten(rephrase_acc))}')
        print(f'Mean Locality: {np.mean(flatten(locality_acc))}\n')
    else: raise Exception(f'Cannot proceed the data {args.data_name}')