# Standard Library Imports
import os

# Third-Party Library Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# Local Imports
from layers import encoder
from utils import save_dict

class classifier(nn.Module):
    """Binary classifier: an ``encoder`` followed by a three-layer MLP head.

    The encoder is built with ``in_channels=4`` and its output is passed
    directly to ``fc1``, so it must produce 500 features on its last
    dimension. The head ends in a sigmoid, so ``forward`` returns
    probabilities rather than logits.
    """

    def __init__(self):
        super().__init__()

        self.encoder = encoder(in_channels=4)

        self.fc1 = nn.Linear(500, 500)
        self.fc2 = nn.Linear(500, 500)
        self.fc3 = nn.Linear(500, 1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the predicted probability for each sample in ``x``.

        Args:
            x: Input batch accepted by the encoder (presumably 4-channel
                data, given ``in_channels=4`` -- confirm against ``encoder``).

        Returns:
            Tensor of sigmoid outputs with the trailing size-1 dimension
            produced by ``fc3`` removed.
        """
        # Encoder
        out = self.encoder(x)

        # First layer
        out = self.relu(self.fc1(out))

        # Second layer
        out = self.relu(self.fc2(out))

        # 3rd layer: one output unit, squashed to (0, 1); drop the unit axis
        out = self.sigmoid(self.fc3(out)).squeeze(-1)

        return out

    def loss(self, pred_probs, true_vals):
        """Return the mean binary cross-entropy of probabilities vs. targets.

        Args:
            pred_probs: Predicted probabilities (values in [0, 1]).
            true_vals: Float targets with the same shape as ``pred_probs``.
        """
        return F.binary_cross_entropy(pred_probs, true_vals, reduction="mean")

    def _run_epoch(self, loader, optimizer, device, is_train):
        """Run one pass over ``loader``; return (mean batch loss, mean batch acc).

        The optimizer is stepped only when ``is_train`` is true. The caller is
        responsible for setting train/eval mode and the autograd mode.
        """
        running_loss = 0.0
        running_acc = 0.0

        for x, y in tqdm(loader):
            x = x.float().to(device)
            y = y.float().to(device)

            # forward() already removes the trailing unit axis. Squeezing a
            # second time would turn a batch of one sample into a 0-d tensor
            # whose shape no longer matches y, and the loss would raise.
            pred_probs = self(x)

            # Compute metrics
            loss = self.loss(pred_probs, y)
            y_hat = (pred_probs >= 0.5) * 1.0
            acc = torch.sum(y_hat == y) / len(y)
            running_loss += loss.item()
            running_acc += acc.item()

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # NOTE: these are unweighted means over batches, so a smaller final
        # batch counts as much as a full one.
        num_batches = len(loader)
        return running_loss / num_batches, running_acc / num_batches

    def train_and_validate(self, dataloaders, optimizer, save_path, num_epochs=20, device='cuda', real_data=False):
        """Train the model, validating after every epoch and checkpointing.

        Each epoch runs a "train" pass followed by a "val" pass. Whenever the
        validation accuracy is at least as high as the best seen so far, the
        validation metrics are handed to ``save_dict`` (as "results.txt") and
        the model's ``state_dict`` is saved to ``<save_path>/clf.pt``.

        Args:
            dataloaders: Mapping with "train" and "val" keys; each loader
                yields ``(x, y)`` batches.
            optimizer: Optimizer over this model's parameters.
            save_path: Directory for "results.txt" and "clf.pt". It is created
                if it does not exist.
            num_epochs: Number of train/val epochs to run.
            device: Device the model and batches are moved to.
            real_data: If true, ``loader.dataset.on_epoch_start()`` is called
                before every pass (train and val).
        """
        self.to(device)

        # Fail early (before any training) if the checkpoint directory cannot
        # be created; an empty path means "current directory".
        if save_path:
            os.makedirs(save_path, exist_ok=True)

        best_acc = 0.0

        for epoch in range(num_epochs):

            print(f"Epoch [{epoch+1}/{num_epochs}]")

            for op in ("train", "val"):
                is_train = op == "train"

                # Context-manager form restores the previous autograd mode on
                # exit, so gradients are not left globally disabled after the
                # final validation pass.
                with torch.set_grad_enabled(is_train):
                    self.train(is_train)

                    loader = dataloaders[op]
                    if real_data:
                        loader.dataset.on_epoch_start()

                    epoch_loss, epoch_acc = self._run_epoch(
                        loader, optimizer, device, is_train
                    )

                if is_train:
                    print(f"Train Loss: {epoch_loss:.4f}, Train Acc: {epoch_acc:.4f}")
                    continue

                print(f"Val Loss: {epoch_loss:.4f}, Val Acc: {epoch_acc:.4f}")
                results_dict = {
                    "val_loss": epoch_loss,
                    "val_acc": epoch_acc,
                    "epoch": epoch,
                }

                # ">=" means ties overwrite the checkpoint with the later epoch.
                if epoch_acc >= best_acc:
                    best_acc = epoch_acc
                    print("Saving new model!")
                    save_dict(results_dict, save_path, "results.txt")
                    torch.save(self.state_dict(), os.path.join(save_path, "clf.pt"))
                


