import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
import argparse
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from sklearn.model_selection import StratifiedKFold
import os, shutil
# Device used for the model and for every tensor built in this script.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Parameters
N = 20  # Number of input tokens (equal to output tokens for now)
M = N   # Number of output tokens (equal to input tokens per requirements)
VOCAB_SIZE = 1 + N + M  # <bos> + input tokens + output tokens (the <bos> id is reserved even when USE_BOS is False)
P_CORRECT_TRAIN = 0.8  # Probability of y = f(x) and y' = f(x') in training samples
P_CORRECT_EVAL = 0.5  # Probability of y = f(x) and y' = f(x') in evaluation samples

# SEQ_LENGTH = 5  # [<bos>, x, y, x', y']
USE_BOS = False
# Number of *input* positions seen by the model. A full sample is
# [<bos>,] x, y, x', y' (SEQ_LENGTH + 1 tokens); the last token is only used as
# a next-token target, so the model input is [<bos>,] x, y, x'.
SEQ_LENGTH = 4 if USE_BOS else 3  # [x, y, x'] without <bos>
# EMBEDDING_DIM = 256
# Sized so the one-hot initialisation in train_model can place token one-hots,
# position one-hots and an output read-out segment in disjoint coordinates.
EMBEDDING_DIM = 2 * VOCAB_SIZE + SEQ_LENGTH
NUM_HEADS = 1
BATCH_SIZE = 128
NUM_LAYERS = 1
LEARNING_RATE = 3e-1  # used with plain SGD in train_model
USE_FIRST_LAYER_MLP = False # if True, the first layer is MLP-only, the others attention-only. Otherwise, all layers are attention-only.

# Token mappings: id 0 is <bos>, ids 1..N are inputs x, ids N+1..N+M are outputs y.
BOS_TOKEN = 0
INPUT_TOKENS = list(range(1, N+1))
OUTPUT_TOKENS = list(range(N+1, N+M+1))

