import pandas as pd
import random
import torch
    
class PromptDataset:
    """Pool of prompts read from a CSV file and handed out without repetition.

    The CSV is loaded with ``pandas.read_csv`` and must provide a ``prompt``
    column. ``unseen_indices`` holds the row labels that have not been
    returned yet; it shrinks on every draw and is refilled by ``reset()``.
    """

    def __init__(self, csv_file):
        """Load ``csv_file`` and mark every row as unseen."""
        self.data = pd.read_csv(csv_file)
        self.unseen_indices = list(self.data.index)

    def get_random_prompts(self, num_prompts=1):
        """Return up to ``num_prompts`` randomly chosen unseen prompts.

        The returned prompts are removed from the unseen pool. When fewer
        than ``num_prompts`` prompts are left, only the remaining ones are
        returned (possibly an empty list).

        Args:
            num_prompts: Maximum number of prompts to draw.

        Returns:
            A list with the values of the ``prompt`` column for the drawn rows.

        Raises:
            ValueError: If ``num_prompts`` is negative (from ``random.sample``).
        """
        # Never ask for more prompts than are still unseen.
        num_prompts = min(num_prompts, len(self.unseen_indices))

        # Sample positions inside the unseen list rather than its values, so
        # the removal below is one pass instead of a list.remove() per draw.
        chosen_positions = random.sample(range(len(self.unseen_indices)), num_prompts)
        selected_indices = [self.unseen_indices[pos] for pos in chosen_positions]

        # Slice assignment keeps the same list object, as list.remove() did,
        # and leaves the surviving indices in their original order.
        dropped = set(chosen_positions)
        self.unseen_indices[:] = [
            idx for pos, idx in enumerate(self.unseen_indices) if pos not in dropped
        ]

        # Return the prompts corresponding to the selected indices.
        return self.data.loc[selected_indices, 'prompt'].tolist()

    def has_unseen_prompts(self):
        """Return True while at least one prompt has not been drawn yet."""
        return bool(self.unseen_indices)

    def reset(self):
        """Mark every row of the loaded data as unseen again."""
        self.unseen_indices = list(self.data.index)

    def check_unseen_prompt_count(self):
        """Return how many prompts have not been drawn yet."""
        return len(self.unseen_indices)
    
def retain_prompt(dataset_retain):
    """Create the PromptDataset holding the prompts to be retained.

    Args:
        dataset_retain: Name of the retain set; one of ``'imagenet243'``,
            ``'imagenet243_no_filter'``, ``'coco_object'`` or
            ``'coco_object_no_filter'``.

    Returns:
        A ``PromptDataset`` loaded from the CSV file registered for that name.

    Raises:
        ValueError: If ``dataset_retain`` is not one of the known names.
    """
    # Known retain sets and the CSV file each one is loaded from.
    known_datasets = (
        ('imagenet243', './data/prompts/train/imagenet243_retain.csv'),
        ('imagenet243_no_filter', './data/prompts/train/imagenet243_no_filter_retain.csv'),
        ('coco_object', './data/prompts/train/coco_object_retain.csv'),
        ('coco_object_no_filter', './data/prompts/train/coco_object_no_filter_retain.csv'),
    )

    for dataset_name, csv_path in known_datasets:
        if dataset_retain == dataset_name:
            return PromptDataset(csv_path)

    raise ValueError('Invalid dataset for retaining prompts')

def retain_loss(diffuser, fintuner, retain_dataset, retain_batch, nsteps, latents, criteria, device):
    """Compare two noise predictions for a random batch of retain prompts.

    The same partially denoised latent and text embedding are passed to
    ``diffuser.predict_noise`` twice: once under ``torch.no_grad()`` and once
    inside the ``fintuner`` context. ``criteria`` is applied to the two results.
    NOTE(review): presumably ``fintuner`` switches the model to its fine-tuned
    weights while active, making the first prediction the frozen reference;
    confirm against the ``fintuner`` implementation.

    Args:
        diffuser: Object providing ``set_scheduler_timesteps``,
            ``get_text_embeddings``, ``diffusion`` and ``predict_noise``.
        fintuner: Context manager active during the second prediction.
        retain_dataset: Prompt source providing ``check_unseen_prompt_count``,
            ``reset`` and ``get_random_prompts``.
        retain_batch: Number of prompts requested for this call.
        nsteps: Number of scheduler steps used for the partial denoising.
        latents: Starting latents passed to ``diffuser.diffusion``.
        criteria: Callable taking the two predictions and returning the loss.
        device: Device both predictions are moved to before ``criteria``.

    Returns:
        Whatever ``criteria`` returns for the two noise predictions.
    """
    # Refill the prompt pool when it can no longer serve a full batch.
    unseen_left = retain_dataset.check_unseen_prompt_count()
    if unseen_left < retain_batch:
        retain_dataset.reset()

    # The upper bound of torch.randint is exclusive: the step is in [1, nsteps - 2].
    step = torch.randint(1, nsteps - 1, (1,)).item()
    diffuser.set_scheduler_timesteps(nsteps)

    with torch.no_grad():
        prompts = retain_dataset.get_random_prompts(retain_batch)
        text_emb = diffuser.get_text_embeddings(prompts, n_imgs=retain_batch)

        # Denoise from the first step up to the randomly drawn one.
        diffusion_out, _ = diffuser.diffusion(
            latents,
            text_emb,
            start_iteration=0,
            end_iteration=step,
            guidance_scale=3,
            show_progress=False,
        )
        partial_latent = diffusion_out[0]

        # Switch to a 1000-step schedule and rescale the step index to match.
        diffuser.set_scheduler_timesteps(1000)
        step = int(step / nsteps * 1000)

        noise_no_grad = diffuser.predict_noise(step, partial_latent, text_emb, guidance_scale=1)

    with fintuner:
        noise_in_context = diffuser.predict_noise(step, partial_latent, text_emb, guidance_scale=1)

    return criteria(noise_no_grad.to(device), noise_in_context.to(device))