import os
import shutil
import sys
import random
import torch

import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.optim import AdamW
from typing import Optional, Union, List

from datasets import Dataset
from transformers import (AutoTokenizer, AutoModelForCausalLM, EncoderDecoderModel, BitsAndBytesConfig,
                          TrainingArguments, Trainer, DataCollatorForLanguageModeling, EarlyStoppingCallback)
from transformers import EarlyStoppingCallback
from transformers import TrainerCallback

from utils.augmentation.utils_data_augmentation import get_feature_distribution


class EvalLossEarlyStopping(TrainerCallback):
    """Trainer callback that halts training when ``eval_loss`` stops improving.

    After every evaluation the reported ``eval_loss`` is compared with the
    lowest value seen so far. Each evaluation that does not strictly improve
    on it increments a counter; once the counter reaches ``patience`` the
    callback sets ``control.should_training_stop``.
    """

    def __init__(self, patience=3):
        self.patience = patience  # evaluations without improvement tolerated
        self.best_loss = None     # lowest eval_loss observed so far
        self.counter = 0          # consecutive evaluations without improvement

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        """Update the patience counter from ``metrics`` and return ``control``."""
        current = metrics.get("eval_loss")
        # Nothing to compare when the evaluation did not report a loss.
        if current is None:
            return control

        improved = self.best_loss is None or current < self.best_loss
        if improved:
            self.best_loss, self.counter = current, 0
            return control

        self.counter += 1
        if self.counter >= self.patience:
            print(f"[EarlyStopping] No improvement in eval_loss for {self.patience} evals. Stopping at step {state.global_step}.")
            control.should_training_stop = True
        return control


# Platform-dependent default keyword arguments forwarded to TrainingArguments.
# On macOS the MPS device is requested; the fp16/bf16 flags are intentionally
# left unset. Every other platform starts from an empty set of overrides.
_train_args = {'use_mps_device': True} if sys.platform == 'darwin' else {}

