from converter import DatasetConverter
from fuzzywuzzy import fuzz
import numpy as np
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoConfig
from loguru import logger
from parse import ClassificationParser
from query import OpenAIQuery
from utils import get_best_device
from scipy.spatial.distance import cosine

class ExactMatchSelector(DatasetConverter):
    """Selector that drops exact duplicate rows from a dataset."""

    def __init__(self, **kwargs) -> None:
        # Extra keyword arguments are accepted for call-site uniformity but
        # are not forwarded to the base class.
        super().__init__()

    def run(self, dataset, *args, **kwargs):
        """Return a copy of ``dataset`` with duplicate rows removed.

        The input dataset is not modified; the work happens on
        ``dataset.copy()``. Duplicates are dropped from the copy's ``df``
        in place and ``reset_calculation()`` is called on the copy
        afterwards.

        Args:
            dataset: Object exposing ``copy()``, ``size()``, ``df`` and
                ``reset_calculation()``.

        Returns:
            The de-duplicated copy.
        """
        dataset = dataset.copy()
        size_before = dataset.size()
        dataset.df.drop_duplicates(inplace=True)
        # Report the number of removed rows as "before - after" so the
        # logged count is non-negative.
        logger.info(f"Removing {size_before - dataset.size()} sentences")
        dataset.reset_calculation()

        return dataset

class IndexSelector(DatasetConverter):
    """Selector that randomly drops a share of the entries flagged with 0."""

    def __init__(self, drop_percentage, **kwargs):
        # NOTE(review): ``run`` reads ``self.drop_percentage``, so the base
        # class presumably stores its keyword arguments as attributes --
        # confirm against DatasetConverter.
        super().__init__(drop_percentage=drop_percentage)

    def run(self, dataset, correctness_indices, *args, **kwargs):
        """Return a copy of ``dataset`` with some zero-flagged entries removed.

        Args:
            dataset: Object exposing ``copy()`` and ``remove(indices)``.
            correctness_indices: One value per entry (array or any sequence);
                positions whose value equals 0 are candidates for removal.

        Returns:
            The copy, after removing ``int(n_candidates * drop_percentage)``
            candidates sampled without replacement via ``np.random.choice``.
        """
        dataset = dataset.copy()
        # Coerce to an array so the ``== 0`` comparison is element-wise even
        # when a plain Python list is passed in.
        flags = np.asarray(correctness_indices)
        indices = np.where(flags == 0)[0]
        drop_count = int(len(indices) * self.drop_percentage)
        drop_indices = np.random.choice(indices, drop_count, replace=False)
        dataset.remove(drop_indices)
        return dataset
    
class IndexQuantileSelector(DatasetConverter):
    """Selector that removes the entries lying beyond a quantile of a score."""

    def __init__(self, drop_percentage, minimum=True, **kwargs):
        # NOTE(review): ``run`` reads ``self.drop_percentage`` and
        # ``self.minimum``; presumably the base class exposes its keyword
        # arguments as attributes -- confirm against DatasetConverter.
        super().__init__(drop_percentage=drop_percentage, minimum=minimum)

    def run(self, dataset, correctness_indices, *args, **kwargs):
        """Return a copy of ``dataset`` without the entries past the quantile.

        When ``self.minimum`` is truthy, entries whose score is at or above
        the ``1 - drop_percentage`` quantile are removed; otherwise entries
        whose score is at or below the ``drop_percentage`` quantile are
        removed.

        Args:
            dataset: Object exposing ``copy()`` and ``remove(indices)``.
            correctness_indices: Per-entry scores the quantile is taken over.

        Returns:
            The filtered copy.
        """
        result = dataset.copy()
        if not self.minimum:
            # Drop the low tail.
            threshold = np.quantile(correctness_indices, self.drop_percentage)
            selected = correctness_indices <= threshold
        else:
            # Drop the high tail.
            threshold = np.quantile(correctness_indices, 1 - self.drop_percentage)
            selected = correctness_indices >= threshold

        result.remove(np.where(selected)[0])
        return result
    
class RandomSelector(DatasetConverter):
    """Selector that removes a uniformly random fraction of the dataset."""

    def __init__(self, drop=0.2, **kwargs) -> None:
        # Unlike some sibling selectors, extra keyword arguments are
        # forwarded to the base class here.
        super().__init__(drop=drop, **kwargs)

    def run(self, dataset, *args, **kwargs):
        """Return a copy of ``dataset`` with ``int(size * drop)`` entries removed.

        The positions to remove are sampled without replacement through
        ``np.random.choice``; the input dataset itself is left untouched.
        """
        result = dataset.copy()
        total = result.size()
        how_many = int(total * self.drop)
        victims = np.random.choice(total, how_many, replace=False)
        logger.info(f"Removing {len(victims)} sentences")
        result.remove(victims)

        return result

