import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import pickle
import torch
import torch.nn.functional as F
from tqdm import tqdm
import os
import argparse
import pandas as pd
import random
from secret_utils import DICT_GET_PRIOR
import gc
import hashlib
from datetime import datetime
from datasets import load_dataset

# Allow TF32-style reduced-precision float32 matmuls for speed.
torch.set_float32_matmul_precision('high')


parser = argparse.ArgumentParser(description='Get data')
parser.add_argument('--model_name', type=str, default='EleutherAI/pythia-12b', help='Name of the model to use')
parser.add_argument('--max_length', type=int, default=1024, help='Maximum length of the tokens')
parser.add_argument('--suffix_length', type=int, default=64, help='Length of the suffix in characters')
parser.add_argument('--secret_type', type=str, required=True, help='Type of secret to evaluate')
# --dataset is the stem of the pickle read as "<input_path>/<dataset>.pkl".
parser.add_argument('--dataset', type=str, required=True, help='')
parser.add_argument('--n_nonmembers', type=int, default=128, help='Number of non-members to evaluate')
# Groups with start_id <= group index < end_id are processed.
parser.add_argument('--start_id', type=int, default=0, help='Start id for the dataset')
parser.add_argument('--end_id', type=int, default=100, help='End id for the dataset')
parser.add_argument('--input_path', type=str, default='.', help='Path to the input data')
# Was `delefault=`, which argparse rejects with a TypeError at startup.
parser.add_argument('--output_path', type=str, default='.', help='Path to the output data')

args = parser.parse_args()
input_path = args.input_path
output_path = args.output_path

os.makedirs(f'{output_path}', exist_ok=True)

def set_seed(seed):
    """Seed PyTorch (CPU and every CUDA device), Python's `random`, and NumPy.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        torch.manual_seed,
        random.seed,
        np.random.seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


set_seed(42)


device = torch.device('cuda')
print(device)
print('args:', args)

# Load the pickled DataFrame and keep only rows of the requested secret type.
# NOTE(review): read_pickle can execute arbitrary code - only load trusted files.
df = pd.read_pickle(f'{input_path}/{args.dataset}.pkl')
df = df[df['secret_type'] == args.secret_type]
if df.empty:
    print('No data')
    # `exit` is a site-module helper intended for the interactive shell;
    # raising SystemExit directly also works under `python -S` / frozen apps.
    raise SystemExit(0)

print('number of secrets:', len(df))


tokenizer = AutoTokenizer.from_pretrained(args.model_name)
# Reuse EOS as the pad token. Pad and truncate on the left so the end of each
# "prefix + secret + suffix" string is what survives truncation.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"
# Model name without its hub namespace; used in the output file name.
printable_model_name = args.model_name.split('/')[-1]


model = AutoModelForCausalLM.from_pretrained(
    args.model_name,
    device_map=device,
    torch_dtype=torch.bfloat16,
    # force_download=True,
).eval()

# BetterTransformer is an optional speed-up that not every model/installation
# supports, so fall back to the plain model if it fails. Catch Exception (not
# a bare `except:`) so KeyboardInterrupt / SystemExit still propagate.
try:
    model = model.to_bettertransformer()
except Exception as err:
    print('to_bettertransformer unavailable, using the original model:', err)
model = torch.compile(model, mode="max-autotune")

def hinge_score(raw_predictions: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the margin of each label's score over its best competitor.

    For row ``i`` the result is ``raw_predictions[i, ..., labels[i]]`` minus the
    maximum over all *other* entries of the last dimension, computed in float64.

    Args:
        raw_predictions: Tensor with at least two dims. Dim 0 has one row per
            label; the last dim indexes classes.
        labels: 1-D integer tensor of class indices, one per row.

    Returns:
        float64 tensor of shape ``raw_predictions.shape[:-1]``.

    The input tensor is never modified. ``Tensor.to`` returns ``self`` when no
    conversion is needed, so masking it in place would corrupt a caller's
    float64 tensor.
    """
    assert raw_predictions.dim() >= 2 and labels.dim() == 1 and raw_predictions.size(0) == len(labels)
    raw_predictions = raw_predictions.to(dtype=torch.float64)

    rows = torch.arange(len(labels), device=raw_predictions.device)
    target_predictions = raw_predictions[rows, ..., labels]

    if raw_predictions.size(-1) < 2:
        # No competing class: the max over an empty set is -inf.
        return target_predictions - float("-inf")

    # The best competitor is the top-1 entry unless that entry is the label
    # itself, in which case it is the top-2 entry. This avoids allocating a
    # masked copy of a potentially huge logits tensor.
    top2_values, top2_indices = raw_predictions.topk(2, dim=-1)
    label_view = labels.reshape(-1, *([1] * (raw_predictions.dim() - 2)))
    best_is_label = top2_indices[..., 0] == label_view
    runner_up = torch.where(best_is_label, top2_values[..., 1], top2_values[..., 0])
    return target_predictions - runner_up


