from comet_ml import Experiment

import torch
import torch.nn as nn
import json
from tqdm import tqdm
from utils import set_global_seed
from cifar10_training import get_data, get_optimizer
import models

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Optimizer key: selects the tuned-parameter JSON file, is passed to
# get_optimizer, and is part of the experiment name and output file name.
OPTIMIZER_NAME = 'AdamW'
# Key looked up in models.model_map to build the network.
MODEL_NAME = 'SimpleCNN'
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'
SEED = 42
NUM_EPOCHS = 50
BATCH_SIZE = 128
USE_AUGMENTATIONS = False

# ---------------------------------------------------------------------------
# Comet.ml settings handed to Experiment(...); empty placeholders in source.
# ---------------------------------------------------------------------------
API_KEY = ''
COMET_WORKSPACE = ''
COMET_PROJECT = ''


# Load tuned optimizer hyper-parameters for OPTIMIZER_NAME if a tuning result
# file exists; otherwise fall back to a hard-coded default set.
try:
    with open(f'tuning/cifar10/simplecnn/{OPTIMIZER_NAME}.json', 'r') as f:
        optimizer_params = json.load(f)
except FileNotFoundError:
    print(f"Warning: No tuned parameters found for {OPTIMIZER_NAME}, using defaults")
    optimizer_params = dict(
        lr=0.01,
        momentum=0.9,
        weight_decay=0.001,
        tmin=2.0,
        tmax=20.0,
        warmup_iters=0.8,  # Fraction of total iterations
    )
else:
    # The tuning file also stores the scores reached during tuning; they are
    # not optimizer arguments, so drop them before use.
    for score_key in ('val_score', 'test_score'):
        optimizer_params.pop(score_key, None)
    print(f"Loaded parameters: {optimizer_params}")

# The batch size is carried alongside the optimizer parameters in either case.
optimizer_params['batch_size'] = BATCH_SIZE

# Seed before anything else is constructed (model, data loaders, optimizer).
set_global_seed(SEED)
print(f"Using device: {DEVICE}")

# Comet experiment for this run. Automatic parameter logging is disabled;
# run settings are recorded explicitly via log_other further down.
experiment = Experiment(
    api_key=API_KEY,
    workspace=COMET_WORKSPACE,
    project_name=COMET_PROJECT,
    auto_param_logging=False,
)
experiment.set_name(f'{OPTIMIZER_NAME}_{MODEL_NAME}')


# Build the network selected by MODEL_NAME and move it to the target device.
model_factory = models.model_map[MODEL_NAME]
model = model_factory().to(DEVICE)
print(f"Model: {MODEL_NAME}")


# get_data returns three loaders: training, validation and test.
train_loader, eval_loader, test_loader = get_data(
    use_augmentations=USE_AUGMENTATIONS,
    batch_size=BATCH_SIZE,
    seed=SEED,
)
# Report the size of each split.
for split_label, split_loader in (
    ("Train", train_loader),
    ("Val", eval_loader),
    ("Test", test_loader),
):
    print(f"{split_label} samples: {len(split_loader.dataset)}")


# Total number of optimizer steps over the run: batches per epoch x epochs.
n_iters = len(train_loader) * NUM_EPOCHS

# No tuning search is running here, so search_space and trial are None and the
# already-resolved optimizer_params are passed in directly.
optimizer, clipping = get_optimizer(
    OPTIMIZER_NAME,
    model,
    optimizer_params=optimizer_params,
    n_iters=n_iters,
    search_space=None,
    trial=None,
)
print(f"Optimizer: {OPTIMIZER_NAME}")
# NOTE(review): this is a truthiness check, while the training loop applies
# clipping whenever the value `is not None`; a value of 0 would be applied
# there but not reported here — confirm which behaviour is intended.
if clipping:
    print(f"Gradient clipping: {clipping}")

# Mean cross-entropy over the batch is the training objective.
criterion = nn.CrossEntropyLoss()

# Record the run configuration on the Comet experiment, in a fixed order.
run_metadata = {
    'optimizer': OPTIMIZER_NAME,
    'model': MODEL_NAME,
    'seed': SEED,
    'num_epochs': NUM_EPOCHS,
    'batch_size': BATCH_SIZE,
    'use_augmentations': USE_AUGMENTATIONS,
    'clipping': clipping,
}
for key, value in run_metadata.items():
    experiment.log_other(key, value)


# Then every optimizer hyper-parameter. Kept as a separate pass so that keys
# present in both mappings (e.g. 'batch_size') are logged again, as before.
for key, value in optimizer_params.items():
    experiment.log_other(key, value)


