import torch
import torch.optim as optim
import numpy as np
from tqdm.auto import tqdm
from src.utils.utility import get_lr_scheduler
from src.utils.test import test_classification_model
from transformers import (
    TrainingArguments,
    Trainer,
    TrainerCallback,
)
from sklearn.metrics import accuracy_score


def train_classification_model_step(
    model,
    train_loader,
    optimizer,
    loss_fn,
    epoch,
    device,
    lr_scheduler=None,
):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Module called as ``model(data)``; its output must expose a
            ``.logits`` attribute, which is what gets passed to ``loss_fn``.
        train_loader: Sized iterable yielding ``(data, target)`` pairs whose
            elements support ``.to(device)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        loss_fn: Callable invoked as ``loss_fn(logits, target)``; the result
            must support ``.backward()``.
        epoch: Epoch number, used for the progress message and as the base of
            the fractional epoch handed to ``lr_scheduler``.
        device: Device the batches are moved to.
        lr_scheduler: Optional per-batch scheduler. When given, it is stepped
            after every optimizer step with ``epoch + batch_idx / num_batches``.
            NOTE(review): this assumes the scheduler's ``step`` accepts a
            fractional epoch argument -- confirm for each scheduler in use.

    Returns:
        Tuple ``(model, total_loss)`` where ``total_loss`` is the mean of the
        per-batch losses as a Python float (``nan`` if the loader is empty).
    """
    model.train()
    num_batches = len(train_loader)
    batch_losses = []
    for batch_idx, (data, target) in enumerate(
        tqdm(train_loader, total=num_batches)
    ):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output.logits, target)
        loss.backward()
        optimizer.step()
        if lr_scheduler is not None:
            lr_scheduler.step(epoch + batch_idx / num_batches)
        # Store a detached copy: only the value is needed for reporting, and
        # keeping the attached tensor would hold on to autograd history for
        # every batch until the end of the epoch.
        batch_losses.append(loss.detach())

    if batch_losses:
        # One stacked reduction instead of converting tensors element-wise.
        total_loss = torch.stack(batch_losses).float().mean().item()
    else:
        # Mean over zero batches is undefined; report NaN explicitly.
        total_loss = float("nan")
    print("Train Epoch: {} \tLoss: {}".format(epoch, total_loss))
    return model, total_loss


def _save_checkpoint(save_path, epoch, model, optimizer, loss):
    """Write a checkpoint with the epoch, model/optimizer states and loss."""
    torch.save(
        {
            "best_epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        save_path,
    )


def train_classification_model(
    model,
    train_loader,
    loss_fn,
    device,
    lr,
    num_epochs,
    val_loader=None,
    lr_scheduler: None | str = None,
    patience: int = 5,
    weight_decay: float = 1e-6,
    save_path: None | str = None,
):
    """Train ``model`` with AdamW, optional validation and early stopping.

    Args:
        model: Model to train; moved to ``device`` before training.
        train_loader: Training batches, forwarded to
            ``train_classification_model_step``.
        loss_fn: Loss callable used for both training and validation.
        device: Device the model is moved to and batches are processed on.
        lr: Learning rate for AdamW.
        num_epochs: Maximum number of epochs; epochs are numbered from 1.
        val_loader: Optional validation loader. When given, validation loss
            drives early stopping, checkpointing and the per-epoch scheduler.
        lr_scheduler: Scheduler identifier forwarded to ``get_lr_scheduler``.
        patience: Number of consecutive validation epochs without improvement
            after which training stops.
        weight_decay: Weight decay for AdamW.
        save_path: Optional checkpoint path. With validation, the checkpoint
            is written whenever validation loss improves; without validation
            it is overwritten after every epoch.

    Returns:
        The trained model (the same object that was passed in).
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    lrs = get_lr_scheduler(lr_scheduler=lr_scheduler, optimizer=optimizer)

    # Decide once which kind of scheduler (if any) is active: a batch
    # scheduler is stepped inside the training step, any other scheduler is
    # stepped once per epoch with the validation loss.
    has_scheduler = lrs.scheduler is not None
    batch_scheduler = (
        lrs.scheduler if has_scheduler and lrs.is_batch_scheduler else None
    )
    epoch_scheduler = (
        lrs.scheduler if has_scheduler and not lrs.is_batch_scheduler else None
    )
    # NOTE(review): the per-epoch scheduler is only stepped when a validation
    # loader is supplied; without one it never advances -- confirm intended.

    best_loss = torch.inf
    epochs_without_improvement = 0
    for e in range(1, num_epochs + 1):
        model, epoch_loss = train_classification_model_step(
            model=model,
            train_loader=train_loader,
            optimizer=optimizer,
            loss_fn=loss_fn,
            epoch=e,
            device=device,
            lr_scheduler=batch_scheduler,
        )

        if val_loader is None:
            # No validation signal: keep the most recent weights on disk.
            if save_path is not None:
                # ``e`` is already 1-based, so it is stored as-is.
                _save_checkpoint(save_path, e, model, optimizer, epoch_loss)
            continue

        val_loss, _ = test_classification_model(model, val_loader, loss_fn, device)

        if epoch_scheduler is not None:
            epoch_scheduler.step(val_loss)

        if best_loss > val_loss:
            best_loss = val_loss
            epochs_without_improvement = 0
            if save_path is not None:
                # ``e`` is already 1-based, so it is stored as-is.
                _save_checkpoint(save_path, e, model, optimizer, best_loss)
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(
                "Early Stopping: Epoch {} \t Best Val Loss {}".format(e, best_loss)
            )
            break

    return model


