import hydra
from omegaconf import DictConfig
import json

from trainer.utils import seed_everything
from model import get_model

import torch
from tqdm import tqdm

# Absolute locations of the JSON inputs read by load_facts_dataset() and
# load_enttities() below.
# NOTE(review): the empty `users//` path segment looks like a missing
# user-directory component -- confirm these paths resolve on the target host.
PATH_TO_FACTS = '/sensei-fs-3/users//msa_unlearning/datasets/facts/facts.json'
PATH_TO_ENTITIES = '/sensei-fs-3/users//msa_unlearning/datasets/facts/entities.json'

def load_facts_dataset():
    """Load and return the parsed JSON stored at ``PATH_TO_FACTS``.

    Returns:
        The object produced by ``json.load``. ``evaluate_restor`` indexes it
        with the keys ``'dset'`` and ``'context'``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    with open(PATH_TO_FACTS, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_enttities():
    """Load and return the parsed JSON stored at ``PATH_TO_ENTITIES``.

    The misspelled function name is kept as-is because other code in this
    file calls it under this name.

    Returns:
        The object produced by ``json.load``. ``evaluate_restor`` takes its
        ``len`` and iterates it, using each element as a key.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    with open(PATH_TO_ENTITIES, 'r', encoding='utf-8') as f:
        return json.load(f)


def get_context(pid, context, qid, num_of_examples=5):
    """Build a few-shot context string for property ``pid``.

    Entries of ``context[pid]`` are taken in order, skipping any entry whose
    ``'entity'`` equals ``qid``, until ``num_of_examples`` entries have been
    collected or the entries run out.

    Args:
        pid: Key selecting the list of example entries in ``context``.
        context: Mapping from ``pid`` to an iterable of entries; each entry is
            indexed with ``'entity'``, ``'question'`` and ``'answer'``.
        qid: Entity whose own entries must be left out of the context.
        num_of_examples: Maximum number of examples to include. A value of
            zero or less yields an empty string.

    Returns:
        The selected examples, each rendered as ``"<question> <answer>"`` and
        terminated by a newline, concatenated into one string.

    Raises:
        KeyError: If ``pid`` is not a key of ``context``.
    """
    info = context[pid]

    # Counting down to "== 0" never terminates for a non-positive budget and
    # would return every example; treat that case as "no examples".
    if num_of_examples <= 0:
        return ''

    lines = []
    for entry in info:
        if entry['entity'] == qid:
            continue

        lines.append(f'{entry["question"]} {entry["answer"]}\n')
        if len(lines) >= num_of_examples:
            break

    return ''.join(lines)


