import argparse
import os
import random
import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.cuda.amp import autocast, GradScaler

import torchvision
from torchvision import transforms

try:
    import timm
except Exception:
    timm = None

# ===================== Model =====================
class PatchEmbedding(nn.Module):
    """Convert an image batch into a sequence of patch embeddings.

    A single convolution whose kernel size equals its stride (``patch_size``)
    projects every non-overlapping patch to an ``embed_dim``-sized vector.

    Attributes:
        img_size: Image side length passed to the constructor.
        patch_size: Patch side length passed to the constructor.
        n_patches: ``(img_size // patch_size) ** 2``.
        proj: The patch-projection convolution.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Project ``x``, flatten everything from dim 2 on, then swap dims 1 and 2.

        For a batched image input this yields one ``embed_dim`` vector per
        patch, laid out as (batch, patches, embed_dim).
        """
        return self.proj(x).flatten(2).transpose(1, 2)


class MLP(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Width of the input.
        hidden_features: Width of the hidden layer; a falsy value (``None``
            or ``0``) falls back to ``in_features``.
        out_features: Width of the output; a falsy value falls back to
            ``in_features``.
        dropout: Dropout probability used after the activation and after
            the second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, dropout=0.1):
        super().__init__()
        if not hidden_features:
            hidden_features = in_features
        if not out_features:
            out_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # A single Dropout module serves both dropout sites in forward().
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the two linear layers with GELU and dropout in between."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Embedding width of the input tokens.
        num_heads: Number of attention heads; must be positive and divide
            ``dim`` evenly.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``dim``. Without this check the mismatch only surfaces later as
            an opaque reshape error inside ``forward``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Scale logits by 1/sqrt(head_dim) before the softmax.
        self.scale = head_dim ** -0.5

        # One linear layer produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the token dimension of a 3-D input ``(B, N, C)``.

        Returns a tensor of the same ``(B, N, C)`` shape.
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class TransformerEncoderBlock(nn.Module):
    """Pre-norm encoder block: residual self-attention, then a residual MLP.

    Args:
        embed_dim: Token embedding width.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``
            (truncated to an int).
        dropout: Dropout probability shared by the attention weights, the
            attention output projection and the MLP.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        mlp_hidden = int(embed_dim * mlp_ratio)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = Attention(
            embed_dim,
            num_heads=num_heads,
            qkv_bias=True,
            attn_drop=dropout,
            proj_drop=dropout,
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_hidden, dropout=dropout)

    def forward(self, x):
        """Add the attention output, then the MLP output, to the input."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        fed_forward = self.mlp(self.norm2(x))
        return x + fed_forward


class DeiT(nn.Module):
    """Vision transformer with a class token and a distillation token.

    The input is split into patch embeddings, prefixed with the two learnable
    tokens, summed with a learnable position embedding, and passed through
    ``depth`` encoder blocks. Each of the two tokens feeds its own linear
    classification head.

    Args:
        img_size: Input image side length.
        patch_size: Patch side length.
        in_chans: Number of input channels.
        num_classes: Output size of both heads.
        embed_dim: Token embedding width.
        depth: Number of encoder blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout probability used after the position embedding and
            inside every block.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dim=192, depth=12, num_heads=3, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        # Two extra positions: one for the class token, one for the distillation token.
        seq_len = self.patch_embed.n_patches + 2

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        encoder_layers = [
            TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ]
        self.blocks = nn.Sequential(*encoder_layers)
        self.norm = nn.LayerNorm(embed_dim)

        self.head = nn.Linear(embed_dim, num_classes)
        self.head_dist = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init for the embeddings and head weights; zero head biases."""
        normal_init_targets = (
            self.pos_embed,
            self.cls_token,
            self.dist_token,
            self.head.weight,
            self.head_dist.weight,
        )
        for tensor in normal_init_targets:
            nn.init.trunc_normal_(tensor, std=0.02)
        for bias in (self.head.bias, self.head_dist.bias):
            nn.init.zeros_(bias)

    def forward(self, x):
        """Classify a batch of images.

        Returns:
            In training mode, the tuple ``(cls_logits, dist_logits)``.
            In eval mode, the element-wise mean of the two heads' logits.
        """
        batch = x.shape[0]
        patch_tokens = self.patch_embed(x)

        # Broadcast the two learnable tokens across the batch and prepend them.
        cls_prefix = self.cls_token.expand(batch, -1, -1)
        dist_prefix = self.dist_token.expand(batch, -1, -1)
        tokens = torch.cat((cls_prefix, dist_prefix, patch_tokens), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)

        tokens = self.norm(self.blocks(tokens))

        # Position 0 is the class token, position 1 the distillation token.
        cls_logits = self.head(tokens[:, 0])
        dist_logits = self.head_dist(tokens[:, 1])

        if not self.training:
            return (cls_logits + dist_logits) / 2
        return cls_logits, dist_logits


def get_deit_tiny(num_classes=1000, dropout=0.1):
    """Build a DeiT-Tiny model (embed_dim=192, depth=12, 3 attention heads).

    Args:
        num_classes: Output size of both classification heads.
        dropout: Dropout probability forwarded to ``DeiT``.

    Returns:
        A freshly initialised ``DeiT`` instance.
    """
    # DeiT-Tiny is specified with 3 heads (head dim 192 / 3 = 64), which is
    # also the DeiT class default. The previous value of 12 shrank every head
    # to 16 dims. Parameter shapes do not depend on the head count, so
    # existing checkpoints still load, but their attention is computed
    # differently.
    return DeiT(num_classes=num_classes, dropout=dropout, embed_dim=192, depth=12, num_heads=3)


