import sys
import math
import pandas as pd
import numpy as np
from datetime import datetime
import time
from tqdm import tqdm
import copy

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch.optim import AdamW
from torch.nn.utils.rnn import pad_packed_sequence

from sklearn.metrics import roc_auc_score, f1_score
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler
from imblearn.pipeline import Pipeline

from utils import *
from lle import SkipGram, strat_avg_all, strat_select_id

from predictors import LSTMClassifier, TransformerEncoderClassifier
from dataloader.data_manager import FastDataLoader, PackSequenceDataset, SequenceDataset

def create_ngram_dataset(data: pd.DataFrame, context: int, add_event_type: bool = True, verbose: bool = True) -> dict:
    """Build unique skip-gram pairs and a token vocabulary from session data.

    Each session is padded with ``context`` ``"START"`` tokens in front and
    ``context`` ``"END"`` tokens behind. Every original token becomes a centre
    token whose context is the ``int(context / 2)`` tokens before it followed
    by the ``ceil(context / 2)`` tokens after it (``context`` tokens in total).

    :param data: Session table. Needs a ``touchpoint`` column with one token
        sequence per row and, if ``add_event_type`` is true, an ``event_type``
        column with sequences of the same length.
    :type data: pd.DataFrame
    :param context: Number of context tokens per centre token.
    :type context: int
    :param add_event_type: Use ``"<event_type>:<touchpoint>"`` strings as
        tokens instead of the raw touchpoints, defaults to True
    :type add_event_type: bool, optional
    :param verbose: Print how many n-grams and tokens were created, defaults to True
    :type verbose: bool, optional
    :return: ``{"ngrams": [(centre, [context tokens]), ...], "vocab_map": {token: id}}``.
        Duplicate pairs are collapsed, so the order of ``ngrams`` carries no
        meaning. Ids are ranked by how many *distinct* pairs a token is the
        centre of (most first); ``"START"`` and ``"END"`` receive the two
        highest ids.
    :rtype: dict
    """

    def session_pairs(tokens, n: int) -> set:
        """Return the set of ``(centre, context_tuple)`` pairs of one session."""
        padded = ["START"] * n + list(tokens) + ["END"] * n
        # int(-n / 2) truncates towards zero: floor(n/2) tokens before, ceil(n/2) after.
        offsets = [*range(int(-n / 2), 0), *range(1, math.ceil(n / 2) + 1)]
        pairs = set()
        for pos in range(n, n + len(padded[n:-n])):
            pairs.add((padded[pos], tuple(padded[pos + off] for off in offsets)))
        return pairs

    sessions = data.touchpoint.tolist()
    if add_event_type:
        # one combined token per step: "<event_type>:<touchpoint>"
        sessions = [
            [":".join(map(str, pair)) for pair in zip(row.event_type, row.touchpoint)]
            for row in data.itertuples()
        ]

    # Deduplicate pairs across all sessions.
    unique_pairs = set()
    for session in sessions:
        for pair in session_pairs(session, context):
            unique_pairs.add(pair)
    ngrams = [(centre, list(ctx)) for centre, ctx in unique_pairs]

    # Token ids: most frequent centre token first; START and END are appended last.
    centre_counts = pd.Series([centre for centre, _ in ngrams]).value_counts()
    vocab_map = {}
    for token_id, token in enumerate(centre_counts.index):
        vocab_map[token] = token_id
    n_real_tokens = len(vocab_map)
    vocab_map["START"] = n_real_tokens
    vocab_map["END"] = n_real_tokens + 1

    if verbose:
        print(
            f"""{len(ngrams)} ngrams created.\n{len(vocab_map)} different tokens exist for embedding Training.\n
            """
        )
    return {"ngrams": ngrams, "vocab_map": vocab_map}

def create_dataset_from_ngrams(ngrams, vocab_map, unknown_sample: float = 0.0, unknown_value: int = None) -> SequenceDataset:
    """Encode skip-gram pairs as integer arrays and wrap them in a ``SequenceDataset``.

    :param ngrams: Non-empty list of ``(centre_token, [context_tokens...])``
        pairs (as built by ``create_ngram_dataset``); every pair must have the
        same number of context tokens as the first one.
    :param vocab_map: Mapping token -> integer id. Every token in ``ngrams``
        must be present, otherwise ``KeyError`` is raised.
    :param unknown_sample: Forwarded unchanged to ``SequenceDataset``.
    :param unknown_value: Id forwarded to ``SequenceDataset`` as
        ``unknown_value``; defaults to ``len(vocab_map)`` when ``None``.
        An explicit ``0`` is respected.
    :return: ``SequenceDataset`` built from ``X`` (int32, shape ``(n,)``, centre
        ids) and ``y`` (int32, shape ``(n, context)``, context ids).
    :raises ValueError: If ``ngrams`` is empty (the context size cannot be inferred).
    """
    # `is None` instead of truthiness so that an explicit id of 0 is not replaced.
    if unknown_value is None:
        unknown_value = len(vocab_map)

    if not ngrams:
        raise ValueError("create_dataset_from_ngrams: 'ngrams' is empty, cannot infer the context size")

    context = len(ngrams[0][1])
    X = np.zeros(len(ngrams), dtype=np.int32)
    y = np.zeros((len(ngrams), context), dtype=np.int32)
    for i, (centre, context_tokens) in enumerate(ngrams):
        X[i] = vocab_map[centre]
        y[i] = [vocab_map[token] for token in context_tokens]

    return SequenceDataset(X, y, unknown_sample=unknown_sample, unknown_value=unknown_value)