def pca_layers_plot(model,
                    f_map,
                    token_indices=(2, 3),
                    num_eval_samples=2000,
                    batch_size=BATCH_SIZE,
                    seed=2024,
                    max_points=2000,
                    fname_prefix="pca",
                    show=False):
    """
    Save one figure containing every probed position's layer-wise PCA grids.

    • Rows    → different probe positions (a block of rows per position)
    • Columns → layers (embedding + each transformer block)
    • Points are coloured by `is_true` (truthfulness).

    Parameters
    ----------
    model : module whose forward accepts ``return_activations=True`` and then
        returns ``(logits, hidden_states)`` with one hidden state per layer.
    f_map : input-token -> output-token mapping passed to ``gen_batch``.
    token_indices : positions of the model input whose hidden states are shown.
    num_eval_samples : number of evaluation samples generated (must be >= 2).
    batch_size : generation / forward-pass batch size.
    seed : seed for sample generation and for point subsampling.
    max_points : at most this many randomly chosen points are drawn per subplot.
    fname_prefix : the figure is saved as ``<fname_prefix>_alltokens.png``.
    show : also display the figure interactively.

    Raises
    ------
    ValueError
        If ``token_indices`` is empty or ``num_eval_samples < 2``.

    The model's train/eval mode is restored once activations are collected.
    """
    token_indices = tuple(token_indices)
    if not token_indices:
        raise ValueError("token_indices must contain at least one position")
    if num_eval_samples < 2:
        raise ValueError("num_eval_samples must be >= 2 for a 2-component PCA")

    # Names of the model-input positions: the input is [<bos>,] x, y, x'
    # (the final y' is only ever a target).
    pos_names = (["BOS"] if USE_BOS else []) + ["X1", "Y1", "X2"]

    rng = np.random.default_rng(seed)

    # -------- collect activations once --------
    # Per position: one list of per-layer arrays per batch. They are
    # concatenated once at the end instead of re-concatenating every batch.
    chunks = {tok: [] for tok in token_indices}
    label_chunks = []
    was_training = model.training
    model.eval()
    try:
        collected = 0
        while collected < num_eval_samples:
            cur_bs = min(batch_size, num_eval_samples - collected)
            x, is_true = gen_batch(rng, cur_bs, f_map, is_train=False)
            inp = torch.from_numpy(x[:, :-1]).to(device)   # drop the final target token

            with torch.no_grad():
                _, hids = model(inp, return_activations=True)   # list len = L+1

            for tok in token_indices:
                chunks[tok].append([h[:, tok, :].cpu().numpy() for h in hids])
            label_chunks.append(is_true)
            collected += cur_bs
    finally:
        # Called during training: do not leave the model in eval mode.
        model.train(was_training)

    labels = np.concatenate(label_chunks)          # shape (num_eval_samples,)
    acts = {tok: [np.concatenate(layer, axis=0) for layer in zip(*chunks[tok])]
            for tok in token_indices}

    # --------- figure layout --------
    n_layers = len(acts[token_indices[0]])          # embed + blocks
    n_cols = min(4, n_layers)
    rows_per_tok = math.ceil(n_layers / n_cols)
    n_rows = len(token_indices) * rows_per_tok

    fig = plt.figure(figsize=(4 * n_cols, 3 * n_rows))
    try:
        fig.suptitle("Layer-wise PCA • coloured by truthfulness")

        # --------- draw every position's grids --------
        for r, tok in enumerate(token_indices):
            tok_str = pos_names[tok] if 0 <= tok < len(pos_names) else f"pos{tok}"
            for i, vecs in enumerate(acts[tok]):
                pcs = PCA(n_components=2).fit_transform(vecs)
                row = r * rows_per_tok + i // n_cols
                ax = fig.add_subplot(n_rows, n_cols, row * n_cols + i % n_cols + 1)

                # optional subsample for over-crowded plots
                if pcs.shape[0] > max_points:
                    idx = rng.choice(pcs.shape[0], max_points, replace=False)
                else:
                    idx = slice(None)
                ax.scatter(pcs[idx, 0], pcs[idx, 1],
                           c=labels[idx], s=6, alpha=0.7, cmap="coolwarm")

                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_title(f"tok {tok_str}  •  L{i}", fontsize=8)

        fig.tight_layout(rect=[0, 0.03, 1, 0.96])
        fig.savefig(f"{fname_prefix}_alltokens.png", dpi=120, bbox_inches="tight")
        if show:
            plt.show()
    finally:
        plt.close(fig)



def eval_linear_separability(model,
                             f_map,
                             token_indices=(2, 3),
                             num_eval_samples=2000,
                             batch_size=BATCH_SIZE,
                             n_splits=5,
                             seed=12345):
    """
    k-fold (stratified) cross-validated linear probes for ``is_true``.

    Generates ``num_eval_samples`` evaluation samples and records, for every
    position in ``token_indices``, the hidden state of every layer (embedding
    output plus each transformer block). A logistic-loss ``SGDClassifier`` is
    then trained per (position, layer) and scored by ROC-AUC on held-out folds.
    The same folds are used for every layer. Both the sample generator and the
    probes are seeded with ``seed``.

    Returns
    -------
    auc_dict : dict[int, tuple[list[float], list[float]]]
        key   = token_index
        value = (mean_auc_per_layer, std_auc_per_layer) over the ``n_splits`` folds

    The model's train/eval mode is restored before the probes are trained.
    """
    rng = np.random.default_rng(seed)
    token_indices = tuple(token_indices)

    # --------- collect hidden states & labels ---------
    # Per position: one list of per-layer arrays per batch, concatenated once.
    chunks = {tok: [] for tok in token_indices}
    label_chunks = []
    was_training = model.training
    model.eval()
    try:
        collected = 0
        while collected < num_eval_samples:
            cur_bs = min(batch_size, num_eval_samples - collected)
            x, is_true = gen_batch(rng, cur_bs, f_map, is_train=False)
            inp = torch.from_numpy(x[:, :-1]).to(device)   # drop the final target token

            with torch.no_grad():
                _, hids = model(inp, return_activations=True)   # list len = L+1

            for tok in token_indices:
                chunks[tok].append([h[:, tok, :].cpu().numpy() for h in hids])
            label_chunks.append(is_true)
            collected += cur_bs
    finally:
        # Called during training: do not leave the model in eval mode.
        model.train(was_training)

    labels = np.concatenate(label_chunks)     # (num_eval_samples,)

    # Folds depend only on the labels and the seed, so compute them once.
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    folds = list(skf.split(np.zeros((len(labels), 1)), labels))

    # --------- train & score probes with CV ---------
    auc_dict = {}
    for tok in token_indices:
        vecs_per_layer = [np.concatenate(layer, axis=0) for layer in zip(*chunks[tok])]
        means, stds = [], []
        for X in vecs_per_layer:
            fold_aucs = []
            for train_idx, dev_idx in folds:
                clf = SGDClassifier(loss="log_loss",
                                    max_iter=1000,
                                    tol=1e-4,
                                    random_state=seed)   # reproducible probes
                clf.fit(X[train_idx], labels[train_idx])
                prob = clf.predict_proba(X[dev_idx])[:, 1]
                fold_aucs.append(roc_auc_score(labels[dev_idx], prob))
            means.append(np.mean(fold_aucs))
            stds.append(np.std(fold_aucs))
        auc_dict[tok] = (means, stds)

    return auc_dict


