from tqdm import tqdm

import torch
import torch.nn.functional as F


def generate_image(generator, sample_size=1, condition=None, device='cpu',
                   latent_dim=100):
    """Sample a batch of images from a conditional generator.

    The generator is switched to eval mode and moved to ``device``, then
    called as ``generator(noise, condition)``. ``noise`` is standard-normal
    with shape ``(sample_size, latent_dim, 1, 1)``.

    Args:
        generator: conditional generator module (must support ``eval()`` and
            ``to(device)``).
        sample_size: number of noise vectors to draw.
        condition: class-label tensor passed to the generator. If None,
            ``sample_size`` labels are drawn uniformly from 0..9. A tensor
            condition (sampled or provided) is moved to ``device``.
        device: device for the noise, the condition and the generator.
        latent_dim: size of the noise vector's channel dimension (default
            100, the value previously hard-coded).

    Returns:
        ``(generated_image, condition)``: the raw generator output and the
        condition tensor that was used.
    """
    generator.eval().to(device)
    if condition is None:
        condition = torch.randint(0, 10, (sample_size,))
    # A caller-provided condition must live on the same device as the noise
    # and the generator; previously only the sampled one was moved.
    if torch.is_tensor(condition):
        condition = condition.to(device)
    noise = torch.randn(sample_size, latent_dim, 1, 1, device=device)
    generated_image = generator(noise, condition)
    return generated_image, condition


def generate_dataset(
        generator,
        classifier, 
        reward_function, 
        sample_size=10000, 
        device='cpu', 
        verbose=False,
        alpha=0.,
        batch_size=100,
    ):
    """Build a list of ``(label, preferred, rejected)`` image triples.

    Each iteration draws ``batch_size`` images from ``generator``, with random
    class conditions sampled by ``generate_image``, and scores them with
    ``reward_function(images, labels)``. For every class label in 0..9 that
    appears at least twice in the batch, the highest-reward image of that
    class becomes ``y1`` (preferred) and the lowest-reward one becomes ``y0``
    (rejected).

    Args:
        generator: conditional generator forwarded to ``generate_image``.
        classifier: moved to ``device``; not otherwise used in this function.
        reward_function: callable ``(images, labels) -> rewards`` exposing
            ``to(device)``. Rewards are assumed to be one score per image,
            i.e. shape ``(batch_size,)``.
        sample_size: number of triples to return.
        device: device used for generation and scoring.
        verbose: if True, show a tqdm progress bar (closed on return).
        alpha: blend factor applied to the rejected image,
            ``y0 = alpha * noise + (1 - alpha) * y0``; 0 leaves it unchanged.
        batch_size: images generated per iteration (default 100, the value
            previously hard-coded). Each generator output is assumed to hold
            28*28 values per image, since it is viewed as ``(1, 28, 28)``.

    Returns:
        list of ``(label, y1, y0)`` where ``label`` is a Python int and
        ``y1``/``y0`` are tensors of shape ``(1, 28, 28)``.
    """
    generator.to(device)
    classifier.to(device)
    reward_function.to(device)

    samples = []
    pbar = tqdm(total=sample_size) if verbose else None
    try:
        while len(samples) < sample_size:
            with torch.no_grad():
                images, labels = generate_image(
                    generator, sample_size=batch_size, device=device)
                images = images.view(batch_size, 1, 28, 28)
                rewards = reward_function(images, labels)

            # Sort the whole batch once by descending reward. Filtering this
            # order by label below preserves it, so the first/last entry of
            # each class is its best/worst image.
            _, order = torch.sort(rewards, descending=True)
            sorted_labels = labels[order]

            for label in range(10):
                label_order = order[sorted_labels == label]
                if len(label_order) < 2:
                    # Need two distinct images to form a preference pair.
                    continue

                y1 = images[label_order[0]]
                y0 = images[label_order[-1]]
                # randn_like already allocates on y0's device.
                y0 = torch.randn_like(y0) * alpha + y0 * (1 - alpha)
                samples.append((label, y1, y0))
                if pbar is not None:
                    pbar.update(1)
                if len(samples) >= sample_size:
                    return samples
    finally:
        # Close the bar on every exit path, including the early return.
        if pbar is not None:
            pbar.close()

    return samples


class RewardFunction():
    """Score images by classifier confidence blended with a discriminator term.

    reward = alpha * p_classifier(x | y) + (1 - alpha) * (1 - discriminator(y, x))

    With the default ``alpha=1.`` the reward is the classifier probability
    alone. Images are zero-padded by 2 pixels on each spatial side before
    being fed to the classifier (e.g. 28x28 -> 32x32).
    """

    def __init__(self, classifier, discriminator, device='cpu', alpha=1.):
        self.alpha = alpha  # weight of the classifier term, in the blend above
        self.device = device  # device that predicted labels are moved to
        self.classifier = classifier
        self.discriminator = discriminator
        
    def to(self, device):
        """Move the classifier and discriminator to ``device``; return self."""
        self.device = device
        self.classifier.to(device)
        self.discriminator.to(device)
        return self

    def _logits(self, y):
        """Classifier logits for ``y`` zero-padded by 2 px per spatial side."""
        pad_y = F.pad(y, (2, 2, 2, 2), value=0)
        return self.classifier(pad_y)

    def pred_label(self, y):
        """Return the classifier's argmax class for each image in ``y``."""
        return torch.argmax(self._logits(y), dim=1)

    def __call__(self, y, x=None):
        """Compute one reward per image.

        Args:
            y: batch of images (first dimension is the batch).
            x: class-label tensor, one per image. If None, the classifier's
                own argmax prediction is used.

        Returns:
            1-D tensor with the same length as the batch.
        """
        # Run the classifier once and reuse its logits both for the label
        # prediction and for the probabilities. Previously it ran twice,
        # which could disagree if the classifier is stochastic (e.g. dropout).
        logits = self._logits(y)
        if x is None:
            x = torch.argmax(logits, dim=1).to(self.device)
        probs = F.softmax(logits, dim=1)
        # Probability assigned to each image's own label.
        probs = probs[torch.arange(len(probs), device=probs.device), x]

        scores = self.discriminator(y, x)
        # Flatten e.g. an (N, 1) discriminator output to (N,); otherwise
        # (N,) + (N, 1) would silently broadcast to an (N, N) reward matrix.
        scores = 1 - scores.reshape(probs.shape)
        return self.alpha * probs + (1 - self.alpha) * scores


class MNISTPreferneceDataset(torch.utils.data.Dataset):
    """Map-style dataset over an indexable collection of preference triples.

    Each record must unpack into exactly three items, ``(label, y1, y0)``,
    e.g. the list returned by ``generate_dataset``. Records are returned as
    fresh 3-tuples.
    """

    def __init__(self, dataset):
        # Indexable, sized collection of 3-item records.
        self.dataset = dataset

    def __len__(self):
        """Number of records in the wrapped collection."""
        n_records = len(self.dataset)
        return n_records

    def __getitem__(self, idx):
        """Return record ``idx`` as ``(label, preferred, rejected)``.

        Raises ValueError if the record does not have exactly three items.
        """
        record = self.dataset[idx]
        label, preferred, rejected = record
        return (label, preferred, rejected)