from torch.utils.data import DataLoader
from src.datasets import load_dataset
from sentence_transformers import SentenceTransformer, InputExample, losses
from pathlib import Path
import numpy as np
import logging

logger = logging.getLogger(__name__)

def compute_acc(model, edits, rephrases, localities, threshold=0.6):
    """Return the fraction of triplets the model separates at ``threshold``.

    Row ``i`` of the three inputs forms one triplet: the first column of
    ``edits`` is the anchor, the first column of ``rephrases`` is the
    positive and the first column of ``localities`` is the negative.
    A triplet counts as correct when the anchor/positive similarity is
    ``>= threshold`` and the anchor/negative similarity is ``< threshold``.

    This is a pure evaluation helper; it can be called both before and
    after training.

    Args:
        model: Object exposing ``encode(texts, show_progress_bar=...)`` and
            ``similarity(a, b)`` (e.g. a ``SentenceTransformer``).
        edits: 2-D array-like supporting ``[:, 0].tolist()``.
        rephrases: 2-D array-like supporting ``[:, 0].tolist()``.
        localities: 2-D array-like supporting ``[:, 0].tolist()``.
        threshold: Similarity cut-off used for both checks. Defaults to 0.6.

    Returns:
        The mean of the per-triplet correctness flags, or NaN when there
        are no triplets to score.
    """
    anchors = edits[:, 0].tolist()
    positives = rephrases[:, 0].tolist()
    negatives = localities[:, 0].tolist()

    corrects = []
    # zip stops at the shortest input, so extra rows in longer inputs are ignored.
    for anchor, positive, negative in zip(anchors, positives, negatives):
        # similarity() yields one row per anchor; row 0 holds the scores
        # against [positive, negative], in that order.
        scores = model.similarity(
            model.encode([anchor], show_progress_bar=False),
            model.encode([positive, negative], show_progress_bar=False),
        )[0]
        positive_score = scores[0].item()
        negative_score = scores[1].item()
        corrects.append(positive_score >= threshold and negative_score < threshold)

    if not corrects:
        # Averaging an empty array would emit a RuntimeWarning; return NaN directly.
        return float("nan")
    return np.array(corrects).mean()

def train_gate(cfg):
    """Fine-tune the gate sentence encoder and save it under ``checkpoints/``.

    Loads (edits, rephrases, localities) via ``load_dataset(cfg)``, logs the
    accuracy of the pretrained ``all-MiniLM-L6-v2`` model, fine-tunes it on
    (edit, rephrase) pairs with ``MultipleNegativesRankingLoss`` for 50
    epochs, logs the accuracy again and saves the model to
    ``checkpoints/{cfg.name}``. An existing checkpoint at that path is
    overwritten.

    The localities are used only for the accuracy measurements; the training
    examples are built from edits and rephrases alone.

    Args:
        cfg: Configuration object; ``cfg.name`` selects the checkpoint
            directory and ``cfg`` itself is forwarded to ``load_dataset``.
    """
    # Local import: the module only imports torch.utils.data at file level.
    import torch

    checkpoint_path = f'checkpoints/{cfg.name}'
    if Path(checkpoint_path).exists():
        logger.info('Gate has been trained, overwrite')

    edits, rephrases, localities = (np.array(part) for part in load_dataset(cfg))

    # Prefer the GPU but keep the function usable on CPU-only hosts instead
    # of crashing on an unconditional .to('cuda').
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
        logger.warning('CUDA is not available, training gate on CPU')
    model = SentenceTransformer('all-MiniLM-L6-v2').to(device)

    logger.info('Before-acc: %s', compute_acc(model, edits, rephrases, localities))

    # Each training example is an (anchor, positive) pair taken from the
    # first column of edits and rephrases.
    train_examples = [
        InputExample(texts=[anchor, positive])
        for anchor, positive in zip(edits[:, 0].tolist(), rephrases[:, 0].tolist())
    ]

    logger.info('Training examples: %s', len(train_examples))

    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
    train_loss = losses.MultipleNegativesRankingLoss(model=model)

    model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=50)

    logger.info('After-acc: %s', compute_acc(model, edits, rephrases, localities))

    model.save(checkpoint_path)