def data_to_tensor(data, embedding, action_map, device=torch.device("cpu")):
    """Embed every session into a detached tensor of token vectors.

    :param data: Iterable of sessions; each session is an iterable of tokens.
    :param embedding: Object exposing ``embed(LongTensor) -> Tensor``.
    :param action_map: Mapping token -> integer id. Tokens missing from the
        map are encoded as ``len(action_map)`` (presumably the unknown-token
        row of the embedding; TODO confirm against ``SkipGram``).
    :param device: Device the id tensors are moved to before embedding.
    :return: List with one tensor per session, in input order, detached from
        the autograd graph.
    """
    unknown_id = len(action_map)  # loop-invariant
    X = []
    # Pure inference: avoid building an autograd graph that .detach() would
    # discard anyway.
    with torch.no_grad():
        for session in data:
            ids = torch.LongTensor([action_map.get(action, unknown_id) for action in session]).to(device)
            X.append(embedding.embed(ids).detach())

    return X

def _active_cl_methods(approach) -> set:
    """Return the continual-learning method names encoded in ``approach``.

    ``approach`` may be a ``"+"``-joined string (``"ewc+mas"``), an iterable
    of such strings (``["ewc+mas"]``) or ``None``.
    """
    if approach is None:
        return set()
    if isinstance(approach, str):
        approach = [approach]
    return {part.strip() for entry in approach for part in str(entry).split("+")}


def train_embedding(embedding, dataloader, criterion, optimizer, approach, approach_params, device) -> float:
    """Train ``embedding`` for one epoch and return the mean batch loss.

    Per batch the loss is the sum over context positions ``i`` of
    ``criterion(out[i], targets[:, i])``, where ``out`` is the list returned by
    ``embedding(features)``. Enabled continual-learning penalties
    (``lwf_loss``, ``mas_loss``, ``ewc_loss``) are then added to it.

    :param embedding: Skip-gram model; put into train mode.
    :param dataloader: Iterable of ``(features, targets)`` batches; targets are
        indexed as ``targets[:, i]`` per context position.
    :param criterion: Loss applied per context position.
    :param optimizer: Optimizer stepped once per batch.
    :param approach: Enabled methods, either a ``"+"``-joined string such as
        ``"ewc+mas"`` or a list of such strings such as ``["ewc+mas"]``.
    :param approach_params: Keyword arguments per method (keys ``"lwf"``,
        ``"mas"``, ``"ewc"``); a penalty is only added when its key is present.
        When ``"ewc"`` is enabled, ``approach_params["fisher_info_new"]`` must
        exist and is updated every batch.
    :param device: Device batches are moved to.
    :return: Mean of the per-batch total losses.
    """
    # Split "ewc+mas" / ["ewc+mas"] into {"ewc", "mas"}. A plain `in` on the list
    # form would only match exact single-method entries, which silently
    # disabled every combined approach.
    methods = _active_cl_methods(approach)

    embedding.train()
    train_loss = []

    for batch in dataloader:
        embedding.zero_grad()
        features = batch[0].long().to(device)
        targets = batch[1].long().to(device)

        out = embedding(features)
        loss = sum([criterion(o, targets[:, i]) for i, o in enumerate(out)])

        if "lwf" in methods and "lwf" in approach_params:
            loss += lwf_loss(embedding, out, features=features, **approach_params["lwf"])
        if "mas" in methods and "mas" in approach_params:
            loss += mas_loss(embedding, **approach_params["mas"])
        if "ewc" in methods and "ewc" in approach_params:
            loss += ewc_loss(embedding, **approach_params["ewc"])

        train_loss.append(loss.item())
        loss.backward()
        if "ewc" in methods:
            # update_fisher_information only receives the model, so it can only
            # observe this batch through the parameters (presumably their .grad).
            # Those were cleared by zero_grad() above, so call it after
            # backward() and before step() changes the parameters.
            update_fisher_information(approach_params["fisher_info_new"], model=embedding, scale=1.0 / features.size(0))
        optimizer.step()

    return np.mean(train_loss)