def evaluate_restor(model, tokenizer, all_data, batch_size, max_length, number_of_repeats=1,):
    """Generate answers for every (entity, property) question in the dataset.

    For each entity returned by ``load_enttities()`` and each property id in
    ``all_data['dset'][entity]``, a prompt is built from the few-shot context
    (``get_context``) followed by the question. Each prompt is repeated
    ``number_of_repeats`` times, the prompts are run through
    ``model.generate`` in batches, and the decoded continuations are collected.

    Args:
        model: Object exposing ``.device`` and ``.generate(...)``.
        tokenizer: Callable tokenizer whose result supports ``.to(device)``;
            must also expose ``batch_decode`` and ``eos_token_id``.
        all_data: Mapping with the keys ``'dset'`` (entity -> pid -> record
            containing ``'question'``) and ``'context'`` (passed to
            ``get_context``).
        batch_size: Number of prompts per generation call; must be >= 1.
        max_length: Forwarded unchanged to ``model.generate`` as ``max_length``.
        number_of_repeats: How many times each prompt is generated from.

    Returns:
        ``results[entity][pid]`` is a dict with the keys ``'question'``,
        ``'prompt'``, ``'raw_outputs'`` (text after the prompt) and
        ``'outputs'`` (``raw_outputs`` cut at the first newline), the last two
        holding one item per repeat.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        AssertionError: If the entity list does not have exactly 50 items.
    """
    # A zero batch size used to fail with ZeroDivisionError and a negative one
    # silently produced empty results; reject both up front.
    if batch_size < 1:
        raise ValueError(f'batch_size must be a positive integer, got {batch_size}')

    dset = all_data['dset']
    context = all_data['context']
    entities_list = load_enttities()

    # The check itself (exactly 50 entities) is unchanged; only the message
    # was corrected so that it describes what is actually being verified.
    assert len(entities_list) == 50, (
        f"Expected exactly 50 entities, found {len(entities_list)}."
    )

    device = model.device
    results = {}

    print('step 1 (doing)\t[aggregating prompts over all (entity, properties)]')

    prompts = []
    # owners[k] is the (entity, pid) pair that prompts[k] was built for.
    owners = []

    for entity in entities_list:
        results[entity] = {}
        for pid, record in dset[entity].items():
            question = record['question']
            rel_context = get_context(pid=pid, context=context, qid=entity)
            prompt = f'{rel_context}{question}'
            prompts.extend([prompt] * number_of_repeats)
            owners.extend([(entity, pid)] * number_of_repeats)

    counter = len(prompts)
    print(f'step 1 (done)\t[total # of prompts: {counter}]')

    # Ceiling division: the last batch may be smaller than batch_size.
    num_of_batches = (counter + batch_size - 1) // batch_size
    print(f'step 2 (doing)\t[total # of batches: {num_of_batches}]')

    with torch.no_grad():
        for batch_id in tqdm(range(num_of_batches)):
            start = batch_id * batch_size
            batch_prompts = prompts[start:start + batch_size]
            batch_owners = owners[start:start + batch_size]

            # Padding is only requested when batches can hold several prompts.
            if batch_size == 1:
                inputs = tokenizer(batch_prompts, return_tensors="pt").to(device)
            else:
                inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True).to(device)

            # do_sample=True: generation is sampled rather than greedy.
            generate_ids = model.generate(
                **inputs,
                min_new_tokens=2,
                pad_token_id=tokenizer.eos_token_id,
                max_length=max_length,
                do_sample=True)

            outputs = tokenizer.batch_decode(
                generate_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False)

            for j, (prompt, (entity, pid)) in enumerate(zip(batch_prompts, batch_owners)):
                entry = results[entity].setdefault(pid, {
                    'question': dset[entity][pid]['question'],
                    'prompt': prompt,
                    'outputs': [],
                    'raw_outputs': [],
                })

                # Drop the echoed prompt plus one following character.
                # NOTE(review): assumes the decoded text starts with the exact
                # prompt followed by a single separator character -- confirm
                # for the tokenizer in use.
                completion = outputs[j][len(prompt) + 1:]
                entry['raw_outputs'].append(completion)

                # Keep only the first line of the continuation as the answer.
                answer = completion.split('\n', 1)[0]

                if not answer:
                    print('[WARNING] -- there is empty string in the output.')

                entry['outputs'].append(answer)

    print(f'step 2 (done)\t[calculated all outputs with {number_of_repeats} replications]')
    return results



@hydra.main(version_base=None, config_path="../configs", config_name="restor_eval.yaml")
def main(cfg: DictConfig):
    """Run the evaluation end to end and write the generations to disk.

    Args:
        cfg (DictConfig): Hydra config. This function reads ``seed``,
            ``model``, ``path_to_save`` and ``batch_size`` from it.
    """
    seed_everything(cfg.seed)

    # Read every config value up front, before the model is loaded.
    model_cfg = cfg.model
    assert model_cfg is not None, "Invalid model yaml passed in train config."
    output_path = cfg.path_to_save
    prompts_per_batch = int(cfg.batch_size)

    model, tokenizer = get_model(model_cfg)
    # NOTE(review): left padding is presumably chosen so that generated tokens
    # directly follow each prompt in a padded batch -- confirm.
    tokenizer.padding_side = "left"

    facts = load_facts_dataset()
    results = evaluate_restor(
        model=model,
        tokenizer=tokenizer,
        all_data=facts,
        batch_size=prompts_per_batch,
        max_length=200,
        number_of_repeats=3,
    )

    # Persist the nested results dict as indented JSON.
    with open(output_path, 'w') as out_file:
        json.dump(results, out_file, indent=4)
    print(f"[done]: saved the results to {output_path}")



# Script entry point: main() is called without arguments; it is wrapped by the
# @hydra.main decorator above.
if __name__ == "__main__":
    main()
