import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import argparse
import os
import math
from timm.data import Mixup

def parse_args(argv=None):
    """Parse the command-line options of this training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the existing no-argument call relies on.

    Returns:
        argparse.Namespace with the attributes ``adamw_lr``, ``eta``,
        ``epoch``, ``train_dir``, ``test_dir`` and ``bssl``.
    """
    # The title is a description of the program, not its executable name;
    # passing it positionally would set ``prog`` and garble the usage line.
    parser = argparse.ArgumentParser(description="Post Evaluation Script")
    parser.add_argument('--adamw-lr', type=float, default=0.001, help='Learning rate for AdamW optimizer')
    parser.add_argument('--eta', type=int, default=2,
                        help='Divisor applied to the cosine learning-rate schedule phase')
    parser.add_argument('--epoch', type=int, default=300,
                        help='Number of training epochs')
    parser.add_argument('--train-dir', type=str, default='', help='The path of the dataset')
    parser.add_argument('--test-dir', default='', type=str, help='The path to the dataset directory')
    parser.add_argument('--bssl', action='store_true', help='Use BSSL')
    return parser.parse_args(argv)

# ---- Run configuration ----
args = parse_args()
train_dir, test_dir = args.train_dir, args.test_dir

num_classes = 1000
batch_size = 16  # training batch size; the validation loader sets its own
num_epochs = args.epoch
pretrain = False

# ---- Image pre-processing ----
# Per-channel mean / std handed to Normalize by both pipelines.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training: random crop + horizontal flip, then tensor conversion and normalisation.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Validation: deterministic resize + centre crop, same normalisation.
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# ---- Datasets and loaders ----
train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transforms)
val_dataset = datasets.ImageFolder(root=test_dir, transform=val_transforms)

# Worker / pinning options shared by the two loaders.
_loader_options = {'num_workers': 10, 'pin_memory': True}
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, **_loader_options)
val_loader = DataLoader(
    val_dataset, batch_size=128, shuffle=False, **_loader_options)

# Device selection: use CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

from timm.data import Mixup

# Build the CutMix augmenter (mixup_alpha=0.0 leaves Mixup disabled).
_cutmix_options = dict(
    num_classes=num_classes,
    label_smoothing=0,
    cutmix_alpha=1.0,
    mixup_alpha=0.0,
)
mixup_fn = Mixup(**_cutmix_options)


def train(model, loader, criterion, optimizer, teacher, temperature=20):
    """Run one distillation epoch and return the mean per-sample loss.

    Each batch is augmented with the module-level ``mixup_fn``; the student
    is then trained to match the teacher's temperature-softened predictions
    on the augmented images. The ground-truth labels are not used by the loss.

    Args:
        model: Student network; switched to train mode here.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(student_log_probs, teacher_probs)``.
        optimizer: Optimizer that updates the student.
        teacher: Network producing the soft targets. Its train/eval mode is
            NOT changed here, so the mode chosen by the caller (e.g. through
            the ``--bssl`` switch) is the one actually used.
        temperature: Value dividing both student and teacher logits before
            the softmax. Defaults to 20, the previously hard-coded value.

    Returns:
        Sum of ``loss * batch_size`` over the processed batches divided by the
        number of samples actually processed, or 0.0 if nothing was processed.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0
    for inputs, labels in loader:
        # mixup_fn is only ever given even-sized batches: drop the last
        # sample of an odd batch.
        if len(inputs) % 2 != 0:
            inputs = inputs[:-1]
            labels = labels[:-1]
        # A single-sample batch is empty after trimming; there is nothing to
        # train on, so skip it instead of feeding an empty batch forward.
        if len(inputs) == 0:
            continue

        # Only the augmented images are needed; the mixed labels are unused
        # because the targets come from the teacher.
        inputs, _ = mixup_fn(inputs, labels)
        inputs = inputs.to(device)

        optimizer.zero_grad()
        # Teacher outputs serve purely as targets, so no autograd graph is
        # built for them (saves memory and avoids accumulating unused grads).
        with torch.no_grad():
            soft_labels = F.softmax(teacher(inputs) / temperature, dim=1)
        outputs = F.log_softmax(model(inputs) / temperature, dim=1)

        loss = criterion(outputs, soft_labels)
        loss.backward()
        optimizer.step()

        batch_count = inputs.size(0)
        running_loss += loss.item() * batch_count
        samples_seen += batch_count

    # Normalise by what was really processed: trimmed or skipped samples
    # never contributed to running_loss.
    if samples_seen == 0:
        return 0.0
    return running_loss / samples_seen


def validate(model, loader):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    The model is put in eval mode. Cross-entropy is accumulated weighted by
    batch size, and a prediction counts as correct when the highest-scoring
    class equals the label.

    Args:
        model: Network to evaluate.
        loader: Iterable of ``(inputs, labels)`` batches exposing ``.dataset``.

    Returns:
        Tuple ``(loss, accuracy)``: the accumulated loss divided by
        ``len(loader.dataset)`` and the fraction of correct predictions.
    """
    model.eval()
    cross_entropy = nn.CrossEntropyLoss()
    loss_sum = 0.0
    hits = 0
    seen = 0
    with torch.no_grad():
        for images, targets in loader:
            images = images.to(device)
            targets = targets.to(device)
            logits = model(images)

            batch_loss = cross_entropy(logits, targets)
            loss_sum += batch_loss.item() * images.size(0)

            # Index of the maximum along the class dimension.
            top_class = logits.max(1).indices
            hits += (top_class == targets).sum().item()
            seen += targets.size(0)

    return loss_sum / len(loader.dataset), hits / seen