@torch.no_grad()
def score_logits(logits, labels, scores_string, prefix):
    """Compute per-token scores from `logits` and store them in `scores_string`.

    Writes these CPU tensors under ``prefix + name``:
      * ``token_log_probs`` - log-probability assigned to each label token
      * ``mu``              - E[log p] under the predicted distribution
      * ``sigma``           - E[(log p)^2] - mu^2 (a variance, not a std dev)
      * ``classification``  - whether the argmax token equals the label
      * ``hinge``           - label logit minus the best non-label logit

    Assumes `logits` is ``(..., vocab)`` and `labels` matches
    ``logits.shape[:-1]`` (inferred from the reshape/argmax usage; confirm
    with callers).
    """
    vocab = logits.shape[-1]
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)

    flat_labels = labels.flatten()
    picked = log_probs.reshape(-1, vocab)[torch.arange(len(flat_labels)), flat_labels]
    token_log_probs = picked.reshape(labels.shape)

    mu = (probs * log_probs).sum(-1)
    sigma = (probs * torch.square(log_probs)).sum(-1) - torch.square(mu)

    per_token = {
        'token_log_probs': token_log_probs,
        'mu': mu,
        'sigma': sigma,
        'classification': logits.argmax(-1) == labels,
    }
    for name, value in per_token.items():
        scores_string[f'{prefix}{name}'] = value.cpu()

    # Keep this last: hinge_score receives a view of `logits` and, when
    # `logits` is already float64, may write into it.
    hinge = hinge_score(logits.reshape(-1, vocab), flat_labels)
    scores_string[f'{prefix}hinge'] = hinge.reshape(labels.shape).cpu()


@torch.no_grad()
def get_score(model, input_ids, ref_input_ids):
    """Score `input_ids` with `model`, both as given and behind a reference prefix.

    Each row is scored with next-token prediction: the model sees
    ``ids[:, :-1]`` and is evaluated against ``ids[:, 1:]``.

    Two passes are made:
      1. the rows of `input_ids` themselves (keys without a prefix);
      2. the same rows with the first row of `ref_input_ids` prepended to each
         (keys prefixed with ``'recall_'``).

    Returns:
        dict mapping score names to CPU tensors, filled by `score_logits`.
    """
    scores = {}

    # Pass 1: the sequences exactly as given.
    plain_out = model(input_ids[:, :-1], use_cache=False)
    plain_labels = input_ids[:, 1:]
    plain_logits = plain_out.logits.to(torch.float64)
    score_logits(plain_logits, plain_labels, scores, '')
    # Release pass-1 activations before the longer pass-2 forward.
    del plain_logits, plain_labels, plain_out
    gc.collect()
    torch.cuda.empty_cache()

    # Pass 2: every row prefixed by the first reference sequence.
    joined = torch.cat(
        [ref_input_ids[0:1].repeat(input_ids.shape[0], 1), input_ids],
        dim=1,
    )
    joined_out = model(joined[:, :-1], use_cache=False)
    score_logits(joined_out.logits.to(torch.float64), joined[:, 1:], scores, 'recall_')

    return scores


def compute_metrics(model, input_ids, ref_input_ids, batch_size=32):
    """Run `get_score` over `input_ids` in mini-batches and stitch the results.

    `ref_input_ids` is moved to the module-level `device` once; each batch of
    `batch_size` rows is moved there just before scoring.

    Returns:
        dict mapping each score name to a CPU tensor formed by concatenating
        the per-batch results along dim 0 (empty dict for empty input).
    """
    ref_on_device = ref_input_ids.to(device)
    collected = {}
    for start in tqdm(range(0, len(input_ids), batch_size)):
        with torch.no_grad():
            chunk = input_ids[start:start + batch_size].to(device)
            for name, value in get_score(model, chunk, ref_on_device).items():
                collected.setdefault(name, []).append(value)
            del chunk
    return {name: torch.cat(parts, dim=0).cpu() for name, parts in collected.items()}


# Find the best batch size: try 2, 4, 8, ... up to 2048 on random token
# sequences of length args.max_length + 1, stopping at the first RuntimeError
# (CUDA OOM is a RuntimeError subclass). The size before the last doubling is
# kept; if even 2 fails, fall back to 1 (which is not itself tested).
seq_len = args.max_length + 1
batch_size = 2
probing = True
while probing:
    torch.cuda.empty_cache()
    gc.collect()
    ref_batch_input_ids = torch.randint(0, tokenizer.vocab_size, (4, seq_len), device=device)
    try:
        with torch.no_grad():
            batch_input_ids = torch.randint(0, tokenizer.vocab_size, (batch_size, seq_len), device=device)
            _ = get_score(model, batch_input_ids, ref_batch_input_ids)
        batch_size *= 2
        probing = batch_size <= 2048
    except RuntimeError as err:
        print(err)
        probing = False