def train_model(model, dataloader, criterion, optimizer, device) -> tuple:
    """Train ``model`` for one epoch and return its training ``(f1, auc)``.

    Each batch is ``(packed_features, targets)``. The packed features are
    unpacked with ``pad_packed_sequence`` (``batch_first=True``, zero padding)
    and passed to the model as a ``(padded, lengths)`` tuple. The model output
    is compared against ``targets.float().unsqueeze(1)``, which assumes one
    logit per sample (TODO confirm the model returns shape ``(batch, 1)``).

    :return: ``(f1, auc)`` over all batches of the epoch. Predictions are
        ``round(sigmoid(logit))``. A metric sklearn refuses to compute
        (``ValueError``, e.g. only one class present) is reported as 0.0.
    """
    model.train()
    predictions = []
    predictions_probs = []
    all_targets = []

    for batch in dataloader:
        features, targets = batch[0], batch[1]

        features, lengths = pad_packed_sequence(features.to(device), padding_value=0.0, batch_first=True)
        targets = targets.to(device).float().unsqueeze(1)

        model.zero_grad()
        out = model((features, lengths))  # the model expects a (padded, lengths) tuple
        loss = criterion(out, targets)
        loss.backward()
        optimizer.step()

        # compute the sigmoid once and derive both probabilities and hard labels
        probs = torch.sigmoid(out).detach().cpu()
        predictions += torch.round(probs).tolist()
        predictions_probs += probs.tolist()
        all_targets += targets.cpu().tolist()

    # Narrow except: only sklearn's "metric undefined" errors map to 0.0. A bare
    # except would also hide KeyboardInterrupt and real bugs.
    try:
        f1 = f1_score(all_targets, predictions)
    except ValueError:
        f1 = 0.0
    try:
        auc = roc_auc_score(all_targets, predictions_probs)
    except ValueError:
        auc = 0.0
    return f1, auc


@torch.no_grad()
def validate_embedding(embedding, dataloader, criterion, device) -> float:
    """Return the mean per-batch skip-gram loss of ``embedding`` on ``dataloader``.

    Runs in eval mode without gradient tracking. The loss of one batch is the
    sum over context positions ``i`` of ``criterion(out[i], targets[:, i])``,
    where ``out`` is the list returned by ``embedding(features)``.
    """
    embedding.eval()
    batch_losses = []

    for batch in dataloader:
        inputs = batch[0].long().to(device)
        labels = batch[1].long().to(device)

        predictions = embedding(inputs)
        batch_loss = sum(criterion(pred, labels[:, pos]) for pos, pred in enumerate(predictions))
        batch_losses.append(batch_loss.item())

    return np.mean(batch_losses)

@torch.no_grad()
def validate_model(model, dataloader, device) -> tuple:
    """Evaluate ``model`` on ``dataloader`` without gradients and return ``(f1, auc)``.

    Batches are ``(packed_features, targets)``. The features are unpacked with
    ``pad_packed_sequence`` (``batch_first=True``, zero padding) and fed to the
    model as a ``(padded, lengths)`` tuple. Predictions are
    ``round(sigmoid(logit))``. A metric sklearn refuses to compute
    (``ValueError``, e.g. only one class present) is reported as 0.0.
    """
    model.eval()
    predictions = []
    all_targets = []
    predictions_probs = []

    for batch in dataloader:
        features, targets = batch[0], batch[1]

        features, lengths = pad_packed_sequence(features.to(device), padding_value=0, batch_first=True)
        targets = targets.to(device).float().unsqueeze(1)

        out = model((features, lengths))  # the model expects a (padded, lengths) tuple
        # compute the sigmoid once and derive both probabilities and hard labels
        probs = torch.sigmoid(out).cpu()
        predictions += torch.round(probs).tolist()
        predictions_probs += probs.tolist()
        all_targets += targets.cpu().tolist()

    # Narrow except: only sklearn's "metric undefined" errors map to 0.0.
    try:
        f1 = f1_score(all_targets, predictions)
    except ValueError:
        f1 = 0.0
    try:
        auc = roc_auc_score(all_targets, predictions_probs)
    except ValueError:
        auc = 0.0
    return f1, auc


def check_embedding_size(embedding, ngrams, criterion, optimizer, epochs, batch_size, device):
    """Train ``embedding`` on a random 90/10 split of ``ngrams`` and report its losses.

    The split is drawn with ``np.random.permutation``; both parts are encoded
    with ``create_dataset_from_ngrams(..., unknown_sample=0.01)``. Each epoch
    runs one training pass (no continual-learning terms) and one validation
    pass. After every epoch ``detect_loss_divergence(train, val, growth_factor=1.05)``
    is consulted; callers treat a positive result as overfitting.

    :return: ``(epoch, train_loss, val_loss, diverged)``:

        * on divergence: the 0-based epoch it was detected, the mean of the
          last five train / validation losses, and ``True``;
        * otherwise: the 0-based epoch with the lowest validation loss, the
          train / validation loss of that epoch, and ``False``.
    """
    n_total = len(ngrams)
    shuffled = np.random.permutation(n_total)
    cut = int(0.9 * n_total)
    train_idx, val_idx = set(shuffled[:cut]), set(shuffled[cut:])

    vocab = embedding.embedding.vocab_map
    train_part = [gram for pos, gram in enumerate(ngrams) if pos in train_idx]
    val_part = [gram for pos, gram in enumerate(ngrams) if pos in val_idx]
    train_set = create_dataset_from_ngrams(train_part, vocab, unknown_sample=0.01)
    val_set = create_dataset_from_ngrams(val_part, vocab, unknown_sample=0.01)

    train_batches = FastDataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_batches = FastDataLoader(val_set, batch_size=batch_size, shuffle=True)

    train_history, val_history = [], []
    for epoch in range(epochs):
        train_history.append(train_embedding(embedding, train_batches, criterion, optimizer, [], {}, device))
        val_history.append(validate_embedding(embedding, val_batches, criterion, device))

        if detect_loss_divergence(train_history, val_history, growth_factor=1.05):
            print(f"Loss diverged at epoch {epoch} for embedding size {embedding.embedding.embedding_dim}")
            return epoch, np.mean(train_history[-5:]), np.mean(val_history[-5:]), True

    best = np.argmin(val_history)
    return best, train_history[best], val_history[best], False

