import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm
import argparse
import os
import datetime


class ResidualBlock(nn.Module):
    """Two-convolution residual block with an optional shortcut projection.

    The main path is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.
    The block input (passed through ``downsample`` when one is given) is
    added to the main-path result before the final ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (the second always uses 1).
        downsample: Optional module applied to the input so the shortcut
            matches the main path's output shape. ``None`` means the input
            is added unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Compare against None explicitly: truthiness of a module falls back
        # to __len__ for container modules, so an empty container would be
        # treated as "no shortcut" by a plain ``if self.downsample:``.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class CustomResNet(nn.Module):
    """Residual image classifier built from four stages of ``ResidualBlock``.

    Layout: a 7x7 stride-2 convolution with BatchNorm/ReLU and a 3x3
    stride-2 max-pool, followed by four stages of two residual blocks each
    (widths ``base_channels`` x 1, 2, 4, 8; stages 2-4 use stride 2), global
    average pooling, dropout (p=0.5) and a linear classifier.

    Args:
        num_classes: Size of the output logit vector.
        input_channels: Number of channels of the input images.
        base_channels: Width of the stem and of the first stage.
    """

    def __init__(self, num_classes, input_channels=3, base_channels=64):
        super().__init__()

        # Stem.
        self.conv1 = nn.Conv2d(
            input_channels,
            base_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(base_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; every stage after the first doubles the width.
        w1, w2, w3, w4 = (base_channels * mult for mult in (1, 2, 4, 8))
        self.layer1 = self._make_layer(w1, w1, 2)
        self.layer2 = self._make_layer(w1, w2, 2, stride=2)
        self.layer3 = self._make_layer(w2, w3, 2, stride=2)
        self.layer4 = self._make_layer(w3, w4, 2, stride=2)

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w4, num_classes)

        self.dropout = nn.Dropout(0.5)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block changes resolution/width; it receives a 1x1
        convolution + BatchNorm shortcut whenever the stride is not 1 or the
        channel counts differ.
        """
        shortcut = None
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            projection = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
            )
            shortcut = nn.Sequential(projection, nn.BatchNorm2d(out_channels))

        stage = [ResidualBlock(in_channels, out_channels, stride, shortcut)]
        stage.extend(
            ResidualBlock(out_channels, out_channels)
            for _ in range(blocks - 1)
        )
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        feats = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)

        pooled = torch.flatten(self.avgpool(feats), 1)
        # Dropout is applied to the pooled features, right before the
        # final linear layer.
        return self.fc(self.dropout(pooled))