class ApproximateMatchSelector(DatasetConverter):
    """Selector that removes sentences that nearly duplicate an earlier one."""

    def __init__(self, overlap=0.8, **kwargs) -> None:
        # NOTE(review): ``run`` reads ``self.overlap``; presumably the base
        # class exposes its keyword arguments as attributes -- confirm.
        super().__init__(overlap=overlap)

    def run(self, dataset, *args, **kwargs):
        """Return a copy of ``dataset`` without fuzzy near-duplicates.

        Every sentence is compared with all sentences that follow it using
        ``fuzz.ratio`` (scaled to the 0..1 range). A later sentence whose
        similarity to any earlier sentence is strictly greater than
        ``self.overlap`` is removed. The comparison is pairwise, so the cost
        grows quadratically with the number of sentences.

        Args:
            dataset: Object exposing ``copy()``, ``get_sentences()`` and
                ``remove(indices)``.

        Returns:
            The filtered copy.
        """
        dataset = dataset.copy()
        # Fetch the sentences once instead of re-requesting and re-slicing
        # them on every outer iteration.
        sentences = list(dataset.get_sentences())
        total = len(sentences)
        index_to_remove = set()
        for i, text in enumerate(sentences):
            for j in range(i + 1, total):
                if fuzz.ratio(text, sentences[j]) / 100 > self.overlap:
                    index_to_remove.add(j)
        logger.info(f"Removing {len(index_to_remove)} sentences")
        dataset.remove(list(index_to_remove))

        return dataset

class CritiqueSelector(DatasetConverter):
    """Placeholder selector: adds nothing on top of ``DatasetConverter``."""

class PerplexitySelector(DatasetConverter):
    """Selector that drops the sentences a masked language model finds hardest.

    Each sentence is scored with a masked-LM loss (one token masked at a
    time) and the sentences whose score falls in the top ``drop`` quantile
    are removed.
    """

    def __init__(self, drop=0.2, model_name='distilbert-base-uncased', device=None, batch_size=16, **kwargs) -> None:
        # NOTE(review): ``run`` reads these values back from ``self``;
        # presumably the base class exposes its keyword arguments as
        # attributes -- confirm against DatasetConverter.
        super().__init__(drop=drop, model_name=model_name, device=device, batch_size=batch_size)

    def _sentence_loss(self, sentence, model, tokenizer):
        """Return the mean per-batch masked-LM loss of ``sentence`` as a float.

        Returns ``0.0`` when the model raises ``RuntimeError`` (a warning is
        logged) or when there is nothing to score.
        """
        # NOTE(review): with padding='max_length' every sentence is expected
        # to be encoded to 512 tokens, so the positions masked below would
        # include padding and each sentence costs ~510 forward rows. Left
        # unchanged to keep the scores comparable -- confirm this is intended.
        tensor_input = tokenizer.encode(sentence, return_tensors='pt', max_length=512, truncation=True, padding='max_length', add_special_tokens=True)
        seq_len = tensor_input.size(-1)
        # One copy of the sequence per maskable position (all positions
        # except the first and the last).
        repeat_input = tensor_input.repeat(seq_len - 2, 1)
        # Ones on the first super-diagonal: row k selects position k + 1.
        mask = torch.ones(seq_len - 1).diag(1)[:-2]
        masked_input = repeat_input.masked_fill(mask == 1, tokenizer.mask_token_id)
        # -100 marks positions that must not contribute to the loss.
        labels = repeat_input.masked_fill(masked_input != tokenizer.mask_token_id, -100)

        batch_losses = []
        try:
            for start in range(0, masked_input.size(0), self.batch_size):
                stop = start + self.batch_size
                output = model(
                    masked_input[start:stop].to(self.device),
                    labels=labels[start:stop].to(self.device),
                )
                batch_losses.append(output.loss.item())
        except RuntimeError as e:
            logger.warning(f"RuntimeError {e} perplexity for sentence (probably empty): {sentence}")
            # Plain float (not a tensor) so every score has the same type.
            return 0.0

        if not batch_losses:
            return 0.0
        return sum(batch_losses) / len(batch_losses)

    def run(self, dataset, *args, **kwargs):
        """Return a copy of ``dataset`` without its highest-loss sentences.

        Loads ``self.model_name`` as a masked LM together with its tokenizer,
        scores every sentence, and removes those whose score is at or above
        the ``1 - drop`` quantile of all scores. If ``self.device`` is
        ``None`` it is resolved through ``get_best_device()`` and stored on
        the instance.

        Args:
            dataset: Object exposing ``copy()``, ``get_sentences()`` and
                ``remove(indices)``.

        Returns:
            The filtered copy (returned unchanged when it has no sentences).
        """
        dataset = dataset.copy()
        if self.device is None:
            self.device = torch.device(get_best_device())
        model = AutoModelForMaskedLM.from_pretrained(self.model_name).to(self.device)
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        perplexities = []
        try:
            with torch.no_grad():
                for sentence in dataset.get_sentences():
                    perplexities.append(self._sentence_loss(sentence, model, tokenizer))
        finally:
            # Release the model even when scoring fails part-way through.
            del model, tokenizer
            torch.cuda.empty_cache()

        if not perplexities:
            # A quantile of an empty array is undefined; nothing to drop.
            logger.info("Removing 0 sentences")
            return dataset

        scores = np.array(perplexities)
        drop_indices = np.where(scores >= np.quantile(scores, 1 - self.drop))[0]
        logger.info(f"Removing {len(drop_indices)} sentences")
        dataset.remove(drop_indices)

        return dataset
    

