import copy
import json
import os
import re
import sys
import argparse
import pdb

import fire

import torch
import sys
sys.path.append(os.path.join(os.getcwd(), "peft/src/"))
from peft import PeftModel, PeftConfig
from tqdm import tqdm
from transformers import GenerationConfig, LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer
from llama import LlamaForCausalLM
from llama_explore.modeling_llama_explore import New_LlamaForCausalLM

import wandb


# Select the compute device once at import time: CUDA if present, otherwise CPU;
# Apple MPS is checked last and therefore wins whenever it reports available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    # Older torch builds have no `torch.backends.mps`; keep the device chosen above.
    pass


def main(
        load_8bit: bool = False,
        base_model: str = "",
        lora_weights: str = "tloen/alpaca-lora-7b",
        share_gradio: bool = False,
):
    """
    Run the evaluation loop over ``dataset/{args.dataset}/test.json``.

    All configuration comes from ``parse_args()`` (i.e. ``sys.argv``). The
    keyword parameters above only give ``fire.Fire(main)`` a signature to bind
    and are not read in this function.

    For every sample, a response is generated and scored against the sample's
    ``answer``. The accumulated per-sample records are rewritten to
    ``experiment/{args.model}-{args.adapter}-{args.dataset}.json`` after each
    sample, and the running accuracy is logged to an offline wandb run.
    """
    args = parse_args()
    # Seed all RNGs and force deterministic cuDNN kernels for reproducible runs.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print(args)
    print(args.explore_weight)
    run = wandb.init(
        project=args.wandb_project,
        name=args.wandb_run_name,
        config={
            "lambda": args.explore_weight,
            "decoding_method": args.decoding_method,
            "num_epochs": args.num_epochs,
            # "clip_value": args.clip_value,
            "seed": args.seed,
            "alpha": args.alpha,
            "beta": args.beta,
            "model_type": args.model_type
        },
        mode="offline"
    )

    def evaluate(
            instruction,
            input_=None,
            num_forest=0,
            temperature=0.1,
            top_p=0.95,
            top_k=40,
            num_beams=4,
            max_new_tokens=256,
            **kwargs,
    ):
        """
        Generate a response for one sample and return the text after ``### Response:``.

        Uses ``tokenizer``/``model`` and, when ``args.explore_flag == 1``, the
        explore forests, all of which ``main`` assigns after this definition
        (they are looked up when ``evaluate`` is called). ``num_forest`` is the
        number of explore generation configs to build.
        """
        prompt = generate_prompt(instruction, input_)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            generation_mode="greedy",
            num_beams=num_beams,
            do_sample=False,
            **kwargs,
        )
        # One GenerationConfig per explore model; only built when exploring.
        # NOTE(review): for any decoding_method other than 'greedy'/'beam_search'
        # (it is optional on the CLI) this list stays empty while the explore
        # models are still passed to generate below.
        generation_explore_config_forest = []
        if args.explore_flag == 1:
            if args.decoding_method == 'greedy':
                for _ in range(num_forest):
                    generation_explore_config = GenerationConfig(
                        generation_mode="greedy",  # Set mode to greedy
                        num_beams=1,               # Ensure no beam search
                        do_sample=False,           # Disable sampling
                        output_hidden_states=True,
                        output_logits=True,
                        **kwargs
                    )
                    generation_explore_config_forest.append(generation_explore_config)
            elif args.decoding_method == 'beam_search':
                for _ in range(num_forest):
                    generation_explore_config = GenerationConfig(
                        temperature=temperature,
                        top_p=top_p,
                        top_k=top_k,
                        generation_mode="beam_search",  # Set mode to beam search
                        num_beams=num_beams,            # Use beam search
                        # NOTE(review): sampling is ENABLED here, unlike the
                        # greedy branch; confirm this is intended.
                        do_sample=True,
                        output_hidden_states=True,
                        **kwargs
                    )
                    generation_explore_config_forest.append(generation_explore_config)

        with torch.no_grad():
            generation_output = model.generate(
                model_explore=model_explore_forest if args.explore_flag == 1 else None,
                explore_lora_weights=explore_lora_weights_forest if args.explore_flag == 1 else None,
                generation_explore_config=generation_explore_config_forest if args.explore_flag == 1 else None,
                explore_weight=explore_weight_forest if args.explore_flag == 1 else 0.0,
                top_k_explore_logits=args.top_k_explore_logits if args.explore_flag == 1 else -1,
                topk_logits=args.topk_logits if args.explore_flag == 1 else None,
                # clip_value=args.clip_value if args.explore_flag == 1 else None,
                max_new_tokens_generation=max_new_tokens,
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
                use_cache=False,
            )
        s = generation_output.sequences[0]
        output = tokenizer.decode(s)
        return output.split("### Response:")[1].strip()

    save_file = f'experiment/{args.model}-{args.adapter}-{args.dataset}.json'
    # args.model may contain '/', so create the save file's real parent
    # directory (and 'experiment/' itself when missing) instead of a fixed path.
    os.makedirs(os.path.dirname(save_file), exist_ok=True)

    dataset = load_data(args)
    if args.explore_flag == 1:
        tokenizer, model, tokenizer_explore, model_explore_forest, explore_weight_forest, explore_lora_weights_forest = load_model(args)
        num_forest = len(model_explore_forest)
    else:
        tokenizer, model = load_model(args)
        num_forest = 0

    total = len(dataset)
    correct = 0
    miss = 0.001  # absolute tolerance for numeric answers
    output_data = []
    pbar = tqdm(total=total)
    for idx, data in enumerate(dataset):
        instruction = data.get('instruction')
        input_ = data.get('input')

        outputs = evaluate(instruction, input_, num_forest=num_forest)
        label = data.get('answer')
        flag = False
        if args.dataset.lower() in ['aqua']:
            predict = extract_answer_letter(args, outputs)
            if label == predict:
                correct += 1
                flag = True
        elif args.dataset.lower() in ['llmtools']:
            predict = extract_new_answer_letter(args, outputs)
            if label == predict:
                correct += 1
                flag = True
        else:
            if isinstance(label, str):
                try:
                    label = float(label)
                except ValueError:
                    # Unparseable labels become inf so they can never be matched.
                    label = float('inf')
            predict = extract_answer_number(args, outputs)
            if abs(label - predict) <= miss:
                correct += 1
                flag = True
        new_data = copy.deepcopy(data)
        new_data['output_pred'] = outputs
        new_data['pred'] = predict
        new_data['flag'] = flag
        output_data.append(new_data)
        print(' ')
        print('---------------')
        print("MODEL_OUTPUT", outputs)
        print('prediction:', predict)
        print('label:', label)
        print('---------------')
        print(f'\rtest:{idx + 1}/{total} | accuracy {correct}  {correct / (idx + 1)}')
        # Rewrite the whole results file after every sample so partial results
        # survive an interrupted run.
        with open(save_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=4)
        run.log({"acc": correct / (idx + 1)})
        pbar.update(1)
    pbar.close()
    run.finish()
    print('\n')
    print('test finished')