def adapt_embedding_size(embedding, ngrams, criterion, optimizer_function, epochs, batch_size, context_size, device) -> tuple:
    """Search for an embedding dimension around the current one and return the best.

    The current dimension is evaluated first on a deep copy of ``embedding``
    via ``check_embedding_size``. Its fourth return value (named
    ``overfitting`` here; it comes from ``detect_loss_divergence`` in
    ``utils``) decides the search direction:

    * flagged: try smaller sizes in steps of ``step`` (2) with freshly
      initialised ``SkipGram`` models, until a size is not flagged or the size
      reaches ``step``. Any size with a lower validation loss becomes the best.
    * not flagged: try larger sizes in steps of 2. Each try takes a fresh copy
      of ``embedding`` and calls ``extend_embedding``, up to twice the starting
      dimension. The search stops when a size is flagged or its validation loss
      exceeds the best so far.

    Side effect: ``embedding`` itself is moved to the CPU. It is never trained;
    only copies / new models are.

    :param embedding: SkipGram-like model exposing ``embedding.embedding_dim``,
        ``embedding.vocab_map`` and ``extend_embedding``.
    :param ngrams: ``(centre, [context...])`` pairs used for every candidate.
    :param criterion: Loss passed through to ``check_embedding_size``.
    :param optimizer_function: Optimizer factory, called as
        ``optimizer_function(params, lr=1e-3)`` for every candidate.
    :param epochs: Maximum epochs per candidate.
    :param batch_size: Batch size per candidate.
    :param context_size: Context size for newly created ``SkipGram`` models.
    :param device: Device the candidates are trained on.
    :return: ``(best_embedding_dim, best_epochs + 1, best_embedding_loss)``.
        The middle value converts the 0-based epoch index from
        ``check_embedding_size`` into a number of epochs.
    """
    step = 2
    best_embedding_dim = embedding.embedding.embedding_dim
    best_epochs = epochs  # overwritten below after the first evaluation

    embedding.cpu() # embedding is only used for copies so move to cpu
    embedding_new = copy.deepcopy(embedding).to(device) # create always a new copy of the best old embedding
    # check if the embedding is overfitting
    epochs_trained, train_loss, val_loss, overfitting = check_embedding_size(
        embedding_new, ngrams, criterion, optimizer_function(embedding_new.parameters(), lr=1e-3), epochs, batch_size, device
    )
    best_embedding_loss = val_loss
    best_epochs = epochs_trained
    

    if overfitting:
        # if overfitting check for a smaller embedding size
        embedding_size = best_embedding_dim
        while overfitting and embedding_size > step:
            embedding_size -= step
            # new randomly initialised model with the same vocabulary (learned weights are not reused)
            embedding_new = SkipGram(embedding.embedding.vocab_map, embedding_dim=embedding_size, context_size=context_size, unknown_idx=True, padding_idx=True).to(device)
            epochs_trained, train_loss, val_loss, overfitting = check_embedding_size(
                embedding_new, ngrams, criterion, optimizer_function(embedding_new.parameters(), lr=1e-3), epochs, batch_size, device
            )
            if val_loss < best_embedding_loss:
                # if the embedding val loss is lower than the best embedding loss
                # set the best parameters
                best_embedding_loss = val_loss
                best_epochs = epochs_trained
                best_embedding_dim = embedding_size

    else:
        # model is not overfitting
        # check if the embedding size can be increased
        embedding_size = best_embedding_dim
        while not overfitting and embedding_size < embedding.embedding.embedding_dim * 2: 
            embedding_new = copy.deepcopy(embedding).to(device) # create always a new copy of the best old embedding
            embedding_size += step
            embedding_new.extend_embedding(embedding_size)
            epochs_trained, train_loss, val_loss, overfitting = check_embedding_size(
                embedding_new, ngrams, criterion, optimizer_function(embedding_new.parameters(), lr=1e-3), epochs, batch_size, device
            )
            if not overfitting and val_loss < best_embedding_loss:
                # if the embedding is not overfitting and the val loss is lower than the best embedding loss
                # set the best parameters
                best_embedding_loss = val_loss
                best_epochs = epochs_trained
                best_embedding_dim = embedding_size
            elif best_embedding_loss < val_loss:
                break # break the loop if the val loss is not decreasing anymore
    # epochs are count from 0 to epochs-1, so add 1 to the best epochs
    embedding_new.cpu()  # release the last candidate's GPU memory
    print(f"Embedding size: {best_embedding_dim}, Validation loss: {best_embedding_loss}, Epochs: {best_epochs}")
    return best_embedding_dim, best_epochs+1, best_embedding_loss

