import json
from utils.utils import set_seed
import pickle
import sys
set_seed(11)
from transformers import AutoTokenizer
import os
from transformers.trainer_utils import get_last_checkpoint
from transformers import logging, AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM
from arguments import get_args
from llm_logger import main_logger
from src import get_trainer
from src.better_tasks import get_preprocessed_dataset
from pathlib import Path
import numpy as np
import datasets
from attacks.mia_utils import get_losses
import torch
import getpass
import argparse
import re

def get_trained_model(path):
    """Load the tokenizer and the causal LM stored at ``path`` onto the GPU.

    The checkpoint is first loaded as a PEFT model; if that fails for any
    reason, it is loaded as a plain ``AutoModelForCausalLM`` instead.

    Args:
        path: Checkpoint directory (or model identifier) handed to the
            ``from_pretrained`` loaders.

    Returns:
        A ``(tokenizer, model)`` tuple with the model moved to ``"cuda"``.
    """
    tokenizer = AutoTokenizer.from_pretrained(path)
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(path).to("cuda")
        model.print_trainable_parameters()
    except Exception as err:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit still propagate, and record why the
        # PEFT path was abandoned instead of hiding it.
        main_logger.info(f"Could not load {path} as a PEFT model ({err!r}); loading it as a plain causal LM.")
        model = AutoModelForCausalLM.from_pretrained(path).to("cuda")

    return tokenizer, model


def parse_args():
    """Parse the command-line arguments of this script.

    Returns:
        An ``argparse.Namespace`` with ``init_checkpoint``, ``output_dir``
        and ``cache_dir`` attributes.
    """
    parser = argparse.ArgumentParser(description='')
    # The checkpoint path is parsed and loaded unconditionally by main(), so a
    # missing value is reported here as a usage error instead of failing later
    # with an AttributeError on None.
    parser.add_argument('--init_checkpoint', required=True, type=str, help='initial checkpoint')
    parser.add_argument('--output_dir', type=str, default="results")
    parser.add_argument('--cache_dir', type=str, default="results")

    return parser.parse_args()


# Matches scientific-notation values such as "1e-4" so that the hyphen in the
# exponent is not mistaken for a field separator.
reg = re.compile(r"(\d+)e-(\d+)")
def get_checkpoint_info(checkpoint):
    """Parse the run settings encoded in a checkpoint directory name.

    The last path component is expected to consist of at least nine
    hyphen-separated fields, in this order: dataset, setting, lr, epoch,
    batch_size, target_epsilon, prefix_length, prefix_type, shadow_id.
    Any further fields are ignored.

    Args:
        checkpoint: "/"-separated path to the checkpoint; trailing slashes
            are tolerated.

    Returns:
        A dict mapping each field name to its string value.

    Raises:
        IndexError: If the name contains fewer than nine fields.
    """
    fields = ('dataset', 'setting', 'lr', 'epoch', 'batch_size',
              'target_epsilon', 'prefix_length', 'prefix_type', 'shadow_id')
    # Without stripping, "runs/foo/" would yield an empty last component.
    basename = checkpoint.rstrip("/").split("/")[-1]
    # Temporarily turn "1e-4" into "1e~4" so splitting on "-" keeps it whole.
    name = reg.sub(r"\1e~\2", basename)
    parts = name.split('-')
    if len(parts) < len(fields):
        # IndexError is kept (the original indexing raised it) for compatibility.
        raise IndexError(
            f"checkpoint name {basename!r} has {len(parts)} '-'-separated fields, "
            f"expected at least {len(fields)}: {', '.join(fields)}"
        )
    # zip stops after the known fields; "~" is restored to "-" in each value.
    return {field: part.replace('~', '-') for field, part in zip(fields, parts)}


def main():
    """Load a checkpoint, rebuild its dataset splits and save their losses.

    The run settings are recovered from the checkpoint name, the dataset is
    rebuilt with ``get_preprocessed_dataset``, and for each of the
    ``train_tokens``, ``val_tokens`` and ``z_tokens`` splits the result of
    ``get_losses`` is printed (as a mean) and written to
    ``<output_dir>/losses_<split>.npy``.
    """
    args = parse_args()
    print(args)
    os.makedirs(args.output_dir, exist_ok=True)
    info = get_checkpoint_info(args.init_checkpoint)

    tokenizer, model = get_trained_model(args.init_checkpoint)

    # Pad with the EOS token.
    tokenizer.pad_token = tokenizer.eos_token
    max_seq_length = 256

    def encode_split(token_ids):
        # Decode the stored token ids back to text, then re-tokenize so every
        # example is padded/truncated to max_seq_length.
        texts = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
        return tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=max_seq_length,
            return_tensors='pt',
        )

    # NOTE(review): the meaning of the positional literals 10 and 0.01 is
    # defined by get_preprocessed_dataset and is not visible here.
    output = get_preprocessed_dataset(
        info['dataset'].lower(),
        args.cache_dir,
        tokenizer,
        max_seq_length,
        int(info['shadow_id']),
        info['prefix_type'],
        10,
        int(info['prefix_length']),
        0.01,
        z_ratio=0.1,
    )

    # These datasets are only used to report the split sizes below.
    train_set = datasets.Dataset.from_dict(encode_split(output['train_tokens']))
    val_set = datasets.Dataset.from_dict(encode_split(output['val_tokens']))
    main_logger.info(f"Dataset loaded. Train dataset length: {len(train_set)}, Validation dataset length: {len(val_set)}")

    # No gradients are needed while computing the losses.
    with torch.no_grad():
        for split_name in ('train_tokens', 'val_tokens', 'z_tokens'):
            encoded = encode_split(output[split_name])
            # Third argument is presumably the batch size -- confirm in get_losses.
            split_losses = get_losses(model, encoded, 16).numpy(force=True)
            print(np.mean(split_losses))
            np.save(f'{args.output_dir}/losses_{split_name}.npy', split_losses)



# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':    
    main()
