import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import os
import json
from transformers import AutoModelForMaskedLM, AutoTokenizer
import re
from language_tool_python import LanguageTool
from inspect import signature
import finetune
from nltk.translate.bleu_score import sentence_bleu
from scipy.spatial import distance
from utils import get_best_device
from loguru import logger
from sklearn.preprocessing import LabelEncoder
from utils import unnest_dictionary
from datetime import datetime
import mauve
import time
from vendi_score import text_utils
from scipy.spatial.distance import cdist
import lstm
from torchtext.vocab import FastText
from torchtext.data import get_tokenizer
import self_bleu
from model_loader import load_model, load_tokenizer
from causal_finetuning import Finetune
from transformers import set_seed
import copy

class Evaluator:
    """Base class that computes, aggregates and persists metrics for one dataset.

    Subclasses add metric methods. :meth:`calculate_all` runs every callable
    attribute that is not defined on this base class and whose parameters all
    have defaults; the metric methods record their values in ``self.results``.
    """

    def __init__(self, dataset) -> None:
        self.dataset = dataset
        # metric name -> value, or metric name -> {sub key -> value}
        self.results = dict()

    @staticmethod
    def KL_divergence(word_distrib1 : dict, word_distrib2 : dict, smoothing=1e-6):
        """Return the Laplace-smoothed KL divergence KL(word_distrib1 || word_distrib2).

        Both arguments map words to frequencies. Every word of the union of both
        vocabularies gets ``smoothing`` added and is divided by
        ``1 + smoothing * |V|``; that renormalisation only yields a distribution if
        each input already sums to 1 (presumably relative frequencies — confirm
        against the dataset's ``get_word_distribution``). Uses the natural log.
        """
        words = set(word_distrib1.keys()).union(set(word_distrib2.keys()))
        normaliser = 1 + smoothing * len(words)

        kl_div = 0
        for word in words:
            freq1 = (word_distrib1.get(word, 0) + smoothing) / normaliser
            freq2 = (word_distrib2.get(word, 0) + smoothing) / normaliser
            kl_div += freq1 * np.log(freq1 / freq2)
        return kl_div

    @staticmethod
    def JS_divergence(word_distrib1 : dict, word_distrib2 : dict, smoothing=1e-6):
        """Return the Jensen-Shannon divergence of two word distributions.

        Computed as the mean of the smoothed KL divergences of each input against
        their word-wise average (natural log, so the value is at most about ln 2).
        """
        all_words = set(word_distrib1.keys()).union(word_distrib2.keys())
        average = {word: 1 / 2 * (word_distrib1.get(word, 0) + word_distrib2.get(word, 0)) for word in all_words}
        return 1 / 2 * Evaluator.KL_divergence(word_distrib1, average, smoothing) + 1 / 2 * Evaluator.KL_divergence(word_distrib2, average, smoothing)

    def set_results(self, results):
        """Replace all stored results with ``results``."""
        self.results = results

    def save(self, filename):
        """Write ``{"dataset_path": ..., "results": ...}`` to ``filename`` as JSON.

        Missing parent directories are created. If ``filename`` is an existing
        directory, the file is written inside it as ``<number of entries>.json``.
        """
        directory = os.path.dirname(filename)
        if directory != "":
            os.makedirs(directory, exist_ok=True)
        if os.path.isdir(filename):
            filename = os.path.join(filename, f"{len(os.listdir(filename))}.json")

        output = {
            "dataset_path": self.dataset.save_path,
            "results": self.results,
        }

        with open(filename, "w") as f:
            json.dump(output, f, indent=2, sort_keys=True)

    def store_result(self, key_name, result, sub_key_name=None):
        """Record ``result`` under ``key_name`` or under ``results[key_name][sub_key_name]``.

        Results without a sub key, and dict results, replace ``results[key_name]``
        entirely; otherwise the value is added to the nested dict for ``key_name``
        (created if missing).
        """
        logger.debug(f"Storing result {key_name}, {sub_key_name} for {str(self.dataset)} with result {result}")
        if sub_key_name is None or isinstance(result, dict):
            self.results[key_name] = result
        elif key_name in self.results:
            self.results[key_name][sub_key_name] = result
        else:
            self.results[key_name] = {sub_key_name: result}

    def in_result(self, function_name):
        """Return True if any stored result key starts with ``function_name``.

        Prefix matching lets methods that store several keys (e.g. ``linear`` ->
        ``linear_fake_to_fake``) count as already computed.
        NOTE(review): it also matches longer unrelated names, e.g. ``mauve`` counts
        as computed once only ``mauve_labeled`` is stored.
        """
        return any(result.startswith(function_name) for result in self.results)

    @staticmethod
    def _method_kwargs(function, available):
        """Return the subset of ``available`` that ``function`` accepts, or None if it has a required parameter."""
        parameters = signature(function).parameters.values()
        if any(parameter.default is parameter.empty for parameter in parameters):
            return None
        return {parameter.name: available[parameter.name] for parameter in parameters if parameter.name in available}

    @staticmethod
    def _summarize(values):
        """Return ``(mean, standard error of the mean)`` of ``values``, or None if empty or non-numeric."""
        if not values or not all(isinstance(value, (int, float, np.number)) for value in values):
            return None
        return np.mean(values), np.std(values) / np.sqrt(len(values))

    @staticmethod
    def _aggregate_runs(run_results):
        """Combine the ``results`` dicts of several runs into means and standard errors.

        Keys (and, for dict-valued keys, sub keys) are taken from the first run.
        Scalars become ``key`` -> mean and ``key + "_std"`` -> standard error; dict
        values become ``{sub_key: mean, str(sub_key) + "_std": standard error}``.
        Non-numeric values (strings, lists, nested dicts) are skipped, and runs
        missing a key are left out of that key's statistics.
        """
        aggregated = dict()
        if not run_results:
            return aggregated
        for key, first in run_results[0].items():
            if isinstance(first, dict):
                summary = dict()
                for sub_key in first:
                    values = [run[key][sub_key] for run in run_results if isinstance(run.get(key), dict) and sub_key in run[key]]
                    stats = Evaluator._summarize(values)
                    if stats is not None:
                        summary[sub_key], summary[str(sub_key) + "_std"] = stats
                aggregated[key] = summary
            else:
                stats = Evaluator._summarize([run[key] for run in run_results if key in run])
                if stats is not None:
                    aggregated[key], aggregated[key + "_std"] = stats
        return aggregated

    def calculate_all(self, n_runs=1, exclude=None, release_memory=False, save_path=None, results=None, force=None, **kwargs):
        """Run every applicable metric method and return ``self.results``.

        Args:
            n_runs: for a ``SupervisedEvaluator`` with ``n_runs > 1``, each metric is
                computed on ``self.dataset.split(n_runs)`` train/valid pairs and the
                mean plus standard error of the mean (stored as ``*_std``) is kept.
            exclude: method names that are never run.
            release_memory: call ``self.dataset.release_memory()`` at the end.
            save_path: if given, the flattened results are merged into this JSON
                file (keys already in the file but not recomputed are kept).
            results: previously computed results; metrics already present in them
                are skipped unless listed in ``force``.
            force: method names recomputed even if present in ``results``.
            **kwargs: forwarded by name to every metric method that accepts them.
        """
        exclude = () if exclude is None else exclude
        force = () if force is None else force
        if results is not None:
            self.results = results

        base_methods = set(dir(Evaluator))
        for method in dir(self):
            if results is not None and method not in force and self.in_result(method):
                continue
            if method in base_methods or method in exclude:
                continue
            function = getattr(self, method)
            if not callable(function):
                continue
            # Only methods whose parameters all have defaults can be called automatically.
            method_kwargs = self._method_kwargs(function, kwargs)
            if method_kwargs is None:
                continue

            logger.debug(f"Calculating {method} for {str(self.dataset)}")
            if n_runs == 1 or not isinstance(self, SupervisedEvaluator):
                function(**method_kwargs)
                continue

            train_datasets, valid_datasets = self.dataset.split(n_runs)
            evaluators = [SupervisedEvaluator(train_dataset, self.evaluation_dataset, valid_dataset) for train_dataset, valid_dataset in zip(train_datasets, valid_datasets)]
            for evaluator in evaluators:
                getattr(evaluator, method)(**method_kwargs)
            for key, value in self._aggregate_runs([evaluator.results for evaluator in evaluators]).items():
                self.store_result(key, value)

        if save_path is not None:
            logger.info(f"Saving metrics to {save_path}")
            unnested = unnest_dictionary(self.results)
            if os.path.isfile(save_path):
                with open(save_path, "r") as f:
                    reloading = json.load(f)
                for key, value in reloading.items():
                    unnested.setdefault(key, value)

            # dirname is "" for a bare file name, and os.makedirs("") raises.
            directory = os.path.dirname(save_path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(save_path, "w") as f:
                json.dump(unnested, f, indent=2, sort_keys=True)

        if release_memory:
            self.dataset.release_memory()
        return self.results

class UnsupervisedEvaluator(Evaluator):
    """Metrics that need only the evaluated dataset: diversity, fluency, size and perplexity."""

    def __init__(self, dataset) -> None:
        super().__init__(dataset)

    @staticmethod
    def _clean_words(sentences):
        """Join ``sentences`` with spaces, keep only ``[a-zA-Z ]`` characters and return the non-empty words."""
        # Only letters and spaces survive the substitution, so str.split() drops exactly the
        # empty strings that split(" ") produces for runs of spaces (e.g. around removed punctuation).
        stripped_text = re.sub(r"[^a-zA-Z ]", "", " ".join(sentences))
        return stripped_text.split()

    def _label_average(self, score_fn):
        """Average ``score_fn(label)`` over ``self.dataset.get_labels()``.

        Labels whose score is None add 0 but still count in the denominator.
        Returns None instead of dividing by zero when there are no labels.
        """
        labels = self.dataset.get_labels()
        if len(labels) == 0:
            return None
        total = 0
        for label in labels:
            score = score_fn(label)
            if score is not None:
                total += score
        return total / len(labels)

    def vendi_score(self, max_sentences=2000):
        """Embedding Vendi score of the first ``max_sentences`` sentences (SimCSE embeddings).

        Best effort: any failure is logged and None is returned.
        """
        try:
            score = text_utils.embedding_vendi_score(list(self.dataset.get_sentences()[:max_sentences]), model_path="princeton-nlp/unsup-simcse-bert-base-uncased")
            self.store_result("vendi_score", float(score))
            return float(score)
        except Exception as e:
            logger.warning(f"Vendi score failed with {e} for {str(self.dataset)}")
            return None

    def distinctness(self, n=3, store=True, sentences=None):
        """Return the ratio of distinct words to total words.

        Based on "A Diversity-Promoting Objective Function for Neural Conversation
        Models". Text is reduced to ASCII letters and spaces before splitting.

        NOTE(review): ``n`` is only used as the sub key of the stored result; the
        ratio is always computed over single words, not n-grams.

        Args:
            n: sub key for the stored result.
            store: record the result under ``"distinctness"``.
            sentences: iterable of strings (sentences or single words); defaults to
                the dataset's sentences.

        Returns:
            The distinct/total word ratio, or None when there are no words.
        """
        # Caveat: the ratio depends strongly on text length (empirically it drops quickly as
        # the text grows), so only compare corpora of similar size; average_distinctness
        # evaluates fixed-size samples instead.
        if sentences is None:
            sentences = self.dataset.get_sentences()
        words = self._clean_words(sentences)
        if not words:
            return None
        result = len(set(words)) / len(words)
        if store:
            self.store_result("distinctness", result, n)
        return result

    def average_distinctness(self, n=3, group_size=5000, repetitions=50, store=True, sentences=None, doc=None, spacy=False):
        """Mean distinctness over ``repetitions`` random samples of ``group_size`` tokens.

        Token source, in order of precedence:
        - ``sentences`` is None and not ``spacy``: the dataset's sentences split on spaces;
        - ``spacy`` and ``doc`` given: lemmas of ``doc``'s non-stop, non-punctuation tokens;
        - ``spacy`` and ``sentences`` is None: the same from ``self.dataset.get_doc()``;
        - otherwise ``sentences`` is used as the token list.
        With fewer than ``group_size`` tokens, the distinctness of all of them is returned.

        Returns:
            The mean distinctness, or None if no sample contained any word.
        """
        if sentences is None and not spacy:
            sentences = [word for sentence in self.dataset.get_sentences() for word in sentence.split(" ")]
        elif doc is not None and spacy:
            sentences = [token.lemma_ for token in doc if not token.is_stop and not token.is_punct]
        elif sentences is None and spacy:
            sentences = [token.lemma_ for token in self.dataset.get_doc() if not token.is_stop and not token.is_punct]

        if len(sentences) < group_size:
            logger.warning(f"Dataset {self.dataset} has less than {group_size} sentences.")
            # Score the tokens selected above (e.g. a label subset or spaCy lemmas), not the whole dataset.
            result = self.distinctness(n, store=False, sentences=sentences)
        else:
            pool = np.asarray(sentences)  # convert once instead of inside every np.random.choice call
            scores = [self.distinctness(n, store=False, sentences=np.random.choice(pool, group_size, replace=False)) for _ in range(repetitions)]
            scores = [score for score in scores if score is not None]
            result = sum(scores) / len(scores) if scores else None

        if store and result is not None:
            self.store_result("average_distinctness", result, n)

        return result

    def average_distinctness_spacy(self, n=1, group_size=5000, repetitions=1000, store=True, doc=None):
        """``average_distinctness`` over spaCy lemmas (stop words and punctuation removed)."""
        result = self.average_distinctness(n, group_size, repetitions, False, doc=doc, spacy=True)
        if store and result is not None:
            self.store_result("average_distinctness_spacy", result, n)

        return result

    def self_bleu(self, sentences=None, n=3, max_sentences=500, store=True):
        """Self-BLEU of the sentences (lower means more diverse).

        Source: https://dl.acm.org/doi/pdf/10.1145/3209978.3210080

        Args:
            sentences: array-like of strings (any shape; flattened); defaults to the
                dataset's sentences.
            n: n-gram order passed to ``self_bleu.cal_self_bleu``; also the sub key.
            max_sentences: maximum number of sampled sentences.
            store: record the result under ``"self_bleu"``.

        Returns:
            The self-BLEU score, or None for an empty input.
        """
        if sentences is None:
            sentences = self.dataset.get_sentences()
        # np.asarray also accepts plain lists, which have no .reshape.
        sentences = list(np.asarray(sentences).reshape(-1))
        num_sentences = len(sentences)

        if num_sentences == 0:
            return None
        average_bleu_score = self_bleu.cal_self_bleu(sentences, n_sample=min(num_sentences, max_sentences), n_gram=n)[0]
        if store:
            self.store_result("self_bleu", average_bleu_score, n)
        return average_bleu_score

    def self_blue_labeled(self, n=3):
        """Self-BLEU computed per label and averaged over labels (empty labels count as 0)."""
        average_blue = self._label_average(lambda label: self.self_bleu(self.dataset.get_sentences(label), n, store=False))
        if average_blue is not None:
            self.store_result("self_blue_labeled", average_blue, n)
        return average_blue

    def size(self):
        """Store and return ``self.dataset.size()``."""
        size = self.dataset.size()
        self.store_result("size", size)
        return size

    def spelling_grammar(self, language="en-US", max_tries=10, retry_delay=60):
        """LanguageTool matches per sentence, in total and per category.

        All sentences are checked as one space-joined text. The LanguageTool
        instance is closed after every attempt; failed attempts are retried up to
        ``max_tries`` times with ``retry_delay`` seconds between attempts.

        Returns:
            ``{"total": float, "by_category": {category: float}}`` (counts divided by
            the number of sentences), or None if there are no sentences or every
            attempt failed.
        """
        sentences = self.dataset.get_sentences()
        if len(sentences) == 0:
            return None

        mistakes = None
        for attempt in range(1, max_tries + 1):
            try:
                with LanguageTool(language) as tool:
                    mistakes = tool.check(" ".join(sentences))
                break
            except Exception as e:  # best effort: the LanguageTool backend may fail or time out
                logger.warning(f"{type(e).__name__} ({e}) while calculating spelling and grammar for dataset {self.dataset}. Attempt {attempt}/{max_tries}.")
                if attempt < max_tries:
                    time.sleep(retry_delay)
        if mistakes is None:
            logger.error(f"Could not calculate spelling and grammar for dataset {self.dataset} after {max_tries} attempts. Skipping.")
            return None

        n_sentences = len(sentences)
        counts = dict()
        for mistake in mistakes:
            counts[mistake.category] = counts.get(mistake.category, 0) + 1
        result = {
            "total": len(mistakes) / n_sentences,
            "by_category": {category: count / n_sentences for category, count in counts.items()},
        }
        self.store_result("spelling_grammar", result)
        return result

    def spacy_analysis(self):
        """Fractions of stop-word tokens, alphabetic tokens and each POS tag in the spaCy doc.

        Returns:
            ``{"stop": float, "alpha": float, "pos": {tag: float}}``, or None for an empty doc.
        """
        doc = self.dataset.get_doc()
        size = len(doc)
        if size == 0:
            return None

        stop_count = 0
        alpha_count = 0
        pos_counts = dict()
        for token in doc:
            if token.is_stop:
                stop_count += 1
            if token.is_alpha:
                alpha_count += 1
            pos_counts[token.pos_] = pos_counts.get(token.pos_, 0) + 1

        result = {
            "stop": stop_count / size,
            "alpha": alpha_count / size,
            "pos": {tag: count / size for tag, count in pos_counts.items()},
        }
        self.store_result("spacy_analysis", result)
        return result

    def distinctness_labeled(self, n=3):
        """Per-label ``distinctness`` averaged over labels (labels without words count as 0)."""
        result = self._label_average(lambda label: self.distinctness(n, store=False, sentences=self.dataset.get_sentences(label)))
        if result is not None:
            self.store_result("distinctness_labeled", result, n)

        return result

    def distinctness_averaged_labeled(self, n=3, group_size=5000, repetitions=1000):
        """Per-label ``average_distinctness`` (space-split words) averaged over labels."""
        def label_score(label):
            words = [word for sentence in self.dataset.get_sentences(label) for word in sentence.split(" ")]
            return self.average_distinctness(n, group_size=group_size, repetitions=repetitions, sentences=words, store=False)

        result = self._label_average(label_score)
        if result is not None:
            self.store_result("distinctness_averaged_labeled", result, n)

        return result

    def distinctness_averaged_spacy_labeled(self, n=1, group_size=1000, repetitions=1000):
        """Per-label ``average_distinctness_spacy`` averaged over labels."""
        result = self._label_average(
            lambda label: self.average_distinctness_spacy(n, group_size=group_size, repetitions=repetitions, doc=self.dataset.get_doc(label), store=False)
        )
        if result is not None:
            self.store_result("distinctness_averaged_spacy_labeled", result, n)

        return result

    def vector_distinctness(self, vectors=None, distance_measure="cosine", store=True, max_sentences=10000):
        """Mean pairwise ``distance_measure`` distance among the first ``max_sentences`` vectors.

        The mean includes the zero self-distances on the diagonal.

        Args:
            vectors: 2-D array-like (one row per sentence); defaults to the dataset's vectors.
            distance_measure: any metric name accepted by ``scipy.spatial.distance.cdist``.
            store: record the result under ``"vector_distinctness"`` with the metric as sub key.
            max_sentences: maximum number of rows used.

        Returns:
            The mean distance, or None when there are no vectors.
        """
        if vectors is None:
            vectors = np.array(self.dataset.get_vectors())
        vectors = np.asarray(vectors)
        if len(vectors) == 0:
            return None

        subset = vectors[:max_sentences]
        avg = np.mean(cdist(subset, subset, distance_measure))
        if store:
            self.store_result("vector_distinctness", float(avg), distance_measure)
        return avg

    def euclid_vector_distinctness(self):
        """Euclidean ``vector_distinctness``, stored without a sub key."""
        result = self.vector_distinctness(distance_measure="euclidean", store=False)
        if result is not None:
            self.store_result("euclid_vector_distinctness", float(result))
        return result

    def vector_label_distinctness(self, distance_measure="cosine", store=True):
        """Per-label ``vector_distinctness`` averaged over labels."""
        all_vectors = np.array(self.dataset.get_vectors())

        def label_distance(label):
            _, indices = self.dataset.get_label_sentences(label)
            return self.vector_distinctness(all_vectors[indices], distance_measure=distance_measure, store=False)

        result = self._label_average(label_distance)
        if store and result is not None:
            self.store_result("vector_label_distinctness", float(result), distance_measure)

        return result

    def euclid_labeled_vector_distinctness(self):
        """Euclidean ``vector_label_distinctness``, stored without a sub key."""
        result = self.vector_label_distinctness(distance_measure="euclidean", store=False)
        if result is not None:
            self.store_result("euclid_labeled_vector_distinctness", float(result))
        return result

    @staticmethod
    def get_perplexity(dataset, model, tokenizer, device, batch_size=1, max_length=None, dtype=torch.bfloat16, **kwargs):
        """Token-level perplexity of a causal language model on ``dataset["text"]``.

        Every token except the first of each sentence is predicted from its prefix.
        Padded positions are excluded via the attention mask, so this works for left
        and right padding and when the pad token equals a real token (e.g. EOS).

        Args:
            dataset: DataFrame whose ``"text"`` column holds the sentences; rows whose
                text is not a string (e.g. missing) are skipped.
            model: a loaded model, or a name passed to ``load_model(name, dtype=dtype)``.
            tokenizer: tokenizer matching ``model``; it is called with ``padding=True``.
            device: unused — inputs are moved to ``model.device``; kept for compatibility.
            batch_size: sentences per forward pass.
            max_length: if given, only the first ``max_length`` rows are evaluated
                (rows, not tokens; every sentence is truncated to 128 tokens).
            dtype: dtype used when ``model`` has to be loaded from a name.
            **kwargs: ignored.

        Returns:
            ``exp(total NLL / number of predicted tokens)`` as a float, or None when
            no token could be scored.
        """
        logger.info(f"Calculating perplexity for dataset of size {len(dataset)}")
        if isinstance(model, str):
            model = load_model(model, dtype=dtype)
        if max_length is not None:
            dataset = dataset[:max_length]

        # Select texts up front so that a row without text can never prevent the last batch from being scored.
        texts = [text for text in dataset["text"].tolist() if isinstance(text, str)] if "text" in dataset.columns else []

        loss_func = torch.nn.NLLLoss(ignore_index=-100, reduction="sum")
        sum_nllos = 0.0
        n_tokens = 0
        for start in range(0, len(texts), batch_size):
            encodings = tokenizer(texts[start:start + batch_size], return_tensors="pt", truncation=True, padding=True, max_length=128)
            input_ids = encodings["input_ids"].to(model.device)
            attention_mask = encodings["attention_mask"].to(model.device)

            # A target position is scored only if both it and the preceding position are real tokens.
            valid = attention_mask[..., 1:] * attention_mask[..., :-1]
            n_tokens_here = int(valid.sum().item())
            if n_tokens_here == 0:
                continue

            with torch.no_grad():
                logits = model(input_ids, attention_mask=attention_mask).logits
            # log-softmax in float32: bf16 loses precision over a large vocabulary.
            logprobs = logits[..., :-1, :].float().log_softmax(dim=-1)
            labels = input_ids[..., 1:].masked_fill(valid == 0, -100)
            loss = loss_func(logprobs.reshape(-1, logprobs.size(-1)), labels.reshape(-1))

            sum_nllos += loss.item()
            n_tokens += n_tokens_here

        if n_tokens == 0:
            logger.warning("No tokens could be scored; perplexity is undefined.")
            return None
        return float(np.exp(sum_nllos / n_tokens))

    def perplexity(self, model="llama-13b", batch_size=8):
        """Compute, store and return the perplexity of ``model`` on the dataset's sentences."""
        device = get_best_device()
        model, tokenizer = load_model(model, return_tokenizer=True)
        model.eval()
        df = pd.DataFrame({"text": self.dataset.get_sentences()})
        perplexity = self.get_perplexity(df, model, tokenizer, device, batch_size=batch_size)
        if perplexity is not None:
            self.store_result("perplexity", perplexity)
        return perplexity

class SemiSupervisedEvaluator(UnsupervisedEvaluator):
    """Evaluator that also compares the dataset with a reference ``evaluation_dataset``.

    The metrics defined here use word distributions only, not labels.
    """

    def __init__(self, dataset, evaluation_dataset, valid_dataset=None) -> None:
        super().__init__(dataset)
        # reference data the evaluated dataset is compared against
        self.evaluation_dataset = evaluation_dataset
        # optional held-out data; only stored here
        self.valid_dataset = valid_dataset

    def unlabeled_JS_divergence(self):
        """JS divergence between the word distributions of the two datasets."""
        generated = self.dataset.get_word_distribution()
        reference = self.evaluation_dataset.get_word_distribution()
        divergence = self.JS_divergence(generated, reference)
        self.store_result("unlabeled_JS_divergence", float(divergence))
        return divergence

    def unlabeled_KL_divergence(self):
        """Smoothed KL(dataset || evaluation_dataset) of the word distributions."""
        generated = self.dataset.get_word_distribution()
        reference = self.evaluation_dataset.get_word_distribution()
        divergence = self.KL_divergence(generated, reference)
        self.store_result("unlabeled_KL_divergence", float(divergence))
        return divergence


class SupervisedEvaluator(SemiSupervisedEvaluator):
    def __init__(self, dataset, evaluation_dataset, valid_dataset=None) -> None:
        """Evaluator comparing ``dataset`` with the labelled ``evaluation_dataset``.

        ``valid_dataset``, when given, is used by ``finetune`` as validation data
        instead of a random split of ``dataset``.
        """
        super().__init__(dataset, evaluation_dataset, valid_dataset)

    def average_vector_distance(self, n_vectors=10 ** 5, distance_metric=distance.cosine, store=True):
        """Mean distance between randomly paired vectors of the two datasets.

        ``n_vectors`` indices are drawn with replacement from each dataset; the
        drawn vectors are paired position-wise and ``distance_metric`` (a callable
        on two 1-D vectors) is averaged over the pairs. Stored with the metric's
        ``__name__`` as sub key.
        """
        own_idx = np.random.randint(0, self.dataset.size(), (n_vectors,))
        ref_idx = np.random.randint(0, self.evaluation_dataset.size(), (n_vectors,))
        own = np.array(self.dataset.get_vectors())[own_idx]
        ref = np.array(self.evaluation_dataset.get_vectors())[ref_idx]
        pair_distances = [distance_metric(own_vec, ref_vec) for own_vec, ref_vec in zip(own, ref)]
        mean_distance = np.mean(pair_distances)
        if store:
            self.store_result("average_vector_distance", float(mean_distance), distance_metric.__name__)
        return mean_distance
    
    def average_closest_vector_distance(self, n_vectors=2000, distance_metric="cosine", k=10, store=True):
        """Mean distance from sampled dataset vectors to their ``k`` nearest evaluation vectors.

        ``n_vectors`` dataset vectors are drawn with replacement; for each, the
        ``k`` smallest ``distance_metric`` distances to all evaluation vectors are
        averaged (``k`` is capped at the number of evaluation vectors).

        Returns:
            The mean over all sampled vectors and their ``k`` nearest neighbours.
        """
        indices = np.random.randint(0, self.dataset.size(), (n_vectors,))
        vectors = np.array(self.dataset.get_vectors())[indices]
        cdists = cdist(vectors, self.evaluation_dataset.get_vectors(), metric=distance_metric)
        k = min(k, cdists.shape[1])
        # partition(kth=k-1) moves the k smallest entries of each row into [:k]; slicing
        # [:k-1] would drop one neighbour and give NaN (mean of nothing) for k == 1.
        result = np.partition(cdists, k - 1, axis=1)[:, :k].mean()
        if store:
            self.store_result("average_closest_vector_distance", float(result), distance_metric)
        return result

    def average_inverse_closest_vector_distance(self, n_vectors=2000, distance_metric="cosine", k=10, store=True):
        """Mean distance from sampled evaluation vectors to their ``k`` nearest dataset vectors.

        Mirror of ``average_closest_vector_distance`` with the roles of the two
        datasets swapped (``k`` is capped at the number of dataset vectors).
        """
        indices = np.random.randint(0, self.evaluation_dataset.size(), (n_vectors,))
        vectors = np.array(self.evaluation_dataset.get_vectors())[indices]
        cdists = cdist(vectors, self.dataset.get_vectors(), metric=distance_metric)
        k = min(k, cdists.shape[1])
        # partition(kth=k-1) moves the k smallest entries of each row into [:k]; slicing
        # [:k-1] would drop one neighbour and give NaN (mean of nothing) for k == 1.
        result = np.partition(cdists, k - 1, axis=1)[:, :k].mean()
        if store:
            self.store_result("average_inverse_closest_vector_distance", float(result), distance_metric)
        return result

    def euclidean_average_vector_distance(self, n_vectors=10 ** 5):
        """Euclidean ``average_vector_distance``, stored without a sub key."""
        value = self.average_vector_distance(n_vectors=n_vectors, distance_metric=distance.euclidean, store=False)
        self.store_result("euclidean_average_vector_distance", float(value))
        return value

    def euclidean_average_closest_vector_distance(self, n_vectors=2000, k=10):
        """Euclidean ``average_closest_vector_distance``, stored without a sub key."""
        value = self.average_closest_vector_distance(n_vectors=n_vectors, k=k, distance_metric="euclidean", store=False)
        self.store_result("euclidean_average_closest_vector_distance", float(value))
        return value

    def euclidean_average_inverse_closest_vector_distance(self, n_vectors=2000, k=10):
        """Euclidean ``average_inverse_closest_vector_distance``, stored without a sub key."""
        value = self.average_inverse_closest_vector_distance(n_vectors=n_vectors, k=k, distance_metric="euclidean", store=False)
        self.store_result("euclidean_average_inverse_closest_vector_distance", float(value))
        return value

    def mauve(self, vectors=None, eval_vectors=None, store=True, max_sentences=4000):
        """MAUVE score between the two datasets' feature vectors.

        Uses at most the first ``max_sentences`` vectors of each side. A
        ``RuntimeError`` raised by ``mauve.compute_mauve`` is logged and turns into
        a None return.
        """
        p_features = self.dataset.get_vectors() if vectors is None else vectors
        q_features = self.evaluation_dataset.get_vectors() if eval_vectors is None else eval_vectors
        try:
            output = mauve.compute_mauve(p_features[:max_sentences], q_features[:max_sentences])
            if store:
                self.store_result("mauve", float(output.mauve))
        except RuntimeError as e:
            logger.warning(f"RuntimeError {e} mauve for dataset {self.dataset}.")
            return None
        return output.mauve
    
    def normalized_distinctness(self, n=3, sentences=None, other_sentences=None, store=True):
        """Distinctness of the dataset divided by that of the evaluation dataset.

        Both word lists are truncated to the shorter length first, because
        distinctness depends on text length.

        Args:
            n: forwarded to ``distinctness``.
            sentences: word list of the evaluated side; defaults to the dataset's
                sentences split on spaces.
            other_sentences: word list of the reference side; defaults likewise
                from ``self.evaluation_dataset``.
            store: record the result under ``"normalized_distinctness"``.

        Returns:
            The ratio, or None when either side has no scorable words.
        """
        if sentences is None:
            sentences = [word for sentence in self.dataset.get_sentences() for word in sentence.split(" ")]
        if other_sentences is None:
            other_sentences = [word for sentence in self.evaluation_dataset.get_sentences() for word in sentence.split(" ")]
        max_sentences = min(len(sentences), len(other_sentences))
        distinctness = self.distinctness(n=n, sentences=sentences[:max_sentences], store=False)
        other_distinctness = self.distinctness(n=n, sentences=other_sentences[:max_sentences], store=False)
        # distinctness may return None for input without words; avoid None arithmetic and division by zero.
        if distinctness is None or not other_distinctness:
            return None
        result = distinctness / other_distinctness
        if store:
            self.store_result("normalized_distinctness", float(result))
        return result
    
    def spacy_normalized_distinctness(self, n=1, doc=None, other_doc=None, store=True):
        """``normalized_distinctness`` over spaCy lemmas (stop words and punctuation removed).

        Args:
            n: forwarded to ``normalized_distinctness``.
            doc: spaCy doc of the evaluated side; defaults to ``self.dataset.get_doc()``.
            other_doc: spaCy doc of the reference side; defaults to the evaluation dataset's doc.
            store: record the result under ``"spacy_normalized_distinctness"``.

        Returns:
            The ratio, or None when it could not be computed.
        """
        if doc is None:
            doc = self.dataset.get_doc()
        if other_doc is None:
            other_doc = self.evaluation_dataset.get_doc()
        self_tokens = [token.lemma_ for token in doc if not token.is_stop and not token.is_punct]
        other_tokens = [token.lemma_ for token in other_doc if not token.is_stop and not token.is_punct]
        result = self.normalized_distinctness(n=n, sentences=self_tokens, other_sentences=other_tokens, store=False)
        if store and result is not None:
            self.store_result("spacy_normalized_distinctness", float(result))
        return result

    def labeled_normalized_distinctness(self, n=3):
        """Per-label ``normalized_distinctness`` averaged over the dataset's labels.

        For each label, the space-split words of that label's sentences in both
        datasets are compared. Labels yielding None add 0 but still count in the
        denominator.
        """
        def words_of(dataset, label):
            label_sentences, _ = dataset.get_label_sentences(label)
            return [word for sentence in label_sentences for word in sentence.split(" ")]

        labels = self.dataset.get_labels()
        total = 0
        for label in labels:
            own_words = words_of(self.dataset, label)
            reference_words = words_of(self.evaluation_dataset, label)
            score = self.normalized_distinctness(sentences=own_words, other_sentences=reference_words, store=False, n=n)
            if score is not None:  # e.g. when a label has no scorable words
                total += score

        result = total / len(labels)
        self.store_result("labeled_normalized_distinctness", float(result))

        return result
    
    def labeled_spacy_normalized_distinctness(self, n=1):
        """Per-label ``spacy_normalized_distinctness`` averaged over the dataset's labels.

        Labels yielding None add 0 but still count in the denominator.
        """
        labels = self.dataset.get_labels()
        total = 0
        for label in labels:
            own_doc = self.dataset.get_doc(label)
            reference_doc = self.evaluation_dataset.get_doc(label)
            score = self.spacy_normalized_distinctness(doc=own_doc, other_doc=reference_doc, store=False, n=n)
            if score is not None:  # e.g. when a label has no scorable tokens
                total += score

        result = total / len(labels)
        self.store_result("labeled_spacy_normalized_distinctness", float(result))

        return result

    def mauve_labeled(self):
        """MAUVE computed per label and averaged over the labels where it succeeded.

        Returns None (and stores nothing) when MAUVE failed for every label.
        """
        own_vectors = np.array(self.dataset.get_vectors())
        reference_vectors = np.array(self.evaluation_dataset.get_vectors())
        scores = []
        for label in self.dataset.get_labels():
            _, own_idx = self.dataset.get_label_sentences(label)
            _, reference_idx = self.evaluation_dataset.get_label_sentences(label)
            score = self.mauve(own_vectors[own_idx], reference_vectors[reference_idx], store=False)
            if score is not None:
                scores.append(score)

        if not scores:
            return None
        result = sum(scores) / len(scores)
        self.store_result("mauve_labeled", float(result))
        return result

    def labeled_JS_divergence(self):
        """Per-label JS divergence of word distributions, averaged over the dataset's labels."""
        labels = self.dataset.get_labels()
        divergences = [
            self.JS_divergence(self.dataset.get_word_distribution(label), self.evaluation_dataset.get_word_distribution(label))
            for label in labels
        ]
        result = sum(divergences) / len(labels)
        self.store_result("labeled_JS_divergence", float(result))
        return result
    
    def labeled_KL_divergence(self):
        """Per-label smoothed KL(dataset || evaluation) of word distributions, averaged over labels."""
        labels = self.dataset.get_labels()
        divergences = [
            self.KL_divergence(self.dataset.get_word_distribution(label), self.evaluation_dataset.get_word_distribution(label))
            for label in labels
        ]
        result = sum(divergences) / len(labels)
        self.store_result("labeled_KL_divergence", float(result))
        return result
    
    def preprocess_data(self, X, y=None):
        """Split 80/20 (stratified by ``y`` when possible) and standardise with train statistics.

        Returns:
            ``(X_train, X_test, y_train, y_test, scaler)`` where ``scaler`` is the
            ``StandardScaler`` fitted on ``X_train``.
        """
        split_options = dict(test_size=0.2, random_state=42)
        try:
            X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, **split_options)
        except ValueError:  # stratification fails e.g. when a label appears only once
            X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)
        scaler = StandardScaler()
        scaler.fit(X_train)
        return scaler.transform(X_train), scaler.transform(X_test), y_train, y_test, scaler
    
    def linear(self, label_change=0):
        """Train logistic-regression probes on synthetic ("fake") and real vectors and cross-evaluate them.

        Each dataset is split 80/20 by ``preprocess_data``. The probe trained on the
        fake training split is scored on the fake test split and on *all* real
        vectors; the probe trained on the real training split is scored on the real
        test split and on all fake vectors. Results are stored under
        ``linear_fake_to_fake``, ``linear_fake_to_real``, ``linear_real_to_real`` and
        ``linear_real_to_fake``.

        Args:
            label_change: fraction of training labels (in both datasets) replaced by
                labels drawn at random from ``self.dataset.get_labels()``.

        Returns:
            ``(fake_to_fake, fake_to_real, real_to_fake, real_to_real)`` metric dicts.
        """
        _, y_fake = self.dataset.get_all()
        _, y_real = self.evaluation_dataset.get_all()
        if self.dataset.is_dual_task():
            # get_vectors(include_other=True) returns two feature blocks; they are concatenated per sample.
            X_fake1, X_fake2 = self.dataset.get_vectors(include_other=True)
            X_fake = np.concatenate((X_fake1, X_fake2), axis=1)
            # The real features must come from the evaluation dataset, not from self.dataset.
            X_real1, X_real2 = self.evaluation_dataset.get_vectors(include_other=True)
            X_real = np.concatenate((X_real1, X_real2), axis=1)
        else:
            X_fake = self.dataset.get_vectors()
            X_real = self.evaluation_dataset.get_vectors()

        X_train_real, X_test_real, y_train_real, y_test_real, scaler_real = self.preprocess_data(X_real, y_real)
        X_train_fake, X_test_fake, y_train_fake, y_test_fake, scaler_fake = self.preprocess_data(X_fake, y_fake)

        if label_change > 0:
            # Label noise is injected in place: real training labels first, then fake ones.
            for y_train in (y_train_real, y_train_fake):
                change_indices = np.random.choice(len(y_train), int(len(y_train) * label_change), replace=False)
                y_train[change_indices] = np.random.choice(self.dataset.get_labels(), len(change_indices))

        log_reg_fake = LogisticRegression(max_iter=1000)
        log_reg_fake.fit(X_train_fake, y_train_fake)
        fake_to_fake_metrics = self._classification_metrics(y_test_fake, log_reg_fake.predict(X_test_fake))
        self.store_result("linear_fake_to_fake", fake_to_fake_metrics)

        fake_to_real_metrics = self._classification_metrics(y_real, log_reg_fake.predict(scaler_fake.transform(X_real)))
        self.store_result("linear_fake_to_real", fake_to_real_metrics)

        log_reg_real = LogisticRegression(max_iter=1000)
        log_reg_real.fit(X_train_real, y_train_real)
        real_to_real_metrics = self._classification_metrics(y_test_real, log_reg_real.predict(X_test_real))
        self.store_result("linear_real_to_real", real_to_real_metrics)

        real_to_fake_metrics = self._classification_metrics(y_fake, log_reg_real.predict(scaler_real.transform(X_fake)))
        self.store_result("linear_real_to_fake", real_to_fake_metrics)

        return fake_to_fake_metrics, fake_to_real_metrics, real_to_fake_metrics, real_to_real_metrics

    @staticmethod
    def _classification_metrics(y_true, y_pred):
        """Return JSON-serialisable accuracy, macro-F1 and confusion matrix of ``y_pred`` against ``y_true``."""
        return {
            "accuracy": float(accuracy_score(y_true, y_pred)),
            "f1": float(f1_score(y_true, y_pred, average="macro")),
            "confusion": confusion_matrix(y_true, y_pred).tolist(),
        }
    
    def finetune(self, batch_size=8, max_length=64, training_steps=2000, lr=1e-5, label_smoothing=0.1, alpha_temporal=0.9, lambda_temporal=0.0, device=None, store_result_name="finetune",
                       persist_temp_folder=False, num_warmup_steps=600, store=True, reload_best_model=True, only_store_model=False, max_epochs=5):
        """Fine-tune a classifier on ``self.dataset`` and evaluate it twice.

        Labels of ``self.dataset`` and ``self.evaluation_dataset`` are merged into a single
        ``LabelEncoder`` so both evaluations share one class index space. Validation data is
        ``self.valid_dataset.df`` when a validation dataset is set. Otherwise it is a random
        30% split of ``self.dataset.df``, and the remaining 70% is used for training.
        Training is delegated to ``finetune.run``. Evaluation uses ``finetune.evaluate_metrics``.

        Args:
            store_result_name: prefix for the stored metric entries
                (``<prefix>_fake_to_real`` / ``<prefix>_fake_to_fake``).
            persist_temp_folder: forwarded to ``finetune.run``. When true, the model and
                label-encoder locations returned by ``finetune.run`` are also recorded via
                ``store_result``.
            store / only_store_model: metrics are stored when either one is truthy.
            Remaining arguments are forwarded to ``finetune.run``.

        Returns:
            ``{"fake": <metrics on the validation data>,
            "real": <metrics on self.evaluation_dataset.df>}``.
        """
        if device is None:
            device = torch.device(get_best_device())

        # Per-run, timestamped sub-folder of "temp" handed to finetune.run.
        run_stamp = datetime.now().strftime("%Y%m%d-%H%M%S%f")

        set_seed(42)

        all_labels = set(self.dataset.get_labels()) | set(self.evaluation_dataset.get_labels())
        encoder = LabelEncoder()
        encoder.fit(list(all_labels))

        if self.valid_dataset is not None:
            train_df, eval_df = self.dataset.df, self.valid_dataset.df
        else:
            train_df, eval_df = train_test_split(self.dataset.df, test_size=0.3)

        run_options = dict(
            device=device,
            alpha_temporal=alpha_temporal,
            lambda_temporal=lambda_temporal,
            label_encoder=encoder,
            validation_data=eval_df,
            temp_folder=os.path.join("temp", run_stamp),
            persist_temp_folder=persist_temp_folder,
            num_warmup_steps=num_warmup_steps,
            reload_best_model=reload_best_model,
            max_epochs=max_epochs,
        )
        # finetune.run hands back its own label encoder; that one is used from here on.
        (model, tokenizer, encoder,
         model_dir, encoder_dir) = finetune.run(train_df, batch_size, max_length, training_steps, lr,
                                                label_smoothing, **run_options)

        if persist_temp_folder:
            self.store_result("model", model_dir, store_result_name)
            self.store_result("label_encoders", encoder_dir, store_result_name)

        should_store = store or only_store_model

        real_metrics = finetune.evaluate_metrics(model, tokenizer, encoder, self.evaluation_dataset.df, device, max_length)
        if should_store:
            self.store_result(store_result_name + "_fake_to_real", real_metrics)

        fake_metrics = finetune.evaluate_metrics(model, tokenizer, encoder, eval_df, device, max_length)
        if should_store:
            self.store_result(store_result_name + "_fake_to_fake", fake_metrics)

        # Drop references to the model before asking torch to release cached GPU memory.
        del model, tokenizer
        torch.cuda.empty_cache()

        return {"fake": fake_metrics, "real": real_metrics}

    def finetune_no_smoothing(self, batch_size=8, max_length=64, training_steps=2000, lr=1e-5, device=None,
                              persist_temp_folder=False, num_warmup_steps=600, reload_best_model=True, only_store_model=False, max_epochs=5,
                              store=True):
        """Run ``self.finetune`` with label smoothing disabled (``label_smoothing=0``).

        Results are stored under the ``"finetune_no_smoothing"`` prefix. ``store`` is
        forwarded like every other argument, matching ``finetune_temporal``. Callers can
        therefore suppress metric storage. It is appended last, so existing positional
        calls are unaffected. The temporal arguments keep ``self.finetune``'s defaults.

        Returns:
            Whatever ``self.finetune`` returns.
        """
        return self.finetune(batch_size=batch_size, max_length=max_length, training_steps=training_steps, lr=lr,
                             label_smoothing=0, device=device, store_result_name="finetune_no_smoothing",
                             persist_temp_folder=persist_temp_folder, num_warmup_steps=num_warmup_steps, store=store,
                             reload_best_model=reload_best_model, only_store_model=only_store_model, max_epochs=max_epochs)

    def finetune_temporal(self, batch_size=8, max_length=64, training_steps=2000, lr=1e-5, label_smoothing=0.1, alpha_temporal=0.9, lambda_temporal=1.0,
                                device=None, persist_temp_folder=False, num_warmup_steps=600, store=True, reload_best_model=True, only_store_model=False, max_epochs=5):
        """Run ``self.finetune`` with a default of ``lambda_temporal=1.0``.

        All arguments are forwarded unchanged. Results are stored under the
        ``"finetune_temporal"`` prefix.

        Returns:
            Whatever ``self.finetune`` returns.
        """
        options = {
            "batch_size": batch_size,
            "max_length": max_length,
            "training_steps": training_steps,
            "lr": lr,
            "label_smoothing": label_smoothing,
            "alpha_temporal": alpha_temporal,
            "lambda_temporal": lambda_temporal,
            "device": device,
            "store_result_name": "finetune_temporal",
            "persist_temp_folder": persist_temp_folder,
            "num_warmup_steps": num_warmup_steps,
            "store": store,
            "reload_best_model": reload_best_model,
            "only_store_model": only_store_model,
            "max_epochs": max_epochs,
        }
        return self.finetune(**options)

    def neural(self, max_epochs=5, n_training_steps=1000, batch_size=32, early_stopping_epochs=3):
        """Train an LSTM classifier on ``self.dataset`` and evaluate it on ``self.evaluation_dataset``.

        Text is tokenized with torchtext's ``"basic_english"`` tokenizer and embedded with
        ``FastText()`` vectors. The label encoder is fit on the union of both datasets' labels.
        Training and evaluation are delegated to ``lstm.train_model`` / ``lstm.evaluate_metrics``.

        Returns:
            The metrics from ``lstm.evaluate_metrics``. They are also stored under
            ``"neural_fake_to_real"``.
        """
        word_tokenizer = get_tokenizer("basic_english")
        embeddings = FastText()
        train_texts, train_labels = self.dataset.get_all()

        known_labels = set(self.dataset.get_labels())
        known_labels = known_labels.union(set(self.evaluation_dataset.get_labels()))
        encoder = LabelEncoder()
        encoder.fit(list(known_labels))

        device = get_best_device()
        trained = lstm.train_model(
            train_texts,
            train_labels,
            word_tokenizer,
            embeddings,
            device=device,
            label_encoder=encoder,
            max_epochs=max_epochs,
            n_training_steps=n_training_steps,
            batch_size=batch_size,
            early_stopping_epochs=early_stopping_epochs,
        )

        scores = lstm.evaluate_metrics(trained, word_tokenizer, encoder, self.evaluation_dataset.df, embeddings, device)
        self.store_result("neural_fake_to_real", scores)
        return scores

    def linear_classifier(self):
        """Train a logistic-regression probe that separates generated vectors from real ones.

        Vectors from ``self.dataset`` get target 0 and vectors from
        ``self.evaluation_dataset`` get target 1. ``self.preprocess_data`` produces the
        train/test split. The probe is then fit on the training part and scored on the
        held-out part.

        Returns:
            ``{"accuracy": float, "f1": float}``. ``f1`` is sklearn's binary F1 with the
            default positive label 1, i.e. the evaluation-dataset vectors. The dict is also
            stored under ``"linear_classifier"``.
        """
        fake_vectors = self.dataset.get_vectors()
        real_vectors = self.evaluation_dataset.get_vectors()
        features = np.concatenate([fake_vectors, real_vectors], axis=0)
        targets = [0] * len(fake_vectors) + [1] * len(real_vectors)
        X_train, X_test, y_train, y_test, _scaler = self.preprocess_data(features, targets)

        probe = LogisticRegression(max_iter=1000)
        probe.fit(X_train, y_train)
        predicted = probe.predict(X_test)

        result = dict(
            accuracy=float(accuracy_score(y_test, predicted)),
            f1=float(f1_score(y_test, predicted)),
        )
        self.store_result("linear_classifier", result)
        return result
        
    def finetune_causal(self, dataset, model_name="gpt2-xl", dtype=torch.bfloat16, model=None):
        """Fine-tune a causal language model on ``dataset`` using ``Finetune``.

        Args:
            dataset: training data passed straight to ``Finetune.finetune``.
            model_name: checkpoint name used for loading the model and/or tokenizer.
            dtype: dtype passed to ``load_model``. It is only used when ``model`` is None.
            model: optional pre-loaded model. When given, only the tokenizer is loaded.

        Returns:
            ``(model, tokenizer)``. The model is the one returned by ``Finetune.finetune``.
        """
        set_seed(42)
        if model is not None:
            tokenizer = load_tokenizer(model_name)
        else:
            model, tokenizer = load_model(model_name, return_tokenizer=True, dtype=dtype)
        trained = Finetune().finetune(model_name, dataset, model=model)
        return trained, tokenizer

    def finetune_causal_metrics(self, model_name="gpt2-xl", persist_temp_folder=False, dtype=torch.bfloat16, batch_size=8, max_length=1000):
        train, test = train_test_split(self.dataset.df, test_size=0.3, random_state=42)
        average_perplexity = []
        average_eval_perplexity = []
        time_folder = datetime.now().strftime("%Y%m%d-%H%M%S%f")
        state_model = load_model(model_name, return_tokenizer=False, dtype=dtype)

        for label in self.dataset.get_labels():
            df_train = train[train["label"] == label]
            df_train.drop(columns=["label"], inplace=True)
            model, tokenizer = self.finetune_causal(df_train, model_name=model_name, dtype=dtype, model=copy.deepcopy(state_model))
            # set model in eval mode and calculate perplexity on test set
            if persist_temp_folder:
                model.save_pretrained(os.path.join("temp", time_folder, label))

            model.eval()
            perplexity = self.get_perplexity(test[test["label"] == label], model, tokenizer, get_best_device(), batch_size=batch_size, 
                                                max_length=max_length)
            average_perplexity.append(perplexity)
            perplexity_eval = self.get_perplexity(self.evaluation_dataset.df[self.evaluation_dataset.df["label"] == label], model, 
                                                  tokenizer, get_best_device(), batch_size=batch_size, max_length=max_length)

            average_eval_perplexity.append(perplexity_eval)
        
            del model
            torch.cuda.empty_cache()

        if persist_temp_folder:
            self.store_result("model", os.path.join("temp", time_folder), "finetune_causal_metrics")

        result = {
            "perplexity_fake_to_fake": np.mean(average_perplexity),
            "perplexity_fake_to_real": np.mean(average_eval_perplexity)
        }

        self.store_result("finetune_causal_metrics", result)
