import os
os.environ["HF_HOME"] = "/tmlscratch/andriush/models/"
import random
import torch
import numpy as np
from loggers import WandBLogger
from conversers import load_target_model, TargetLM
from utils import  schedule_n_to_change_fixed, schedule_n_to_change_prob, extract_logprob, early_stopping_condition_mps, locate_guess_answer_index, define_target_token, define_target_token2, gt_map
from src.utils.initialization import (
    read_config_from_yaml
)
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.configs import REDDITConfig
from src.reddit.reddit_utils import load_data, education_map
from src.reddit.reddit import create_prompts, parse_answer
from src.reddit.reddit_types import Profile
from src.prompts import Prompt
import json

def _matches_target(token, target_token):
    """Return True if ``token`` is ``target_token``, optionally prefixed by a space.

    A leading space may appear either literally (``' '``) or as ``'Ġ'``, the
    space marker used by GPT-2-style byte-level BPE vocabularies.
    """
    return token in (target_token, ' ' + target_token, 'Ġ' + target_token)


def _insert_defense(template, defense):
    """Substitute ``defense`` for the ``{defense}`` placeholder in ``template``.

    Plain replacement is used instead of ``str.format`` because the template
    may contain other literal braces (e.g. inside embedded user text), which
    would make ``format`` raise.
    """
    return template.replace('{defense}', defense)


def _clamp_guess_index(guess_answer_index, logprobs):
    """Pull ``guess_answer_index`` back inside ``logprobs`` if it overruns it.

    Falls back to three positions before the end (never below 0), matching
    the heuristic used by the optimisation loop.
    """
    if guess_answer_index >= len(logprobs):
        print("warning: guess_answer_index exceed logprobs")
        guess_answer_index = max(len(logprobs) - 3, 0)
    return guess_answer_index