def embedding_train_loop(embedding, ngrams, criterion, optimizer, epochs, batch_size, approaches, approach_params, device):
    """Train ``embedding`` on ``ngrams`` for ``epochs`` epochs with optional CL penalties.

    :param embedding: Skip-gram model; its ``embedding.vocab_map`` encodes ``ngrams``.
    :param ngrams: ``(centre, [context...])`` pairs (encoded with ``unknown_sample=0.01``).
    :param criterion: Per-context-position loss.
    :param optimizer: Optimizer for ``embedding``'s parameters.
    :param epochs: Number of training epochs.
    :param batch_size: Batch size of the shuffled data loader.
    :param approaches: Enabled continual-learning methods, either a
        ``"+"``-joined string (``"ewc+mas"``) or a list of such strings
        (``["ewc+mas"]``); forwarded to ``train_embedding``.
    :param approach_params: Per-method parameters; mutated in place. When MAS
        is enabled, ``approach_params["new_importance_weights"]`` is set after
        training from ``compute_mas_importance_weights``.
    :param device: Training device.
    """
    dataset = create_dataset_from_ngrams(ngrams, embedding.embedding.vocab_map, unknown_sample=0.01)
    dataloader = FastDataLoader(dataset, batch_size=batch_size, shuffle=True)

    pbar = tqdm(range(epochs))
    for epoch in pbar:
        loss = train_embedding(embedding, dataloader, criterion, optimizer, approaches, approach_params, device)
        pbar.set_description(f"Epoch {epoch+1}/{epochs}")
        pbar.set_postfix({
            'Train Loss': f"{loss:.4f}",
        })

    # Split "ewc+mas" / ["ewc+mas"] into method names. A plain `in` on the list
    # form only matched exact entries, so combined approaches never got new
    # MAS importance weights.
    entries = [approaches] if isinstance(approaches, str) else list(approaches or [])
    methods = {part.strip() for entry in entries for part in str(entry).split("+")}
    if "mas" in methods:
        approach_params["new_importance_weights"] = compute_mas_importance_weights(embedding, dataloader, device)

def model_train_loop(model, embedding, X_train, y_train, X_test, y_test, criterion, optimizer, epochs, batch_size, device, early_stopping: int = 5) -> tuple:
    """Train a sequence classifier on embedded sessions with early stopping.

    Sessions are embedded once via ``data_to_tensor`` with
    ``embedding.embedding.vocab_map``. They are wrapped in
    ``PackSequenceDataset`` and loaded in batches of ``batch_size`` (train
    shuffled, test not).

    An epoch counts as an improvement only if BOTH test F1 and test AUC exceed
    their best values so far. Training stops once ``early_stopping``
    consecutive epochs bring no improvement.

    :param y_train: numpy label array (converted with ``torch.from_numpy``).
    :param y_test: numpy label array (converted with ``torch.from_numpy``).
    :param epochs: Maximum number of epochs; must be >= 1.
    :param early_stopping: Patience in epochs without improvement.
    :return: ``(epoch, f1, auc, train_f1, train_auc)`` of the LAST epoch that
        ran (0-based ``epoch``). The model weights are not rolled back to the
        best epoch, so the metrics match the returned model state.
    """
    vocab_map = embedding.embedding.vocab_map
    train_dataset = PackSequenceDataset(data_to_tensor(X_train, embedding, vocab_map, device), torch.from_numpy(y_train))
    train_dataloader = FastDataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataset = PackSequenceDataset(data_to_tensor(X_test, embedding, vocab_map, device), torch.from_numpy(y_test))
    test_dataloader = FastDataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    print(f"Train size: {len(train_dataset)}, Test size: {len(test_dataset)}")
    best_f1, best_auc = 0, 0
    epochs_without_improvement = 0

    pbar = tqdm(range(epochs))
    for epoch in pbar:
        train_f1, train_auc = train_model(model, train_dataloader, criterion, optimizer, device)
        f1, auc = validate_model(model, test_dataloader, device)

        pbar.set_description(f"Epoch {epoch+1}/{epochs}")
        pbar.set_postfix({
        'Train F1': f"{train_f1:.4f}",
        'Train AUC': f"{train_auc:.4f}",
        'Test F1': f"{f1:.4f}",
        'Test AUC': f"{auc:.4f}"
        })

        # Early stopping counts epochs WITHOUT improvement and resets on
        # improvement. The previous version did the opposite and stopped after
        # a run of improving epochs.
        if f1 > best_f1 and auc > best_auc:
            best_f1 = f1
            best_auc = auc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping:
                break

    return epoch, f1, auc, train_f1, train_auc

