# Standard library imports
import argparse
import os

# Third party library imports
import torch
from torch.utils.data.dataloader import DataLoader
import wandb
from tqdm import tqdm

# Local Imports
from dataset import RSNADataset
from clf import HemorrhageDetector

"""
PREREQUISITES
"""
# Loss function used as the training/validation criterion
def focal_loss(pred, target, alpha=None, eps=1e-7):
    """Return the mean focal loss of binary predictions against 0/1 targets.

    Args:
        pred: Tensor of positive-class scores. The formula uses ``pred`` and
            ``1 - pred`` as probabilities, so values are expected in [0, 1].
        target: Tensor of labels broadcastable against ``pred``. Entries equal
            to 1 select ``pred``; entries equal to 0 select ``1 - pred``.
        alpha: Constant weight applied to every element. When ``None`` it is
            0.60 if the module-level ``weak_supervision`` flag is truthy and
            0.86 otherwise.
        eps: Lower bound applied to the probability inside the logarithm.

    Returns:
        Scalar tensor holding the mean of the element-wise losses.
    """
    if alpha is None:
        # Default weighting depends on the supervision mode chosen at module level.
        alpha = 0.60 if weak_supervision else 0.86

    # Probability the model assigned to the labelled class of each element.
    y_hat = pred * (target == 1) + (1 - pred) * (target == 0)

    # Focusing exponent: 5 for badly-classified elements (y_hat < 0.2), 3 otherwise.
    gamma = 5 * (y_hat < 0.2) + 3 * (y_hat >= 0.2)

    # Clamp only inside the log: a saturated wrong prediction (y_hat == 0)
    # would otherwise give log(0) = -inf and poison the loss and gradients.
    log_y_hat = torch.log(y_hat.clamp(min=eps))

    loss = -alpha * (1 - y_hat) ** gamma * log_y_hat
    return loss.mean()

# Compute device: the first CUDA GPU when one is usable, the CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Dataset root and the directory checkpoints are written to
rsna_dir = "/export/gaon1/data/jteneggi/data/rsna-intracranial-hemorrhage-detection/RSNA"
model_dir = "models"
os.makedirs(model_dir, exist_ok=True)  # no error if it already exists

# Build one dataset per split, keyed by split name
weak_supervision = True
ops = ["train", "val"]
datasets = {
    split: RSNADataset(data_dir=rsna_dir, op=split, weak_supervision=weak_supervision)
    for split in ops
}

# Wrap every dataset in a loader; only the training split is shuffled
dataloaders = {
    split: DataLoader(
        datasets[split],
        batch_size=4,
        shuffle=(split == "train"),
        num_workers=4,
        persistent_workers=True,
    )
    for split in ops
}

# Instantiate the classifier and place it on the selected device
model = HemorrhageDetector(
    encoder="resnet18",
    n_dim=128,
    hidden_size=64,
    embedding_dropout=0.5,
    attention_dropout=0.25,
    attention_activation="sparsemax",
).to(device)

# Number of passes over the training split
num_epochs = 15

# Loss, optimizer and learning-rate schedule
criterion = focal_loss
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-5,
    weight_decay=1e-7,
)
optimizer.zero_grad()  # start from cleared gradients
# Multiply the learning rate by 0.3 after every 3 scheduler steps
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=3,
    gamma=0.3,
)

# Main function
def _evaluate_batch(batch):
    """Run the model on one batch and score it.

    Args:
        batch: An ``(images, norm_images, target)`` triple as yielded by the
            dataloaders; only ``norm_images`` and ``target`` are used.

    Returns:
        Tuple ``(loss, correct, n)``: the criterion value for the batch, the
        number of thresholded predictions (``output >= 0.5``) that equal the
        target, and the batch size taken from ``target.size(0)``.
    """
    _, norm_images, target = batch
    # Read the batch size before squeezing, which may drop the batch dimension.
    n = target.size(0)

    norm_images = norm_images.to(device)
    target = target.to(device).squeeze()

    output = model(norm_images, attention=weak_supervision).squeeze()

    loss = criterion(output, target)
    correct = torch.sum((output >= 0.5) == target)
    return loss, correct, n


# NOTE(review): this runs on import as well as on script execution; it is kept
# at module level to preserve the existing behaviour.
wandb.init(project="Hemorrhage_Detection", entity="beepulbharti", name="CT_clf_training")
if __name__ == "__main__":
    # Training metrics are flushed to W&B once at least this many samples have
    # been accumulated since the previous flush.
    log_every_n_samples = 52
    # Validation runs after every `val_step`-th epoch.
    val_step = 1
    checkpoint_path = os.path.join(
        model_dir,
        f"{'wl_model' if weak_supervision else 'sl_model'}.pt",
    )

    best_val_accuracy = 0.0
    # Running training statistics; they deliberately carry over across epochs
    # and are reset only when they are logged.
    running_loss = 0.0
    running_correct = 0
    running_count = 0

    for epoch_idx in range(num_epochs):
        model.train()

        for batch in tqdm(dataloaders["train"]):
            loss, correct, n = _evaluate_batch(batch)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Weight the batch-mean loss by batch size so the logged value is
            # a per-sample average.
            running_loss += loss.detach() * n
            running_correct += correct
            running_count += n

            # Use >= rather than ==: a short final batch shifts the count off
            # the multiples of the batch size, so an exact match on 52 could
            # be skipped and logging would stall.
            if running_count >= log_every_n_samples:
                wandb.log(
                    {
                        "train_loss": (running_loss / running_count).item(),
                        "train_accuracy": (running_correct / running_count).item(),
                    }
                )
                running_loss = 0.0
                running_correct = 0
                running_count = 0

        if (epoch_idx + 1) % val_step == 0:
            model.eval()

            val_loss = 0.0
            val_correct = 0
            val_count = 0

            # Scoped no-grad region instead of toggling the global grad mode.
            with torch.no_grad():
                for batch in tqdm(dataloaders["val"]):
                    loss, correct, n = _evaluate_batch(batch)
                    val_loss += loss * n
                    val_correct += correct
                    val_count += n

            val_loss = (val_loss / val_count).item()
            val_accuracy = (val_correct / val_count).item()
            wandb.log(
                {
                    "val_loss": val_loss,
                    "val_accuracy": val_accuracy,
                }
            )

            # Keep only the weights of the best-scoring epoch on disk.
            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                torch.save(model.state_dict(), checkpoint_path)

        # Advance the learning-rate schedule once per epoch.
        scheduler.step()