class SDForger(object):
    def __init__(self,
                 model_path: str,
                 text_template: str,
                 output_dir: Optional[str] = None,
                 float_type: str = 'float32',
                 finetune: Optional[str] = None,
                 device: Optional[str] = None,
                 **kwargs
                 ):
        """Prepare the output directory, seed, device and dtype, then load the model.

        Args:
            model_path: Path or identifier handed to ``handle_model``.
            text_template: Name of the row-to-text template used later on.
            output_dir: Directory for fine-tuning artefacts. When falsy it
                defaults to ``output/finetuned_models/<last path component of
                model_path>``. WARNING: everything already inside this
                directory is deleted.
            float_type: ``'float32'`` or ``'float16'``; selects ``self.dtype``.
            finetune: Placeholder, currently not acted upon.
            device: Explicit device; when falsy ``choose_device`` picks one.
            **kwargs: Stored as ``self.kwargs``. ``seed`` (default 42) and
                ``k_bit`` (4 or 8, enables a ``BitsAndBytesConfig``) are read
                here.

        Raises:
            ValueError: If ``float_type`` or ``k_bit`` has an unsupported value.
        """
        self.model_path = model_path
        self.text_template = text_template
        if not output_dir:
            self.output_dir = '/'.join(['output/finetuned_models', model_path.split('/')[-1]])
        else:
            self.output_dir = output_dir
        # Delete what's already present in the output directory
        if os.path.exists(self.output_dir):
            for file_name in os.listdir(self.output_dir):
                file_path = os.path.join(self.output_dir, file_name)
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.remove(file_path)  # Remove the file
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)  # Remove the directory
            # TODO: uncomment print
            # print(f"All files in {self.output_dir} have been deleted.")
        # TODO: uncomment print
        # else: print(f"The directory {self.output_dir} does not exist.")
        self.kwargs = kwargs
        # Copy the module-level defaults: fit() writes into this dict and must
        # not leak learning rate / epochs / batch size into other instances.
        self.train_args = dict(_train_args)
        self.seed = kwargs.get('seed', 42)
        self.set_seed()
        self.device = self.choose_device(device)

        # TODO: uncomment print
        # print(f" Device: {self.device}")

        # Resolve the dtype before anything else: both the quantization config
        # below and handle_model() read self.dtype.
        supported_dtypes = {'float32': torch.float32, 'float16': torch.float16}
        if float_type not in supported_dtypes:
            raise ValueError("sdforger_float_type can only be 'float16' or 'float32'")
        self.dtype = supported_dtypes[float_type]

        if 'k_bit' in self.kwargs:
            k_bit = self.kwargs['k_bit']
            if k_bit == 4:
                self.quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=self.dtype)
            elif k_bit == 8:
                # BitsAndBytesConfig documents no compute-dtype option for 8-bit loading.
                self.quantization_config = BitsAndBytesConfig(load_in_8bit=True)
            else:
                raise ValueError('k_bit must be either 4 or 8')
        else:
            print('')
            print(f'dtype = {self.dtype}')

            self.quantization_config = None
        self.handle_model(model_path, self.quantization_config)
        if finetune is not None:
            # TODO: set utils function that set the type of finetuning in utils, move to fit
            pass

    def handle_model(self, path, quantization_config):
        """Load the causal language model and its tokenizer from ``path``.

        The loading options depend on whether ``self.model_path`` contains
        ``'granite'`` or ``'gpt2'``; only the granite branch forwards
        ``quantization_config`` and only the gpt2 branch moves the model to
        ``self.device`` right away.

        Args:
            path: Location passed to ``from_pretrained`` for model and tokenizer.
            quantization_config: Optional quantization settings (granite only).

        Raises:
            ValueError: If the model cannot be loaded; the underlying error is
                printed and attached as ``__cause__``.
        """
        try:
            if 'granite' in self.model_path:  # Check for Granite model
                self.model = AutoModelForCausalLM.from_pretrained(
                    path,
                    torch_dtype=self.dtype,
                    # device_map='auto',  # Automatically map to available devices
                    quantization_config=quantization_config,  # Use quantization if provided
                    trust_remote_code=True,
                    ignore_mismatched_sizes=True
                )
            elif 'gpt2' in self.model_path:
                self.model = AutoModelForCausalLM.from_pretrained(
                    path,
                    torch_dtype=self.dtype,
                )
                self.model.to(self.device)
            else:
                self.model = AutoModelForCausalLM.from_pretrained(
                    path,
                    torch_dtype=self.dtype,
                    # device_map='auto',
                    # trust_remote_code=True
                )
        except Exception as e:
            print(e)
            raise ValueError('Model not yet handled, please raise an issue on the repository.') from e

        # Tokenizer setup runs only after a successful model load. Inside a
        # `finally` clause it would also run after a failure and could replace
        # the ValueError above with an AttributeError on the missing model.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        if self.tokenizer.pad_token is None:
            print('pad token is None')
            # Reuse the end-of-sequence token for padding.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model.resize_token_embeddings(len(self.tokenizer))
        print('padding_side', self.tokenizer.padding_side)
        # self.tokenizer.padding_side = 'right' # if we want to force padding on the right

    def fit(self,
            dataset: Union[pd.DataFrame, pd.Series, np.ndarray],
            hf_trainer: bool = True,
            learning_rate: float = 8e-5,
            embedded_dims: Optional[list] = None,
            **kwargs
            ):
        """Fine-tune the language model on a textual encoding of ``dataset``.

        Each row is rendered to text with ``self.text_template``, tokenized,
        split 80/20 into train/validation and used to train ``self.model``
        either with the Hugging Face ``Trainer`` or with a plain AdamW loop.

        Args:
            dataset: Training table. A numpy array is wrapped in a DataFrame
                using ``kwargs['columns']`` or generated ``column_<i>`` names.
                Anything else than a DataFrame/ndarray is rejected.
            hf_trainer: Use ``Trainer`` (True) or the custom loop (False).
            learning_rate: Optimizer learning rate.
            embedded_dims: Stored on the instance for ``generate``; defaults
                to an empty list.
            **kwargs: Optional ``columns``, ``decimal`` (default 4),
                ``shuffle`` (default True), ``permute`` (default True),
                ``epochs`` (default 3) and ``batch_size`` (default 8).

        Raises:
            TypeError: If ``dataset`` is neither a DataFrame nor an ndarray.
            ValueError: If ``self.text_template`` is not a known template.

        Side effects: sets ``original_data``, ``embedded_dims``,
        ``feature_distribution``, ``columns``, ``decimal`` and ``permute`` on
        the instance, silences all warnings process-wide, and (Trainer path)
        saves to ``<output_dir>/best_model`` and reloads the model from there.
        """
        if not isinstance(dataset, (pd.DataFrame, np.ndarray)):
            raise TypeError('Dataset must be a pandas dataframe or numpy array')
        if isinstance(dataset, np.ndarray):
            columns = kwargs.get('columns')
            if columns is None or len(columns) == 0:
                # assigning phantom column names
                n_cols = dataset.shape[1] if dataset.ndim > 1 else 1
                columns = [f'column_{i+1}' for i in range(n_cols)]
            # Always convert: everything below relies on the DataFrame API.
            dataset = pd.DataFrame(dataset, columns=list(columns))

        # Set `self.original_data` to store the initial training dataset
        self.original_data = dataset.copy()
        self.embedded_dims = [] if embedded_dims is None else embedded_dims

        # NOTE(review): get_feature_distribution is defined elsewhere; this
        # file reads entry [0] as 'numerical'/'categorical', [1] as a flag and
        # [2] as a dict of statistics -- confirm against the helper.
        self.feature_distribution = {}
        for col in dataset.columns:
            print('feature_distribution col', col)
            self.feature_distribution[col] = get_feature_distribution(dataset[col].tolist())

        self._column_format(dataset)
        self.decimal = kwargs.get('decimal', 4)
        self.permute = kwargs.get('permute', True)

        if kwargs.get('shuffle', True):
            dataset = dataset.sample(frac=1, random_state=self.seed)

        dataset = Dataset.from_pandas(dataset)

        def data_to_text(row, column, permute: bool, text_template: str):
            """Render one row as a training sentence for ``text_template``."""
            if permute:
                # `column` is a per-row copy, so shuffling in place is safe.
                random.shuffle(column)
            if text_template == 'base_template':
                text = ", ".join(
                    [
                        f"{c} is {row[c]:.{self.decimal}f}" if self.feature_distribution[c][0] == 'numerical' else
                        f"{c} is {row[c]}" for c in column
                    ]
                )
                text = text + self.tokenizer.eos_token
            elif text_template == 'fim_template':
                text_input = ", ".join(
                    [
                        f"{c} is [blank]" for c in column
                    ]
                )
                text_target = " ".join(
                    [
                        f"{row[c]:.{self.decimal}f} [answer]" if self.feature_distribution[c][0] == 'numerical' else
                        f"{row[c]} [answer]" for c in column
                    ]
                )
                text = "Input: " + text_input + " [sep] Target: " + text_target + self.tokenizer.eos_token
            elif text_template == 'fim_template_textual_encoding':
                text_categorical = "".join(
                    [
                        f"{c} is {row[c]}, " if self.feature_distribution[c][0] == 'categorical' else
                        "" for c in column
                    ]
                ).strip(", ")

                text_input = "".join(
                    [
                        f"{c} is [blank], " if self.feature_distribution[c][0] == 'numerical' else
                        "" for c in column
                    ]
                ).strip(", ")
                text_target = "".join(
                    [
                        f"{float(row[c]):.{self.decimal}f} [answer] " if self.feature_distribution[c][0] == 'numerical' else
                        "" for c in column
                    ]
                ).strip(" ")
                text = "Condition: " + text_categorical + " [sep] Input: " + text_input + " [sep] Target: " + text_target + self.tokenizer.eos_token
            else:
                raise ValueError(f"Unknown text_template: {text_template!r}")
            return text

        data2text = [data_to_text(row, self.columns.copy(), self.permute, self.text_template) for row in dataset]

        print('')
        print('')
        print('')
        print('-----------------------------------------------')
        print('FROM TABULAR DATA TO TEXT')
        print('-----------------------------------------------')
        print('\n DEVICE:', self.device, '\n')

        print("\n\n".join(data2text))
        dataset = dataset.add_column('data2text', data2text)

        print('')
        print('')
        print('')
        print('-----------------------------------------------')
        print('              FINE-TUNING LLM')
        print('-----------------------------------------------')
        # TODO: remove after demo
        import warnings
        # Suppress warnings
        warnings.filterwarnings('ignore')
        print('\n DEVICE:', self.device, '\n')

        def preprocess_function(row):
            """Tokenize a batch of sentences; labels mirror the input ids."""
            model_inputs = self.tokenizer(row['data2text'], padding=True)
            model_inputs['labels'] = model_inputs['input_ids']
            return model_inputs

        # Preprocess full dataset
        tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=dataset.column_names)
        # Split into train (80%) and validation (20%)
        split_dataset = tokenized_dataset.train_test_split(test_size=0.2, seed=self.seed)
        train_data = split_dataset["train"]
        val_data = split_dataset["test"]

        self.train_args['learning_rate'] = learning_rate
        # Fall back to the same defaults the custom loop below assumes.
        self.train_args['num_train_epochs'] = kwargs.get('epochs', 3)
        self.train_args['per_device_train_batch_size'] = kwargs.get('batch_size', 8)
        if hf_trainer:
            print(f'{self.model.dtype=}')

            # NOTE(review): bf16 is forced on regardless of self.dtype and of
            # the selected device -- confirm this is intended.
            training_args = TrainingArguments(
                self.output_dir,
                **self.train_args,
                adam_epsilon=1e-04,
                logging_strategy="steps",
                logging_steps=10,
                evaluation_strategy="steps",
                eval_steps=5,
                save_strategy="steps",
                save_steps=100,
                load_best_model_at_end=True,
                metric_for_best_model="eval_loss",
                greater_is_better=False,
                bf16=True,
            )

            trainer = Trainer(
                model=self.model,
                args=training_args,
                tokenizer=self.tokenizer,
                train_dataset=train_data,
                eval_dataset=val_data,
                callbacks=[EvalLossEarlyStopping(patience=5)]
            )
            trainer.train()
            print(f'{self.model.dtype=}')
            best_model_dir = self.output_dir + '/best_model'
            trainer.save_model(best_model_dir)
            # Reload from disk so self.model / self.tokenizer match what was saved.
            self.handle_model(best_model_dir, self.quantization_config)

        else:
            # Custom Training Procedure
            epochs = self.train_args.get('num_train_epochs', 3)
            batch_size = self.train_args.get('per_device_train_batch_size', 8)
            learning_rate = self.train_args.get('learning_rate', 5e-5)

            # Define optimizer (equivalent to what Trainer would use)
            optimizer = AdamW(self.model.parameters(), lr=learning_rate, eps=1e-04)

            tensor_columns = ('input_ids', 'attention_mask', 'labels')

            def collate_rows(rows):
                # Rows hold plain Python lists; default collation would turn
                # them into per-position lists of tensors, not (batch, seq) tensors.
                return {name: torch.tensor([r[name] for r in rows]) for name in tensor_columns}

            # Set up the data loader
            train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_rows)

            # Training loop
            self.model.train()  # Set model to training mode
            # TODO: add a learning rate scheduler

            for epoch in tqdm(range(epochs)):
                print(f"Epoch {epoch + 1}/{epochs}")
                epoch_loss = 0

                for batch in train_loader:
                    optimizer.zero_grad()  # Reset gradients

                    # Move batch to the same device as the model
                    input_ids = batch['input_ids'].to(self.model.device)
                    attention_mask = batch['attention_mask'].to(self.model.device)
                    labels = batch['labels'].to(self.model.device)

                    # Forward pass
                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    loss = outputs.loss

                    # Backward pass
                    loss.backward()
                    optimizer.step()

                    epoch_loss += loss.item()

                avg_loss = epoch_loss / len(train_loader)
                print(f"Average loss for epoch {epoch + 1}: {avg_loss}")

    def _column_format(self, df: pd.DataFrame):
        # Update the column names (and numerical columns for some sanity checks after sampling)
        self.columns = df.columns.to_list()
        # self.num_cols = df.select_dtypes(include=np.number).columns.to_list()

    def generate(self,
                 n_samples_min: int = 10,
                 n_samples_max: int = 1500,
                 stopping_treshold: float = 0.98,
                 start_col: Optional[str] = "",
                 start_col_dist: Optional[Union[dict, list]] = None,
                 temperature: float = 1,
                 k: int = 100, # 32 JUST FOR PHI
                 max_length: int = 3000,
                 drop_nan: bool = False,
                 permute: bool = None,
                 check_distribution: bool = False,
                 init_value: bool = True,
                 device: Optional[str] = None,
                 ) -> pd.DataFrame:
        """Sample synthetic rows from the fine-tuned model.

        Batches of ``k`` prompts are generated, decoded back to a table,
        filtered (missing / non-numeric values, duplicates after rounding to
        3 decimals, optionally norm outliers) and accumulated until
        ``n_samples_max`` rows exist or the diversity criterion triggers.

        Args:
            n_samples_min: Minimum number of rows before early stopping may fire.
            n_samples_max: Upper bound on accumulated rows; also the cap of
                the fallback return value.
            stopping_treshold: Early stop once the largest norm-diversity
                score is below this value.
            start_col: Forwarded to ``init_tokens``.
            start_col_dist: Forwarded to ``init_tokens``.
            temperature: Sampling temperature for ``model.generate``.
            k: Number of prompts per iteration.
            max_length: ``max_length`` for ``model.generate``.
            drop_nan: Additionally drop rows containing any NaN.
            permute: Overrides ``self.permute`` when not ``None``.
            check_distribution: Discard new rows whose per-group L2 norm lies
                outside ``[Q1 - 3*IQR, Q3 + 3*IQR]`` of the norms seen so far.
            init_value: Forwarded to ``init_tokens``.
            device: Explicit device; when falsy ``choose_device`` picks one.

        Returns:
            DataFrame of generated rows with a fresh index. Errors raised
            inside the loop are printed and whatever was collected so far is
            returned (an empty frame with ``self.columns`` if nothing was).
        """

        # TODO: check starting strategy
        self.device = self.choose_device(device)
        self.model.to(self.device)

        # `None` keeps the current setting; an explicit False must be honoured too.
        if permute is not None:
            self.permute = permute

        # Init list for generated DataFrames
        dfs = []

        print('')
        print('')
        print('')
        print('-----------------------------------------------')
        print('NEW SAMPLES GENERATION')
        print('-----------------------------------------------')

        # Start generation process
        # TODO: uncomment first
        # with tqdm(total=n_samples_max) as pbar
        with tqdm(total=n_samples_max, disable=True) as pbar:

            # Initialize list to monitor data generation
            numerical_columns = [col for col in self.columns if self.feature_distribution[col][0] == 'numerical']  # TODO: add this
            # Split points between groups of numerical columns; the last
            # cumulative sum is the total width and is not a split point.
            split_indices = np.cumsum(self.embedded_dims).astype(int)[:-1]
            original_data_splits = np.split(self.original_data[numerical_columns].values, split_indices, axis=1)
            old_l2_norms_splits = [list(np.linalg.norm(split.astype(float), axis=1)) for split in original_data_splits]
            new_samples = []
            duplicates = []
            nas = []
            # One history per norm group. Sized from the actual splits so an
            # empty embedded_dims (a single group) is tracked as well.
            norms_diversity = [[] for _ in old_l2_norms_splits]

            already_generated = 0
            _cnt = 0
            try:
                while n_samples_max > already_generated:
                    input_prompts = self.init_tokens(k, start_col, start_col_dist, distribution=None, init_value=init_value, text_template=self.text_template) # starting token
                    input_ids = self.tokenizer(
                        input_prompts,
                        return_tensors='pt'
                    ).to(self.device)

                    print('')
                    print(' --------------')
                    print(f' ITER {_cnt}')
                    print(' --------------')

                    small_input_prompts = input_prompts[0:5]
                    print('')
                    print('  * INPUT PROMPTS:')
                    print("\n".join([f"{prompt}" for prompt in small_input_prompts]))
                    print('   ...\n')

                    # Generate tokens
                    self.model.eval()

                    # NOTE(review): generation stops on the *pad* token id; this
                    # equals the eos id only when the tokenizer had no pad token
                    # of its own -- confirm for other tokenizers.
                    with torch.inference_mode():
                        tokens = self.model.generate(
                            **input_ids,
                            max_length=max_length,
                            do_sample=True,
                            temperature=temperature,
                            # pad_token_id=self.tokenizer.pad_token_id,
                            eos_token_id=self.tokenizer.pad_token_id,
                        )

                    # Convert tokens back to tabular data
                    text_data = self.tokenizer.batch_decode(tokens)
                    # Clean text
                    text_data = [d.replace(self.tokenizer.eos_token, "") for d in text_data]
                    text_data = [d.replace("\n", " ") for d in text_data]
                    text_data = [d.replace("\r", "") for d in text_data]

                    small_text_data = text_data[0:5]
                    print('')
                    print('  * GENERATED TEXT:')
                    print("\n".join([f"{sample}" for sample in small_text_data]))
                    print('   ...')
                    print('   ...')

                    df_gen = self._convert_text_to_tabular_data(text_data, self.columns, self.text_template)

                    # Remove rows where we have not generated anything
                    df_gen = df_gen[~(df_gen == "NaN").any(axis=1)]
                    df_gen = df_gen[~df_gen.isnull().any(axis=1)]

                    # Remove rows where all values are NaN
                    df_gen = df_gen.dropna(how="all")

                    # Optional: Remove rows with any NaN values
                    if drop_nan:
                        df_gen = df_gen.dropna()

                    # Remove rows with flawed numerical values but keep NaNs;
                    # categorical columns are left untouched.
                    for col, dist in self.feature_distribution.items():
                        if dist[0] != 'categorical':
                            coerced_series = pd.to_numeric(df_gen[col], errors="coerce")
                            # Copy so the cast below writes to an independent frame.
                            df_gen = df_gen[coerced_series.notnull() | df_gen[col].isna()].copy()
                            df_gen[col] = df_gen[col].astype(float)

                    # Track the size of df_gen after NaN drop
                    after_na = df_gen.shape[0]

                    # Append df_gen to dfs, then concatenate and remove duplicates across all generated samples
                    dfs.append(df_gen)
                    data = pd.concat(dfs).round(3).drop_duplicates()
                    after_drop_duplicates = data[already_generated:].shape[0]

                    if check_distribution:
                        new_data = data[already_generated:]  # New batch of generated samples

                        # Calculate the L2 norm of every column group for the new rows
                        new_data_splits = np.split(new_data[numerical_columns].values.astype(float), split_indices, axis=1)
                        new_generated_norms = [np.linalg.norm(split, axis=1) for split in new_data_splits]

                        # Calculate IQR bounds for each set of norms
                        bounds = []
                        for norms in old_l2_norms_splits:
                            Q1, Q3 = np.percentile(norms, [25, 75])
                            IQR = Q3 - Q1
                            bounds.append((Q1 - 3 * IQR, Q3 + 3 * IQR))  # Store lower and upper bounds for each subset

                        # A row is accepted only if every group norm is within its bounds
                        accepted_mask = np.ones(len(new_data), dtype=bool)
                        for i, (norms, (lower, upper)) in enumerate(zip(new_generated_norms, bounds)):
                            discard_mask = (norms < lower) | (norms > upper)  # Outliers for the current subset
                            accepted_mask &= ~discard_mask

                            # Print details for discarded norms, if any
                            if np.any(discard_mask):
                                print(f"Var {i + 1} Discarded Norms: {norms[discard_mask]}")

                        # Keep only accepted data within bounds across all subsets
                        data = pd.concat([data[:already_generated], new_data[accepted_mask]])

                        # Update old norms for each subset
                        for i, norms in enumerate(new_generated_norms):
                            old_l2_norms_splits[i].extend(norms[accepted_mask])

                    # printing
                    # TODO: uncomment after demo
                    print(f"\n After NaN drop: {k} -> {after_na}")
                    print(f" After duplicates drop: {after_na} -> {after_drop_duplicates}")
                    if check_distribution:
                        print(f" After norm-check drop: {after_drop_duplicates} -> {sum(accepted_mask)}")
                    print(f" Total samples: {already_generated} -> {data.shape[0]}")

                    # update lists
                    delta = data.shape[0] - already_generated
                    new_samples.append(delta)
                    duplicates.append(after_na - after_drop_duplicates)
                    nas.append(k - after_na)

                    rounding_factor = 3
                    for i, subset_norms in enumerate(old_l2_norms_splits):
                        # Proportion of unique rounded norms in the subset
                        unique_norms = len(set(np.round(subset_norms, rounding_factor)))
                        norms_diversity[i].append(unique_norms / len(subset_norms))

                    # important, update already generated
                    already_generated = data.shape[0]

                    # Reset dfs to hold the combined unique samples for the next cycle
                    dfs = [data]

                    # Update the progress bar with the number of new samples accepted in this iteration
                    pbar.update(delta)

                    # Check if we are actually generating synthetic samples and if not, break everything
                    _cnt += 1
                    if _cnt > 13 and already_generated == 0:
                        # Caught by the handler below; used to leave the loop.
                        raise Exception("Breaking the generation loop!")

                    # check stopping criterion
                    max_element = max(sublist[-1] for sublist in norms_diversity)
                    if already_generated > n_samples_min and max_element < stopping_treshold:
                        print(f"\nStop generation due to stopping criterion ({stopping_treshold})")
                        print(f"Generated samples: {already_generated}")
                        df_gen = pd.concat(dfs)
                        df_gen = df_gen.reset_index(drop=True)
                        return df_gen

            except Exception as e:
                # Best effort: report the problem and return what was collected.
                print(f"ERROR: {str(e)}")

        print(f"\nStopping criterion not satisfied. Reached max generations ({n_samples_max})")
        if not dfs:
            # Nothing was collected (e.g. the first iteration failed);
            # pd.concat would raise on an empty list.
            return pd.DataFrame(columns=self.columns)
        df_gen = pd.concat(dfs)
        df_gen = df_gen.reset_index(drop=True)

        return df_gen.head(n_samples_max)

    @staticmethod
    def choose_device(device) -> str:
        if device:
            return device
        if sys.platform == 'darwin':
            return 'mps'
        return 'cuda' if torch.cuda.is_available() else 'cpu'

    def init_tokens(self, k, start_col, start_col_dist, distribution: Optional[str] = None, init_value: bool = True, text_template: str = 'base_template'):
        """Build ``k`` generation prompts for ``text_template``.

        Prompts are left-padded with the tokenizer's pad token (as text) so
        that all of them tokenize to the same number of ids.

        Args:
            k: Number of prompts to create.
            start_col: Currently unused.
            start_col_dist: Must be falsy for ``'base_template'``; a custom
                start distribution is not implemented.
            distribution: Must be falsy when ``init_value`` is used; custom
                value distributions are not implemented.
            init_value: For ``'base_template'``, whether the prompt carries a
                sampled value for its starting column.
            text_template: ``'base_template'``, ``'fim_template'`` or
                ``'fim_template_textual_encoding'``.

        Returns:
            List of ``k`` prompt strings.

        Raises:
            NotImplementedError: For the unsupported options named above.
            ValueError: If ``text_template`` is unknown.
        """

        def left_pad(prompts):
            # Prepend pad-token text so every prompt has the longest token count.
            if not prompts:
                return prompts
            lengths = [len(self.tokenizer(prompt)['input_ids']) for prompt in prompts]
            max_tokens = max(lengths)
            padding = self.tokenizer.pad_token
            return [padding * (max_tokens - n) + prompt for n, prompt in zip(lengths, prompts)]

        if text_template == 'base_template':
            if start_col_dist:
                raise NotImplementedError('start_col_dist is not supported yet')

            if self.permute:
                col_starter = random.choices(self.columns, k=k)
            else:
                col_starter = [self.columns[0]] * k

            if init_value:
                if distribution:
                    raise NotImplementedError('custom value distributions are not supported yet')

                def sample_normal(c):
                    stats = self.feature_distribution[c][2]
                    return np.random.normal(stats['mean'], stats['std'], size=1).item()

                def sample_jittered(c):
                    # Draw an observed value and add noise scaled to 10% of the std.
                    stats = self.feature_distribution[c][2]
                    return np.random.choice(stats['data']) + np.random.normal(0, 0.1 * stats['std'])

                # NOTE(review): entry [1] of a feature distribution selects the
                # normal sampler; presumably a normality flag -- confirm.
                starter = [
                    (c, f'{sample_normal(c):.{self.decimal}f}') if self.feature_distribution[c][1]
                    else (c, f'{sample_jittered(c):.{self.decimal}f}')
                    for c in col_starter
                ]
            else:
                starter = [(c, '') for c in col_starter]

            input_prompts = [''.join([c, " is ", v, "," if init_value else '']) for c, v in starter]
            input_prompts = [s[:-1] if s.endswith(" ") else s for s in input_prompts] # remove last space of the prompt if present
            input_prompts = left_pad(input_prompts)

        elif text_template == 'fim_template':
            if self.permute:
                col_starter = []
                col_list = self.columns.copy()
                for _ in range(k):
                    random.shuffle(col_list)
                    col_starter.append(col_list.copy())
            else:
                col_starter = [self.columns.copy()] * k
            input_prompts = ['Input: ' + ' is [blank], '.join(starter) + ' is [blank] [sep] Target:' for starter in col_starter]
            input_prompts = left_pad(input_prompts)

        elif text_template == 'fim_template_textual_encoding':
            input_prompts = []

            for _ in range(k):
                # Always start from the stored column order; shuffle only on request.
                col_list = self.columns.copy()
                if self.permute:
                    random.shuffle(col_list)

                text_categorical = "".join([f"{c} is {random.choice(list(self.feature_distribution[c][2].keys()))}, " if self.feature_distribution[c][0] == 'categorical' else "" for c in self.columns]).strip(", ")

                text_input = "".join([f"{c} is [blank], " if self.feature_distribution[c][0] == 'numerical' else "" for c in col_list]).strip(", ")

                text = "Condition: " + text_categorical + " [sep] Input: " + text_input + " [sep] Target:"

                input_prompts.append(text)

            input_prompts = left_pad(input_prompts)

        else:
            raise ValueError(f"Unknown text_template: {text_template!r}")

        return input_prompts

    def _convert_tokens_to_text(self):
        """Placeholder with no behaviour yet; always returns ``None``."""
        return None

    def _convert_text_to_tabular_data(self,
            text: List[str], columns: List[str], text_template: str
    ) -> pd.DataFrame:
        """Converts the sentences back to tabular data

        Args:
            text: List of the tabular data in text form
            columns: Column names of the data

        Returns:
            Pandas DataFrame with the tabular data from the text appended
        """
        generated = []

        if text_template == 'base_template':
            # Convert text to tabular data for base_template
            for t in text:
                features = t.split(",")
                td = dict.fromkeys(columns, "NaN")

                # Transform all features back to tabular data
                for f in features:
                    values = f.strip().split(" is ")
                    if values[0] in columns and td[values[0]] == "NaN":
                        try:
                            if len(values) >= 2:
                                if self.feature_distribution[values[0]][0] == 'numerical':
                                    i = 1
                                    while i <= len(values):
                                        try:
                                            trial = float(values[i])
                                            td[values[0]] = values[i]
                                            i = len(values) + 1
                                        except ValueError:
                                            i += 1
                                            continue
                                elif self.feature_distribution[values[0]][0] == 'categorical':
                                    # TODO: create categorical check
                                    pass
                        except IndexError:
                            # print("An Index Error occurred - if this happends a lot, consider fine-tuning your model further.")
                            pass
                generated.append(td)
            df_gen = pd.DataFrame(generated)
            df_gen.replace("None", None, inplace=True)
        
        if text_template == 'fim_template':
            # Convert text to tabular data
            for t in text:
                try: 
                    features_input = t.split(",")
                    td = dict.fromkeys(columns, "NaN")

                    input_part, output_part = t.split(" [sep] Target: ")
                    input_values = input_part.removeprefix("Input: ")
                    output_values = output_part.removesuffix(" [answer]")

                    features_input = input_values.split(",")
                    features_output = output_values.split(" [answer] ")
                except ValueError:
                    # If text is malformed, skip it
                    print(f"Skipping malformed text: {t}")
                    continue
                if len(features_output) >= len(features_input):
                    # Transform all features back to tabular data
                    for i in range(0, len(features_input)):
                        f = features_input[i]
                        col_i = f.strip().split(" is ")[0]
                        if col_i in columns:
                            try:
                                val_i = features_output[i]
                                if col_i in columns:
                                    if self.feature_distribution[col_i][0] == 'numerical':
                                        td[col_i] = float(val_i)  # Parse numerical values
                                    else:
                                        td[col_i] = val_i  # Keep categorical as string
                            except ValueError:
                                print(f"Skipping malformed output: {features_output}")
                                continue
                generated.append(td)
            # Create DataFrame and replace placeholder "NaN" with actual NaN
            df_gen = pd.DataFrame(generated)
            df_gen.replace("NaN", None, inplace=True)
        
        if text_template == 'fim_template_textual_encoding':
            # Convert text to tabular data
            for t in text:
                try: 
                    features_input = t.split(",")
                    td = dict.fromkeys(columns, "NaN")
                                
                    # get data_name
                    new_part_condition, new_part_values = t.split(" [sep] Input: ")
                    input_textual_data = new_part_condition.removeprefix("Condition: data is ")
                    
                    input_values, output_part = new_part_values.split(" [sep] Target: ")
                    # input_values = input_part.removeprefix("Input: ")
                    output_values = output_part.removesuffix(" [answer]")
                    
                    features_input = input_values.split(",")
                    features_output = output_values.split(" [answer] ")
                
                except ValueError:
                    # If text is malformed, skip it
                    print(f"Skipping malformed text: {t}")
                    continue

                if len(features_output) >= len(features_input):
                    # Transform all features back to tabular data
                    td['data']=input_textual_data
                    for i in range(0, len(features_input)):
                        f = features_input[i]
                        col_i = f.strip().split(" is ")[0]
                        if col_i in columns:
                            try:
                                val_i = features_output[i]
                            
                                if col_i in columns:
                                    if self.feature_distribution[col_i][0] == 'numerical':
                                        td[col_i] = float(val_i)  # Parse numerical values
                                    else:
                                        td[col_i] = val_i  # Keep categorical as string
                            except ValueError:
                                print(f"Skipping malformed output: {features_output}")
                                continue
                generated.append(td)
            df_gen = pd.DataFrame(generated)
            df_gen.replace("NaN", None, inplace=True)

        return df_gen

    def set_seed(self):
        os.environ['PYTHONHASHSEED'] = '0'
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)
        torch.cuda.manual_seed_all(self.seed)