
import numpy as np 
import hydra, os, sys, re, ipdb
from datetime import datetime
from importlib import import_module
import logging
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification
from huggingface_hub import login
from vllm import LLM

# Project checkout root. Its `code/` directory is appended to sys.path so the
# project-local modules imported right below (`generate`, `helpers.*`) resolve.
ROOT = "/u/audreyh/workspace/test-code"
sys.path.append(os.path.join(ROOT, 'code'))
from generate import get_sampling_params
import helpers.test_functions as tf
import helpers.io as io  # NOTE: this name shadows the stdlib `io` module in this file
# Cache locations for this cluster. NOTE(review): these are assigned *after*
# transformers / huggingface_hub / vllm were imported above; any library that
# resolves its cache directory at import time may not pick them up — confirm.
os.environ['XDG_CACHE_HOME'] = "/work/hdd/bdkj/audreyh/.cache"
os.environ['OUTLINES_CACHE_DIR'] = ROOT

# Month-day stamp (e.g. "03-14"); main() uses it to name this run's output folder.
DATE = datetime.now().strftime("%m-%d")

# Sweep axes shared by the checks below. Switched-off entries are kept in the
# comment above each list; move them into the list to re-enable them.

# Disabled policies: 'mistral-7b', 'phi-3-small', 'phi-3-mini'
MODEL_LIST = ['gemma-2-2b', 'llama-3-3b']

# Disabled reward models: 'mistral-7b', 'fsfairx-8b', 'eurus-rm-7b',
# 'armo-rm', 'oasst-rm', 'rm-gemma-2b'
REWARD_LIST = ['grm-llama-3b']

# Disabled tasks: 'gsm8k', 'alpaca'
TASK_LIST = ['mmlu']

# Fixed arithmetic word problem and worked answer used as synthetic data by the
# reward checks. ANSWER intentionally begins with a single space.
QUESTION = (
    "There are 15 trees in the grove. "
    "Grove workers will plant trees in the grove today. "
    "After they are done, there will be 21 trees. "
    "How many trees did the grove workers plant today?"
)
ANSWER = (
    " We start with 15 trees. Later we have 21 trees."
    " The difference must be the number of trees they planted."
    " So, they must have planted 21 - 15 = 6 trees."
    " The answer is 6."
)


class Tee(object):
    """Minimal text-stream stand-in that duplicates every write to several streams.

    The checks in this file assign an instance to ``sys.stdout`` so that
    ``print`` output reaches both the console and a log file.
    """

    def __init__(self, *files):
        # Streams to write to, in order (e.g. the console and a log file).
        self.files = files

    def write(self, message):
        """Write *message* to every stream.

        Returns:
            ``len(message)``, mirroring ``io.TextIOBase.write`` so callers that
            check the return value keep working.
        """
        for file in self.files:
            file.write(message)
        return len(message)

    def writelines(self, lines):
        """Write each string in *lines* to every stream (no newlines are added)."""
        for line in lines:
            self.write(line)

    def flush(self):
        """Flush every stream."""
        for file in self.files:
            file.flush()

    def isatty(self):
        """Report a TTY only when there is at least one stream and all are TTYs.

        A log file is never a terminal, so code that probes ``sys.stdout.isatty()``
        will not emit terminal control sequences into the log.
        """
        return bool(self.files) and all(
            getattr(file, 'isatty', lambda: False)() for file in self.files
        )

def set_config(**kwargs):
    """Compose the ``master`` Hydra config with small, test-sized overrides.

    Every call sets ``mode.max_samples=2`` and ``sampling.k=1``. Each keyword
    argument then becomes one more ``key=value`` override. ``include_prompt``
    and ``batch_size`` are placed under the ``evaluation`` group.

    ``hydra.compose`` needs an initialised Hydra instance, e.g. being called
    from inside the ``@hydra.main`` entry point.

    Returns:
        The config object produced by ``hydra.compose``.
    """
    evaluation_keys = ('include_prompt', 'batch_size')
    overrides = ["mode.max_samples=2", "sampling.k=1"]
    overrides.extend(
        f"{'evaluation.' + name if name in evaluation_keys else name}={val}"
        for name, val in kwargs.items()
    )
    return hydra.compose(config_name="master", overrides=overrides)