def get_parameters(model):
    """Split a model's parameters into optimizer groups by weight-decay policy.

    A parameter goes into the decayed group when its name contains
    ``'weight'`` and it has more than one dimension; every other parameter
    goes into a group whose ``weight_decay`` is fixed to 0.

    Args:
        model: Module exposing ``named_parameters()`` and ``parameters()``.

    Returns:
        A two-element list of param-group dicts: the decayed group (which
        carries no explicit ``weight_decay`` key) followed by the
        non-decayed group.
    """
    decayed, undecayed = [], []
    for name, param in model.named_parameters():
        is_multi_dim_weight = 'weight' in name and param.dim() > 1
        bucket = decayed if is_multi_dim_weight else undecayed
        bucket.append(param)

    # Sanity check: every parameter landed in exactly one of the two groups.
    total_params = sum(1 for _ in model.parameters())
    assert total_params == len(decayed) + len(undecayed)

    return [
        {'params': decayed},
        {'params': undecayed, 'weight_decay': 0.},
    ]

# Student network, trained from scratch (change the constructor here to use
# a different model).
model = models.resnet18(pretrained=False).to(device)
model.train()

# BSSL settings: with --bssl the teacher is put in train mode, otherwise in
# eval mode.
teacher = models.resnet18(pretrained=True).to(device)
if not args.bssl:
    teacher.eval()
    print('using eval mode for teacher')
else:
    teacher.train()
    print('using train mode for teacher')

# Wrap only the student for multi-GPU execution when several GPUs are visible.
if torch.cuda.device_count() > 1:
    print(f"🚀 Using {torch.cuda.device_count()} GPUs for training.")
    model = nn.DataParallel(model)

criterion = nn.KLDivLoss(reduction='batchmean')

# AdamW over the two parameter groups built by get_parameters; the 1e-4
# decay applies to the group that does not override weight_decay.
optimizer = torch.optim.AdamW(
    get_parameters(model),
    lr=args.adamw_lr,
    weight_decay=1e-4,
)


def _lr_factor(step):
    """Return the LR multiplier for ``step``: a cosine whose phase is divided by eta, 0 past the last epoch."""
    if step > args.epoch:
        return 0
    return 0.5 * (1. + math.cos(math.pi * step / args.epoch / args.eta))


scheduler = LambdaLR(optimizer, _lr_factor, last_epoch=-1)

# Main loop: one distillation epoch followed by a scheduler step; validation
# runs every tenth epoch and after the final one.
for epoch in range(num_epochs):
    print(f"🚀 Starting Epoch {epoch+1}/{num_epochs}")
    train_loss = train(model, train_loader, criterion, optimizer, teacher)
    scheduler.step()

    print(f"✅ Epoch [{epoch+1}/{num_epochs}] Completed")
    print(f"📊 Train Loss: {train_loss:.4f}")

    is_final_epoch = epoch == num_epochs - 1
    is_tenth_epoch = (epoch + 1) % 10 == 0
    if is_final_epoch or is_tenth_epoch:
        val_loss, val_acc = validate(model, val_loader)
        print(f"📊 Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc * 100:.2f}%")
    print("-" * 50)
    