import torch
import random
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
from datasets import load_dataset
import json
import argparse

def prepare_recover_data(model, trainset, batch_size, path, ratio=0.01):
    """Sample part of ``trainset``, relabel it with ``model`` and save it as JSON.

    ``int(len(trainset) * ratio)`` examples are drawn at random without
    replacement. Every non-label column of the drawn examples is copied
    unchanged; the ``label`` column is replaced by the model's argmax
    prediction. The fraction of predictions that match the original labels is
    printed.

    Args:
        model: Object providing ``to(device)`` and ``eval()`` that is called
            with ``input_ids`` and ``attention_mask`` keyword tensors and
            returns either the logits or an object with a ``logits`` attribute.
        trainset: Indexable, sized collection of dict examples holding at
            least ``input_ids``, ``attention_mask`` and ``label``. The
            sequences inside one batch must share a length because they are
            stacked with ``torch.tensor``.
        batch_size: Number of examples per inference batch.
        path: Destination of the JSON file.
        ratio: Fraction of ``trainset`` to sample.

    Raises:
        ValueError: Propagated from ``random.sample`` when ``ratio`` asks for
            a negative number of examples or more than ``trainset`` holds.
    """
    total = len(trainset)
    num_samples = int(total * ratio)
    indices = random.sample(range(total), num_samples)

    # Fetch every sampled example once and reuse it for both the saved
    # columns and the inference batches.
    samples = [trainset[idx] for idx in indices]

    # Column names come from the first example; an empty dataset has none.
    feature_keys = [key for key in trainset[0].keys() if key != 'label'] if total else []
    recover_data = {key: [item[key] for item in samples] for key in feature_keys}
    true_labels = [item['label'] for item in samples]
    predicted_labels = []

    if samples:
        def collate_fn(batch):
            # Only the two tensors passed to the model are batched.
            return {
                'input_ids': torch.tensor([item['input_ids'] for item in batch], dtype=torch.long),
                'attention_mask': torch.tensor([item['attention_mask'] for item in batch], dtype=torch.long),
            }

        # shuffle=False keeps predictions aligned with ``true_labels``.
        dataloader = torch.utils.data.DataLoader(samples, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        model.eval()

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Processing"):
                outputs = model(
                    input_ids=batch["input_ids"].to(device),
                    attention_mask=batch["attention_mask"].to(device),
                )
                logits = outputs.logits if hasattr(outputs, 'logits') else outputs
                # One device-to-host transfer per batch instead of per example.
                predicted_labels.extend(torch.argmax(logits, dim=1).cpu().tolist())

        correct = sum(
            1 for predicted, expected in zip(predicted_labels, true_labels)
            if predicted == expected
        )
        accuracy = correct / len(predicted_labels)
        print(f"Accuracy of predicted labels vs. true labels: {accuracy:.4f}")
    else:
        # Previously this case ended in ZeroDivisionError before anything was
        # written; now an empty dataset is saved and the accuracy is skipped.
        print(
            f"No examples sampled (dataset size {total}, ratio {ratio}); "
            "writing an empty recover dataset."
        )

    recover_data['label'] = predicted_labels
    with open(path, "w") as f:
        json.dump(recover_data, f, indent=4)
        

def parse_args():
    """Parse the script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``model_name_or_path``, ``src_len``,
        ``tgt_len``, ``data_path``, ``num_labels``, ``batch_size``,
        ``output_path`` and ``ratio``.
    """
    # NOTE(review): the description text looks inherited from another script;
    # it is kept unchanged because it is user-facing output.
    cli = argparse.ArgumentParser(description='Obfuscate with random v')

    option_specs = (
        # Model to load and the tokenizer length limits.
        ('--model_name_or_path', dict(type=str, required=True, help='model name or path, you can also pass the path of model you want to attack')),
        ('--src_len', dict(type=int, default=512, help='max source sentence length')),
        ('--tgt_len', dict(type=int, default=128, help='max target sentence length')),
        # Dataset location, label count, batching and output settings.
        ('--data_path', dict(type=str, required=True, help='Path to the original training dataset.')),
        ('--num_labels', dict(type=int, required=True, help='The number of the dataset labels.')),
        ('--batch_size', dict(type=int, required=True, help='Batch size')),
        ('--output_path', dict(type=str, required=True, help='The output path of recover dataset.')),
        ('--ratio', dict(type=float, default=0.01, help='ratio=len(recover_dataset)/len(train_dataset).')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()


def main():
    """Load the model and dataset, tokenize the train split and build the recover data.

    The task is chosen by substring match on ``--data_path`` (``mnli``,
    ``sst2``, ``qnli``, ``qqp`` or ``wic``, checked in that order) and decides
    which text columns are tokenized.

    Raises:
        ValueError: If ``--data_path`` names none of the supported tasks.
    """
    args = parse_args()

    # Text columns fed to the tokenizer for each plain sentence / sentence-pair
    # task. Insertion order is the matching priority.
    text_columns = {
        'mnli': ("premise", "hypothesis"),
        'sst2': ("sentence",),
        'qnli': ("question", "sentence"),
        'qqp': ("question1", "question2"),
    }
    supported_tasks = (*text_columns, 'wic')
    task = next((name for name in supported_tasks if name in args.data_path), None)
    if task is None:
        # Fail before loading the model; previously this surfaced as a
        # TypeError on ``None`` from inside ``Dataset.map``.
        raise ValueError(
            f"Unsupported dataset path {args.data_path!r}: expected it to contain one of "
            f"{', '.join(supported_tasks)}."
        )

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, num_labels=args.num_labels, trust_remote_code=True)
    model = model.to(device)

    # Print the model architecture.
    print(model)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Fall back to the EOS token so that padding='max_length' works.
        tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id

    def tokenize_function(example):
        """Tokenize one batch of raw examples for the selected task."""
        tokenizer_kwargs = dict(truncation=True, max_length=args.src_len, padding='max_length')

        if task == 'wic':
            label = example["label"]  # 0: different sense, 1: same sense
            # Both sentences and the target word are joined into one string.
            input_texts = [
                f"{sent1} [SEP] {sent2} [SEP] {target_word}"
                for target_word, sent1, sent2 in zip(example["word"], example["sentence1"], example["sentence2"])
            ]
            encoding = tokenizer(input_texts, **tokenizer_kwargs)
            return {
                "input_ids": encoding["input_ids"],
                "attention_mask": encoding["attention_mask"],
                "label": torch.tensor(label, dtype=torch.long),
            }

        result = tokenizer(*(example[column] for column in text_columns[task]), **tokenizer_kwargs)
        if "label" in example:
            result["label"] = torch.tensor(example["label"], dtype=torch.long)
        return result

    train_dataset = load_dataset(args.data_path)['train']
    # Drop the raw columns so only the tokenizer outputs and the label remain.
    tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=train_dataset.column_names)
    prepare_recover_data(model, tokenized_train_dataset, args.batch_size, args.output_path, args.ratio)

    
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()