# an example: 
# os.system('CUDA_VISIBLE_DEVICES=7 torchrun --nproc_per_node=1 --master_port=18149 eval_energy.py model_family=llama2-7b split=forget01 model_path=*/llama2-7b/method_1e-05_forget01_8_0.0_2_0.1/checkpoint-25 ps_type=exact topk=5')
# Note: config/eval_everything.yaml must define a 'ps_type' property. This script only handles 'exact' and 'perturb'; any other value (including 'similar') raises a RuntimeError.
import pdb, os, hydra
import logging
import random,time,zlib
import numpy as np
import sklearn.metrics as sk

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM, AutoConfig
from utils import get_model_identifiers_from_yaml


import safetensors
# Seed the Python, NumPy and PyTorch RNGs with the same fixed value so that
# repeated runs draw the same random numbers.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Logger registered under the name "Unlearning".
log = logging.getLogger("Unlearning")

# Default recall level; not referenced in the visible part of this file.
recall_level_default = 0.95

def model_mix(model, before, after, update_ratio):
    """Overwrite every parameter of *model* with a blend of two checkpoints.

    Each parameter ``p`` named ``name`` is replaced by
    ``update_ratio * before[name] + (1 - update_ratio) * after[name]``.

    Args:
        model: module whose ``named_parameters()`` are overwritten in place.
        before: mapping from parameter name to tensor (weight ``update_ratio``).
        after: mapping from parameter name to tensor (weight
            ``1 - update_ratio``).
        update_ratio: interpolation coefficient; ``0`` yields exactly the
            ``after`` weights, ``1`` exactly the ``before`` weights.

    Returns:
        The same ``model`` object, modified in place.

    Raises:
        KeyError: if a parameter name is missing from ``before`` or ``after``.
    """
    for name, parameter in model.named_parameters():
        # Blend on the device the parameter already lives on instead of the
        # default CUDA device, so this works on CPU and on non-zero ranks.
        target_device = parameter.device
        blended = (
            update_ratio * before[name].to(target_device)
            + (1 - update_ratio) * after[name].to(target_device)
        )
        # Keep the dtype the model was loaded with (e.g. bfloat16) even when
        # the checkpoint tensors were saved in a different precision.
        parameter.data = blended.to(parameter.dtype)
    return model