def plot_auc(auc_dict,
             step_idx=None,
             prefix="lin_sep",
             show=False):
    """
    Single figure that overlays every probed position's mean ± std AUC curve.

    Parameters
    ----------
    auc_dict : dict[int, tuple[list[float], list[float]]]
        As returned by ``eval_linear_separability``: position -> (means, stds),
        one value per layer.
    step_idx : optional training step, used in the title and file name.
    prefix : file-name prefix.
    show : also display the figure interactively.

    Saves ``classification_plots/<prefix>_alltokens[_step<step>].png``; the
    directory is created if missing.

    Raises
    ------
    ValueError
        If ``auc_dict`` is empty.
    """
    if not auc_dict:
        raise ValueError("auc_dict is empty; nothing to plot")

    # Names of the model-input positions: the input is [<bos>,] x, y, x'.
    pos_names = (["BOS"] if USE_BOS else []) + ["X1", "Y1", "X2"]
    layers = range(len(next(iter(auc_dict.values()))[0]))  # x-axis once

    os.makedirs("classification_plots", exist_ok=True)
    suffix = "" if step_idx is None else f"_step{step_idx}"
    fname = f"classification_plots/{prefix}_alltokens{suffix}.png"

    fig = plt.figure()
    try:
        for tok, (means, stds) in auc_dict.items():
            tok_str = pos_names[tok] if 0 <= tok < len(pos_names) else f"pos{tok}"
            plt.errorbar(layers, means, yerr=stds,
                         fmt='-o', capsize=3, label=f"token {tok_str}")

        plt.xlabel("Layer (0 = embedding)")
        plt.ylabel("AUC (mean ± 1σ)")
        if step_idx is not None:
            plt.title(f"Linear separability (step {step_idx})")
        else:
            plt.title("Linear separability")

        plt.ylim(0.0, 1.0)
        plt.grid(alpha=0.3)
        plt.xticks(layers)
        plt.legend()
        plt.tight_layout()
        plt.savefig(fname, dpi=120, bbox_inches="tight")
        if show:
            plt.show()
    finally:
        plt.close(fig)

# Create a random permutation function f for inputs to outputs
def create_random_mapping(rng=None):
    """
    Return a random bijection f from INPUT_TOKENS to OUTPUT_TOKENS as a dict.

    Parameters
    ----------
    rng : numpy.random.Generator, optional
        Source of randomness. When omitted, NumPy's legacy global RNG
        (``np.random.permutation``) is used as before, so the result then
        depends on ``np.random.seed``.

    Assumes M == N (true for the module constants), so every output token is
    used exactly once.
    """
    permute = np.random.permutation if rng is None else rng.permutation
    output_indices = permute(N)
    return {inp: OUTPUT_TOKENS[j] for inp, j in zip(INPUT_TOKENS, output_indices)}

