import torch
import random
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
from datasets import load_dataset
import json
import argparse

def prepare_recover_data(model, trainset, batch_size, path, ratio=0.01):
    """Sample part of ``trainset``, pseudo-label it with ``model``, save as JSON.

    ``int(len(trainset) * ratio)`` examples are drawn at random without
    replacement. Every non-``labels`` field of each drawn example is copied
    unchanged; the ``labels`` field of the output holds the model's
    prediction (sigmoid of the logits, thresholded at 0.5) instead of the
    ground truth, which is never read.

    Args:
        model: Called as ``model(input_ids=..., attention_mask=...)``; may
            return the logits directly or an object exposing ``.logits``.
            It is moved to ``cuda:0`` when CUDA is available (CPU otherwise)
            and left in eval mode.
        trainset: Sized, integer-indexable collection of dict-like examples
            containing at least ``input_ids`` and ``attention_mask``, each
            of equal length across examples so they can be stacked.
        batch_size: Number of examples per forward pass.
        path: Destination of the JSON file (overwritten if it exists).
        ratio: Fraction of ``trainset`` to sample.

    Raises:
        IndexError: If ``trainset`` is empty (its first example is used to
            discover the field names).
        ValueError: From ``random.sample`` when the requested sample count
            is negative or larger than ``len(trainset)``.
    """
    num_samples = int(len(trainset) * ratio)
    # Sampling from the range draws the same indices as sampling from a
    # materialised list, without building that list.
    indices = random.sample(range(len(trainset)), num_samples)

    # 'labels' is created last so it stays the final key in the JSON output.
    feature_keys = [key for key in trainset[0].keys() if key != 'labels']
    recover_data = {key: [] for key in feature_keys}
    recover_data['labels'] = []

    # Copy the raw features of every sampled example; the true labels are
    # deliberately skipped because predicted labels replace them below.
    for idx in indices:
        item = trainset[idx]
        for key in feature_keys:
            recover_data[key].append(item[key])

    def collate_fn(batch):
        # Only the two model inputs are batched; other fields are ignored.
        return {
            'input_ids': torch.tensor(
                [item['input_ids'] for item in batch], dtype=torch.long
            ),
            'attention_mask': torch.tensor(
                [item['attention_mask'] for item in batch], dtype=torch.long
            ),
        }

    # shuffle=False keeps predictions in the same order as `indices`, so the
    # i-th predicted label lines up with the i-th copied feature row.
    subset = torch.utils.data.Subset(trainset, indices)
    dataloader = torch.utils.data.DataLoader(
        subset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
    )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Processing"):
            inputs = {
                "input_ids": batch["input_ids"].to(device),
                "attention_mask": batch["attention_mask"].to(device),
            }
            outputs = model(**inputs)
            outputs = outputs.logits if hasattr(outputs, 'logits') else outputs
            probabilities = torch.sigmoid(outputs)
            predictions = (probabilities > 0.5).int()
            # One device-to-host copy per batch instead of one per example.
            recover_data["labels"].extend(predictions.cpu().tolist())

    with open(path, "w", encoding="utf-8") as f:
        json.dump(recover_data, f, indent=4)
        

def parse_args():
    """Parse the command-line options of the recover-dataset script.

    Returns:
        argparse.Namespace: Parsed options with the attributes
        ``model_name_or_path``, ``src_len``, ``tgt_len``, ``data_path``,
        ``batch_size``, ``output_path`` and ``ratio``.
    """
    # Each entry is (flag, keyword arguments for add_argument); the options
    # are registered in this order.
    option_specs = (
        # Model selection and sequence-length limits.
        ('--model_name_or_path', {
            'type': str,
            'required': True,
            'help': 'model name or path, you can also pass the path of model you want to attack',
        }),
        ('--src_len', {
            'type': int,
            'default': 512,
            'help': 'max source sentence length',
        }),
        ('--tgt_len', {
            'type': int,
            'default': 128,
            'help': 'max target sentence length',
        }),
        # Dataset location, batching and output.
        ('--data_path', {
            'type': str,
            'required': True,
            'help': 'Path to the original training dataset.',
        }),
        ('--batch_size', {
            'type': int,
            'required': True,
            'help': 'Batch size',
        }),
        ('--output_path', {
            'type': str,
            'required': True,
            'help': 'The output path of recover dataset.',
        }),
        ('--ratio', {
            'type': float,
            'default': 0.01,
            'help': 'ratio=len(recover_dataset)/len(train_dataset).',
        }),
    )

    parser = argparse.ArgumentParser(description='Obfuscate with random v')
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()


def main():
    """Build a pseudo-labelled recover dataset from command-line options.

    Loads the sequence-classification model and tokenizer given by
    ``--model_name_or_path``, tokenizes the ``train`` split of
    ``--data_path`` and passes the result to ``prepare_recover_data``
    together with ``--batch_size``, ``--output_path`` and ``--ratio``.
    """
    args = parse_args()

    # Width of the classifier head and of the multi-hot label vectors.
    num_labels = 28

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        num_labels=num_labels,
        problem_type="multi_label_classification",
    ).to(device)

    # Show the model architecture.
    print(model)

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path, trust_remote_code=True
    )
    # Fall back to the EOS token when the tokenizer defines no pad token,
    # and keep the model config in sync with whichever pad token is used.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id

    def encode_batch(examples):
        """Tokenize a batch of texts and attach multi-hot label vectors."""
        encoded = tokenizer(
            examples['text'],
            truncation=True,
            padding='max_length',
            max_length=args.src_len,
            return_tensors="pt",
        )

        # Each 'labels' entry indexes the label columns to switch on.
        raw_labels = examples['labels']
        multi_hot = torch.zeros((len(raw_labels), num_labels), dtype=torch.float32)
        for row, active in enumerate(raw_labels):
            multi_hot[row, active] = 1.0

        encoded['labels'] = multi_hot
        return encoded

    train_dataset = load_dataset(args.data_path)['train']
    # Drop the original columns so only the tokenizer output and the
    # multi-hot labels remain.
    tokenized_train_dataset = train_dataset.map(
        encode_batch,
        batched=True,
        remove_columns=train_dataset.column_names,
    )
    prepare_recover_data(
        model,
        tokenized_train_dataset,
        args.batch_size,
        args.output_path,
        args.ratio,
    )

    
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()