class EfficientCNN(nn.Module):
    """Efficient CNN model (for MNIST and Fashion-MNIST).

    The feature extractor has three stages of widths 32, 64 and 128. Each
    stage is two (conv3x3 -> BatchNorm -> ReLU) units followed by 2x2
    max-pooling and ``Dropout2d(0.25)``; a global average pool closes the
    extractor. The classifier is Dropout -> Linear(128, 256) -> ReLU ->
    Dropout -> Linear(256, num_classes).

    Args:
        num_classes: Size of the output logit vector.
        input_channels: Number of channels of the input images.
    """

    def __init__(self, num_classes, input_channels=1):
        super().__init__()

        def conv_unit(c_in, c_out):
            # One conv3x3 (padding keeps spatial size) + BatchNorm + ReLU.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Modules are appended in the same order as a flat listing, so the
        # nn.Sequential indices (and therefore state_dict keys) are stable.
        layers = []
        prev_width = input_channels
        for width in (32, 64, 128):
            layers += conv_unit(prev_width, width)
            layers += conv_unit(width, width)
            layers += [nn.MaxPool2d(2, 2), nn.Dropout2d(0.25)]
            prev_width = width
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.features = nn.Sequential(*layers)

        head = [
            nn.Dropout(0.5),
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        pooled = self.features(x)
        return self.classifier(torch.flatten(pooled, 1))


class SVHN_CNN(nn.Module):
    """Convolutional classifier for 3-channel images.

    The feature extractor has three stages: two conv units at width 64, two
    at width 128 and three at width 256. A conv unit is conv3x3 -> BatchNorm
    -> ReLU. Every stage ends with 2x2 max-pooling and ``Dropout2d(0.3)``,
    and a global average pool closes the extractor. The classifier is a
    three-layer MLP (256 -> 512 -> 256 -> num_classes) with dropout (p=0.5)
    before each linear layer.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # (stage width, number of conv units in the stage)
        stage_plan = ((64, 2), (128, 2), (256, 3))

        # Modules are appended in the same order as a flat listing, so the
        # nn.Sequential indices (and therefore state_dict keys) are stable.
        layers = []
        prev_width = 3  # the first convolution is fixed to 3 input channels
        for width, num_units in stage_plan:
            for _ in range(num_units):
                layers.append(
                    nn.Conv2d(prev_width, width, kernel_size=3, padding=1)
                )
                layers.append(nn.BatchNorm2d(width))
                layers.append(nn.ReLU(inplace=True))
                prev_width = width
            layers.append(nn.MaxPool2d(2, 2))
            layers.append(nn.Dropout2d(0.3))
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.features = nn.Sequential(*layers)

        head = [
            nn.Dropout(0.5),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        pooled = self.features(x)
        return self.classifier(torch.flatten(pooled, 1))


def get_transforms(dataset_name, train=True):
    """Build the torchvision preprocessing pipeline for a dataset.

    Every pipeline ends with ``ToTensor()``. Training pipelines prepend the
    dataset-specific augmentations listed below; evaluation pipelines do not.

    * ``cifar10``: random 32x32 crop (padding 4), horizontal flip, colour
      jitter.
    * ``mnist``: resize to 32x32 (train and eval), then random rotation and
      random translation when training.
    * ``fashionmnist``: horizontal flip and random rotation.
    * ``svhn``: random 32x32 crop (padding 4) and colour jitter.

    Args:
        dataset_name: Dataset identifier, matched case-insensitively.
        train: Whether to include the training-time augmentations.

    Returns:
        A ``transforms.Compose`` instance.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
            The original code fell through and returned ``None`` here.
    """
    name = dataset_name.lower()

    if name == "cifar10":
        steps = []
        if train:
            steps = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ColorJitter(
                    brightness=0.2, contrast=0.2, saturation=0.2
                ),
            ]
    elif name == "mnist":
        # MNIST is resized to 32x32 in both modes.
        steps = [transforms.Resize((32, 32))]
        if train:
            steps += [
                transforms.RandomRotation(10),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            ]
    elif name == "fashionmnist":
        steps = []
        if train:
            steps = [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(10),
            ]
    elif name == "svhn":
        steps = []
        if train:
            steps = [
                transforms.RandomCrop(32, padding=4),
                transforms.ColorJitter(
                    brightness=0.3, contrast=0.3, saturation=0.3
                ),
            ]
    else:
        # Fail loudly instead of silently handing ``None`` to the dataset.
        raise ValueError(f"Not supported dataset: {dataset_name}")

    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)


def get_dataset(dataset_name, root="./data", download=True):
    """Create the train/test datasets for ``dataset_name``.

    Args:
        dataset_name: One of ``cifar10``, ``mnist``, ``fashionmnist`` or
            ``svhn`` (matched case-insensitively).
        root: Directory passed to the torchvision dataset as ``root``.
        download: Passed through to the torchvision dataset as ``download``.

    Returns:
        tuple: ``(train_dataset, test_dataset, num_classes, input_channels)``.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    train_tf = get_transforms(dataset_name, train=True)
    eval_tf = get_transforms(dataset_name, train=False)
    key = dataset_name.lower()

    if key == "svhn":
        # SVHN selects its partition with ``split=`` instead of ``train=``.
        train_set = datasets.SVHN(
            root=root, split="train", download=download, transform=train_tf
        )
        test_set = datasets.SVHN(
            root=root, split="test", download=download, transform=eval_tf
        )
        return train_set, test_set, 10, 3

    # Datasets that share the ``train=True/False`` constructor:
    # name -> (torchvision class name, image channels).
    train_flag_datasets = {
        "cifar10": ("CIFAR10", 3),
        "mnist": ("MNIST", 1),
        "fashionmnist": ("FashionMNIST", 1),
    }
    if key not in train_flag_datasets:
        raise ValueError(f"Not supported dataset: {dataset_name}")

    class_name, channels = train_flag_datasets[key]
    dataset_cls = getattr(datasets, class_name)
    train_set = dataset_cls(
        root=root, train=True, download=download, transform=train_tf
    )
    test_set = dataset_cls(
        root=root, train=False, download=download, transform=eval_tf
    )
    return train_set, test_set, 10, channels


def get_model(model_name, num_classes, input_channels):
    """Instantiate the model registered under ``model_name``.

    Args:
        model_name: ``resnet``, ``efficientcnn`` or ``svhncnn`` (matched
            case-insensitively).
        num_classes: Number of output classes.
        input_channels: Number of image channels. Not forwarded to
            ``SVHN_CNN``, which is constructed from ``num_classes`` only.

    Returns:
        A freshly constructed model.

    Raises:
        ValueError: If ``model_name`` is not supported.
    """
    # Factories are lazy so only the requested model is ever constructed.
    factories = {
        "resnet": lambda: CustomResNet(num_classes, input_channels),
        "efficientcnn": lambda: EfficientCNN(num_classes, input_channels),
        "svhncnn": lambda: SVHN_CNN(num_classes),
    }
    build = factories.get(model_name.lower())
    if build is None:
        raise ValueError(f"Not supported model: {model_name}")
    return build()


def train_epoch(
    model, train_loader, criterion, optimizer, device, scheduler=None
):
    """Train ``model`` for one pass over ``train_loader``.

    For every batch: move it to ``device``, run forward and backward, clip
    the total gradient norm to 1.0 and step the optimizer. A tqdm progress
    bar shows the current batch loss and the running accuracy. If a
    scheduler is supplied it is stepped once, after the last batch.

    Args:
        model: Module to train; switched to training mode here.
        train_loader: Iterable of ``(data, target)`` batches that also
            supports ``len()``.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.
        scheduler: Optional learning-rate scheduler.

    Returns:
        tuple: ``(epoch_loss, epoch_acc)`` where ``epoch_loss`` is the mean
        of the per-batch loss values and ``epoch_acc`` is the accuracy in
        percent. Both are ``0.0`` when the loader yields no samples
        (the original raised ``ZeroDivisionError`` in that case).
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(train_loader, desc="Training")

    for data, target in progress_bar:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # Read the scalar once; each .item() call forces a host sync.
        batch_loss = loss.item()
        running_loss += batch_loss
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        progress_bar.set_postfix(
            {"Loss": f"{batch_loss:.4f}", "Acc": f"{100.*correct/total:.2f}%"}
        )

    if scheduler is not None:
        scheduler.step()

    # Guard both averages so an empty loader does not divide by zero.
    num_batches = len(train_loader)
    epoch_loss = running_loss / num_batches if num_batches else 0.0
    epoch_acc = 100.0 * correct / total if total else 0.0

    return epoch_loss, epoch_acc


def validate(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` without tracking gradients.

    Args:
        model: Module to evaluate; switched to eval mode here (it is not
            switched back afterwards).
        test_loader: Iterable of ``(data, target)`` batches that also
            supports ``len()``.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(test_loss, test_acc)`` where ``test_loss`` is the mean of
        the per-batch loss values and ``test_acc`` is the accuracy in
        percent. Both are ``0.0`` when the loader yields no samples
        (the original raised ``ZeroDivisionError`` in that case).
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc="Validation")

        for data, target in progress_bar:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss_sum += criterion(output, target).item()

            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            progress_bar.set_postfix({"Acc": f"{100.*correct/total:.2f}%"})

    # Guard both averages so an empty loader does not divide by zero.
    num_batches = len(test_loader)
    test_loss = loss_sum / num_batches if num_batches else 0.0
    test_acc = 100.0 * correct / total if total else 0.0

    return test_loss, test_acc


def _build_arg_parser():
    """Return the argument parser for the training command line."""
    parser = argparse.ArgumentParser(
        description="PyTorch Classifier Training Script"
    )

    # Data options.
    parser.add_argument(
        "--dataset",
        type=str,
        default="cifar10",
        choices=["cifar10", "mnist", "fashionmnist", "svhn"],
    )
    parser.add_argument(
        "--data_root", type=str, default="./data", help="data root directory"
    )
    parser.add_argument(
        "--batch_size", type=int, default=128, help="batch size"
    )

    # Model options.
    parser.add_argument(
        "--model",
        type=str,
        default="auto",
        choices=["auto", "resnet", "efficientcnn", "svhncnn"],
    )

    # Optimisation options.
    parser.add_argument(
        "--epochs", type=int, default=100, help="number of training epochs"
    )
    parser.add_argument(
        "--lr", type=float, default=0.001, help="initial learning rate"
    )
    parser.add_argument(
        "--weight_decay", type=float, default=1e-4, help="weight decay"
    )
    parser.add_argument(
        "--patience", type=int, default=15, help="early stopping patience"
    )

    # Output options.
    parser.add_argument(
        "--experiment_name", type=str, default=None, help="experiment name"
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="./results/classification",
        help="model save directory",
    )

    # Runtime options.
    parser.add_argument(
        "--device", type=str, default="auto", help="device to use"
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="number of data loader workers",
    )
    return parser


def _trace_and_save(model, example_input, *paths):
    """Trace ``model`` on ``example_input`` and save the result to each path.

    Exceptions from tracing or saving propagate to the caller.
    """
    traced_model = torch.jit.trace(model, example_input)
    for path in paths:
        traced_model.save(path)


def main():
    """Train a classifier from command-line options and save checkpoints.

    Steps: parse arguments, pick the device, build datasets, loaders and the
    model, then train with AdamW and cosine annealing with warm restarts.
    Whenever validation accuracy improves, a ``*_best.pth`` checkpoint and a
    traced ``*_best_jit.pt`` module are written. Training stops early after
    ``--patience`` epochs without improvement. A ``*_final.pth`` checkpoint
    and traced final modules are written at the end, followed by a summary.

    Validation runs on the dataset's test split.
    """
    args = _build_arg_parser().parse_args()

    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    os.makedirs(args.save_dir, exist_ok=True)

    if args.experiment_name is None:
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        # Uses the raw --model value, so auto-selected runs contain "auto".
        args.experiment_name = f"{args.dataset}_{args.model}_{timestamp}"

    # "auto" resolves to efficientcnn for every dataset.
    model_name = "efficientcnn" if args.model == "auto" else args.model

    print(f"Dataset: {args.dataset.upper()}")
    print(f"Model: {model_name.upper()}")

    train_dataset, test_dataset, num_classes, input_channels = get_dataset(
        args.dataset, args.data_root
    )

    print(f"Train samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}")
    print(f"Number of classes: {num_classes}")
    print(f"Input channels: {input_channels}")

    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )

    model = get_model(model_name, num_classes, input_channels)
    model = model.to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=1e-6
    )

    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []

    best_val_acc = 0.0
    patience_counter = 0

    # Pre-initialised so the final checkpoint and summary cannot hit unbound
    # names when the loop never runs (--epochs 0) or accuracy never improves.
    best_model_path = None
    epochs_run = 0
    val_acc = 0.0

    print(f"\nStarting training for {args.epochs} epochs...")

    for epoch in range(args.epochs):
        epochs_run = epoch + 1
        print(f"\nEpoch {epochs_run}/{args.epochs}")
        print("-" * 50)

        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device, scheduler
        )
        val_loss, val_acc = validate(model, test_loader, criterion, device)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0

            best_model_path = os.path.join(
                args.save_dir, f"{args.experiment_name}_best.pth"
            )
            torch.save(
                {
                    "epoch": epochs_run,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "best_val_acc": best_val_acc,
                    "train_losses": train_losses,
                    "train_accs": train_accs,
                    "val_losses": val_losses,
                    "val_accs": val_accs,
                    "args": args,
                    "model_name": model_name,
                },
                best_model_path,
            )

            model.eval()
            dummy_input = torch.randn(1, input_channels, 32, 32, device=device)

            # The TorchScript export is best-effort: a failure is reported
            # but must not abort training.
            try:
                best_jit_path = os.path.join(
                    args.save_dir, f"{args.experiment_name}_best_jit.pt"
                )
                _trace_and_save(model, dummy_input, best_jit_path)

                print(f"JIT model saved: {best_jit_path}")
            except Exception as e:
                print(f"Warning: JIT save failed - {e}")

            model.train()

            print(
                f"★ New best model saved! "
                f"Validation Accuracy: {best_val_acc:.2f}%"
            )
        else:
            patience_counter += 1

        if patience_counter >= args.patience:
            print(
                f"\nEarly stopping: {patience_counter} epochs without "
                f"improvement"
            )
            break

        if device.type == "cuda":
            torch.cuda.empty_cache()

    final_model_path = os.path.join(
        args.save_dir, f"{args.experiment_name}_final.pth"
    )
    torch.save(
        {
            "epoch": epochs_run,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "final_val_acc": val_acc,
            "best_val_acc": best_val_acc,
            "train_losses": train_losses,
            "train_accs": train_accs,
            "val_losses": val_losses,
            "val_accs": val_accs,
            "args": args,
            "model_name": model_name,
        },
        final_model_path,
    )

    model.eval()
    dummy_input = torch.randn(1, input_channels, 32, 32, device=device)

    try:
        final_jit_path = os.path.join(
            args.save_dir, f"{args.experiment_name}_final_jit.pt"
        )
        # Second copy under a fixed, dataset-based file name.
        simple_final_jit_path = os.path.join(
            args.save_dir, f"{args.dataset.lower()}_final_jit.pt"
        )
        _trace_and_save(
            model, dummy_input, final_jit_path, simple_final_jit_path
        )

        print(f"Final JIT model saved: {final_jit_path}")
        print(f"Final JIT model (simple): {simple_final_jit_path}")
    except Exception as e:
        print(f"Warning: Final JIT save failed - {e}")

    print("\n" + "=" * 60)
    print("TRAINING COMPLETED!")
    print("=" * 60)
    print(f"Dataset: {args.dataset.upper()}")
    print(f"Model: {model_name.upper()}")
    print(f"Total Epochs: {epochs_run}")
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
    print(f"Final Validation Accuracy: {val_acc:.2f}%")
    print(f"Total Parameters: {total_params:,}")
    if best_model_path is not None:
        print(f"Best Model Saved: {best_model_path}")
    else:
        print("Best Model Saved: none (validation accuracy never improved)")
    print(f"Final Model Saved: {final_model_path}")
    print("=" * 60)


# Run the training command line only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