class ClassSelectionSelector(DatasetConverter):
    """Selector that drops sentences whose queried class disagrees with the label."""

    def __init__(self, prompt, querier=None, parser=None):
        """Store the prompt, the querier and an optional answer parser.

        Args:
            prompt: Prompt sent alongside every sentence.
            querier: Object with an awaitable ``query(prompts, sentences)``
                method. Defaults to a new ``OpenAIQuery()`` per instance.
            parser: Object with ``parse(result)``; when ``None`` a
                ``ClassificationParser`` is built from the dataset labels on
                each ``run``.
        """
        # NOTE(review): the original carried the comment "51 does this" at
        # this point; its meaning is unclear -- confirm or remove.
        # Build the default querier here rather than in the signature, where
        # it would be created once at class-definition time and shared by
        # every instance.
        if querier is None:
            querier = OpenAIQuery()
        super().__init__(prompt=prompt, querier=querier, parser=parser)

    async def run(self, dataset, *args, **kwargs):
        """Return a copy of ``dataset`` keeping only sentences classified as labelled.

        Every sentence is sent to the querier with ``self.prompt``; the first
        element of the parsed answer is compared with the sentence's label
        and mismatching entries are removed.

        Args:
            dataset: Object exposing ``copy()``, ``size()``,
                ``get_sentences()``, ``get_labels()`` and ``remove(indices)``.

        Returns:
            The filtered copy.
        """
        dataset = dataset.copy()
        labels = dataset.get_labels()
        # Keep a lazily built parser local so labels from one dataset do not
        # leak into later runs on a different dataset.
        parser = self.parser if self.parser is not None else ClassificationParser(labels)

        prompts = [self.prompt for _ in range(dataset.size())]
        results_queries = await self.querier.query(prompts, dataset.get_sentences())
        drop_indices = [
            i
            for i, result_query in enumerate(results_queries)
            if parser.parse(result_query)[0] != labels[i]
        ]

        logger.info(f"Removing {len(drop_indices)} sentences")
        dataset.remove(drop_indices)
        return dataset


class DiversitySelector(DatasetConverter):
    """Selector that greedily keeps a mutually distant subset of the dataset."""

    def __init__(self, drop=0.2) -> None:
        # NOTE(review): ``run`` reads ``self.drop``; presumably the base class
        # exposes its keyword arguments as attributes -- confirm.
        super().__init__(drop=drop)

    def run(self, dataset, *args, **kwargs):
        """Return a copy of ``dataset`` reduced to ``int(size * (1 - drop))`` entries.

        Selection is greedy: at each step the entry with the largest
        accumulated cosine distance to the entries already kept is added
        (the first pick is index 0, since all totals start at zero).
        Everything that was not picked is removed.

        Args:
            dataset: Object exposing ``copy()``, ``size()``,
                ``get_vectors()`` and ``remove(indices)``.

        Returns:
            The filtered copy.
        """
        dataset = dataset.copy()
        vectors = dataset.get_vectors()
        size = dataset.size()
        amount_to_keep = int(size * (1 - self.drop))
        # Running sum of distances from each candidate to the kept entries.
        distance_totals = [0 for _ in range(size)]
        # A set keeps the membership tests below O(1).
        kept = set()
        for _ in range(amount_to_keep):
            max_index = int(np.argmax(distance_totals))
            kept.add(max_index)
            for j in range(size):
                if j not in kept:
                    distance_totals[j] += cosine(vectors[max_index], vectors[j])
            # Sentinel below any cosine-distance sum, so a kept entry is not
            # picked again while unpicked candidates remain.
            distance_totals[max_index] = -1

        drop_indices = [i for i in range(size) if i not in kept]
        logger.info(f"Removing {len(drop_indices)} sentences")
        dataset.remove(drop_indices)

        return dataset