def train_hf_model(
    model,
    train_dataset,
    val_dataset,
    test_dataset,
    num_epochs,
    model_name,
    batch_size=32,
    lr=2e-5,
):
    """Fine-tune ``model`` with the Hugging Face ``Trainer`` and report test accuracy.

    Evaluation runs once per epoch. The weights with the best validation
    accuracy seen so far are written to ``f"{model_name}.pt"``, training stops
    after two consecutive evaluations without improvement, and the best
    weights are loaded back into ``model`` before predicting on
    ``test_dataset``.

    Args:
        model: Model handed to ``Trainer``; also the object whose
            ``state_dict`` is checkpointed and restored.
        train_dataset: Dataset used for training.
        val_dataset: Dataset used for the per-epoch evaluation.
        test_dataset: Dataset used for the final prediction.
        num_epochs: Maximum number of training epochs.
        model_name: Stem of the checkpoint file (``f"{model_name}.pt"``).
        batch_size: Per-device batch size for training and evaluation.
        lr: Learning rate passed to ``TrainingArguments``.
    """
    checkpoint_path = f"{model_name}.pt"

    class EarlyStoppingCallback(TrainerCallback):
        """Track the best ``eval_accuracy``, checkpoint it, and stop on plateau."""

        def __init__(self, early_stopping_patience=2):
            self.early_stopping_patience = early_stopping_patience
            self.best_metric = None
            self.patience_counter = 0
            # Set once this run has written ``checkpoint_path``.
            self.checkpoint_saved = False

        def on_evaluate(self, args, state, control, metrics, **kwargs):
            current_metric = metrics.get("eval_accuracy")
            if current_metric is None:
                # Nothing to compare against; leave the tracking state alone.
                return
            if self.best_metric is None or current_metric > self.best_metric:
                # The first evaluation counts as a best result too, so a
                # checkpoint always exists even if later epochs never improve.
                self.best_metric = current_metric
                self.patience_counter = 0
                torch.save(model.state_dict(), checkpoint_path)
                self.checkpoint_saved = True
            else:
                self.patience_counter += 1
                if self.patience_counter >= self.early_stopping_patience:
                    control.should_training_stop = True

    def compute_metrics(eval_pred):
        """Return accuracy of the argmax predictions against the labels."""
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return {"accuracy": accuracy_score(labels, predictions)}

    training_args = TrainingArguments(
        output_dir="./results",
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=lr,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_epochs,
        weight_decay=0.01,
        load_best_model_at_end=False,
        metric_for_best_model="accuracy",
        logging_dir="./logs",
    )
    early_stopping = EarlyStoppingCallback(early_stopping_patience=2)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[early_stopping],
    )
    trainer.train()
    # Restore only a checkpoint written during this run; otherwise a leftover
    # file from an earlier run could be loaded, or the load would fail.
    if early_stopping.checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path))
    test_results = trainer.predict(test_dataset)
    print(f"Test Accuracy: {test_results.metrics['test_accuracy']:.4f}")