def extract_value_from_file(file_path, prefix='correct:'):
    """Return the first number that follows *prefix* at the start of a line.

    Args:
        file_path: Path of the text/log file to scan.
        prefix: Literal text the line must start with (e.g. ``'correct:'``).
            It is matched verbatim, so regex metacharacters in it have no
            special meaning.

    Returns:
        One of three results:

        * the parsed value as a ``float``;
        * the string ``'file does not exist'`` if *file_path* is missing;
        * ``'N/A'`` if no line starts with *prefix* followed by a number.
    """
    if not os.path.exists(file_path):
        return 'file does not exist'
    # Escape the prefix so it is matched literally. Capture only well-formed
    # numbers (optional sign, decimal point, exponent) so float() cannot fail
    # on captures such as "." or "1.2.3". Compiled once, outside the loop.
    pattern = re.compile(
        re.escape(prefix) + r'\s*([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)'
    )
    with open(file_path, 'r') as file:
        for line in file:
            # match() anchors at the line start, i.e. the line must begin with prefix.
            match = pattern.match(line)
            if match:
                return float(match.group(1))
    return 'N/A'

def test_querybuilder(save_path):
    """Print the task description and first built query for every task/policy pair.

    For each task in ``TASK_LIST`` and policy in ``MODEL_LIST``, this builds the
    queries with ``tasks.<task name>.DataLoader``. It then prints the query
    builder's task description (when non-empty) and the first query.

    Output is mirrored to ``<save_path>/querybuilder-output.log``. The previous
    ``sys.stdout`` is restored and the log is closed even if a check raises.
    """
    from contextlib import redirect_stdout

    log_path = os.path.join(save_path, 'querybuilder-output.log')
    with open(log_path, 'w') as log_file, redirect_stdout(Tee(sys.stdout, log_file)):
        for task in TASK_LIST:
            for model in MODEL_LIST:
                print(f"=====Testing querybuilder for {model.center(12)} on {task.center(8)}=====")
                cfg = set_config(task=task, policy=model)
                data_module = import_module(f"tasks.{cfg.task.name}", package='code')
                dl = data_module.DataLoader(cfg)
                queries = dl.build_queries()
                if dl.qb.task_desc:
                    print(f"TASK_DESC\n\n{dl.qb.task_desc}\n")
                print(f"QUERY\n\n{queries[0]}\n")
        sys.stdout.flush()


def test_generate(save_path):
    """Generate with vLLM for every task/policy pair and print the first result.

    For each task in ``TASK_LIST`` and policy in ``MODEL_LIST``, this builds the
    task's queries and loads the policy with ``vllm.LLM``. It generates for all
    queries, then prints the first query and the first completion of the first
    output.

    Output is mirrored to ``<save_path>/generate-output.log``. The previous
    ``sys.stdout`` is restored and the log is closed even if a check raises.
    """
    import gc
    from contextlib import redirect_stdout

    log_path = os.path.join(save_path, 'generate-output.log')
    with open(log_path, 'w') as log_file, redirect_stdout(Tee(sys.stdout, log_file)):
        for task in TASK_LIST:
            for model in MODEL_LIST:
                print(f"\n=====Testing generate for {model.center(12)} on {task.center(8)}=====\n")
                cfg = set_config(task=task, policy=model)
                data_module = import_module(f"tasks.{cfg.task.name}", package='code')
                dl = data_module.DataLoader(cfg)
                queries = dl.build_queries()

                print(f'Loading model {cfg.policy.model}')
                llm = LLM(
                    cfg.policy.model,
                    gpu_memory_utilization=cfg.sampling.gpu_memory_utilization,
                    swap_space=cfg.sampling.swap_space,
                    trust_remote_code=True,
                )
                tok = llm.get_tokenizer()
                if cfg.task.name == 'tldr':
                    # For this task: left padding plus an explicit [PAD] token.
                    tok.padding_side = 'left'
                    tok.add_special_tokens({"pad_token": "[PAD]"})
                    llm.set_tokenizer(tok)
                stop_tokens = [tok.eos_token_id]
                sampling_parameters = get_sampling_params(cfg.sampling, stop_tokens)
                outputs = llm.generate(queries, sampling_parameters)
                print(f"> QUERY\n\n{queries[0]}\n")
                print(f"> RESPONSE\n\n{outputs[0].outputs[0].text}\n")

                # Drop this engine before the next iteration constructs another one.
                # Otherwise `llm` keeps it alive while the next LLM(...) reserves its
                # own share of GPU memory. Best-effort: how completely vLLM releases
                # memory this way depends on the vLLM version — confirm.
                del llm, tok
                gc.collect()
                torch.cuda.empty_cache()
        sys.stdout.flush()

