import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import torchvision
from torchvision import transforms
from torchvision.models import resnet18, resnet50
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from prepare_datasets import get_transformations

# def train_resnet(trainloader, testloader, epochs=25,learning_rate=0.001, if_resnet18 = True, num_classes=10, pretrained="No"):
#     scaler = GradScaler() 
#     num_epochs = epochs
#     learning_rate = learning_rate

#     # Pretrained Handling
#     if if_resnet18:
#         if pretrained != "No":
#             model = resnet18(weights="IMAGENET1K_V1")
#         else:
#             model = resnet18(weights=None, num_classes=num_classes)
#     else:
#         model = resnet50(weights=None, num_classes=num_classes)

#     if pretrained == "LastLayer":
#         for param in model.parameters():
#             param.requires_grad = False
#     if pretrained!= "No":
#         model.fc = nn.Linear(model.fc.in_features, num_classes)
#         if pretrained == "LastLayer":
#             for param in model.fc.parameters():
#                 param.requires_grad = True
#     # MODEL PARAMETER DIFFS FIXED

#     model = nn.DataParallel(model)
#     model.to('cuda')
#     criterion = nn.CrossEntropyLoss()
#     optimizer = optim.Adam(model.parameters(), lr=learning_rate)
#     # optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
#     print(f"Using {torch.cuda.device_count()} GPUs")
#     for epoch in range(num_epochs):
#         model.train()
#         running_loss = 0.0
#         for inputs, labels in tqdm(trainloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
#             inputs, labels = inputs.to('cuda'), labels.to('cuda')
    
#             optimizer.zero_grad()
    
#             with autocast(): 
#                 outputs = model(inputs)
#                 loss = criterion(outputs, labels)
    
#             scaler.scale(loss).backward() 
#             scaler.step(optimizer)
#             scaler.update()
    
#             running_loss += loss.item()
    
#         scheduler.step()
#         print(f"Epoch {epoch+1}, Loss: {running_loss / len(trainloader):.4f}")
    
#     model.eval()
#     correct, total = 0, 0
#     with torch.no_grad():
#         for inputs, labels in testloader:
#             inputs, labels = inputs.to('cuda'), labels.to('cuda')
#             outputs = model(inputs)
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()
    
#     acc = 100 * correct / total
    
#     return acc

def _build_resnet(if_resnet18, num_classes, pretrained):
    """Create the ResNet backbone and configure its classification head."""
    if not if_resnet18:
        # NOTE(review): the ResNet-50 path never loads ImageNet weights, so
        # pretrained="LastLayer" freezes a randomly initialised backbone here.
        # Kept as-is to preserve existing behaviour -- confirm whether
        # pretrained ResNet-50 weights were intended.
        model = resnet50(weights=None, num_classes=num_classes)
    elif pretrained == "No":
        model = resnet18(weights=None, num_classes=num_classes)
    else:
        # ImageNet checkpoints ship with a 1000-class head; it is replaced below.
        model = resnet18(weights="IMAGENET1K_V1")

    if pretrained == "LastLayer":
        for param in model.parameters():
            param.requires_grad = False

    if pretrained != "No":
        # A freshly constructed nn.Linear is trainable by default, so in
        # "LastLayer" mode it is the only part of the network that is updated.
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    return model


def _evaluate_accuracy(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` in percent."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # An empty loader reports 0.0 instead of raising ZeroDivisionError.
    return 100 * correct / total if total else 0.0


def train_resnet(trainloader, testloader, validation_loader=None, epochs=25, learning_rate=0.001,
                 if_resnet18=True, num_classes=10, pretrained="No", patience=10):
    """Train a ResNet classifier on CUDA and return its test accuracy.

    The model is wrapped in ``nn.DataParallel`` and trained with mixed
    precision (``autocast`` + ``GradScaler``), cross-entropy loss and a cosine
    annealing learning-rate schedule spanning ``epochs``.

    Args:
        trainloader: Iterable of ``(inputs, labels)`` training batches.
        testloader: Iterable of ``(inputs, labels)`` batches used for the
            final evaluation.
        validation_loader: Optional iterable of ``(inputs, labels)`` batches.
            When given, validation accuracy is computed after every epoch,
            early stopping is enabled, and the weights from the best
            validation epoch are restored before the final test evaluation.
        epochs: Maximum number of training epochs.
        learning_rate: Initial learning rate. Values below 0.05 select Adam;
            anything else selects SGD (momentum 0.9, weight decay 5e-4).
        if_resnet18: Build a ResNet-18 when true, otherwise a ResNet-50.
        num_classes: Number of output classes.
        pretrained: ``"No"`` trains from random initialisation.
            ``"LastLayer"`` freezes every parameter except a newly created
            final ``fc`` layer. Any other value replaces ``fc`` and leaves all
            parameters trainable. ImageNet weights are only loaded on the
            ResNet-18 path.
        patience: Number of consecutive epochs without a validation-accuracy
            improvement after which training stops early.

    Returns:
        Test-set accuracy in percent (0.0 if ``testloader`` is empty).
    """
    device = 'cuda'
    model = nn.DataParallel(_build_resnet(if_resnet18, num_classes, pretrained)).to(device)

    criterion = nn.CrossEntropyLoss()
    if learning_rate < 0.05:
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    else:
        optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = GradScaler()

    print(f"Using {torch.cuda.device_count()} GPUs")

    # --- Early stopping setup ---
    use_early_stopping = validation_loader is not None
    best_acc = 0.0
    epochs_without_improvement = 0
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            # Scale the loss so float16 gradients do not underflow.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item()
            num_batches += 1

        scheduler.step()
        # max(..., 1) keeps the report well-defined for an empty trainloader.
        print(f"Epoch {epoch+1}, Loss: {running_loss / max(num_batches, 1):.4f}")

        # --- Validation for early stopping ---
        if use_early_stopping:
            val_acc = _evaluate_accuracy(model, validation_loader, device)
            print(f"Validation Accuracy: {val_acc:.2f}%")

            if val_acc > best_acc:
                best_acc = val_acc
                epochs_without_improvement = 0
                # state_dict() returns references to the live tensors, which
                # later optimizer steps overwrite in place; clone them so this
                # really is a snapshot of the best epoch.
                best_model_state = {
                    key: value.detach().clone()
                    for key, value in model.state_dict().items()
                }
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break
        torch.cuda.empty_cache()

    # --- Restore best model if early stopping used ---
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # --- Final test evaluation ---
    return _evaluate_accuracy(model, testloader, device)