def create_dir(dir_path):
    """
    Create ``dir_path``, including any missing parent directories.

    Does nothing when ``dir_path`` already exists, whether it is a directory
    or some other filesystem entry. Returns ``None``.
    """
    if not os.path.exists(dir_path):
        # makedirs creates intermediate directories (os.mkdir would fail when the
        # parent is missing); exist_ok tolerates another process creating it
        # between the check and this call.
        os.makedirs(dir_path, exist_ok=True)
    return


def generate_prompt(instruction, input_=None):
    """
    Build the instruction prompt for one sample.

    Args:
        instruction: task text placed under ``### Instruction:``.
        input_: optional context placed under ``### Input:``. When falsy
            (``None`` or ``""``), the instruction-only template is used.

    Returns:
        The prompt string, ending after the ``### Response:`` header.
    """
    # Test the `input_` argument, not the `input` builtin (always truthy), which
    # previously forced the with-input template for every sample.
    if input_:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request, and make sure to put the final answer inside <ans>A/B/C/D</ans> tags.

### Instruction:
{instruction}

### Input:
{input_}

### Response:
"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request, and make sure to put the final answer inside <ans></ans> tags.

### Instruction:
{instruction}

### Response:

"""


def load_data(args) -> list:
    """
    Read the test split of ``args.dataset``.

    Args:
        args: namespace with a ``dataset`` attribute naming a folder under
            ``dataset/`` (relative to the current working directory).

    Returns:
        The parsed JSON content of ``dataset/{args.dataset}/test.json``.

    Raises:
        FileNotFoundError: the test file does not exist.
    """
    file_path = f'dataset/{args.dataset}/test.json'
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"can not find dataset file : {file_path}")
    # The context manager closes the handle. UTF-8 is explicit so parsing does
    # not depend on the platform's locale encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def parse_args(argv=None):
    """
    Parse the evaluation command line.

    Args:
        argv: argument strings to parse. ``None`` (the default, as used by
            ``main``) makes argparse read ``sys.argv[1:]``. Passing a list lets
            callers and tests parse options without touching ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', choices=['AddSub', 'MultiArith', 'SingleEq', 'gsm8k', 'AQuA', 'SVAMP', "mawps", "math_10k_500", "math_10k_small_5", "LLMTOOLs"],
                        required=True)
    parser.add_argument('--model', required=True)
    parser.add_argument('--adapter', choices=['lora', 'AdapterP', 'AdapterH', 'Parallel', 'Prefix'],
                        required=True)
    parser.add_argument('--base_model', required=True)
    parser.add_argument('--lora_weights', required=True)
    parser.add_argument('--load_8bit', action='store_true', default=False)
    parser.add_argument('--logfile', required=False)

    # Arguments for the explore ("copilot") models. load_model splits
    # --base_explore_model, --explore_lora_weights and --explore_weight on ','
    # (one entry per explore model) and places explore model i on the GPU id
    # at position i + 1 of the comma-separated --gpus list.
    parser.add_argument('--explore_model', required=False)
    parser.add_argument('--base_explore_model', required=False)
    parser.add_argument('--explore_lora_weights', required=False)
    parser.add_argument('--explore_flag', type=int, default=0)
    parser.add_argument('--explore_weight', type=str, default="0.0")
    parser.add_argument('--decoding_method', type=str, choices=['greedy', 'beam_search'], required=False)
    parser.add_argument('--explore_logits_factor', type=float, default=100.0)
    parser.add_argument('--top_k_explore_logits', type=int, default=-2)
    parser.add_argument("--topk_logits", type=int, default=5)
    # parser.add_argument("--clip_value", type=float, default=0.1)
    parser.add_argument('--wandb_run_name', type=str, default='llama_explore')
    parser.add_argument('--wandb_project', type=str, default='BoostLLM')
    parser.add_argument('--num_epochs', type=int, default=3)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--alpha', type=float, default=0)
    parser.add_argument('--beta', type=float, default=0)
    parser.add_argument('--model_type', type=str, default='BoostLLM')
    parser.add_argument('--gpus', type=str, default='0')
    return parser.parse_args(argv)