def main(use_case: str):
    """Run the weekly continual-learning experiment for one use case.

    Results are written to ``results/<timestamp>_results_<use_case>.csv``.

    1. Load the pickled session DataFrame and derive a 0-based ``week`` index
       from ``first_ts``. Build skip-gram n-grams and vocabularies per week.
    2. Pick an embedding size on week 0 (``adapt_embedding_size``) and train
       the base embedding (EWC/MAS bookkeeping enabled). Evaluate an LSTM and
       a Transformer classifier trained on week 0 and tested on week 1.
    3. For every combination of vocabulary-update strategy, continual-learning
       approach, dynamic resizing and unused-token removal, walk through the
       weeks:
       * update the embedding vocabulary, optionally resize it;
       * fine-tune it on the week's n-grams, optionally prune unused tokens;
       * train an LSTM on week ``w`` and test it on week ``w + 1``.

    :param use_case: Label used in the output file name.
    """
    import os  # local import: only needed to create the results directory

    now = datetime.now().strftime("%Y%m%d-%H%M%S")

    # TODO: placeholder - set to the path of the pickled session DataFrame.
    data_path = "to_set"

    data = pd.read_pickle(data_path)
    if not pd.api.types.is_datetime64_any_dtype(data.first_ts):
        # If not datetime, infer the epoch unit from the magnitude of the first value
        first_val = data.first_ts.iloc[0]
        unit = 'ns' if first_val > 1e18 else 'ms' if first_val > 1e12 else 's'
        data["first_ts"] = pd.to_datetime(data.first_ts, unit=unit)

    # 0-based week index counted from the Monday of the earliest session's week.
    # Subtracting ISO week numbers breaks across a year boundary (they restart
    # at 1). Within a single ISO year this gives the same result as
    # `isocalendar().week - min`. round() absorbs DST hour shifts for
    # tz-aware timestamps.
    day_start = data.first_ts.dt.normalize()
    week_start = day_start - pd.to_timedelta(day_start.dt.weekday, unit="D")
    data["week"] = ((week_start - week_start.min()) / pd.Timedelta(weeks=1)).round().astype(int)
    last_week = int(data.week.max())
    week_ngrams = []
    week_maps = []

    if "event_type" in data.columns:
        data["X"] = [[":".join(map(str, pair)) for pair in zip(x, y)] for x, y in zip(data.event_type, data.touchpoint)]
    else:
        data["X"] = data.touchpoint

    # Iterate weeks in index order so that week_ngrams[w] / week_maps[w] belong
    # to week w. `unique()` returns order of appearance, which need not be sorted.
    for week in range(last_week + 1):
        week_data = data[data.week == week].reset_index(drop=True)
        week_ngram = create_ngram_dataset(week_data, 2, add_event_type="event_type" in data.columns, verbose=False)
        week_ngrams.append(week_ngram["ngrams"])
        week_maps.append(week_ngram["vocab_map"])

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # initial embedding size check:
    # try dim [2 ... 64] with step 2 until overfitting or no loss decrease after 50 epochs
    epochs = 50
    batch_size = 128
    context_size = 2

    criterion = CrossEntropyLoss()
    embedding_starting_dim = 32

    embedding = SkipGram(set(week_maps[0].keys()), embedding_starting_dim, context_size=context_size, unknown_idx=True, padding_idx=True).to(device)
    best_embedding_dim, best_epochs, best_embedding_loss = adapt_embedding_size(
        embedding=embedding,
        ngrams=week_ngrams[0],
        criterion=criterion,
        optimizer_function=AdamW,
        epochs=epochs,
        batch_size=batch_size,
        context_size=context_size,
        device=device)

    del embedding # free memory

    base_embedding = SkipGram(set(week_maps[0].keys()), embedding_dim=best_embedding_dim, context_size=context_size, unknown_idx=True, padding_idx=True).to(device)

    # Train the base embedding with EWC+MAS bookkeeping so that the Fisher
    # information and MAS importance weights of week 0 are available later.
    approaches = "ewc+mas"
    approach_params = {
        "fisher_info_new" : {
            "fisher_info": {}
        }
    }
    embedding_train_loop(
        embedding=base_embedding,
        ngrams=week_ngrams[0],
        criterion=criterion,
        optimizer=AdamW(base_embedding.parameters(), lr=1e-3),
        epochs=best_epochs,
        batch_size=batch_size,
        approaches=approaches,
        approach_params=approach_params,
        device=device
    )

    base_fisher_info = approach_params["fisher_info_new"]["fisher_info"]
    base_importance = approach_params["new_importance_weights"]

    base_models = {
        "lstm": LSTMClassifier(input_size=best_embedding_dim, hidden_size=best_embedding_dim, output_size=1, num_layers=1),
        "transformer": TransformerEncoderClassifier(input_size=best_embedding_dim, output_size=1, num_heads=4, num_layers=2)
    }

    results = {
        "approach": [],
        "embedding_size": [],
        "context_size": [],
        "strat": [],
        "auc": [],
        "f1": [],
        "train_f1": [],
        "train_auc": [],
        "week": [],
        "epochs": [],
        "train_tokens": [],
        "unknown_tokens": [],
        "removed_tokens": [],
        "models": [],
        "is_dynamic": [],
        "is_unused_tokens": []
    }

    # Under-sample both splits: keep all positives and request 1.5x as many negatives.
    under_sampler = RandomUnderSampler(
        sampling_strategy=lambda y: {
            0: int(1.5 * sum(y == 1)),
            1: sum(y == 1)
        },
        random_state=9
    )
    sampling_pipe = Pipeline(steps=[('u', under_sampler)])

    # Resample indices rather than the ragged sessions themselves.
    X, y = data.X[data.week == 0].to_numpy(), data[data.week == 0].purchase.to_numpy()
    X_test, y_test = data.X[data.week == 1].to_numpy(), data[data.week == 1].purchase.to_numpy()
    X_id, y = sampling_pipe.fit_resample(np.arange(len(X)).reshape(-1, 1), y)
    X = X[X_id.flatten()]
    X_id, y_test = sampling_pipe.fit_resample(np.arange(len(X_test)).reshape(-1, 1), y_test)
    X_test = X_test[X_id.flatten()]

    for name, model in base_models.items():
        torch.cuda.empty_cache()
        epoch, f1, auc, train_f1, train_auc = model_train_loop(
            model=model.to(device),
            embedding=base_embedding,
            X_train=X,
            y_train=y,
            X_test=X_test,
            y_test=y_test,
            criterion=BCEWithLogitsLoss(),
            optimizer=AdamW(model.parameters(), lr=1e-3),
            epochs=epochs,
            batch_size=batch_size,
            device=device
        )
        model.cpu() # free memory on gpu

        results["approach"].append("base")
        results["embedding_size"].append(best_embedding_dim)
        results["context_size"].append(context_size)
        results["strat"].append("base")
        results["auc"].append(auc)
        results["f1"].append(f1)
        results["train_f1"].append(train_f1)
        results["train_auc"].append(train_auc)
        results["week"].append(0)
        results["epochs"].append(epoch)
        results["train_tokens"].append(len(base_embedding.embedding.vocab_map))
        # Tokens of the test week (1) the embedding does not know - same
        # definition as in the weekly loop below.
        results["unknown_tokens"].append(len(week_maps[1].keys() - set(base_embedding.embedding.vocab_map.keys())))
        results["removed_tokens"].append(0)
        results["models"].append(name)
        results["is_dynamic"].append(False)
        results["is_unused_tokens"].append(False)

    os.makedirs("results", exist_ok=True)
    results_path = f"results/{now}_results_{use_case}.csv"
    # Persist base results immediately so they survive even without later weeks.
    pd.DataFrame(results).to_csv(results_path, index=False)

    strats = ["None", "avg", "random", "unknown"] # Which strategy to apply to update the embedding with the new week data
    approaches = ["None", "ewc", "mas", "lwf", "ewc+mas", "ewc+lwf", "mas+lwf", "ewc+mas+lwf"]

    dynamics = [False, True] # if the embedding is extended dynamically or not for the new week data
    remove_unused_classes = [True, False] # if the unused classes should be removed after each week

    base_embedding.cpu() # after here base_embedding is only used for copies so move to cpu
    for strat in strats:
        for approach in approaches:
            for dynamic in dynamics:
                for remove_unused in remove_unused_classes:
                    # Dynamic resizing is only evaluated together with an update strategy.
                    if dynamic and strat == "None":
                        continue
                    # Before starting the week iterations, reset the embedding,
                    # Fisher information and importance weights to the base
                    # (week 0) values. The old model is kept on `device`, as it
                    # is for every later week (deepcopy of the trained embedding).
                    embedding_old = copy.deepcopy(base_embedding).to(device)
                    old_fisher_info = base_fisher_info
                    old_importance = base_importance
                    # reset best_parameter to the one from base
                    best_embedding_dim = base_embedding.embedding.embedding_dim
                    ######## Start the loop for the new data, here weeks ############
                    # train on `week`, test on `week + 1` (<= last_week)
                    for week in range(1, last_week):
                        embedding = copy.deepcopy(embedding_old)
                        print(f"Week {week}, Approach: {approach}, Strategy: {strat}, Dynamic: {dynamic}")

                        ######################### 1. part: update embedding vocabulary #########
                        token_map = embedding.embedding.vocab_map.keys() | week_maps[week].keys()
                        if strat == "random":
                            embedding.update_embedding(token_map)
                        elif strat == "unknown":
                            embedding.update_embedding(token_map, strat_select_id, {"idx": embedding.embedding.unknown_idx})
                        elif strat == "avg":
                            embedding.update_embedding(token_map, strat_avg_all)
                        else:
                            # if no strategy is given a new embedding is created for the new week data
                            embedding = SkipGram(token_map, embedding_dim=embedding.embedding.embedding_dim, context_size=context_size, unknown_idx=True, padding_idx=True).to(device)

                        ######################### 2. part: check overfitting / resize ##########
                        if dynamic and strat != "None":
                            # NOTE(review): best_epochs / best_embedding_loss are computed
                            # but the fine-tuning below always uses 5 epochs.
                            best_embedding_dim, best_epochs, best_embedding_loss = adapt_embedding_size(
                                copy.deepcopy(embedding).to(device), week_ngrams[week], criterion, AdamW, epochs, batch_size, context_size, device
                            )

                        ######################### 3. part: train embeddings w. CL ##############
                        approach_params = {}
                        if "lwf" in approach:
                            approach_params["lwf"] = {
                                "old_model": embedding_old,
                            }
                        if "mas" in approach:
                            approach_params["mas"] = {
                                "old_model": embedding_old,
                                "importance_weights": old_importance,
                            }
                        if "ewc" in approach:
                            approach_params["ewc"] = {
                                "old_model": embedding_old,
                                "fisher_info": old_fisher_info
                            }
                            approach_params["fisher_info_new"] = {
                                "fisher_info": {}
                            }

                        current_dim = embedding.embedding.embedding_dim
                        # if the embeddings size needs to be increased
                        if best_embedding_dim > current_dim:
                            embedding.extend_embedding(best_embedding_dim)
                        # if the embedding size needs to be decreased (fresh model, weights are not reused)
                        elif best_embedding_dim < current_dim:
                            embedding = SkipGram(token_map, embedding_dim=best_embedding_dim, context_size=context_size, unknown_idx=True, padding_idx=True).to(device)

                        embedding.to(device)
                        embedding_train_loop(
                            embedding=embedding,
                            ngrams=week_ngrams[week],
                            criterion=criterion,
                            optimizer=AdamW(embedding.parameters(), lr=1e-3),
                            epochs=5,
                            batch_size=batch_size,
                            approaches=[approach],
                            approach_params=approach_params,
                            device=device
                        )
                        # update the fisher information and importance weights for the next iteration
                        if "fisher_info_new" in approach_params:
                            old_fisher_info = approach_params["fisher_info_new"]["fisher_info"]
                        if "new_importance_weights" in approach_params:
                            old_importance = approach_params["new_importance_weights"]

                        ######################### 4. part: remove useless classes ##############
                        removed_tokens = 0
                        if remove_unused:
                            tokens = len(embedding.embedding.vocab_map)
                            embedding.remove_unused_classes()
                            removed_tokens = tokens - len(embedding.embedding.vocab_map)
                            print(f"Removed {removed_tokens} unused classes from embedding. New size: {len(embedding.embedding.vocab_map)}")

                        ######################### 5. part: train model #########################
                        X, y = data.X[data.week == week].to_numpy(), data[data.week == week].purchase.to_numpy()
                        X_test, y_test = data.X[data.week == week + 1].to_numpy(), data[data.week == week + 1].purchase.to_numpy()
                        X_id, y = sampling_pipe.fit_resample(np.arange(len(X)).reshape(-1, 1), y)
                        X = X[X_id.flatten()]
                        X_id, y_test = sampling_pipe.fit_resample(np.arange(len(X_test)).reshape(-1, 1), y_test)
                        X_test = X_test[X_id.flatten()]

                        for model_name in ["lstm"]:
                            model = None
                            if model_name == "lstm":
                                model = LSTMClassifier(input_size=embedding.embedding.embedding_dim, hidden_size=embedding.embedding.embedding_dim, output_size=1, num_layers=1).to(device)
                            elif model_name == "transformer":
                                model = TransformerEncoderClassifier(input_size=embedding.embedding.embedding_dim, output_size=1, num_heads=4, num_layers=2).to(device)

                            torch.cuda.empty_cache()
                            epoch, f1, auc, train_f1, train_auc = model_train_loop(
                                model=model,
                                embedding=embedding,
                                X_train=X,
                                y_train=y,
                                X_test=X_test,
                                y_test=y_test,
                                criterion=BCEWithLogitsLoss(),
                                optimizer=AdamW(model.parameters(), lr=1e-3),
                                epochs=epochs,
                                batch_size=batch_size,
                                device=device
                            )

                            model.cpu() # free memory on gpu
                            results["approach"].append(approach)
                            results["embedding_size"].append(embedding.embedding.embedding_dim)
                            results["context_size"].append(context_size)
                            results["strat"].append(strat)
                            results["auc"].append(auc)
                            results["f1"].append(f1)
                            results["train_f1"].append(train_f1)
                            results["train_auc"].append(train_auc)
                            results["week"].append(week)
                            results["epochs"].append(epoch)
                            results["train_tokens"].append(len(embedding.embedding.vocab_map))
                            results["unknown_tokens"].append(len(week_maps[week+1].keys() - set(embedding.embedding.vocab_map.keys())))
                            results["removed_tokens"].append(removed_tokens)
                            results["is_unused_tokens"].append(remove_unused)
                            results["models"].append(model_name)
                            results["is_dynamic"].append(dynamic)

                        pd.DataFrame(results).to_csv(results_path, index=False)
                        embedding_old = copy.deepcopy(embedding)
                        embedding.cpu() # free memory on gpu
                        del model
                        del embedding
                        torch.cuda.empty_cache()

if __name__ == "__main__":
    if len(sys.argv) > 1:
        use_case = sys.argv[1]
        print(f"Use case: {use_case}")
        main(use_case)
    else:
        # Without a use case there is nothing to run: fail loudly (exit status 1,
        # message on stderr) instead of silently exiting with status 0.
        sys.exit(f"Usage: {sys.argv[0]} <use_case>")
    