def test_greedy(save_path, seed=101, shots=0, data_root="/work/hdd/bdkj/audreyh/data"):
    """Report greedy-decoding accuracy for each task/policy pair from saved generations.

    Each pair's generations are read from::

        <data_root>/<task>/<policy>/generations/
            <task>-<policy>-greedy-shots-<shots>-seed-<seed>-generations.json

    The check prints the mean of ``tf.get_key(outputs, "correct")``. If a file is
    missing, that is reported and the sweep moves on to the next pair.

    Output is mirrored to ``<save_path>/greedy-output.log``. The previous
    ``sys.stdout`` is restored and the log is closed even if a check raises.
    """
    from contextlib import redirect_stdout

    log_path = os.path.join(save_path, 'greedy-output.log')
    with open(log_path, 'w') as log_file, redirect_stdout(Tee(sys.stdout, log_file)):
        for task in TASK_LIST:
            for policy in MODEL_LIST:
                print(f"=====Testing greedy for {policy.center(12)} on {task.center(8)}=====")
                data_path = os.path.join(data_root, task, policy, "generations")
                filename = f"{task}-{policy}-greedy-shots-{shots}-seed-{seed}-generations.json"
                filepath = os.path.join(data_path, filename)
                if not os.path.exists(filepath):
                    # One missing run should not abort the remaining pairs.
                    print(f"correct: file does not exist ({filepath})\n")
                    continue
                outputs = io.json_load(filepath)
                value = np.mean(tf.get_key(outputs, "correct"))
                print(f"correct: {value}\n")
        sys.stdout.flush()

def test_reward_formatter(save_path, policy='gemma-2-2b', include_prompt=False):
    """Print how each reward model's collator formats the fixed QUESTION/ANSWER pair.

    For each task in ``TASK_LIST`` and reward in ``REWARD_LIST``, this loads
    ``rewards.<reward name>.RewardCollator``. It then prints:

    * the collator's task description;
    * the query and response strings it builds;
    * the combined query-response text.

    Output is mirrored to ``<save_path>/reward-format-output.log``. The previous
    ``sys.stdout`` is restored and the log is closed even if a check raises.
    """
    from contextlib import redirect_stdout

    log_path = os.path.join(save_path, 'reward-format-output.log')
    with open(log_path, 'w') as log_file, redirect_stdout(Tee(sys.stdout, log_file)):
        for task in TASK_LIST:
            for reward in REWARD_LIST:
                print(f"=====Testing reward format for {reward.center(12)} on {task.center(8)}=====")
                cfg = set_config(task=task, policy=policy, reward=reward, include_prompt=include_prompt)
                eval_module = import_module(f"rewards.{cfg.reward.name}", package='code')
                collator = eval_module.RewardCollator(cfg)
                query = collator._build_query_str(QUESTION)
                response = collator._build_response_str(ANSWER)
                qr = collator.format_query_response(QUESTION, ANSWER)
                print(f"TASK_DESC\n\n{collator.task_desc}\n")
                print(f"QUERY\n\n{query}\n")
                print(f"RESPONSE\n\n{response}\n")
                print(f"QUERY-RESPONSE\n\n{qr}\n")
        sys.stdout.flush()