def load_model(args) -> tuple:
    """
    Load the tuned base model, its tokenizer and, optionally, the explore models.

    Args:
        args: namespace from ``parse_args``.
            Always read: ``base_model``, ``lora_weights``, ``model``,
            ``load_8bit``, ``gpus``, ``explore_flag`` and ``explore_weight``.
            When ``explore_flag == 1``, also read: ``base_explore_model`` and
            ``explore_lora_weights`` (both comma-separated, one entry per
            explore model), ``explore_logits_factor``, ``topk_logits`` and
            ``alpha``.

    Returns:
        ``(tokenizer, model)`` when ``explore_flag != 1``. Otherwise
        ``(tokenizer, model, tokenizer_forest, model_explore_forest,
        explore_weight_forest, explore_lora_weights_forest)``, where the two
        weight lists are the raw comma-split CLI strings.

    Raises:
        ValueError: one of the following holds:
            - a required model name or weight path is missing;
            - an explore list has fewer entries than ``--base_explore_model``;
            - an explore weight is not a number;
            - ``--gpus`` lists too few ids.
        RuntimeError: the module-level ``device`` is not ``"cuda"``.
    """
    explore = args.explore_flag == 1

    base_model = args.base_model
    if not base_model:
        raise ValueError(f'can not find base model name by the value: {args.model}')

    lora_weights = args.lora_weights
    if not lora_weights:
        raise ValueError(f'can not find lora weight, the value is: {lora_weights}')

    load_8bit = args.load_8bit
    gpus = args.gpus.split(',')
    explore_weight_forest = args.explore_weight.split(",")

    base_explore_model_forest = []
    explore_lora_weights_forest = []
    if explore:
        # Validate the raw strings: str.split() never returns an empty list, so
        # an emptiness check on the split result cannot detect a missing value.
        if not args.base_explore_model:
            raise ValueError(f'can not find base explore model name by the value: {args.base_explore_model}')
        if not args.explore_lora_weights:
            raise ValueError(f'can not find explore lora weight, the value is: {args.explore_lora_weights}')
        base_explore_model_forest = args.base_explore_model.split(',')
        explore_lora_weights_forest = args.explore_lora_weights.split(',')
        num_explore = len(base_explore_model_forest)
        if len(explore_lora_weights_forest) < num_explore:
            raise ValueError(f'expected {num_explore} explore lora weights, got: {args.explore_lora_weights}')
        if len(explore_weight_forest) < num_explore:
            raise ValueError(f'expected {num_explore} explore weights, got: {args.explore_weight}')
        for weight in explore_weight_forest[:num_explore]:
            # float() rather than eval(): validates the CLI text without executing it.
            try:
                float(weight)
            except ValueError as err:
                raise ValueError(f'explore weight must be a number, got: {weight!r}') from err
        # The base model uses device_map="auto"; explore model i is pinned to
        # gpus[i + 1], so one id more than the number of explore models is needed.
        if len(gpus) < num_explore + 1:
            raise ValueError(
                f'--gpus needs at least {num_explore + 1} ids for {num_explore} explore models, got: {args.gpus}'
            )

    # Base tokenizer; models other than 'LLaMA-7B' get pad id 0 and right padding.
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if args.model != 'LLaMA-7B':
        tokenizer.pad_token_id = 0
        tokenizer.padding_side = "right"

    # One tokenizer per explore model. Each is bound to its own name so the base
    # tokenizer returned below is not replaced by the last explore tokenizer.
    tokenizer_forest = []
    for base_explore_model in base_explore_model_forest:
        explore_tokenizer = AutoTokenizer.from_pretrained(base_explore_model)
        explore_tokenizer.pad_token_id = 0
        explore_tokenizer.padding_side = "right"
        tokenizer_forest.append(explore_tokenizer)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if device != "cuda":
        raise RuntimeError(f'load_model requires a CUDA device, but the selected device is {device!r}')

    model_explore_forest = []
    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
            mode="eval"
            # attn_implementation="flash_attention_2"
        )
        model.eval()
        if explore:
            # The first explore model is built against the base model's hidden
            # size; each later one against the previous explore model's.
            model_hidden_size = model.config.hidden_size
            for i, base_explore_model in enumerate(base_explore_model_forest):
                explore_lora_weights = explore_lora_weights_forest[i]
                print(f"Loading explore model: {base_explore_model}, with lora weights: {explore_lora_weights}")
                print(gpus)
                gpu_id = int(gpus[i + 1])
                model_explore = New_LlamaForCausalLM.from_pretrained(
                    base_explore_model,
                    load_in_8bit=load_8bit,
                    torch_dtype=torch.float16,
                    device_map={"": gpu_id},
                    trust_remote_code=True,
                    exploit_hidden_size=model_hidden_size,
                    explore_logits_factor=args.explore_logits_factor,
                    topk_logits=args.topk_logits,
                    model_idx=i,
                    gpu_id=gpu_id,
                    alpha=args.alpha,
                    # attn_implementation="flash_attention_2"
                )
                print("model_explore")
                print(model_explore)
                model_explore.eval()

                model_hidden_size = model_explore.config.hidden_size
                # model_explore.config.output_hidden_states = True
                model_explore_forest.append(model_explore)

        model.config.output_hidden_states = True

    elif device == "mps":
        # NOTE(review): unreachable while the CUDA check above raises for other devices.
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        # NOTE(review): unreachable while the CUDA check above raises for other devices.
        model = AutoModelForCausalLM.from_pretrained(
            base_model, device_map={"": device}, low_cpu_mem_usage=True
        )
        model = PeftModel.from_pretrained(
            model,
            lora_weights,
            device_map={"": device},
        )

        # unwind broken decapoda-research config
        model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
        model.config.bos_token_id = 1
        model.config.eos_token_id = 2

        if not load_8bit:
            model.half()  # seems to fix bugs for some users.

        model.eval()
        if torch.__version__ >= "2" and sys.platform != "win32":
            model = torch.compile(model)

    if explore:
        return tokenizer, model, tokenizer_forest, model_explore_forest, explore_weight_forest, explore_lora_weights_forest
    return tokenizer, model