@hydra.main(version_base=None, config_path="config", config_name="eval_everything")
def main(cfg):
    """Log per-sample "energy" scores for the TOFU forget and retain sets.

    Every question/answer pair is formatted as a question prompt followed by
    its answer and fed through the model. For each answer-token position the
    logsumexp of the logits over the vocabulary is taken, and the mean of the
    ``cfg.topk`` largest such values is the sample's score. Scores and
    summary statistics are written to a log file under ``cfg.model_path``.

    Config fields read: ``model_family``, ``use_pretrained``, ``model_path``,
    ``split``, ``ps_type`` and ``topk``.

    Raises:
        RuntimeError: if ``cfg.ps_type`` is neither 'exact' nor 'perturb', or
            if the model cannot be loaded after three attempts.
        ValueError: if ``cfg.ps_type`` is not 'perturb' and ``cfg.split`` has
            no matching retain split.
    """
    # Validate the cheap options first so a typo does not cost a model load.
    if cfg.ps_type == 'perturb':
        log_file_path = cfg.model_path+f'/ps_perturb.log'
    elif cfg.ps_type == 'exact':
        log_file_path = cfg.model_path+f'/energy_exact_{cfg.topk}.log'
    else:
        raise RuntimeError(
            f"unsupported ps_type {cfg.ps_type!r}: expected 'exact' or 'perturb'"
        )

    retain_name = {
        'forget10': 'retain90',
        'forget05': 'retain95',
        'forget01': 'retain99',
    }.get(cfg.split)
    # The retain split name is only needed outside the 'perturb' mode.
    if cfg.ps_type != 'perturb' and retain_name is None:
        raise ValueError(
            f"unsupported split {cfg.split!r}: expected 'forget01', 'forget05' or 'forget10'"
        )

    # Fall back to device 0 when not launched through torchrun (no LOCAL_RANK).
    device_map = {'': int(os.environ.get('LOCAL_RANK', '0'))}

    os.environ["WANDB_DISABLED"] = "true"
    model_cfg = get_model_identifiers_from_yaml(cfg.model_family)
    model_id = model_cfg["hf_key"]
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side='left'

    config = AutoConfig.from_pretrained(model_id)
    flash_setting = model_cfg["flash_attention2"]
    model_source = model_id if cfg.use_pretrained else cfg.model_path

    model = None
    last_error = None
    for attempt in range(3):
        if cfg.use_pretrained:
            print(f"Loading pretrained from {model_id}")
        else:
            print(f"Loading checkpoint from {cfg.model_path}")
        try:
            try:
                model = AutoModelForCausalLM.from_pretrained(model_source, config=config, use_flash_attention_2=flash_setting=="true", torch_dtype=torch.bfloat16, trust_remote_code = True, device_map=device_map)
            except Exception:
                # NOTE(review): the fallback flag is `flash_setting == "false"`,
                # which turns flash attention ON when the config says "false".
                # Kept as in the original -- confirm this is intended.
                model = AutoModelForCausalLM.from_pretrained(model_source, config=config, use_flash_attention_2=flash_setting=="false", torch_dtype=torch.bfloat16, trust_remote_code = True, device_map=device_map)
        except Exception as e:
            print(e)
            last_error = e
            continue
        break
    if model is None:
        # Stop here instead of failing later with an AttributeError on None.
        raise RuntimeError("Error: could not load model") from last_error

    # Checkpoints to blend are only available for this one model id.
    before_ckpt = after_ckpt = None
    if model_id=='meta-llama/Llama-3.2-1B-Instruct':
        # Import the submodule explicitly; `import safetensors` alone does not
        # guarantee that `safetensors.torch` is loaded.
        from safetensors.torch import load_file
        before_ckpt=load_file('data/weight/ft_epoch5_lr2e-05_llama3.2-1b_full_wd0.0/checkpoint-1250/model.safetensors')
        after_ckpt=load_file(cfg.model_path+'/model.safetensors')

    model.generation_config.temperature=None
    model.generation_config.top_p=None

    if cfg.ps_type == 'perturb':
        forget_data=load_dataset('locuslab/TOFU',cfg.split+'_perturbed')['train']
        retain_eval_data=load_dataset('locuslab/TOFU','retain_perturbed')['train']
    else:
        forget_data=load_dataset('locuslab/TOFU',cfg.split)['train']
        # Keep only the first 400 retain samples (no shuffling).
        retain_eval_data=load_dataset('locuslab/TOFU',retain_name)['train'].train_test_split(train_size=400,shuffle=False)['train']

    retain_eval_loader=torch.utils.data.DataLoader(retain_eval_data,batch_size=1)
    forget_loader=torch.utils.data.DataLoader(forget_data,batch_size=1)

    def processing(loader, model):
        """Return one energy score per sample yielded by *loader*."""
        question_key = 'paraphrased_question' if cfg.ps_type == 'perturb' else 'question'
        ps_list = []
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            for batch in loader:
                for que, ans in zip(batch[question_key], batch['answer']):
                    prompt = f"### Question: {que}\n ### Answer: "
                    full = f"### Question: {que}\n ### Answer: {ans}"
                    prompt_len = len(tokenizer.encode(prompt, add_special_tokens=True))
                    full_ids = tokenizer.encode(full, add_special_tokens=True, return_tensors='pt').to(model.device)

                    hidden_states = model.model(full_ids)[0]
                    logits = model.lm_head(hidden_states)
                    # Logits at position i predict token i+1, so this slice
                    # covers the predictions for the answer tokens.
                    shift_logits = logits[:, prompt_len-1:-1]
                    energy_sample = torch.logsumexp(shift_logits.reshape(-1, shift_logits.size(-1)), dim=-1)

                    # Take the k largest values; clamp k so answers shorter
                    # than cfg.topk tokens do not make torch.topk raise.
                    k = min(cfg.topk, energy_sample.numel())
                    topk_values, _ = torch.topk(energy_sample, k)
                    ps_list.append(topk_values.mean().item())
        return ps_list

    log1 = logging.getLogger("Unlearning")
    file_handler = logging.FileHandler(log_file_path)
    file_handler.setLevel(logging.DEBUG)
    log1.addHandler(file_handler)
    try:
        if before_ckpt is not None:
            # Ratio 0 keeps only the `after` (cfg.model_path) weights.
            model=model_mix(model,before_ckpt,after_ckpt,0)
        energy_forget = processing(forget_loader, model)
        energy_retain = processing(retain_eval_loader, model)
        log1.info(f'Energy Forget Set: {energy_forget}')
        log1.info(f'Energy Retain Set: {energy_retain}')
        retain_mean = np.mean(energy_retain)
        forget_mean = np.mean(energy_forget)
        log1.info(f'Energy Retain Mean: {retain_mean}')
        log1.info(f'Energy Forget Mean: {forget_mean}')
        log1.info(f'Energy Final: {(forget_mean+retain_mean)/2}')
        retain_min = np.min(energy_retain)
        forget_max = np.max(energy_forget)
        log1.info(f'Energy Retain Min: {retain_min}')
        log1.info(f'Energy Forget Max: {forget_max}')
    finally:
        # Detach the handler so repeated invocations (e.g. Hydra multirun)
        # do not duplicate log lines or leak the open file.
        log1.removeHandler(file_handler)
        file_handler.close()
# Script entry point: run main() only when this file is executed directly.
if __name__ == "__main__":
    main()