def build_output(batch_size, prompt=None, response=None):
    """Build a synthetic batch of generation records for the reward checks.

    All records share ``prompt_idx`` 0, the same prompt and the same response.
    The exception is the last record, whose response is cut to half its length,
    so the batch contains responses of different lengths.

    Args:
        batch_size: Number of records. Must be at least 2 so the batch can hold
            both a full and a truncated response.
        prompt: Prompt text; defaults to the module-level ``QUESTION``.
        response: Response text; defaults to the module-level ``ANSWER``.
            Must be non-empty so that truncating it changes its length.

    Returns:
        A list of *batch_size* independent dicts with keys ``prompt_idx``,
        ``prompt`` and ``response``.

    Raises:
        ValueError: If *batch_size* < 2 or *response* is empty.
    """
    if batch_size < 2:
        raise ValueError(f"batch_size must be at least 2, got {batch_size}")
    prompt = QUESTION if prompt is None else prompt
    response = ANSWER if response is None else response
    if not response:
        raise ValueError("response must be non-empty so truncation changes its length")
    # A comprehension (not [record] * n) so each record is a separate dict and
    # editing the last one does not affect the others.
    outputs = [
        {'prompt_idx': 0, 'prompt': prompt, 'response': response}
        for _ in range(batch_size)
    ]
    outputs[-1]['response'] = response[:len(response) // 2]
    return outputs

def test_reward_collator(save_path, policy='gemma-2-2b', batch_size=4):
    """Check that each reward model's collator turns the synthetic batch into batches.

    For each task in ``TASK_LIST`` and reward in ``REWARD_LIST``, this wraps
    ``build_output(batch_size)`` in a ``DataLoader`` using the reward's
    ``RewardCollator``. It prints a ``PASS`` line describing each collated batch,
    or ``FAIL`` with the error if collation raises.

    Output is mirrored to ``<save_path>/reward-collator-output.log``. The previous
    ``sys.stdout`` is restored and the log is closed even if a check raises.

    Returns:
        ``True`` if no collator failed, else ``False``.
    """
    from collections.abc import Mapping
    from contextlib import redirect_stdout

    outputs = build_output(batch_size)
    is_pass = True
    log_path = os.path.join(save_path, 'reward-collator-output.log')
    with open(log_path, 'w') as log_file, redirect_stdout(Tee(sys.stdout, log_file)):
        for task in TASK_LIST:
            for reward in REWARD_LIST:
                print(f"=====Testing reward collator for {reward.center(12)} on {task.center(8)}=====")
                cfg = set_config(task=task, policy=policy, reward=reward, batch_size=batch_size)
                eval_module = import_module(f"rewards.{cfg.reward.name}", package='code')
                collator = eval_module.RewardCollator(cfg)
                try:
                    dataloader = DataLoader(
                        outputs,
                        batch_size=cfg.evaluation.batch_size,
                        collate_fn=collator,
                        shuffle=False,
                    )
                    for batch in dataloader:
                        # One if/elif chain, so exactly one description applies per batch.
                        if isinstance(batch, torch.Tensor):
                            batch_type = "torch.Tensor"
                            value_type = batch.shape
                            assert value_type[0] == batch_size, "Tensor shape does not match batch size"
                        elif isinstance(batch, list):
                            batch_type = "list"
                            value_type = type(batch[0])
                        elif isinstance(batch, Mapping):
                            # Mapping (not just dict) also covers dict-like containers,
                            # e.g. UserDict subclasses.
                            batch_type = "dict"
                            value_type = type(next(iter(batch.values())))
                        else:
                            batch_type = type(batch)
                            value_type = ""
                        print(f'PASS: type {batch_type} of {value_type}')
                except Exception as err:
                    # Report the cause instead of a bare FAIL. A narrower
                    # `except Exception` lets KeyboardInterrupt propagate.
                    print(f'FAIL: {err!r}')
                    is_pass = False
        sys.stdout.flush()
    return is_pass

def test_reward_module(save_path, policy='gemma-2-2b', batch_size=4):
    """Score the synthetic batch with every reward model, with and without the prompt.

    The sweep runs over each ``include_prompt`` setting (True, then False), each
    task in ``TASK_LIST`` and each reward in ``REWARD_LIST``. For every
    combination it:

    * prepares a ``DataLoader`` over ``build_output(batch_size)`` and the reward's
      ``RewardModule`` model with Accelerate;
    * checks that ``get_reward`` returns a list of ``batch_size`` scores;
    * prints the first score.

    Load and scoring failures are reported with their cause, and the sweep
    continues. Output is mirrored to ``<save_path>/reward-module-output.log``.
    The previous ``sys.stdout`` is restored and the log is closed even if a
    check raises.
    """
    from contextlib import redirect_stdout
    from accelerate import Accelerator

    accelerator = Accelerator()
    outputs = build_output(batch_size)
    log_path = os.path.join(save_path, 'reward-module-output.log')
    with open(log_path, 'w') as log_file, redirect_stdout(Tee(sys.stdout, log_file)):
        for include_prompt in [True, False]:
            print(f'******************** include_prompt: {include_prompt} ********************')
            for task in TASK_LIST:
                for reward in REWARD_LIST:
                    print(f"=====Testing reward module for {reward.center(12)} on {task.center(8)}=====")
                    cfg = set_config(task=task, policy=policy, reward=reward,
                                     batch_size=batch_size, include_prompt=include_prompt)
                    eval_module = import_module(f"rewards.{cfg.reward.name}", package='code')
                    collator = eval_module.RewardCollator(cfg)
                    dataloader = accelerator.prepare(DataLoader(
                        outputs,
                        batch_size=cfg.evaluation.batch_size,
                        collate_fn=collator,
                        shuffle=False,
                    ))
                    try:
                        rm = eval_module.RewardModule(cfg)
                        rm.model = accelerator.prepare(rm.model)
                    except Exception as err:
                        print(f'Could not load reward module: {err!r}')
                        accelerator.free_memory()
                        continue
                    for batch in dataloader:
                        try:
                            scores = rm.get_reward(batch)
                            assert isinstance(scores, list), "Scores is not a list."
                            assert len(scores) == batch_size, "Length of scores does not match batch size."
                            print(f'score: {scores[0]}')
                        except Exception as err:
                            print(f'Could not evaluate rewards: {err!r}')
                    # Accelerator.prepare keeps references to everything it prepared.
                    # Release them so each reward model does not stay resident for
                    # the rest of the sweep.
                    del rm
                    accelerator.free_memory()
        sys.stdout.flush()

@hydra.main(config_path=os.path.join(ROOT, "code/configs"), config_name="master", version_base=None)
def main(cfg):
    """Hydra entry point: seed NumPy, create the dated output folder, run enabled checks.

    Logs are written under ``<ROOT>/tests/<MM-DD>-tests``.
    """
    save_path = os.path.join(ROOT, f'tests/{DATE}-tests')
    np.random.seed(cfg.seed)
    os.makedirs(save_path, exist_ok=True)

    # Enabled checks, run in this order. Other checks that can be switched on:
    #   test_generate(save_path)
    #   test_reward_formatter(save_path, include_prompt=True)
    #   if test_reward_collator(save_path): test_reward_module(save_path)
    for run_check in (test_querybuilder, test_greedy):
        run_check(save_path)

     

if __name__ == "__main__":
    # Hugging Face authentication: never hard-code access tokens in source.
    # Export HF_TOKEN in the environment, or log in once with
    # `huggingface-cli login`. Any token previously committed here should be
    # treated as leaked and revoked.

    os.environ['XDG_CACHE_HOME'] = "/work/hdd/bdkj/audreyh/.cache"

    main()