import os
import sys
from typing import Optional

import numpy as np

from language_models import HuggingFace
from language_models import GPT

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.prompts import Prompt


def insert_defense_string(msg, defense):
    """Return `msg` with `defense` appended directly after it (no separator)."""
    defended_msg = msg + defense
    return defended_msg


def schedule_n_to_change_fixed(max_n_to_change, it):
    """Piece-wise constant schedule for `n_to_change` (both characters and tokens).

    The budget starts at `max_n_to_change` and is divided (floor division) by a
    growing power of two as the iteration counter `it` passes 10, 25, 50, 100
    and 500. Any `it` outside `(0, 500]` -- including `it <= 0` -- uses the
    smallest bucket (`// 32`). The result is never below 1.
    """
    # (inclusive upper bound on `it`, divisor applied to `max_n_to_change`)
    buckets = ((10, 1), (25, 2), (50, 4), (100, 8), (500, 16))

    divisor = 32
    if it > 0:
        for upper_bound, bucket_divisor in buckets:
            if it <= upper_bound:
                divisor = bucket_divisor
                break

    # The first bucket hands back `max_n_to_change` untouched (no division).
    budget = max_n_to_change if divisor == 1 else max_n_to_change // divisor
    return max(budget, 1)


def schedule_n_to_change_prob(max_n_to_change, prob, target_model):
    """Piece-wise constant schedule for `n_to_change` based on the best prob.

    The full budget is used while `prob` is at or below a low threshold, half
    of it up to a middle threshold, and a quarter of it up to 1.0. Both
    thresholds are smaller when `target_model.model` is a `HuggingFace`
    instance. The result is never below 1.

    Raises:
        ValueError: if `prob` is not within `[0, 1]`.
    """
    # need to adjust since llama and r2d2 are harder to attack
    if isinstance(target_model.model, HuggingFace):
        low_cut, mid_cut = 0.01, 0.1
    else:
        low_cut, mid_cut = 0.1, 0.5

    if not 0 <= prob <= 1.0:
        raise ValueError(f'Wrong prob {prob}')

    if prob <= low_cut:
        budget = max_n_to_change
    elif prob <= mid_cut:
        budget = max_n_to_change // 2
    else:
        budget = max_n_to_change // 4

    return max(budget, 1)


def extract_logprob(logprob_dict, target_token):
    """Return the best log-probability recorded for `target_token`.

    Three spellings of the token are looked up in `logprob_dict`: prefixed with
    a space, bare, and prefixed with 'Ġ'. The largest value among the spellings
    that are present is returned, or `-np.inf` when none of them is present.
    """
    spellings = (' ' + target_token, target_token, 'Ġ' + target_token)
    found = [logprob_dict[key] for key in spellings if key in logprob_dict]
    return max(found) if found else -np.inf


def early_stopping_condition(best_logprobs, target_model, logprob_dict, target_token, determinstic, no_improvement_history=5000):
    """Decide whether the search can stop early.

    With `determinstic` set and a non-empty `logprob_dict`, the decision is
    solely whether the highest-scoring key of `logprob_dict` is `target_token`
    (bare, space-prefixed or 'Ġ'-prefixed).

    Otherwise the last entry of `best_logprobs` is exponentiated and the
    function returns True when it exceeds 0.8, or when it exceeds 0.55 while
    `target_model.model` is a `HuggingFace` instance and `best_logprobs` holds
    more than `no_improvement_history` entries. An empty `best_logprobs`
    yields False.
    """
    if determinstic and logprob_dict != {}:
        top_token = max(logprob_dict, key=logprob_dict.get)
        return top_token in (target_token, ' ' + target_token, 'Ġ' + target_token)

    if len(best_logprobs) == 0:
        return False

    latest_prob = np.exp(best_logprobs[-1])
    # Relaxed threshold: applies to HuggingFace models once the history is long.
    relaxed_hit = (
        isinstance(target_model.model, HuggingFace)
        and latest_prob > 0.55
        and len(best_logprobs) > no_improvement_history
    )
    return bool(relaxed_hit or latest_prob > 0.8)

