import copy
import os

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms

from colorcube import ColorCubeTransform  # Import the ColorCubeTransform from colorcube.py
from colorcubenet import CustomEfficientNet

# Checkpoint location: `file_name` is the extension-less base path used when
# the model is saved; the directory is created up front if it does not exist.
output_dir = './output'
file_name = os.path.join(output_dir, 'colorcubenet')
os.makedirs(output_dir, exist_ok=True)

# Pick the compute device: CUDA when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Define dataset paths
train_link = '/projects/emarasco/iiitd_patch/train'

# Training-time transformation (with random augmentation).
# NOTE(review): there is no ToTensor step, so ColorCubeTransform presumably
# accepts a PIL image and returns a 9-channel tensor (Normalize below is given
# 9-element statistics) -- confirm against colorcube.py.
transform_train = transforms.Compose([
    transforms.Resize(256),             # Scale the shorter side to 256 px (aspect ratio kept)
    transforms.RandomCrop(224),         # Randomly crop to 224x224 for data augmentation
    transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally for augmentation
    ColorCubeTransform(),               # Convert to the custom ColorCube
    transforms.Normalize(mean=[0.485] * 9, std=[0.229] * 9)  # Normalize for 9 channels
])

# Validation-time transformation: same geometry and normalization, but with a
# deterministic center crop and no flip. Random augmentation on the validation
# split would make the validation loss fluctuate between epochs for reasons
# unrelated to the model, which destabilises early stopping.
transform_val = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    ColorCubeTransform(),
    transforms.Normalize(mean=[0.485] * 9, std=[0.229] * 9)
])

# Load the full training dataset. A second view over the same folder carries
# the validation transform; both views index the same files in the same order.
full_train_dataset = datasets.ImageFolder(train_link, transform=transform_train)
full_val_dataset = datasets.ImageFolder(train_link, transform=transform_val)

# Compute class weights for loss function (inverse class frequency).
# Counts are clamped to 1 so an empty class cannot produce an infinite weight.
class_counts = np.bincount(full_train_dataset.targets)
class_weights = 1. / np.maximum(class_counts, 1)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)

# Split the dataset into training (80%) and validation (remaining) sets
train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_split = random_split(full_train_dataset, [train_size, val_size])
# Re-point the validation indices at the un-augmented view of the same data.
val_dataset = Subset(full_val_dataset, val_split.indices)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Print dataset statistics
print("Training data: {}".format(len(train_dataset)))
print("Validation data: {}".format(len(val_dataset)))

# Build the network for binary classification (Live/Spoof) and place it on
# the selected device.
model = CustomEfficientNet(num_classes=2).to(device)

# Class-weighted cross-entropy loss, optimised with Adam.
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Define early stopping class
class EarlyStopping:
    """Track validation loss, keep the best weights, and signal when to stop.

    Call the instance once per epoch with the validation loss and the model.

    * A loss strictly lower than the best seen so far becomes the new best and
      a snapshot of the model's ``state_dict`` is stored.
    * A loss higher than ``best_loss + delta`` counts as a bad epoch; after
      ``patience`` consecutive bad epochs ``early_stop`` is set to ``True``.
    * A loss in between (no better than the best, but within ``delta`` of it)
      resets the bad-epoch counter without replacing the stored best.

    Args:
        patience: Number of consecutive bad epochs tolerated before stopping.
        delta: Tolerance above ``best_loss`` before an epoch counts as bad.
        verbose: If true, print the counter each time a bad epoch is recorded.

    Attributes are plain and public: ``counter``, ``best_loss``,
    ``early_stop`` and ``best_model_state`` (``None`` until the first call).
    """

    def __init__(self, patience=5, delta=0.01, verbose=False):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_model_state = None  # Snapshot of the best weights seen so far
        self.verbose = verbose

    def __call__(self, val_loss, model):
        """Update the tracker with this epoch's validation loss."""
        if self.best_loss is None or val_loss < self.best_loss:
            self.best_loss = val_loss
            # state_dict() hands back references to the live parameters, so
            # without a deep copy the "best" snapshot would keep changing as
            # training continues.
            self.best_model_state = copy.deepcopy(model.state_dict())
            self.counter = 0
        elif val_loss > self.best_loss + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Within tolerance of the best: not a bad epoch, but not a new
            # best either, so the stored loss and snapshot are left untouched.
            self.counter = 0

# Training function
def train(model, train_loader, criterion, optimizer, device, epoch):
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    Args:
        model: Network to optimise; switched to training mode here.
        train_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before the forward pass.
        epoch: Epoch number, used only in progress messages.

    Returns:
        Tuple of the per-sample mean loss and the fraction of correct
        predictions, both averaged over the samples actually processed.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.train()
    num_batches = len(train_loader)
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (inputs, labels) in enumerate(train_loader, 1):
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Statistics: read the scalar loss once, weight it by the batch size
        batch_loss = loss.item()
        batch_size = labels.size(0)
        running_loss += batch_loss * batch_size
        total += batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()

        # Print progress every 10th batch and on the final batch
        if batch_idx % 10 == 0 or batch_idx == num_batches:
            print(f"Epoch [{epoch}], Batch [{batch_idx}/{num_batches}], Loss: {batch_loss:.4f}")

    # Average over the samples seen rather than len(loader.dataset): the two
    # differ when the loader drops the last batch or draws through a sampler.
    epoch_loss = running_loss / total
    epoch_acc = correct / total

    print(f"Epoch [{epoch}] Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")
    return epoch_loss, epoch_acc

# Evaluation function
def evaluate(model, val_loader, criterion, device, epoch):
    """Score the model on ``val_loader`` and return ``(mean_loss, accuracy)``.

    The model is switched to evaluation mode and gradients are disabled.

    Args:
        model: Network to evaluate.
        val_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.
        epoch: Epoch number, used only in progress messages.

    Returns:
        Tuple of the per-sample mean loss and the fraction of correct
        predictions, both averaged over the samples actually processed.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.eval()
    num_batches = len(val_loader)
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(val_loader, 1):
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Statistics: read the scalar loss once, weight it by the batch size
            batch_loss = loss.item()
            batch_size = labels.size(0)
            running_loss += batch_loss * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

            # Print progress every 10th batch and on the final batch
            if batch_idx % 10 == 0 or batch_idx == num_batches:
                print(f"Epoch [{epoch}], Validation Batch [{batch_idx}/{num_batches}], Loss: {batch_loss:.4f}")

    # Average over the samples seen rather than len(loader.dataset): the two
    # differ when the loader drops the last batch or draws through a sampler.
    epoch_loss = running_loss / total
    epoch_acc = correct / total

    print(f"Epoch [{epoch}] Validation Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")
    return epoch_loss, epoch_acc

# Main loop: train for at most `num_epochs` epochs, letting the early-stopping
# tracker end the run sooner based on the validation loss.
num_epochs = 30
early_stopping = EarlyStopping(patience=5, delta=0.01, verbose=True)

for epoch in range(1, num_epochs + 1):
    # One pass over the training data, then one over the validation data
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device, epoch)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device, epoch)

    # Feed the validation loss to the tracker; carry on unless it says stop
    early_stopping(val_loss, model)
    if not early_stopping.early_stop:
        continue

    print("Early stopping triggered.")
    break

# Write the tracked best weights to disk once, after the loop has finished
best_state = early_stopping.best_model_state
if best_state is not None:
    checkpoint_path = file_name + '.pth'
    torch.save(best_state, checkpoint_path)
    print(f"Best model saved as '{checkpoint_path}'")