def generate_single_sample(rng, f_map, is_train=True):
    """
    Draw one sequence for the mapping task.

    x and x' are drawn independently from INPUT_TOKENS. With probability
    P_CORRECT_TRAIN (training) or P_CORRECT_EVAL (evaluation), the outputs
    follow the mapping: y = f(x) and y' = f(x'). Otherwise y and y' are drawn
    independently from OUTPUT_TOKENS, so they can still equal f(x) by chance.

    Returns
    -------
    (sequence, is_true)
        ``sequence`` is the list [<bos>,] x, y, x', y'. ``is_true`` tells
        whether the mapped branch was taken.
    """
    first_in = rng.choice(INPUT_TOKENS)
    second_in = rng.choice(INPUT_TOKENS)

    p_mapped = P_CORRECT_TRAIN if is_train else P_CORRECT_EVAL
    is_true = rng.random() < p_mapped

    if not is_true:
        # Unrelated outputs, drawn y first and then y'.
        first_out = rng.choice(OUTPUT_TOKENS)
        second_out = rng.choice(OUTPUT_TOKENS)
    else:
        first_out, second_out = f_map[first_in], f_map[second_in]

    prefix = [BOS_TOKEN] if USE_BOS else []
    return (prefix + [first_in, first_out, second_in, second_out], is_true)

def gen_batch(rng: np.random.Generator, batch_size: int, f_map: dict, is_train: bool = True):
    """
    Generate ``batch_size`` independent samples with ``generate_single_sample``.

    Returns
    -------
    x : np.ndarray, shape (batch_size, SEQ_LENGTH + 1)
        One full token sequence per row (model inputs plus the final target).
    is_true : np.ndarray, shape (batch_size,)
        1 where the sample took the mapped branch (y = f(x), y' = f(x')), else 0.
    """
    flat_tokens = []
    flags = []
    for _ in range(batch_size):
        sample, truth = generate_single_sample(rng, f_map, is_train)
        flat_tokens.extend(sample)
        flags.append(int(truth))
    rows = np.array(flat_tokens).reshape(batch_size, SEQ_LENGTH + 1)
    return rows, np.array(flags)

def _batch_worker(queue, rng, batch_size, f_map, is_train):
    """Producer loop run in a child process: push ``gen_batch`` results onto ``queue`` forever."""
    while True:
        queue.put(gen_batch(rng, batch_size, f_map, is_train))


def iterate_batches(f_map: dict,
                    batch_size: int = 20,
                    seed: int = 42,
                    is_train: bool = True):
    """
    Yield an endless stream of ``(inputs, targets, is_true)`` batches.

    Batches are produced by ``cpu_count() - 1`` (at least one) background
    processes. Worker ``i`` uses ``np.random.default_rng([seed, i])``. Because
    workers share one queue, the arrival order depends on process scheduling.
    ``inputs = x[:, :-1]`` and ``targets = x[:, 1:]`` (next-token shift) are
    column views of the generated batch ``x``. ``is_true`` holds per-row labels.

    The worker is a module-level function so this also works with the
    ``spawn`` start method. Workers are daemonic, and they are killed and
    joined when the generator is closed or raises.

    Raises
    ------
    RuntimeError
        If every worker process has exited, which would otherwise block forever.
    """
    import multiprocessing as mp
    from queue import Empty

    num_cores = max(1, mp.cpu_count() - 1)
    q = mp.Queue(maxsize=10000)
    processes = [
        mp.Process(target=_batch_worker,
                   args=(q, np.random.default_rng([seed, i]), batch_size, f_map, is_train),
                   daemon=True)
        for i in range(num_cores)
    ]

    try:
        for p in processes:
            p.start()
        while True:
            try:
                x, is_true = q.get(timeout=5.0)
            except Empty:
                if not any(p.is_alive() for p in processes):
                    raise RuntimeError("all batch-producer processes have exited") from None
                continue
            yield (x[:, :-1], x[:, 1:], is_true)
    finally:
        # Runs on generator close (GeneratorExit) as well as on errors.
        for p in processes:
            if p.is_alive():
                p.kill()
                p.join()


