# Standard library imports
import os
import uuid

# Third party library imports
import wandb
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
from tqdm import tqdm

class Classifier(nn.Module):
    """Binary image classifier: ImageNet-pretrained ResNet-18 with a 1-unit sigmoid head.

    ``forward`` returns probabilities in (0, 1) (a sigmoid is applied to the
    single output unit), and ``loss`` expects those probabilities, not logits.
    """

    def __init__(self, fine_tune=False):
        """Build the network.

        Args:
            fine_tune: If truthy, every pretrained ResNet-18 parameter is frozen
                (``requires_grad = False``) so that only the freshly created
                final linear layer is trainable. Note that despite the name,
                True means "train the new head only", not "fine-tune the
                whole backbone".
        """
        super().__init__()
        self.resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
        if fine_tune:
            for param in self.resnet.parameters():
                param.requires_grad = False
        # The head is replaced *after* freezing, so its new parameters keep
        # requires_grad=True and remain trainable.
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid probabilities of shape ``(batch, 1)`` for input images ``x``."""
        x = self.resnet(x)
        x = self.sigmoid(x)
        return x

    def loss(self, output, label):
        """Mean binary cross-entropy between probabilities ``output`` and float targets ``label``."""
        return nn.functional.binary_cross_entropy(output, label, reduction="mean")

    def train_model(self, dataloaders, optimizer, num_epochs, save_path, log=False, device="cuda"):
        """Train with a validation pass after every epoch, checkpointing the best model.

        Each epoch runs a "train" pass over ``dataloaders["train"]`` followed by a
        "val" pass over ``dataloaders["val"]``. Whenever the mean validation
        accuracy strictly exceeds the best seen so far, ``state_dict()`` is saved
        to ``<save_path>/clf.pt``.

        Args:
            dataloaders: Mapping with "train" and "val" iterables that yield
                ``(imgs, labels)`` batches. Labels are presumably 0/1 values, one
                per image — TODO confirm against the dataset.
            optimizer: Optimizer over this model's trainable parameters.
            num_epochs: Number of train+val epochs to run.
            save_path: Existing directory in which ``clf.pt`` is written.
            log: If True, start a Weights & Biases run and log metrics to it.
                If False, nothing is sent to wandb.
            device: Device the model and every batch are moved to.
        """
        self.to(device)
        # Training metrics are averaged and (optionally) logged every `log_step` batches.
        # Defined unconditionally: it is used even when logging is disabled.
        log_step = 50
        if log:
            run_id = uuid.uuid4().hex[:4]
            wandb.init(project="CelebAHQ_v2", entity="beepulbharti", name=f"classifier-{run_id}")

        best_acc = 0.0
        for _ in range(num_epochs):
            for op in ("train", "val"):
                is_train = op == "train"
                # train(False) is equivalent to eval().
                self.train(is_train)

                running_loss = 0.0
                running_acc = 0.0
                num_batches = 0

                # Scope grad mode to this phase. Calling set_grad_enabled as a plain
                # function flips the global flag and would leave autograd disabled
                # after this method returns (the last phase is "val").
                with torch.set_grad_enabled(is_train):
                    for i, (imgs, labels) in enumerate(tqdm(dataloaders[op])):
                        imgs = imgs.to(device)
                        # BCE requires a float target with the same shape as the output.
                        labels = labels.to(device).float().reshape(-1)

                        # reshape(-1) instead of squeeze(): a batch of one must stay
                        # shape (1,), not collapse to a 0-d tensor.
                        output = self(imgs).reshape(-1)
                        prediction = (output >= 0.5).float()
                        loss = self.loss(output, labels)

                        running_loss += loss.item()
                        # Per-batch accuracy; the reported value is a mean over batches.
                        running_acc += (prediction == labels).sum().item() / labels.size(0)
                        num_batches += 1

                        if is_train:
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()

                            if (i + 1) % log_step == 0:
                                if log:
                                    wandb.log(
                                        {
                                            "train loss": running_loss / log_step,
                                            "train acc": running_acc / log_step,
                                        }
                                    )
                                running_loss = 0.0
                                running_acc = 0.0

                # Skip the val bookkeeping for an empty loader rather than dividing
                # by a batch index left over from the training loop.
                if not is_train and num_batches > 0:
                    val_loss = running_loss / num_batches
                    val_acc = running_acc / num_batches

                    if val_acc > best_acc:
                        best_acc = val_acc
                        torch.save(self.state_dict(), os.path.join(save_path, "clf.pt"))

                    if log:
                        wandb.log(
                            {
                                "val_loss": val_loss,
                                "val_accuracy": val_acc,
                            }
                        )