# Undo the final doubling (or the failed attempt).
batch_size = max(1, batch_size // 2)
print('batch_size:', batch_size)
torch.cuda.empty_cache()
gc.collect()



# Accumulator for everything pickled at the end: each field below gets one
# entry per processed group, and 'args' holds the parsed CLI arguments.
extra_data = {
    field: []
    for field in (
        'members_scores_string',
        'nonmembers_scores_string',
        'prefix',
        'suffix',
        'secrets',
        'prior_secrets',
        'extra',
        'group_id',
    )
}
extra_data['args'] = vars(args)

# Score every (secret_type, prefix, suffix, extra, dataset) group whose
# enumeration index lies in [args.start_id, args.end_id). Members are the
# group's secrets. Non-members come from
# DICT_GET_PRIOR[secret_type](extra, args.n_nonmembers) (presumably sampled
# from a prior; confirm in secret_utils).
for group_id, ((secret_type, prefix, suffix, extra, dataset), curr_df) in enumerate(df.groupby(['secret_type', 'prefix', 'suffix', 'extra', 'dataset'])):
    print(group_id)
    if args.start_id > group_id:
        continue
    if args.end_id <= group_id:
        break

    print('*'*100)
    print('INFO')
    print(group_id, secret_type, len(curr_df))

    # The seed depends only on the group key, so a group's non-members are
    # the same whichever id range is being processed.
    set_seed(
        int(hashlib.sha256(
            (str(secret_type) + str(prefix) + str(suffix) + str(extra) + str(dataset)).encode()
        ).hexdigest(), base=16) & 0xffffffff
    )

    try:
        prior_secrets = DICT_GET_PRIOR[secret_type](extra, args.n_nonmembers)
    except ValueError as e:
        print(e)
        continue

    secrets = curr_df['secret'].tolist()
    truncated_suffix = suffix[:args.suffix_length]

    # Members first, then non-members; the split below relies on this order.
    tokens = tokenizer(
        [prefix + s + truncated_suffix for s in secrets] + [prefix + s + truncated_suffix for s in prior_secrets],
        return_tensors='pt', max_length=args.max_length+1, truncation=True, padding='max_length')

    ref_tokens = tokenizer(
        [prefix + s + truncated_suffix for s in DICT_GET_PRIOR[secret_type](extra, 2)],
        return_tensors='pt', max_length=args.max_length+1, truncation=True, padding='max_length'
    )

    print('[tokenized] member: ', tokenizer.batch_decode(tokens.input_ids[0:1], skip_special_tokens=True))
    print('[tokenized] non-member: ', tokenizer.batch_decode(tokens.input_ids[-1:], skip_special_tokens=True))

    # Retry with halved batch sizes on CUDA OOM. Cleanup happens after the
    # `except` clause has exited: while it is active, the traceback still
    # references the failed frames and their GPU tensors. At batch size 1
    # there is nothing left to shrink, so re-raise instead of looping forever.
    while True:
        oom = False
        try:
            scores_string = compute_metrics(model, tokens.input_ids, ref_tokens.input_ids, batch_size=batch_size)
        except torch.OutOfMemoryError:
            if batch_size == 1:
                raise
            oom = True
        if not oom:
            break
        print('Out of memory, retrying with smaller batch size')
        batch_size = max(1, batch_size//2)
        print('batch_size:', batch_size)
        gc.collect()
        torch.cuda.empty_cache()

    members_scores_string = {k: v[:len(secrets)] for k, v in scores_string.items()}
    nonmembers_scores_string = {k: v[len(secrets):] for k, v in scores_string.items()}
    del scores_string
    gc.collect()

    extra_data['members_scores_string'].append(members_scores_string)
    extra_data['nonmembers_scores_string'].append(nonmembers_scores_string)
    extra_data['prefix'].append(prefix)
    extra_data['suffix'].append(suffix)
    extra_data['extra'].append(extra)
    extra_data['secrets'].append(secrets)
    extra_data['prior_secrets'].append(prior_secrets)
    extra_data['group_id'].append(group_id)


# Persist all per-group results. The file name encodes the secret type,
# dataset, processed group-id range and model. The `with` block guarantees
# the file is flushed and closed.
output_file = f'{output_path}/{args.secret_type} {args.dataset} {args.start_id:06} {args.end_id:06} {printable_model_name}.pkl'
with open(output_file, 'wb') as f:
    pickle.dump(extra_data, f)
print('DONE:', datetime.now())