def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a percentage.

    The model is switched to eval mode and left in that mode; gradients are
    not tracked during the pass.

    Args:
        model: Callable module mapping a batch of inputs to per-class scores
            (the predicted class is the argmax over dimension 1).
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when the loader
        yields no samples (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            # argmax over the class dimension gives the predicted label;
            # no need to go through the legacy ``.data`` attribute.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Empty loader: there is nothing to score.
        return 0.0
    return 100 * correct / total

def collect_per_sample_losses(model, loader, device, criterion):
    """Return the cross-entropy loss of every individual sample in ``loader``.

    The model is switched to eval mode and left in that mode; gradients are
    not tracked during the pass.

    Args:
        model: Callable module mapping a batch of inputs to per-class scores.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        criterion: Accepted for interface compatibility but not used; the
            losses are always computed with an unreduced
            ``nn.CrossEntropyLoss`` using default settings.

    Returns:
        A flat list of Python floats, one per sample, in loader order
        (empty when the loader yields nothing).
    """
    model.eval()
    # Built once rather than on every batch; reduction='none' keeps one loss
    # value per sample instead of averaging over the batch.
    per_sample_loss = nn.CrossEntropyLoss(reduction='none')
    all_losses = []

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            losses = per_sample_loss(model(images), labels)
            # Tensor.tolist() already yields Python floats; no NumPy hop needed.
            all_losses.extend(losses.cpu().tolist())

    return all_losses

# Banner marking the start of training.
separator = "=" * 50
print("\n" + separator)
print("Starting training...")
print(separator + "\n")

# Per-epoch history, appended to once per epoch by the training loop.
train_losses = []
val_accuracies = []
test_accuracies = []
# Declared but never filled in the code visible in this file.
lipschitz_constants = []
grad_norms = []

for epoch in range(NUM_EPOCHS):
    # 1-based epoch number used for the progress bar, logging and printing.
    epoch_num = epoch + 1

    # evaluate()/collect_per_sample_losses() leave the model in eval mode,
    # so training mode has to be re-enabled at the start of every epoch.
    model.train()
    running_loss = 0.0

    # ---- one pass over the training set ----
    pbar = tqdm(train_loader, desc=f"Epoch {epoch_num}/{NUM_EPOCHS}")
    for images, labels in pbar:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        # Optional gradient clipping, measured with the infinity norm.
        if clipping is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clipping, norm_type='inf')

        optimizer.step()

        # Read the scalar once and reuse it for the running sum and the bar.
        batch_loss = loss.item()
        running_loss += batch_loss
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    # ---- end-of-epoch metrics ----
    # Mean of the per-batch mean losses.
    avg_loss = running_loss / len(train_loader)
    val_acc = evaluate(model, eval_loader, DEVICE)
    test_acc = evaluate(model, test_loader, DEVICE)

    # Per-sample losses on both held-out splits, for the loss histograms.
    val_sample_losses = collect_per_sample_losses(model, eval_loader, DEVICE, criterion)
    test_sample_losses = collect_per_sample_losses(model, test_loader, DEVICE, criterion)

    epoch_metrics = {
        'train_loss': avg_loss,
        'val_acc': val_acc,
        'test_acc': test_acc,
    }
    experiment.log_metrics(epoch_metrics, epoch=epoch_num)

    for histogram_name, sample_losses in (
        ("val_loss_distribution", val_sample_losses),
        ("test_loss_distribution", test_sample_losses),
    ):
        experiment.log_histogram_3d(
            sample_losses,
            name=histogram_name,
            step=epoch_num,
        )

    # Keep the history for the final summary and the JSON log.
    train_losses.append(avg_loss)
    val_accuracies.append(val_acc)
    test_accuracies.append(test_acc)

    # One-line epoch summary on stdout.
    print(f"Epoch [{epoch_num}/{NUM_EPOCHS}] "
          f"Loss: {avg_loss:.4f} | "
          f"Val Acc: {val_acc:.2f}% | "
          f"Test Acc: {test_acc:.2f}% | ")

# Best value and the (1-based) epoch of its first occurrence, per split.
best_val_acc = max(val_accuracies)
best_test_acc = max(test_accuracies)
best_val_epoch = val_accuracies.index(best_val_acc) + 1
best_test_epoch = test_accuracies.index(best_test_acc) + 1
final_val_acc = val_accuracies[-1]
final_test_acc = test_accuracies[-1]

print("\n" + "="*50)
print("Training completed!")
print("="*50)
print(f"Best Val Accuracy: {best_val_acc:.2f}% at epoch {best_val_epoch}")
print(f"Best Test Accuracy: {best_test_acc:.2f}% at epoch {best_test_epoch}")
print(f"Final Val Accuracy: {final_val_acc:.2f}%")
print(f"Final Test Accuracy: {final_test_acc:.2f}%")

# Summary metrics go to Comet without an epoch/step attached.
experiment.log_metrics({
    'best_val_acc': best_val_acc,
    'best_test_acc': best_test_acc,
    'final_val_acc': final_val_acc,
    'final_test_acc': final_test_acc,
})

# Full history plus the hyper-parameters used, for the local JSON log.
results = {
    'optimizer': OPTIMIZER_NAME,
    'model': MODEL_NAME,
    'train_losses': train_losses,
    'val_accuracies': val_accuracies,
    'test_accuracies': test_accuracies,
    'best_val_acc': best_val_acc,
    'best_test_acc': best_test_acc,
    'optimizer_params': optimizer_params
}

# Close the Comet run before writing the local file.
experiment.end()

log_path = f'training_log_{OPTIMIZER_NAME}.json'
with open(log_path, 'w') as f:
    json.dump(results, f, indent=2)
print(f"\nResults saved to {log_path}")