# Multi-head self-attention layer with causal masking
class CausalSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention with a causal mask.

    All projections are bias-free. Query position i can attend only to key
    positions <= i. After every forward pass, the softmax attention weights
    (batch, num_heads, seq_len, seq_len) are kept in ``self.attn_weights`` for
    inspection. They are detached from the autograd graph.

    Raises
    ------
    ValueError
        If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Separate bias-free projections for Q, K, V and the output.
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_weights = None  # filled by forward()

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq, d_model) -> (batch, heads, seq, head_dim)."""
        return t.reshape(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, x):
        """x: (batch, seq_len, d_model) -> (batch, seq_len, d_model)."""
        batch_size, seq_len, d_model = x.size()

        q = self._split_heads(self.q_proj(x), batch_size, seq_len)
        k = self._split_heads(self.k_proj(x), batch_size, seq_len)
        v = self._split_heads(self.v_proj(x), batch_size, seq_len)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Causal mask: True strictly above the diagonal (future keys).
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        scores.masked_fill_(mask, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        # Keep a graph-free copy for plotting so the autograd graph is not retained.
        self.attn_weights = attn_weights.detach()
        attn_output = torch.matmul(attn_weights, v)

        # Merge heads and project.
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, d_model)
        return self.out_proj(attn_output)

# Transformer Block
class TransformerBlock(nn.Module):
    """
    Post-norm residual block with optional attention and MLP sublayers.

    Each enabled sublayer is applied as ``x = norm(x + dropout(sublayer(x)))``:
    attention first (with ``norm1``), then the 4x-wide GELU MLP (with ``norm2``).
    The norms are parameter-free RMSNorms, or identities when ``no_norm`` is set.
    A disabled sublayer is stored as ``None`` and skipped.
    """

    def __init__(self, d_model, num_heads, no_attention=False, no_mlp=False, no_norm=False, dropout=0):
        super().__init__()

        def make_norm():
            return nn.Identity() if no_norm else nn.RMSNorm(d_model, elementwise_affine=False)

        # Registration order is kept (attention, norms, MLP, dropout) so that
        # parameter initialisation and state_dict layout are unchanged.
        self.attention = None if no_attention else CausalSelfAttention(d_model, num_heads)
        self.norm1 = make_norm()
        self.norm2 = make_norm()
        hidden_dim = 4 * d_model
        self.feed_forward = None if no_mlp else nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the enabled sublayers in order, each with a post-norm residual."""
        for sublayer, norm in ((self.attention, self.norm1), (self.feed_forward, self.norm2)):
            if sublayer is None:
                continue
            x = norm(x + self.dropout(sublayer(x)))
        return x