def load_instruction(args) -> str:
    """
    Placeholder for loading a task instruction.

    No instruction source is wired in (the local value is always empty), so
    this always raises. ``args`` is accepted but unused.

    Raises:
        ValueError: always, with message 'instruct not initialized'.
    """
    instruction = ''
    if instruction:
        return instruction
    raise ValueError('instruct not initialized')


def extract_answer_number(args, sentence: str) -> float:
    """
    Return the last number that appears in ``sentence``.

    Commas are removed first (so ``1,234`` reads as ``1234``). Then the last
    optionally-signed integer or decimal token is converted to ``float``.

    Args:
        args: namespace whose ``dataset`` (case-insensitive) must be one of the
            numeric-answer datasets listed below.
        sentence: the model's output text.

    Returns:
        The parsed number, or ``float('inf')`` when the text contains no digits.

    Raises:
        NotImplementedError: ``args.dataset`` is not a numeric-answer dataset.
    """
    dataset = args.dataset.lower()
    if dataset not in ("multiarith", "addsub", "singleeq", "gsm8k", "svamp", "mawps", "math_10k_500", "math_10k_small_5"):
        raise NotImplementedError(' not support dataset: {}'.format(dataset))
    numbers = re.findall(r'-?\d+\.?\d*', sentence.replace(',', ''))
    if not numbers:
        return float('inf')
    # Every match is digits with at most one '.', so float() cannot fail here.
    return float(numbers[-1])


