import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import pickle
from loguru import logger
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix


class LSTMEmotionClassifier(nn.Module):
    """LSTM sentence classifier over pre-computed token embeddings.

    The last hidden state of the top LSTM layer goes through dropout and a
    linear layer, and log-probabilities over ``output_dim`` classes are returned.

    Args:
        embedding_dim: size of each input token vector.
        hidden_dim: LSTM hidden size.
        output_dim: number of output classes.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability applied to the final hidden state and,
            when ``num_layers > 1``, between stacked LSTM layers.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int, output_dim: int, num_layers: int, dropout: float):
        super().__init__()

        # nn.LSTM applies its dropout only *between* stacked layers; with a single
        # layer a non-zero value has no effect and only triggers a UserWarning.
        lstm_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, dropout=lstm_dropout, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        """Classify a batch of embedded sequences.

        Args:
            text: tensor of shape (batch, seq_len, embedding_dim) (the LSTM is
                built with ``batch_first=True``).

        Returns:
            Log-probabilities of shape (batch, output_dim).
        """
        _, (hidden, _) = self.lstm(text)
        # hidden: (num_layers, batch, hidden_dim); keep the top layer only.
        last_hidden = self.dropout(hidden[-1])
        output = self.fc(last_hidden)

        return F.log_softmax(output, dim=1)
    
def binary_accuracy(preds, y):
    """Return the fraction of rows whose highest-scoring class equals the label.

    Despite the name this is plain multiclass accuracy: the predicted class is
    the arg-max of ``preds`` along dim 1.

    Args:
        preds: score tensor with classes along dim 1.
        y: tensor of integer class labels, one per row of ``preds``.

    Returns:
        A 0-d float tensor in [0, 1].
    """
    hits = torch.argmax(preds, dim=1).eq(y).float()
    return hits.sum() / len(hits)

class SimpleDataset(Dataset):
    """Map-style dataset yielding ``(features, label, position)`` triples.

    The position is returned alongside each sample so callers can key
    per-sample state by dataset index.

    Args:
        X: indexable collection of feature items.
        y: array-like of labels supporting ``reshape``; flattened to 1-D.
    """

    def __init__(self, X, y):
        self.X = X
        # Flatten so that a column vector of labels behaves like a 1-D array.
        self.y = y.reshape(-1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        features = self.X[index]
        label = self.y[index]
        return features, label, index

def preprocess_sentence(sentence, tokenizer, vectors, padding=None):
    """Tokenize ``sentence`` and look up one embedding row per token.

    Args:
        sentence: raw text passed to ``tokenizer``.
        tokenizer: callable returning a list of tokens.
        vectors: object exposing ``get_vecs_by_tokens(tokens, lower_case_backup=...)``.
        padding: if given, the result is truncated to ``padding`` rows and, if
            shorter, left-padded with zero rows up to exactly ``padding`` rows.

    Returns:
        A 2-D tensor with one row per (kept) token.
    """
    tokens = tokenizer(sentence)
    embedding = vectors.get_vecs_by_tokens(tokens, lower_case_backup=True)
    if padding is None:
        return embedding

    embedding = embedding[:padding]
    missing = padding - len(embedding)
    if missing > 0:
        # Zero rows go in front so the real tokens sit at the end of the sequence.
        filler = torch.zeros((missing, embedding.shape[1]))
        embedding = torch.concat([filler, embedding], dim=0)
    return embedding

def train(model, optimizer, loss_fn, iterator, device):
    """Run one optimisation epoch over ``iterator``.

    Args:
        model: classifier; its output (after ``squeeze(1)``) is passed to
            ``loss_fn`` and ``binary_accuracy`` — presumably (batch, classes).
        optimizer: torch optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(preds, labels, index)`` returning a scalar
            tensor, or ``None`` to signal a failed batch (as
            ``LabelSmoothedCrossEntropyLoss`` does on RuntimeError).
        iterator: iterable of ``(text, labels, index)`` batches.
        device: device the inputs and labels are moved to.

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the batches that produced a
        loss. Batches whose loss is ``None`` are skipped without an optimizer
        step. Returns ``(nan, nan)`` if no batch produced a loss.
    """
    epoch_loss = 0.0
    epoch_acc = 0.0
    n_batches = 0

    model.train()

    for text, labels, index in iterator:
        text, labels = text.to(device), labels.to(device)
        optimizer.zero_grad()
        preds = model(text).squeeze(1)
        loss = loss_fn(preds, labels, index)
        if loss is None:
            # The loss signalled a failure for this batch; skip it rather than
            # crash on ``None.backward()``.
            continue

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += binary_accuracy(preds, labels).item()
        n_batches += 1

    if n_batches == 0:
        return float("nan"), float("nan")
    return epoch_loss / n_batches, epoch_acc / n_batches

def evaluate(model, loss_fn, iterator, device):
    """Compute mean loss and accuracy over ``iterator`` without updating weights.

    Args:
        model: classifier; its output (after ``squeeze(1)``) is passed to
            ``loss_fn`` and ``binary_accuracy``.
        loss_fn: callable ``loss_fn(preds, labels, index)`` returning a scalar
            tensor, or ``None`` for a failed batch.
        iterator: iterable of ``(text, labels, index)`` batches.
        device: device the inputs and labels are moved to.

    Returns:
        ``(mean_loss, mean_accuracy)`` over the batches that produced a loss;
        ``(nan, nan)`` if none did. The model is left in eval mode.
    """
    epoch_loss = 0.0
    epoch_acc = 0.0
    n_batches = 0
    model.eval()

    with torch.no_grad():
        for text, labels, index in iterator:
            text, labels = text.to(device), labels.to(device)
            preds = model(text).squeeze(1)
            loss = loss_fn(preds, labels, index)
            if loss is None:
                # Failed batch: exclude it instead of crashing on ``None.item()``.
                continue

            epoch_loss += loss.item()
            epoch_acc += binary_accuracy(preds, labels).item()
            n_batches += 1

    if n_batches == 0:
        return float("nan"), float("nan")
    return epoch_loss / n_batches, epoch_acc / n_batches

def predict(model, sentence, tokenizer, vectors, device, padding=None):
    """Return the predicted class index for one raw sentence.

    Args:
        model: classifier accepting a batch of embedded sequences.
        sentence: raw text.
        tokenizer: callable used by ``preprocess_sentence``.
        vectors: embedding lookup used by ``preprocess_sentence``.
        device: device the input is moved to.
        padding: optional sequence length forwarded to ``preprocess_sentence``
            (truncate / left-pad); ``None`` keeps the sentence's own length.

    Returns:
        ``int`` index of the highest-scoring class.

    Note:
        The model is switched to eval mode and left there.
    """
    model.eval()
    embedding = preprocess_sentence(sentence, tokenizer, vectors, padding)
    batch = embedding.clone().detach().unsqueeze(0).to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        scores = model(batch)
    return torch.argmax(scores.reshape(-1)).item()


class LabelSmoothedCrossEntropyLoss(nn.Module):
    """Label-smoothed cross-entropy plus an optional temporal-ensemble term.

    References:
        https://arxiv.org/pdf/1512.00567.pdf (label smoothing)
        and https://arxiv.org/pdf/1610.02242.pdf (temporal ensembling),
        or better: https://arxiv.org/pdf/2211.03044.pdf

    The loss is::

        (1 - smoothing) * NLL + smoothing * mean(-log p) + lambda_temporal * KL(ensemble || p)

    where ``ensemble`` is an exponential moving average (factor ``alpha_temporal``)
    of past softmax outputs, stored per sample index in ``ensemble_targets``.
    While a sample's ensemble row is still all zeros, its KL term is 0.

    ``forward`` returns ``None`` if a RuntimeError occurs while computing the
    ensemble term; callers are expected to skip such batches.
    """

    def __init__(self, num_samples, num_classes, device, smoothing=0.0, alpha_temporal=0.9, lambda_temporal=0.0):
        super().__init__()
        self.smoothing = smoothing
        self.alpha_temporal = alpha_temporal
        self.lambda_temporal = lambda_temporal
        # One EMA probability row per sample index; grown on demand in forward().
        self.ensemble_targets = torch.zeros(num_samples, num_classes, dtype=torch.float).to(device)
        self.num_classes = num_classes
        self.device = device

    def _ensure_capacity(self, batch_indices):
        """Append zero rows so every index in ``batch_indices`` is addressable."""
        missing = int(batch_indices.max()) + 1 - self.ensemble_targets.size(0)
        if missing > 0:
            pad = torch.zeros(missing, self.num_classes, dtype=self.ensemble_targets.dtype).to(self.device)
            self.ensemble_targets = torch.cat((self.ensemble_targets, pad), dim=0)

    def forward(self, logits, target, batch_indices):
        """Compute the combined loss and update the temporal ensemble.

        Args:
            logits: (batch, num_classes) unnormalised scores (log-probabilities
                also work, since log_softmax is idempotent).
            target: integer class labels, flattened to (batch,).
            batch_indices: per-sample dataset indices (tensor or sequence of ints).

        Returns:
            Scalar loss tensor, or ``None`` on RuntimeError in the ensemble step.
        """
        target = target.view(-1)
        batch_indices = torch.as_tensor(batch_indices)
        self._ensure_capacity(batch_indices)

        probs = F.softmax(logits, dim=-1)
        log_softmax = F.log_softmax(logits, dim=-1)
        nll_loss = F.nll_loss(log_softmax, target)
        smooth_loss = -log_softmax.mean(dim=-1).mean()
        try:
            ensemble = self.ensemble_targets[batch_indices]
            # F.kl_div expects `input` as log-probabilities and `target` as
            # probabilities; this is KL(ensemble || model), averaged per sample.
            ensemble_loss = F.kl_div(log_softmax, ensemble, reduction="batchmean")
            loss = (1 - self.smoothing) * nll_loss + self.smoothing * smooth_loss + self.lambda_temporal * ensemble_loss
            self.ensemble_targets[batch_indices] = self.alpha_temporal * ensemble + (1 - self.alpha_temporal) * probs.detach()
        except RuntimeError as e:
            logger.warning(f"RuntimeError: {e} - returning None so the caller can skip this batch.")
            return None

        return loss


def _embed_sentences(X, y, tokenizer, vectors, padding):
    """Embed each raw sentence, dropping those that raise RuntimeError together with their labels."""
    embedded = []
    kept_labels = []
    # zip iterates labels positionally, so a pandas Series with a non-default index is safe.
    for sentence, label in zip(tqdm(X), y):
        try:
            embedded.append(preprocess_sentence(sentence, tokenizer, vectors, padding))
        except RuntimeError as e:
            logger.warning(f"Skipping sentence {sentence} due to RuntimeError: {e}")
            continue
        kept_labels.append(label)
    return embedded, kept_labels


def train_model(X, y, tokenizer, vectors, label_encoder, max_epochs=30, n_training_steps=10000, early_stopping_epochs=5, batch_size=64, dropout=0.3, embedding_dim=300, hidden_dim=256, 
                num_layers=2, save_path=None, lr=1e-4, label_smoothing=0.1, 
                device='cpu', valid_size=0.2, padding=128):
    """Train an ``LSTMEmotionClassifier`` with early stopping and return the best model.

    Args:
        X: one of: a path to a pickle of pre-embedded samples; a sequence of raw
            sentences (embedded with ``preprocess_sentence`` using ``padding``);
            or a sequence of pre-embedded samples.
        y: labels understood by ``label_encoder.transform``.
        tokenizer, vectors: passed to ``preprocess_sentence`` for raw sentences.
        label_encoder: fitted encoder with ``transform`` and ``classes_``; the
            model gets one output per entry in ``classes_``.
        max_epochs: upper bound on training epochs.
        n_training_steps: stop once more than this many batches have been run.
        early_stopping_epochs: stop after this many consecutive epochs without
            a validation-loss improvement.
        save_path: if given, the best weights are also written there.
        device: torch device for model and batches.
        valid_size: fraction held out (stratified) for validation.
        padding: sequence length for raw-sentence preprocessing.
        batch_size, dropout, embedding_dim, hidden_dim, num_layers, lr,
        label_smoothing: model / optimisation hyper-parameters.

    Returns:
        The model loaded with the weights of the epoch with the lowest
        validation loss (or the last epoch's weights if none was finite).
    """
    if isinstance(X, str):
        # NOTE(security): pickle.load can execute arbitrary code; only pass trusted files.
        with open(X, "rb") as f:
            X = pickle.load(f)
    elif isinstance(X[0], str):
        X, y = _embed_sentences(X, y, tokenizer, vectors, padding)

    y = np.asarray(label_encoder.transform(y))
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=valid_size, stratify=y)
    # Reshuffle the training set every epoch; validation order does not matter.
    train_iterator = DataLoader(SimpleDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    valid_iterator = DataLoader(SimpleDataset(X_valid, y_valid), batch_size=batch_size)

    # Size the output from the encoder so every encoded label is a valid class
    # index, even if some class is absent from this particular data.
    num_classes = len(label_encoder.classes_)

    model = LSTMEmotionClassifier(
        embedding_dim, hidden_dim, output_dim=num_classes, num_layers=num_layers, dropout=dropout
    ).to(device)

    # Each loss keeps one ensemble row per sample of its own split.
    loss_fn = LabelSmoothedCrossEntropyLoss(smoothing=label_smoothing, num_samples=len(X_train), num_classes=num_classes, device=device)
    valid_loss_fn = LabelSmoothedCrossEntropyLoss(smoothing=label_smoothing, num_samples=len(X_valid), num_classes=num_classes, device=device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_valid_loss = float('inf')
    best_state = None
    early_stopping_counter = 0

    for epoch in range(max_epochs):

        train_loss, train_acc = train(model, optimizer, loss_fn, train_iterator, device)
        valid_loss, valid_acc = evaluate(model, valid_loss_fn, valid_iterator, device)

        logger.debug(f'Epoch: {epoch+1}')
        logger.debug(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
        logger.debug(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            early_stopping_counter = 0
            # In-memory snapshot so the best weights can be restored even without save_path.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            if save_path is not None:
                torch.save(model.state_dict(), save_path)
        else:
            early_stopping_counter += 1
            if early_stopping_counter == early_stopping_epochs:
                logger.debug('Early stopping')
                break

        if len(train_iterator) * (epoch + 1) > n_training_steps:
            break

    if best_state is not None:
        # Same weights as those saved to save_path (if any); restore the best epoch.
        model.load_state_dict(best_state)

    return model

def evaluate_metrics(model, tokenizer, label_encoder, dataset_df, vectorizer, device):
    """Predict every sentence in ``dataset_df`` and compute classification metrics.

    Rows whose label is not in ``label_encoder.classes_`` are ignored with a
    warning. Sentences for which ``predict`` raises RuntimeError are skipped.

    Args:
        model: trained classifier.
        tokenizer: tokenizer passed to ``predict``.
        label_encoder: fitted encoder with ``classes_`` and ``transform``.
        dataset_df: table with ``"text"`` and ``"label"`` columns.
        vectorizer: embedding lookup passed to ``predict`` as ``vectors``.
        device: torch device for inference.

    Returns:
        dict with ``"accuracy"`` (float), ``"f1"`` (macro F1, float) and
        ``"confusion"`` (K x K nested list, K = ``len(label_encoder.classes_)``;
        row/column i corresponds to ``label_encoder.classes_[i]``).
    """
    text, labels = dataset_df["text"], dataset_df["label"]

    known_classes = set(label_encoder.classes_)
    correct_indices = np.array([label in known_classes for label in labels], dtype=bool)
    if not np.all(correct_indices):
        logger.warning(f"Found {len(correct_indices) - np.sum(correct_indices)} labels not in label encoder. Ignoring them.")

    text = text[correct_indices]
    labels = label_encoder.transform(labels[correct_indices])

    predictions = []
    kept_labels = []
    for sentence, label in zip(text, labels):
        try:
            predictions.append(predict(model, sentence, tokenizer, vectorizer, device))
        except RuntimeError as e:
            logger.warning(f"Skipping sentence {sentence} due to RuntimeError: {e}")
            continue
        kept_labels.append(label)

    y_true = np.asarray(kept_labels)
    y_pred = np.asarray(predictions)
    metrics = {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "f1": float(f1_score(y_true, y_pred, average="macro")),
        # Fix the label set so the matrix shape and index meaning do not depend
        # on which classes happen to occur in this evaluation set.
        "confusion": confusion_matrix(y_true, y_pred, labels=np.arange(len(label_encoder.classes_))).tolist(),
    }
    return metrics