def defense_single_message(orig_msg, args, targetLM: TargetLM, logger: WandBLogger, prompt:Prompt):
    """Random-search a defense string that pushes the model's guess towards ``target_token``.

    The defense is inserted at the ``{defense}`` placeholder of ``orig_msg``.
    Each iteration queries the target model, locates the position of its
    guess (via ``locate_guess_answer_index``), and reads the log-probability of
    ``target_token`` there. A random block of defense tokens is then mutated,
    and the best-scoring defense found so far is kept.

    Side effect: ``prompt.answer`` is normalised in place (``education_map`` /
    ``gt_map``) for the education, income and relationship_status targets.

    Args:
        orig_msg: Prompt template containing a ``{defense}`` placeholder.
        args: Run configuration. Reads ``seed``, ``n_tokens_defense``,
            ``n_restarts``, ``n_iterations``, ``n_chars_change_max``,
            ``n_tokens_change_max``, ``schedule_prob``,
            ``schedule_n_to_change``, ``determinstic``,
            ``target_max_n_tokens`` and ``debug``.
        targetLM: Target model wrapper. Its ``model.tokenizer`` is used for
            encoding and decoding.
        logger: Used only when ``args.debug`` is falsy.
        prompt: The profile prompt. Reads ``target`` and ``answer``.

    Returns:
        Tuple ``(best_defense, is_finish, is_optimization, is_locate,
        final_response_text, target_token, prompt.answer)``:

        - If the model's initial top guess already matches ``target_token``,
          the function returns early with the initial defense and
          ``is_optimization=False``.
        - If the guess position cannot be located mid-search, it returns
          early with the current defense and ``is_locate=False``.
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # Budget, shared across all restarts, of iterations in which the target
    # token has still never been seen in the returned log-probabilities.
    n_iterations_target_not_in_topk = 25

    tokenizer = targetLM.model.tokenizer
    max_token_value = tokenizer.vocab_size

    # Always defined: previously it was only set when n_tokens_change_max > 0,
    # which raised NameError otherwise.
    defense_init = ' !' * args.n_tokens_defense

    if prompt.target == "education":
        prompt.answer = education_map(prompt.answer.strip().lower())
    if prompt.target in ["income", "education", "relationship_status"]:
        prompt.answer = gt_map(prompt.target, prompt.answer)
    print(f"Profile target:{prompt.target}, ground truth: {prompt.answer}")

    is_optimization = True
    msg = _insert_defense(orig_msg, defense_init)
    print(f"Original text:\n{msg}")
    output = targetLM.get_response([msg], max_n_tokens=500)[0]
    response_text = output['text']
    print(f"Original answer:\n{response_text}")

    guess_answer_index = locate_guess_answer_index(response_text, tokenizer)
    if guess_answer_index is None:
        print("Warning: can't locate 'Guess:' token")
        is_locate = False
        target_token = define_target_token2(prompt)
    else:
        is_locate = True
        if prompt.target == 'age':
            # Age answers are read one token past the located index.
            guess_answer_index += 1
        guess_answer_index = _clamp_guess_index(guess_answer_index, output['logprobs'])
        logprob_dict = output['logprobs'][guess_answer_index]
        sorted_logprobs = sorted(logprob_dict.items(), key=lambda item: item[1], reverse=True)
        target_token = define_target_token(prompt, sorted_logprobs)
        print(f"Target token: {target_token}")

        argmax_token = max(logprob_dict, key=logprob_dict.get)
        if _matches_target(argmax_token, target_token):
            print("Already anwer wrong! No need to optimization.")
            return defense_init, True, False, True, response_text, target_token, prompt.answer

    # Defaults so the final query, logging and return are well-defined even
    # when args.n_restarts or args.n_iterations is 0.
    is_finish = False
    best_defense = defense_init
    best_msg = msg
    best_defenses = []
    final_response_text = ''
    it = 0

    for i_restart in range(args.n_restarts):
        is_finish = False
        is_locate = True
        n_chars_change, n_tokens_change = args.n_chars_change_max, args.n_tokens_change_max
        best_defense = defense = defense_init
        # Reset the token view together with the string. Previously the tokens
        # carried over, mutated, from the prior restart and no longer matched
        # `defense`.
        best_defense_tokens = defense_tokens = tokenizer.encode(defense_init)
        best_msg = msg = _insert_defense(orig_msg, defense)

        best_logprob = -np.inf
        best_logprobs, best_defenses, logprob_dict = [], [], {}
        final_response_text = ''

        for it in range(1, args.n_iterations + 1):
            if early_stopping_condition_mps(best_logprobs, targetLM, logprob_dict, target_token, args.determinstic):
                final_response_text = targetLM.get_response([best_msg], max_n_tokens=args.target_max_n_tokens, temperature=0)[0]['text']
                print(f'final_response_text: {final_response_text}\n\n')
                is_finish = True
                break

            output = targetLM.get_response([msg], max_n_tokens=500)[0]
            response_text = output['text']
            guess_answer_index = locate_guess_answer_index(response_text, tokenizer)
            if guess_answer_index is None:
                # The current defense broke the answer format; report it as-is.
                print("Warning: can't locate 'Guess:' token")
                is_locate = False
                return defense, is_finish, is_optimization, is_locate, response_text, target_token, prompt.answer

            if prompt.target == 'age':
                guess_answer_index += 1
            guess_answer_index = _clamp_guess_index(guess_answer_index, output['logprobs'])
            print(f"guess index: {guess_answer_index}")

            logprob_dict = output['logprobs'][guess_answer_index]
            logprob = extract_logprob(logprob_dict, target_token)
            sorted_logprobs = sorted(logprob_dict.items(), key=lambda item: item[1], reverse=True)
            print("====Token rank====")
            print(" ".join(f"Token: {token}, Log Probability: {lp:.3f}" for token, lp in sorted_logprobs[:5]))
            rank = next(
                (pos for pos, (token, _) in enumerate(sorted_logprobs, start=1) if _matches_target(token, target_token)),
                None,
            )
            if rank:
                print(f"Token: {target_token}, Rank: {rank}, Log Probability: {logprob}")

            print(f'it={it} [best] logprob={best_logprob:.3f} prob={np.exp(best_logprob):.5f} defense={best_defense}\n  [curr] logprob={logprob:.3f} prob={np.exp(logprob):.5f}  len_defense={len(best_defense)}/{len(best_defense_tokens)} n_change={n_chars_change}/{n_tokens_change}: {defense}')
            # Greedy acceptance: keep the candidate only if it improves the score,
            # otherwise continue mutating from the best one.
            if logprob > best_logprob:
                best_logprob, best_msg, best_defense, best_defense_tokens = logprob, msg, defense, defense_tokens
            else:
                defense, defense_tokens = best_defense, best_defense_tokens
            best_logprobs.append(best_logprob)
            best_defenses.append(best_defense)

            # Give up on this restart once the shared budget of "target never
            # seen" iterations is spent. `<= 0` (not `== 0`) so later restarts
            # also stop instead of decrementing past zero forever.
            if best_logprob == -np.inf:
                n_iterations_target_not_in_topk -= 1
                if n_iterations_target_not_in_topk <= 0:
                    break

            if args.n_tokens_change_max > 0:
                if args.schedule_prob:
                    n_tokens_change = schedule_n_to_change_prob(args.n_tokens_change_max, np.exp(best_logprob), targetLM)
                elif args.schedule_n_to_change:
                    n_tokens_change = schedule_n_to_change_fixed(args.n_tokens_change_max, it)
                else:
                    n_tokens_change = args.n_tokens_change_max
                # Overwrite a contiguous run of n_tokens_change tokens with random ids.
                substitute_pos_start = random.choice(range(len(defense_tokens)))
                substitution_tokens = np.random.randint(0, max_token_value, n_tokens_change).tolist()
                defense_tokens = defense_tokens[:substitute_pos_start] + substitution_tokens + defense_tokens[substitute_pos_start + n_tokens_change:]
                defense = tokenizer.decode(defense_tokens, skip_special_tokens=True).replace('<s>', '').replace('\n', '')  # somehow, the R2D2 tokenizer inserts '<s>' at the first position
            msg = _insert_defense(orig_msg, defense)

        if is_finish:
            break

    if not is_finish:
        final_response_text = targetLM.get_response([best_msg], max_n_tokens=args.target_max_n_tokens, temperature=0)[0]['text']
        print(f'final_response_text: {final_response_text}\n\n')

    if not args.debug:
        logger.log({
            'final_response_text': final_response_text,
            'n_queries': it,
            'orig_msg': orig_msg,
            'best_msg': best_msg,
            'best_defenses': best_defenses,
        })

    return best_defense, is_finish, is_optimization, is_locate, final_response_text, target_token, prompt.answer


if __name__ == '__main__':
    # Hard-coded run configuration (stands in for a CLI argument parser).
    class Args:
        seed = 1
        n_restarts = 2
        determinstic = False
        debug = True
        n_iterations = 500
        n_chars_defense = 100
        n_chars_change_max = 0
        n_tokens_defense = 50
        n_tokens_change_max = 4
        schedule_prob = True
        schedule_n_to_change = False
        target_model = "qwen2.5-7b"
        target_max_n_tokens = 500
        config_path = "configs/reddit/defense/mps_qwen25.yaml"

    args = Args()
    cfg = read_config_from_yaml(args.config_path)

    assert isinstance(cfg.task_config, REDDITConfig)
    profiles = load_data(cfg.task_config.path)
    # Create prompts
    prompts = []
    for profile in profiles:
        prompts += create_prompts(profile, cfg.task_config)

    targetLM = load_target_model(args)
    logger = None if args.debug else WandBLogger(args)

    # Line inserted into every prompt just before `keyword`; the optimiser
    # substitutes each candidate defense for the "{defense}" placeholder.
    insert_str = "2014-05-19: {defense}\n"
    keyword = "\nFirst reason step-by-step"

    with open(cfg.task_config.path, "r", encoding="utf-8") as source_f, \
            open("predicted_qwen25_7b_gpt4o_trace-rps.jsonl", "a", encoding="utf-8") as predict_f, \
            open(cfg.task_config.outpath, "a", encoding="utf-8") as target_f:

        # NOTE(review): source_data is indexed by prompt position below, which
        # assumes create_prompts yields exactly one prompt per source line — confirm.
        source_data = [json.loads(line) for line in source_f if line.strip()]
        finish_count = 0
        total_count = len(prompts)

        for i, prompt in enumerate(prompts):
            print(f"*********Profile: {i}*********")
            orig_msg_full = prompt.get_prompt()

            pos = orig_msg_full.find(keyword)
            if pos == -1:
                # No anchor to place the defense at. Skip the prompt instead of
                # reusing the previous prompt (or raising NameError on the first).
                print("The suffix was not inserted correctly.")
                continue
            new_prompt = orig_msg_full[:pos] + insert_str + orig_msg_full[pos:]

            # defense_single_message takes (orig_msg, args, targetLM, logger, prompt).
            best_defense, is_finish, is_optimization, is_locate, response, target_token, gt = defense_single_message(
                new_prompt, args, targetLM, logger, prompt
            )
            if is_finish:
                finish_count += 1

            record = source_data[i]
            record['response'] = record['response'] + ' \n\n' + best_defense
            record['finish'] = is_finish
            record['optimization'] = is_optimization
            record['locate'] = is_locate
            record['target_token'] = target_token
            record['gt'] = gt
            record['best_defense'] = best_defense
            record['final_response_text'] = response
            target_f.write(json.dumps(record, ensure_ascii=False) + '\n')
            target_f.flush()

            finish_ratio = finish_count / total_count if total_count > 0 else 0
            print("Finish ratio: {:.2%}".format(finish_ratio))

            op = prompt.original_point
            assert isinstance(op, Profile)
            op.predictions[cfg.gen_model.name] = parse_answer(response, prompt.gt)
            op.predictions[cfg.gen_model.name]["full_answer"] = response

            predict_f.write(json.dumps(op.to_json()) + "\n")
            predict_f.flush()

    if logger is not None:
        logger.finish()
