import random

import torch
from torch.utils.data import Dataset


def generate_preference_dataset(
    pretrained_model, reward_function,
    finetuned_model=None,
    sample_size=10000, epsilon=0.1, min_dist=0
    ):
    """Sample pairs from two models and order each pair by reward.

    Args:
        pretrained_model: Object exposing ``sample(num_samples=...)``; it
            provides the initial ``y0`` batch.
        reward_function: Callable applied to each sampled batch. Its outputs
            are compared elementwise and used as a mask over the samples
            (presumably one reward per sample -- confirm against callers).
        finetuned_model: Object exposing ``sample(num_samples=...)`` that
            provides the initial ``y1`` batch. When ``None`` the pretrained
            model is sampled a second time instead.
        sample_size: Number of samples requested from each model.
        epsilon: Forwarded unchanged to ``PreferenceDataset``.
        min_dist: Pairs whose absolute reward difference is strictly below
            this value are discarded; with the default of 0 no pair is
            discarded by this filter.

    Returns:
        A ``PreferenceDataset`` built from the ordered, filtered pairs, with
        the higher-reward sample of each pair in the ``y1`` slot.
    """
    second_model = pretrained_model if finetuned_model is None else finetuned_model

    with torch.no_grad():
        # Sampling and scoring order is kept fixed: pretrained first, then the
        # second model; rewards for y1 first, then y0.
        y0 = pretrained_model.sample(num_samples=sample_size)
        y1 = second_model.sample(num_samples=sample_size)
        r1 = reward_function(y1)
        r0 = reward_function(y0)

        # Wherever y1 did not strictly beat y0, exchange the two samples so
        # that y1 always holds the preferred one. Mask indexing yields copies,
        # so the saved rows are unaffected by the writes below.
        wrong_order = r1 <= r0
        from_y0 = y0[wrong_order]
        from_y1 = y1[wrong_order]
        y1[wrong_order] = from_y0
        y0[wrong_order] = from_y1

        # The reward gap is symmetric, so it can be computed from the
        # un-swapped rewards.
        too_close = torch.abs(r1 - r0) < min_dist
        keep = ~too_close
        y1 = y1[keep]
        y0 = y0[keep]

    return PreferenceDataset(y1, y0, epsilon=epsilon)


class FlowMatcherSampler:
    """Draw samples from a base model and push them through a flow.

    Attributes are stored as given: ``model`` must expose
    ``sample(num_samples=...)`` and ``flow`` must expose
    ``compute_target(source, use_torchdiffeq=...)``.
    """

    def __init__(self, model, flow, device='cpu'):
        self.device, self.model, self.flow = device, model, flow

    @torch.no_grad()
    def sample(self, num_samples=1):
        """Return ``flow.compute_target`` applied to fresh base-model samples.

        The base samples are moved to ``self.device`` first, and the whole
        computation runs with gradient tracking disabled.
        """
        base_draws = self.model.sample(num_samples=num_samples)
        base_draws = base_draws.to(self.device)
        # use_torchdiffeq=False is passed through verbatim; presumably it
        # selects the flow's own integrator -- confirm in the flow class.
        return self.flow.compute_target(base_draws, use_torchdiffeq=False)


class PreferenceDataset(Dataset):
    """Dataset of preference pairs ``(x, y1, y0)`` with optional label noise.

    ``y1`` holds the preferred sample of each pair and ``y0`` the rejected
    one. On construction a random fraction ``epsilon`` of the pairs has its
    two members exchanged, which simulates noisy preference labels.

    Args:
        y1: Indexable collection of preferred samples (tensor or sequence).
        y0: Indexable collection of rejected samples, same length as ``y1``.
        x: Optional indexable collection of per-pair context values. When
            ``None``, ``__getitem__`` returns ``torch.zeros(1)`` as context.
        epsilon: Fraction of pairs to flip; ``int(epsilon * len(y1))`` pairs
            are chosen without replacement via ``random.sample``.

    Raises:
        ValueError: From ``random.sample`` when ``epsilon`` is negative or
            asks for more pairs than exist.

    Note:
        The flip is applied in place, so the ``y1`` / ``y0`` objects passed
        in are modified.
    """

    def __init__(self, y1, y0, x=None, epsilon=0.):
        self.x = x
        self.y1 = y1
        self.y0 = y0
        self.epsilon = epsilon

        perturb_idx = random.sample(
            range(len(y1)), int(epsilon * len(y1))
        )
        if perturb_idx:
            self._swap_pairs(perturb_idx)

    def _swap_pairs(self, perturb_idx):
        """Exchange ``y1[i]`` and ``y0[i]`` in place for every chosen index."""
        if torch.is_tensor(self.y1) and torch.is_tensor(self.y0):
            # Integer indexing of a tensor returns a view, so the tuple swap
            # ``a[i], b[i] = b[i], a[i]`` overwrites a[i] before it is read
            # back and leaves both rows equal. Indexing with a LongTensor
            # copies, so both sides are saved before either is written.
            index = torch.tensor(perturb_idx, dtype=torch.long)
            new_y1 = self.y0[index]
            new_y0 = self.y1[index]
            self.y1[index] = new_y1
            self.y0[index] = new_y0
        else:
            # Plain sequences (e.g. lists) hand back the stored objects, so
            # the element-wise tuple swap is correct for them.
            # NOTE(review): other array types whose integer indexing returns
            # views would need the same copy-before-write treatment.
            for idx in perturb_idx:
                self.y1[idx], self.y0[idx] = self.y0[idx], self.y1[idx]

    def __len__(self):
        """Return the number of pairs (length of ``y0``)."""
        return len(self.y0)

    def __getitem__(self, idx):
        """Return ``(x, y1[idx], y0[idx])``; ``x`` is a zero placeholder if unset."""
        x = torch.zeros(1) if self.x is None else self.x[idx]
        return x, self.y1[idx], self.y0[idx]