def early_stopping_condition2(best_logprobs, target_model, logprob_dict, target_token, determinstic, no_improvement_history=2000):
    """Decide whether the search can stop early (lower-threshold variant).

    With `determinstic` set and a non-empty `logprob_dict`, the decision is
    solely whether the highest-scoring key of `logprob_dict` is `target_token`
    (bare, space-prefixed or 'Ġ'-prefixed).

    Otherwise the last entry of `best_logprobs` is exponentiated and the
    function returns True when it exceeds 0.55, or when it exceeds 0.45 while
    `target_model.model` is a `HuggingFace` instance and `best_logprobs` holds
    more than `no_improvement_history` entries. An empty `best_logprobs`
    yields False.
    """
    if determinstic and logprob_dict != {}:
        top_token = max(logprob_dict, key=logprob_dict.get)
        return top_token in (target_token, ' ' + target_token, 'Ġ' + target_token)

    if len(best_logprobs) == 0:
        return False

    latest_prob = np.exp(best_logprobs[-1])
    # Relaxed threshold: applies to HuggingFace models once the history is long.
    relaxed_hit = (
        isinstance(target_model.model, HuggingFace)
        and latest_prob > 0.45
        and len(best_logprobs) > no_improvement_history
    )
    return bool(relaxed_hit or latest_prob > 0.55)




def early_stopping_condition_mps(best_logprobs, target_model, logprob_dict, target_token,  no_improvement_history=200):
    """Decide whether the search can stop early (argmax check always on).

    Whenever `logprob_dict` is non-empty, the decision is solely whether its
    highest-scoring key is `target_token` (bare, space-prefixed or
    'Ġ'-prefixed); the probability thresholds below are then never consulted.

    With an empty `logprob_dict`, the last entry of `best_logprobs` is
    exponentiated and the function returns True when it exceeds 0.8, or when
    it exceeds 0.55 while `target_model.model` is a `HuggingFace` instance and
    `best_logprobs` holds more than `no_improvement_history` entries. An empty
    `best_logprobs` yields False.
    """
    if logprob_dict != {}:
        top_token = max(logprob_dict, key=logprob_dict.get)
        return top_token in (target_token, ' ' + target_token, 'Ġ' + target_token)

    if len(best_logprobs) == 0:
        return False

    latest_prob = np.exp(best_logprobs[-1])
    # Relaxed threshold: applies to HuggingFace models once the history is long.
    relaxed_hit = (
        isinstance(target_model.model, HuggingFace)
        and latest_prob > 0.55
        and len(best_logprobs) > no_improvement_history
    )
    return bool(relaxed_hit or latest_prob > 0.8)


def locate_guess_answer_index(response_text: str, tokenizer) -> Optional[int]:
    """Return the token index right after the first "Guess" label.

    `response_text` is encoded with `tokenizer.encode(...,
    add_special_tokens=False)` and converted back to token strings. The first
    token whose whitespace-stripped text starts with "Guess" is the label.
    If that token has no ":" but the following token contains one, the colon
    token is treated as part of the label.

    Returns:
        The index of the token following the label (and its colon token, if
        separate), or None when no token starts with "Guess". The index is not
        bounds-checked and may equal `len(tokens)` when the label is last.
    """
    token_ids = tokenizer.encode(response_text, add_special_tokens=False)
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    for i, token in enumerate(tokens):
        label = token.strip()
        # NOTE(review): only whitespace is stripped, so a token that carries a
        # word-boundary marker (e.g. 'ĠGuess') does not match -- confirm intended.
        if not label.startswith("Guess"):
            continue

        colon_is_next_token = (
            ":" not in label
            and i + 1 < len(tokens)
            and ":" in tokens[i + 1]
        )
        label_end = i + 1 if colon_is_next_token else i
        return label_end + 1

    return None


def define_target_token(prompt: "Prompt", sorted_logprobs):
    """Pick a target token that differs from the prompt's true answer.

    Args:
        prompt: object exposing `target` (attribute name) and `answer`
            (true value for that attribute).
        sorted_logprobs: iterable of `(token, logprob)` pairs; the first
            acceptable token wins. Presumably sorted best-first -- verify
            against the caller.

    Returns:
        * gender: "Female" when the answer is "male" (case-insensitive),
          otherwise "Male".
        * age: the first token whose stripped, lower-cased text is not
          contained in the first character of `str(prompt.answer)`.
        * anything else: the first token whose stripped, lower-cased text is
          not a substring of the lower-cased answer, is not a "sin..." token
          for a "no relation" answer, is not a "mar..." token for a "married"
          answer, and contains neither ';' nor ':'.
        When no token qualifies, the result of `define_target_token2(prompt)`.
    """
    answer = str(prompt.answer).lower().strip()

    if prompt.target == "gender":
        return "Female" if answer == "male" else "Male"

    target_token = None

    if prompt.target == "age":
        # Slice instead of index so an empty answer does not raise IndexError.
        first_digit = str(prompt.answer)[:1]
        for token, _ in sorted_logprobs:
            if token.lower().strip() not in first_digit:
                target_token = token
                break
    else:
        for token, _ in sorted_logprobs:
            candidate = token.lower().strip()
            if candidate in answer:
                continue
            # Skip tokens that paraphrase the true relationship status.
            if "no relation" in answer and "sin" in candidate:
                continue
            if "married" in answer and "mar" in candidate:
                continue
            # Punctuation-bearing tokens are never accepted as a target.
            if ";" in token or ":" in token:
                continue
            target_token = token
            break

    # Nothing usable among the candidates: fall back to the fixed mapping.
    if target_token is None:
        target_token = define_target_token2(prompt)
    return target_token

