# main.py

import os
import time
import datetime
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from config import get_args
from models import get_model
from dataset import get_cifar10, get_dataloaders
from optimizer import CLAGR, SAM
from engine import train_one_epoch, evaluate

def setup_seed(seed):
    """Seed all random number generators used by the script.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator plus every CUDA device), then switches cuDNN to its
    deterministic mode and turns off its benchmark auto-tuner.

    Args:
        seed: Integer seed shared by every generator.
    """
    # Pure-Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU first, then all CUDA devices.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags: no auto-tuning, deterministic mode on.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def main():
    """Train a model on CIFAR-10 with a SAM or CLAGR optimizer and evaluate it.

    Steps, in order:
      1. Parse the run arguments with ``get_args``.
      2. Seed the RNGs, select the device and create a timestamped output
         directory under ``args.output_dir``.
      3. Load the dataset and build the train/validation dataloaders.
      4. Build the model and the cross-entropy criterion.
      5. Build the base optimizer, wrap it in SAM or CLAGR, and attach the
         learning-rate scheduler to the base optimizer.
      6. Train for ``args.epochs`` epochs, evaluating after each one and
         saving the ``state_dict`` of the best model (by top-1 accuracy) to
         ``best_model.pth`` in the output directory.

    Raises:
        ValueError: If ``args.base_optimizer`` or ``args.optimizer`` names an
            optimizer this script does not support.
    """
    # 1. Get parameters
    args = get_args()

    # 2. Set up environment
    setup_seed(args.seed)
    device = torch.device(args.device)
    output_path = os.path.join(args.output_dir, f"{args.model}_{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}")
    os.makedirs(output_path, exist_ok=True)
    print(f"Arguments: {vars(args)}")
    print(f"Output path: {output_path}")

    # 3. Load data
    print("Loading dataset...")
    train_data, val_data, n_classes = get_cifar10(args.datadir)
    train_loader, val_loader = get_dataloaders(args, train_data, val_data)
    print(f"Train data: {len(train_data)}, Test data: {len(val_data)}, Classes: {n_classes}")

    # 4. Build model and loss
    print(f"Building model: {args.model}")
    model = get_model(args.model, num_classes=n_classes).to(device)

    criterion = nn.CrossEntropyLoss()

    # 5. Build optimizers and scheduler
    # Base optimizer. Note that the 'adam' choice maps to AdamW.
    base_optimizer_name = args.base_optimizer.lower()
    if base_optimizer_name == 'sgd':
        base_optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif base_optimizer_name == 'adam':
        base_optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise ValueError(f"Base optimizer {args.base_optimizer} not supported.")

    # Wrapping optimizer built on top of the base optimizer.
    optimizer_name = args.optimizer.lower()
    if optimizer_name == 'sam':
        optimizer = SAM(
            model.parameters(),
            base_optimizer=base_optimizer,
            rho=args.rho,
        )
    elif optimizer_name == 'clagr':
        optimizer = CLAGR(
            model.parameters(),
            base_optimizer=base_optimizer,
            rho=args.rho,
            inner_step=args.inner_step,
            cr_lambda=args.cr_lambda,
            lmomentum=args.lmomentum,
            gamma_interp=args.gamma_interp,
            beta_start=args.beta_start,
            beta_end=args.beta_end,
        )
    else:
        # Fail fast: without this branch `optimizer` stays unbound and the
        # run only crashes later, inside the training loop.
        raise ValueError(f"Optimizer {args.optimizer} not supported.")
    print(f"Using Optimizer: {args.optimizer.upper()} with {args.base_optimizer.upper()}")

    # The scheduler is attached to the base optimizer, not to the wrapper.
    if args.lr_scheduler == 'CosineAnnealingLR':
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(base_optimizer, T_max=args.epochs, eta_min=args.eta_min)
    else:
        # Any other value keeps the learning rate constant (factor 1.0).
        lr_scheduler = optim.lr_scheduler.ConstantLR(base_optimizer, factor=1.0, total_iters=args.epochs)

    # 6. Start training
    print(f"Start training for {args.epochs} Epochs on {device}.")
    max_acc = 0.0

    for epoch in range(args.epochs):
        start_epoch_time = time.time()

        train_stats = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, args.log_freq)
        lr_scheduler.step()
        val_stats = evaluate(model, val_loader, criterion, device)

        # The loss column previously displayed val_stats["test_acc5"].
        # NOTE(review): assumes evaluate() reports its loss under
        # "test_loss", mirroring the "train_loss" key — confirm in engine.py.
        test_loss, test_acc = val_stats["test_loss"], val_stats["test_acc1"]
        train_acc, train_loss = train_stats['train_acc1'], train_stats["train_loss"]
        epoch_time = time.time() - start_epoch_time

        # Keep only the weights of the best model seen so far (top-1 accuracy).
        if test_acc > max_acc:
            max_acc = test_acc
            torch.save(model.state_dict(), os.path.join(output_path, 'best_model.pth'))
            print(f"**** New best model saved with accuracy: {max_acc:.2f}% ****")

        # The LR shown is the value after this epoch's scheduler step.
        print(
            f"Epoch [{epoch+1}/{args.epochs}] | "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
            f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}% (Max: {max_acc:.2f}%) | "
            f"LR: {lr_scheduler.get_last_lr()[0]:.5f} | Time: {epoch_time:.2f}s"
        )

    print(f"Training finished. Max Test Accuracy: {max_acc:.2f}%")
    


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()