def extract_answer_letter(args, sentence: str) -> str:
    """
    Return the last uppercase option letter A-E found in ``sentence``.

    Any occurrence counts, including letters inside words (e.g. the 'A' of
    "Answer"), and the last one wins. Returns ``''`` when there is none.
    ``args`` is accepted but unused.
    """
    candidates = re.findall(r'A|B|C|D|E', sentence.strip())
    return candidates[-1] if candidates else ''

def extract_new_answer_letter(args, sentence: str) -> str:
    """
    Return the stripped contents of the first ``<ans>...</ans>`` pair in ``sentence``.

    The tag contents may span several lines. Returns ``''`` when no complete
    pair is present. ``args`` is accepted but unused.
    """
    # DOTALL lets '.' cross newlines, so answers like "<ans>\nB\n</ans>" are
    # not silently scored as empty.
    match = re.search(r"<ans>(.*?)</ans>", sentence.strip(), re.DOTALL)
    if match:
        return match.group(1).strip()
    return ""
if __name__ == "__main__":
    # Fire binds main's keyword parameters; main itself re-reads sys.argv via
    # parse_args() for the real options.
    # NOTE(review): flags not in main's signature are also seen by Fire; confirm
    # Fire does not report them as unconsumed after main returns.
    fire.Fire(component=main)