# Causal Transformer Model
class CausalTransformer(nn.Module):
    """
    Decoder-only (causal) transformer over a small token vocabulary.

    Tokens are embedded and a learned positional encoding (initialised to zero)
    is added. The result runs through ``num_layers`` ``TransformerBlock``s and
    is projected to vocabulary logits by a bias-free linear layer.

    If ``first_layer_mlp`` is True, block 0 is MLP-only and every later block is
    attention-only. Otherwise every block is attention-only.

    ``max_seq_len`` (default: the module-level ``SEQ_LENGTH``) sets the number
    of learned positions. ``forward`` rejects longer inputs.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, dropout=0, first_layer_mlp=False, no_norm=False,
                 max_seq_len=None):
        super().__init__()
        if max_seq_len is None:
            max_seq_len = SEQ_LENGTH
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_encoding = nn.Parameter(torch.zeros(1, max_seq_len, d_model))

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads,
                             no_attention=first_layer_mlp and i == 0,
                             no_mlp=not first_layer_mlp or i > 0,
                             no_norm=no_norm, dropout=dropout)
            for i in range(num_layers)
        ])

        self.output_layer = nn.Linear(d_model, vocab_size, bias=False)
        self.dropout = nn.Dropout(dropout)

        # Re-initialise all Linear/Embedding weights with N(0, 0.02).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, return_activations=False):
        """
        Parameters
        ----------
        x : LongTensor of token ids, shape (batch, seq_len), seq_len <= max_seq_len.
        return_activations : also return the hidden states.

        Returns
        -------
        logits of shape (batch, seq_len, vocab_size). With ``return_activations``
        it returns ``(logits, activations)``, where ``activations`` holds the
        embedding output followed by the output of each block
        (num_layers + 1 tensors).

        Raises
        ------
        ValueError
            If seq_len exceeds the number of learned positions.
        """
        _, seq_len = x.size()
        max_len = self.position_encoding.size(1)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_seq_len {max_len}")

        # Token embeddings + positional encodings
        token_emb = self.token_embedding(x)
        pos_enc = self.position_encoding[:, :seq_len, :]
        x = self.dropout(token_emb + pos_enc)

        # Apply transformer blocks
        activations = [x]
        for block in self.transformer_blocks:
            x = block(x)
            activations.append(x)

        # Output projection
        logits = self.output_layer(x)

        if return_activations:
            return logits, activations
        return logits

def _plot_attention_maps(model, step):
    """
    Save heat-maps of the most recently stored attention weights.

    There is one subplot per block that has an attention layer. Each shows
    head 0 of the first example of the last batch passed through the model.
    The figure is written to ``attention_maps/attention_maps_batch{step}.png``.
    """
    blocks = model.transformer_blocks
    ticks = ['bos', 'x', 'y', "x'"] if USE_BOS else ['x', 'y', "x'"]
    os.makedirs("attention_maps", exist_ok=True)
    fig = plt.figure()
    try:
        for i, block in enumerate(blocks):
            # None for MLP-only blocks (attention is None) or before any forward.
            weights = getattr(block.attention, "attn_weights", None)
            if weights is None:
                continue
            plt.subplot(1, len(blocks), i + 1)
            # [0, 0] = first example in the batch, head 0.
            plt.imshow(weights.detach().cpu().numpy()[0, 0])
            plt.xticks(range(SEQ_LENGTH), ticks)
            plt.yticks(range(SEQ_LENGTH), ticks)
            plt.xlabel('Key')
            plt.ylabel('Query')
        plt.savefig(f'attention_maps/attention_maps_batch{step}.png')
    finally:
        plt.close(fig)


# Training function
def train_model(epochs=1, evaluate=True, num_steps=None, freeze_embeddings=False, one_hot=False, plot_attention_maps=False, no_norm=False,
                plot_classification=False, plot_pca=False):
    """
    Train a CausalTransformer on the synthetic mapping task.

    A fresh random mapping f is drawn and batches are streamed from
    ``iterate_batches``. The model is optimised with plain SGD on next-token
    cross-entropy averaged over every position.

    Parameters
    ----------
    epochs : unused; kept for interface compatibility (the data is an endless stream).
    evaluate : run ``evaluate_model`` after training.
    num_steps : number of optimisation steps; ``None`` trains until interrupted.
    freeze_embeddings : freeze token embedding, positional encoding and output layer.
    one_hot : hand-initialise embeddings, positions and read-out as disjoint one-hot blocks.
    plot_attention_maps : every 1000 steps, save attention heat-maps of the training batch.
    no_norm : build the blocks without RMSNorm.
    plot_classification : every 1000 steps, run linear probes and save AUC plots.
    plot_pca : every 5000 steps, save layer-wise PCA plots.

    Returns
    -------
    (model, f_map, res)
        ``res`` holds one dict per step with the mean loss and the losses on
        the first (y) and second (y') prediction of mapped (is_true) samples.

    Raises
    ------
    ValueError
        If ``one_hot`` is set but EMBEDDING_DIM != 2 * VOCAB_SIZE + SEQ_LENGTH.
    """
    f_map = create_random_mapping()

    # Initialize the model
    model = CausalTransformer(
        vocab_size=VOCAB_SIZE,
        d_model=EMBEDDING_DIM,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        first_layer_mlp=USE_FIRST_LAYER_MLP,
        no_norm=no_norm
    ).to(device)

    if one_hot:
        if EMBEDDING_DIM != 2 * VOCAB_SIZE + SEQ_LENGTH:
            raise ValueError("one-hot initialisation requires EMBEDDING_DIM == 2 * VOCAB_SIZE + SEQ_LENGTH")
        # Coordinate layout (V = VOCAB_SIZE, S = SEQ_LENGTH):
        #   [0, V) token one-hots, [V, V+S) position one-hots,
        #   [V+S, 2V+S) read out by the output layer as logits.
        model.token_embedding.weight.data.zero_()
        model.token_embedding.weight.data[torch.arange(VOCAB_SIZE), torch.arange(VOCAB_SIZE)] = 1
        model.position_encoding.data.zero_()
        model.position_encoding.data[0, torch.arange(SEQ_LENGTH), VOCAB_SIZE + torch.arange(SEQ_LENGTH)] = 1
        model.output_layer.weight.data.zero_()
        model.output_layer.weight.data[torch.arange(VOCAB_SIZE), VOCAB_SIZE + SEQ_LENGTH + torch.arange(VOCAB_SIZE)] = 1

    if freeze_embeddings:
        model.token_embedding.requires_grad_(False)
        model.position_encoding.requires_grad_(False)
        model.output_layer.requires_grad_(False)

    # Zero and freeze the query projection of block 0's attention. With q == 0
    # every score is 0, so that layer attends uniformly over the visible prefix.
    # Block 0 has no attention when it is MLP-only (USE_FIRST_LAYER_MLP); skip
    # it then instead of crashing.
    first_attention = model.transformer_blocks[0].attention if len(model.transformer_blocks) > 0 else None
    if first_attention is not None:
        first_attention.q_proj.weight.data.zero_()
        first_attention.q_proj.weight.requires_grad_(False)

    # Loss function and optimizer (per-position losses, averaged below).
    criterion = nn.CrossEntropyLoss(reduction='none')
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

    probe_positions = (2, 3) if USE_BOS else (1, 2)   # y and x' in the model input
    bos_idx = 1 if USE_BOS else 0                     # position whose target is y

    res = []
    model.train()
    step = 0   # number of completed optimisation steps
    for inputs, targets, is_true in iterate_batches(f_map, batch_size=BATCH_SIZE, is_train=True):
        if num_steps is not None and step >= num_steps:
            break
        inputs = torch.from_numpy(inputs).to(device)
        # targets may arrive as a non-contiguous column slice; reshape (not
        # view) below so flattening also works when no device copy is made.
        targets = torch.from_numpy(targets).to(device)
        is_true = torch.from_numpy(is_true).to(device)
        batch_size = inputs.size(0)

        # Forward pass and per-position loss
        logits = model(inputs).reshape(-1, VOCAB_SIZE)
        all_loss = criterion(logits, targets.reshape(-1))
        loss = all_loss.mean()

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step += 1

        if plot_attention_maps and step % 1000 == 0:
            # Plot before any probe/PCA forward pass overwrites the stored weights.
            _plot_attention_maps(model, step)

        # Losses / probabilities of the first (y) and last (y') predictions,
        # restricted to samples that follow the mapping.
        with torch.no_grad():
            true_rows = is_true == 1
            per_pos_loss = all_loss.view(batch_size, -1)
            first_pred_loss = per_pos_loss[true_rows, bos_idx].mean()
            last_pred_loss = per_pos_loss[true_rows, bos_idx + 2].mean()
            probs = torch.softmax(logits.view(batch_size, -1, VOCAB_SIZE), dim=-1)
            first_pred_prob = probs[true_rows, bos_idx, targets[true_rows, bos_idx]].mean()
            last_pred_prob = probs[true_rows, bos_idx + 2, targets[true_rows, bos_idx + 2]].mean()

        if step % 100 == 0:
            print(f"Batch {step}, Loss: {loss.item():.4f}, First Pred Loss: {first_pred_loss.item():.4f}, Last Pred Loss: {last_pred_loss.item():.4f}, First Pred Prob: {first_pred_prob.item():.4f}, Last Pred Prob: {last_pred_prob.item():.4f}")

        if plot_classification and step % 1000 == 0:
            auc_dict = eval_linear_separability(model, f_map,
                                                token_indices=probe_positions,
                                                num_eval_samples=1000)
            model.train()   # the probe switches the model to eval mode
            for tok, (mean_aucs, std_aucs) in auc_dict.items():
                formatted = "  ".join(
                    f"L{idx}:{m:.3f}±{s:.3f}"
                    for idx, (m, s) in enumerate(zip(mean_aucs, std_aucs))
                )
                print(f"[Probe] step {step:>6}  token {tok}  {formatted}")
            plot_auc(auc_dict, step_idx=step, show=False)

        if plot_pca and step % 5000 == 0:
            pca_layers_plot(model, f_map,
                            token_indices=probe_positions,
                            num_eval_samples=2000,
                            fname_prefix=f"pca_plots/pca_step{step}",
                            show=False)
            model.train()   # the PCA collection switches the model to eval mode

        res.append({"loss": loss.item(), "first_pred_loss": first_pred_loss.item(), "last_pred_loss": last_pred_loss.item(), "batch_idx": step})

    # Save model
    torch.save(model.state_dict(), "causal_transformer_model.pt")

    # Evaluate model
    if evaluate:
        evaluate_model(model, f_map)

    return model, f_map, res

def evaluate_linear_separability(model, f_map, n_eval_samples=1000):
    """
    Evaluate how linearly decodable the truth label of evaluation samples is.

    Thin wrapper around ``eval_linear_separability``: an SGD logistic probe is
    trained, with cross-validation, on every layer's hidden state. The
    positions probed are the ones train_model uses, y and x' of the model
    input: (2, 3) with <bos> and (1, 2) without. The final y' is never a model
    input, so its hidden state cannot be probed.

    Returns
    -------
    dict[int, tuple[list[float], list[float]]]
        position -> (mean AUC per layer, std AUC per layer).
    """
    token_indices = (2, 3) if USE_BOS else (1, 2)
    return eval_linear_separability(model, f_map,
                                    token_indices=token_indices,
                                    num_eval_samples=n_eval_samples)


def evaluate_model(model, f_map):
    """
    Check whether the model predicts f(x) as the token that follows x.

    Each input token x is fed as the prefix that precedes y in training:
    ``[<bos>, x]`` if USE_BOS else ``[x]``. The arg-max prediction at the last
    position is compared with ``f_map[x]``. Per-token results and the overall
    accuracy are printed.

    Returns
    -------
    float
        Fraction of input tokens whose prediction equals f(x).

    The model's train/eval mode is restored after the forward pass.
    """
    print("Evaluating if model learned the mapping function f...")

    # One batched forward pass over all input tokens. The prefix must match the
    # training layout: without <bos>, x sits at position 0.
    prefixes = [[BOS_TOKEN, tok] if USE_BOS else [tok] for tok in INPUT_TOKENS]
    test_seq = torch.tensor(prefixes, dtype=torch.long, device=device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            predictions = model(test_seq)[:, -1, :].argmax(dim=-1).tolist()
    finally:
        model.train(was_training)

    all_correct = 0
    for input_token, prediction in zip(INPUT_TOKENS, predictions):
        expected = f_map[input_token]
        is_correct = (prediction == expected)
        print(f"Input: {input_token}, Predicted: {prediction}, Expected: {expected}, Correct: {is_correct}")
        if is_correct:
            all_correct += 1

    f_accuracy = all_correct / len(INPUT_TOKENS)
    print(f"Accuracy for mapping function f: {f_accuracy:.4f}")
    return f_accuracy


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=1)
    # Boolean switches, registered in their original order.
    for flag in ("--evaluate", "--freeze_embeddings", "--one_hot", "--plot_attention_maps",
                 "--plot_classification", "--plot_pca", "--no_norm"):
        parser.add_argument(flag, action="store_true")
    parser.add_argument("--num_steps", type=int, default=None)
    args = parser.parse_args()

    # Every run starts from empty output directories: wipe them all, then recreate.
    output_dirs = ("classification_plots", "pca_plots", "attention_maps")
    for out_dir in output_dirs:
        if os.path.isdir(out_dir):
            shutil.rmtree(out_dir)
    for out_dir in output_dirs:
        os.makedirs(out_dir, exist_ok=True)

    model, f_map, res = train_model(
        epochs=args.epochs,
        evaluate=args.evaluate,
        num_steps=args.num_steps,
        freeze_embeddings=args.freeze_embeddings,
        one_hot=args.one_hot,
        plot_attention_maps=args.plot_attention_maps,
        no_norm=args.no_norm,
        plot_classification=args.plot_classification,
        plot_pca=args.plot_pca,
    )