# ===================== Utils =====================
class SmoothedValue:
    """Accumulate a weighted sum and a count, and expose their running mean."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, val, n=1):
        """Add ``val`` (converted with ``float``) weighted by ``n``."""
        weighted = float(val) * n
        self.total += weighted
        self.count += n

    @property
    def avg(self):
        """Return ``total / count``, using a divisor of 1 when ``count`` is below 1."""
        divisor = self.count if self.count > 1 else 1
        return self.total / divisor


def accuracy(output, target, topk=(1, 5)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: 2-D score tensor; dim 1 is ranked with ``topk`` to pick
            predictions (assumed to be one row per sample).
        target: 1-D tensor of ground-truth class indices, one per row of
            ``output``.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per requested k.

    Each k is clamped to the number of classes, so the default ``(1, 5)``
    works on problems with fewer than five classes instead of raising inside
    ``Tensor.topk``. An empty batch yields 0.0 rather than a
    ZeroDivisionError.
    """
    with torch.no_grad():
        num_classes = output.size(1)
        maxk = min(max(topk), num_classes)
        batch_size = target.size(0)

        # pred: (maxk, batch) after the transpose; row i holds the i-th best guess.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        percent_per_hit = 100.0 / max(batch_size, 1)
        return [
            correct[:min(k, maxk)].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]


def setup_seed(seed):
    """Seed Python's ``random`` and torch (CPU and all CUDA devices).

    Also switches cuDNN benchmark mode on.
    """
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # NOTE(review): benchmark mode presumably trades run-to-run determinism
    # for speed; confirm this is intended alongside seeding.
    cudnn.benchmark = True


# ===================== Distillation Loss =====================
class DistillationLoss(nn.Module):
    """Combine a supervised loss with an optional teacher-distillation loss.

    Args:
        base_criterion: Callable ``(logits, targets) -> loss`` used for the
            ground-truth term.
        teacher_model: Frozen teacher network; required unless
            ``distill_type == "none"``.
        distill_type: ``"soft"`` (KL divergence against the teacher's
            temperature-softened distribution), ``"hard"`` (cross-entropy
            against the teacher's argmax labels) or ``"none"``.
        alpha: Weight of the distillation term; the supervised term gets
            ``1 - alpha``.
        tau: Softmax temperature for ``"soft"`` distillation.

    The student may return a single logits tensor or a
    ``(cls_logits, dist_logits)`` pair. With a pair, the class head is
    trained on the labels and the distillation head on the teacher. This is
    the role of the distillation token in DeiT; applying the teacher signal
    to the class head instead would leave the distillation head as a second,
    redundant class head.
    """

    def __init__(self, base_criterion, teacher_model=None, distill_type="soft", alpha=0.5, tau=1.0):
        super().__init__()
        assert distill_type in {"soft", "hard", "none"}, \
            f"distill_type must be 'soft', 'hard' or 'none', got {distill_type!r}"
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        self.distill_type = distill_type
        self.alpha = alpha
        self.tau = tau
        if distill_type != "none":
            assert teacher_model is not None, "a teacher_model is required for distillation"
            for p in self.teacher_model.parameters():
                p.requires_grad_(False)
            self.teacher_model.eval()

    def train(self, mode=True):
        """Set the mode of this module while keeping the teacher in eval mode.

        The teacher is a registered submodule, so the default ``train()``
        would switch it back to training mode along with the loss module.
        """
        super().train(mode)
        if self.teacher_model is not None:
            self.teacher_model.eval()
        return self

    def forward(self, inputs, outputs, targets):
        """Return the combined loss.

        Args:
            inputs: Batch fed to the teacher model.
            outputs: Student logits, or a ``(cls_logits, dist_logits)`` pair.
            targets: Ground-truth targets accepted by ``base_criterion``.
        """
        if isinstance(outputs, (tuple, list)):
            cls_logits, dist_logits = outputs
        else:
            cls_logits, dist_logits = outputs, None

        cls_loss = self.base_criterion(cls_logits, targets)

        if self.distill_type == "none":
            if dist_logits is None:
                return cls_loss
            # No teacher: train both heads on the labels so neither is left untrained.
            return 0.5 * (cls_loss + self.base_criterion(dist_logits, targets))

        with torch.no_grad():
            teacher_out = self.teacher_model(inputs)

        # The distillation head learns from the teacher; a single-output
        # student falls back to its only head.
        kd_logits = cls_logits if dist_logits is None else dist_logits

        if self.distill_type == "soft":
            T = self.tau
            # Multiplied by T^2 to offset the temperature scaling of the logits.
            distill_loss = nn.KLDivLoss(reduction="batchmean")(
                nn.functional.log_softmax(kd_logits / T, dim=1),
                nn.functional.softmax(teacher_out / T, dim=1),
            ) * (T * T)
        else:
            hard_targets = teacher_out.argmax(dim=1)
            distill_loss = nn.functional.cross_entropy(kd_logits, hard_targets)

        return (1 - self.alpha) * cls_loss + self.alpha * distill_loss


# ===================== Training / Eval =====================
# (train_one_epoch, evaluate, DDP init, dataloaders, lr schedule, main) are unchanged from the previous version.
# The only changes here: the model is replaced with get_deit_tiny, and DistillationLoss is adapted to handle tuple outputs.
