from data import SNLIDataset, snli_labels, create_collate_fn
from nli import BertNLI
import util
from pathlib import Path
import json
from torch.utils.data import ConcatDataset
from transformers import BertTokenizer, RobertaTokenizer
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

def get_best_epoch(model_save_dir):
    """Return the epoch whose validation record has the highest accuracy.

    Args:
        model_save_dir: Directory (path-like supporting ``/``) containing
            ``val_results.json``, which is read with ``util.load_jsonl``.
            Every record must provide the keys ``'acc'`` and ``'epoch'``.

    Returns:
        The ``'epoch'`` value of the record with the largest ``'acc'``.
        When several records share the top accuracy, the one that comes
        last in the loaded order wins (the sort is stable).

    Raises:
        IndexError: If no validation records were loaded.
    """
    records = util.load_jsonl(model_save_dir / 'val_results.json')
    by_accuracy = sorted(records, key=lambda record: record['acc'])
    best_record = by_accuracy[-1]
    return best_record['epoch']

def load_tokenizer(architecture):
    """Return the tokenizer and RoBERTa+SE flag for a training architecture.

    Args:
        architecture: The ``architecture`` entry of the saved training
            arguments. ``'BERT'`` and ``'RoBERTa+SE'`` are supported.

    Returns:
        A ``(tokenizer, roberta_se)`` tuple; ``roberta_se`` is the boolean
        that is forwarded to ``create_collate_fn``.

    Raises:
        ValueError: If ``architecture`` is not one of the supported values.
    """
    if architecture == 'BERT':
        return BertTokenizer.from_pretrained('bert-base-uncased'), False
    if architecture == 'RoBERTa+SE':
        return RobertaTokenizer.from_pretrained('FacebookAI/roberta-base'), True
    # Fail here with a clear message instead of a NameError further down
    # when the tokenizer is first used.
    raise ValueError(f"unsupported architecture: {architecture!r}")


if __name__ == '__main__':
    # Script entry point: run the best checkpoint of a trained model over the
    # test sets and store per-item class probabilities in
    # <model_save_dir>/preds.jsonl.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("model_save_dir", type=Path)
    parser.add_argument("data_dir", type=Path)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--device', type=str, default='cpu')
    args = parser.parse_args()

    print(args.model_save_dir)
    preds_file = args.model_save_dir/'preds.jsonl'

    # Resume support: predictions from an earlier run are kept and their ids
    # are excluded from the datasets below. Entries whose pairID ends in
    # 'ab-ba' are dropped here, so they are not excluded and get predicted
    # again if the datasets contain them.
    items = util.load_jsonl(preds_file) if preds_file.exists() else []
    items = [item for item in items if not item['pairID'].endswith('ab-ba')]
    existing_ids = {item['pairID'] for item in items}

    # Close the handle deterministically instead of leaving it to the GC.
    with (args.model_save_dir/'args.json').open() as f:
        train_args = json.load(f)

    tokenizer, roberta_se = load_tokenizer(train_args['architecture'])

    label_itos = snli_labels
    label_stoi = {label: index for index, label in enumerate(label_itos)}

    test_data = ConcatDataset([
        SNLIDataset(args.data_dir/'snli_1.0/snli_1.0_test.jsonl', 'gold_label', exclude_ids=existing_ids),
        SNLIDataset(args.data_dir/'generated/llama3.2:3b_test.jsonl', 'model_label', exclude_ids=existing_ids),
        SNLIDataset(args.data_dir/'generated/llama3.3:70b_test.jsonl', 'model_label', exclude_ids=existing_ids),
        SNLIDataset(args.data_dir/'generated/deepseek-r1:70b_test.jsonl', 'model_label', exclude_ids=existing_ids),
        SNLIDataset(args.data_dir/'inferred/llama3.3:70b_test.jsonl', None, exclude_ids=existing_ids),
    ])

    best_epoch = get_best_epoch(args.model_save_dir)
    # NOTE(review): this unpickles a complete model object, so only load
    # checkpoints from a trusted source. Recent PyTorch releases default
    # torch.load to weights_only=True, which may reject such checkpoints
    # unless weights_only=False is passed -- confirm against the installed
    # version.
    model = torch.load(args.model_save_dir/f'epoch-{best_epoch}', map_location=args.device)
    model.eval()

    collate_fn = create_collate_fn(tokenizer, label_stoi, args.device,
                                   roberta_se=roberta_se,
                                   hypothesis_only=train_args['hypothesis_only'])
    test_loader = DataLoader(test_data, batch_size=args.batch_size,
                             collate_fn=collate_fn, shuffle=False)

    with torch.no_grad():
        # The third element of each batch (the labels) is not needed for
        # prediction.
        for item_ids, x, _ in tqdm(test_loader):
            logits = model(x)
            batch_pred_probs = F.softmax(logits, dim=1)
            for item_id, pred_probs in zip(item_ids, batch_pred_probs):
                items.append({'pairID': item_id} | dict(zip(snli_labels, pred_probs.tolist())))

    # Write to a temporary file and rename it over the target so that a
    # failure while writing cannot truncate the predictions kept from
    # earlier runs.
    tmp_file = preds_file.with_name(preds_file.name + '.tmp')
    with tmp_file.open('w') as f:
        for line in items:
            util.write_jsonl(line, f)
    tmp_file.replace(preds_file)