def define_target_token2(prompt: "Prompt"):
    """Return a fixed target token chosen from the prompt's attribute and answer.

    The answer is matched case-insensitively against known substrings for
    gender, income, relationship_status and education. For age, the target is
    the leading digit of `100 - age`, adjusted so that it never equals the
    leading digit of the true age. Location-like attributes and occupation map
    to a constant token. Any other attribute returns `prompt.answer` unchanged.

    Args:
        prompt: object exposing `target` (attribute name) and `answer`
            (true value; a number or numeric string for age).

    Raises:
        ValueError: if the attribute is age and a string answer is not numeric.
    """
    attribute = prompt.target
    answer = str(prompt.answer).strip().lower()

    if attribute == "gender":
        return "Female" if answer == "male" else "Male"

    # (substring of the answer, target token); first match wins, order matters.
    substring_rules = {
        "income": (
            (("no income", "Low"), ("low", "Medium"), ("medium", "High"), ("high (60", "Very")),
            "High",
        ),
        "relationship_status": (
            (("no relation", "In"), ("in relation", "No"), ("married", "Div"), ("divorced", "Married")),
            "Married",
        ),
        "education": (
            (
                ("no highschool", "In"),
                ("in highschool", "HS"),
                ("hs diploma", "In"),
                ("in college", "College"),
                ("college degree", "PhD"),
                ("phd", "College"),
            ),
            "No",
        ),
    }
    if attribute in substring_rules:
        rules, default_token = substring_rules[attribute]
        for needle, token in rules:
            if needle in answer:
                return token
        return default_token

    if attribute == "age":
        raw_age = prompt.answer
        age = float(raw_age) if isinstance(raw_age, str) else raw_age
        # abs() keeps the leading character a digit for ages above 100.
        digit = str(abs(100 - age))[0]
        if digit == answer[:1]:
            # Mirrored age shares the true leading digit (e.g. 50): shift it.
            digit = str(int(digit) % 9 + 1)
        return digit

    # NOTE(review): these constants are returned even when they coincide with
    # the true answer -- confirm whether that case can occur.
    if attribute in ("location", "birth_city_country"):
        return "Japan"
    if attribute == "occupation":
        return "Teacher"

    return prompt.answer

# Per-attribute tables mapping a lower-cased raw label to its canonical label.
_GT_CHOICES = {
    "relationship_status": {
        "single": "no relation",
        "widowed": "no relation",
        "no relation": "no relation",
        "engaged": "in relation",
        "in a relationship": "in relation",
        "relation": "in relation",
        "married": "married",
        "divorced": "divorced",
    },
    "income": {
        "no": "no",
        "low": "low",
        "middle": "medium",
        "medium": "medium",
        "high": "high",
        "very high": "very high",
    },
    "education": {
        "no hs": "no highschool",
        "no highschool": "no highschool",
        "in hs": "in highschool",
        "in highschool": "in highschool",
        "hs": "hs diploma",
        "hs diploma": "hs diploma",
        "in college": "in college",
        "college": "college degree",
        "college degree": "college degree",
        "phd": "phd",
        "bachelor's degree": "college degree",
        "master's degree": "master's degree",
        "diploma in some field": "diploma in some field",
        "law degree": "college degree",
        "degree in some field": "college degree",
    },
}


def gt_map(attribute: str, val: str) -> str:
    """Map a raw ground-truth label to its canonical form for `attribute`.

    Args:
        attribute: one of "relationship_status", "income" or "education".
        val: raw label; matched case-insensitively, ignoring surrounding
            whitespace.

    Returns:
        The canonical label for `val`.

    Raises:
        AssertionError: if `attribute` is not one of the supported names.
        KeyError: if `val` is not a known label for `attribute`.
    """
    if attribute not in _GT_CHOICES:
        # Raised explicitly: a bare `assert` is removed under `python -O`.
        raise AssertionError(f"Unknown attribute {attribute}")
    return _GT_CHOICES[attribute][val.lower().strip()]


