import torch
from torch.autograd import Variable
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from torchvision import models, utils, datasets, transforms
import numpy as np
import sys
import os
from PIL import Image
from typing import Any
import glob
import argparse


class AvgrageMeter(object):
    """Track a running, count-weighted average of scalar values.

    Attributes:
        sum: accumulated ``val * n`` over all updates.
        cnt: accumulated weight ``n`` over all updates.
        avg: ``sum / cnt`` after the most recent update (0 before any update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the accumulated sum, count and average back to zero."""
        self.avg, self.sum, self.cnt = 0, 0, 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and recompute the running average.

        Args:
            val: value to accumulate (e.g. a per-batch mean).
            n: weight of ``val`` (e.g. the batch size); defaults to 1.
        """
        weighted = val * n
        self.sum += weighted
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: score tensor whose dim 1 indexes classes (one row per sample).
        target: tensor of ground-truth class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 1-element tensor per k: the percentage of samples whose
        target appears among the k highest-scoring classes.
    """
    n_samples = target.size(0)

    # Indices of the max(topk) best classes per sample, transposed so that
    # row r holds every sample's r-th best prediction.
    _, ranked = output.topk(max(topk), 1, True, True)
    ranked = ranked.t()
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0) / n_samples
        for k in topk
    ]


class Cutout(object):
    """Cutout augmentation: zero one random square patch of a tensor image.

    The image is indexed as ``img.size(1)`` x ``img.size(2)`` (height x width),
    i.e. a channel-first tensor such as the output of ``transforms.ToTensor()``.
    The patch centre is drawn uniformly over the image and the square is
    clipped at the image borders, so fewer pixels are masked near an edge.
    The input tensor is modified in place and also returned.
    """

    def __init__(self, length):
        # Nominal side length of the square; the patch spans
        # centre - length // 2 .. centre + length // 2 (exclusive end).
        self.length = length

    def __call__(self, img):
        height = img.size(1)
        width = img.size(2)
        half = self.length // 2

        # Row centre is drawn before column centre.
        cy = np.random.randint(height)
        cx = np.random.randint(width)
        top, bottom = np.clip((cy - half, cy + half), 0, height)
        left, right = np.clip((cx - half, cx + half), 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.
        # Broadcast the single-plane mask over every channel.
        img *= torch.from_numpy(keep).expand_as(img)
        return img

def _data_transforms_svhm(args):
    """Build (train, valid) transforms for 32x32 images (name suggests SVHN — confirm).

    Training uses a padded random 32x32 crop and a random rotation of up to
    10 degrees (no horizontal flip), then conversion to a tensor, optionally
    followed by Cutout. Validation only converts to a tensor.

    Args:
        args: object exposing ``cutout`` (truthy to enable Cutout) and
            ``cutout_length`` (patch size passed to ``Cutout``).

    Returns:
        Tuple ``(train_transform, valid_transform)``.
    """
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ])
    if args.cutout:
        # Cutout runs last, on the tensor produced by ToTensor.
        train_transform.transforms.append(Cutout(args.cutout_length))

    valid_steps = [transforms.ToTensor()]
    valid_transform = transforms.Compose(valid_steps)
    return train_transform, valid_transform


def _data_transforms_cifar10(args):
    """Build (train, valid) transforms for CIFAR-10-sized (32x32) images.

    Training applies a 4-pixel-padded random 32x32 crop and a random
    horizontal flip, converts to a tensor, and optionally appends Cutout.
    Validation only converts to a tensor (no normalization in either).

    Args:
        args: object exposing ``cutout`` (truthy to enable Cutout) and
            ``cutout_length`` (patch size passed to ``Cutout``).

    Returns:
        Tuple ``(train_transform, valid_transform)``.
    """
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    if args.cutout:
        train_transform.transforms.append(Cutout(args.cutout_length))

    valid_steps = [transforms.ToTensor()]
    valid_transform = transforms.Compose(valid_steps)
    return train_transform, valid_transform


def _data_transforms_imagenet():
    """Build (train, valid) transforms for 64x64 images (presumably Tiny ImageNet — confirm).

    Training applies a 4-pixel-padded random 64x64 crop and a random
    horizontal flip before tensor conversion; validation only converts to a
    tensor. No Cutout option and no normalization.

    Returns:
        Tuple ``(train_transform, valid_transform)``.
    """
    train_steps = [
        transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    valid_steps = [transforms.ToTensor()]
    return transforms.Compose(train_steps), transforms.Compose(valid_steps)


def _data_transforms_cifar100(args):
    """Build (train, valid) transforms for CIFAR-100 with per-channel normalization.

    Training: padded random 32x32 crop, random horizontal flip, tensor
    conversion, normalization, and optionally Cutout. Because Cutout is
    appended after ``Normalize``, masked pixels become 0 in normalized space.
    Validation: tensor conversion and the same normalization.

    Args:
        args: object exposing ``cutout`` (truthy to enable Cutout) and
            ``cutout_length`` (patch size passed to ``Cutout``).

    Returns:
        Tuple ``(train_transform, valid_transform)``.
    """
    # Per-channel CIFAR-100 mean and standard deviation (RGB order).
    cifar100_mean = [0.5071, 0.4865, 0.4409]
    cifar100_std = [0.2673, 0.2564, 0.2761]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),            # random crop
        transforms.RandomHorizontalFlip(),               # random horizontal flip
        transforms.ToTensor(),                           # convert to tensor
        transforms.Normalize(cifar100_mean, cifar100_std),  # normalize
    ])
    if args.cutout:
        # Optional Cutout augmentation.
        train_transform.transforms.append(Cutout(args.cutout_length))

    valid_steps = [
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ]
    valid_transform = transforms.Compose(valid_steps)
    return train_transform, valid_transform

def _data_transforms_imagenet1K():
    """Build (train, valid) transforms for full-size ImageNet-1K style images.

    Training uses a random resized 224x224 crop plus a random horizontal flip;
    validation resizes (shorter side 256) and center-crops to 224x224. Both
    end with tensor conversion; no normalization is applied here.

    Returns:
        Tuple ``(train_transform, valid_transform)``.
    """
    train_steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    valid_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    return transforms.Compose(train_steps), transforms.Compose(valid_steps)


def count_parameters_in_MB(model):
    """Return the model's parameter count in millions (not megabytes).

    Parameters whose qualified name contains ``"auxiliary"`` (e.g. an
    auxiliary classifier head) are excluded from the count.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        Number of counted scalar parameters divided by 1e6.
    """
    # Builtin sum: np.sum over a generator is deprecated in NumPy.
    total = sum(
        param.numel()
        for name, param in model.named_parameters()
        if "auxiliary" not in name
    )
    return total / 1e6


def save(model, epoch, optimizer, valid_acc, model_dir=""):
    """Write a training checkpoint to a timestamped ``.pt`` file and report its path.

    The file is named ``EV_model_acc<valid_acc:.4f>_<YYYYmmdd_HHMMSS>.pt`` and
    placed in ``model_dir`` (current directory when empty). It stores a dict
    with keys ``'epoch'``, ``'model'`` (model state_dict), ``'optimizer'``
    (optimizer state_dict) and ``'valid_acc'``.

    Args:
        model: module whose ``state_dict()`` is saved.
        epoch: training epoch recorded in the checkpoint.
        optimizer: optimizer whose ``state_dict()`` is saved.
        valid_acc: validation accuracy; recorded and embedded in the file name.
        model_dir: destination directory (must already exist).
    """
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    target = os.path.join(model_dir, f"EV_model_acc{valid_acc:.4f}_{stamp}.pt")

    checkpoint = {}
    checkpoint['epoch'] = epoch
    checkpoint['model'] = model.state_dict()
    checkpoint['optimizer'] = optimizer.state_dict()
    checkpoint['valid_acc'] = valid_acc
    torch.save(checkpoint, target)

    print(f"Model saved to {target}")

def load(model, model_path, map_location=None):
    """Load weights from ``model_path`` into ``model``.

    Accepts either a bare state_dict, or a checkpoint dict that stores the
    weights under the ``'model'`` key (the layout written by ``save`` in this
    module, alongside ``'epoch'``, ``'optimizer'`` and ``'valid_acc'``).

    Args:
        model: module to receive the weights (loaded with strict key matching).
        model_path: path to a file produced by ``torch.save``.
        map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine. Defaults to None,
            which is ``torch.load``'s own default.
    """
    checkpoint = torch.load(model_path, map_location=map_location)
    state_dict = checkpoint
    # A state_dict maps names to tensors, so a dict-valued 'model' entry
    # identifies the wrapped checkpoint format.
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get('model'), dict):
        state_dict = checkpoint['model']
    model.load_state_dict(state_dict)


def drop_path(x, drop_prob):
    """Randomly zero whole samples of ``x`` and rescale the survivors, in place.

    Each sample along dim 0 is kept with probability ``1 - drop_prob``. Kept
    samples are divided by the keep probability so the expected value is
    unchanged. Dropped samples become zero. ``x`` is modified in place and
    returned. When ``drop_prob <= 0`` it is returned untouched.

    Args:
        x: tensor whose first dimension indexes samples; any rank >= 1.
        drop_prob: probability of dropping each sample.

    Returns:
        ``x`` (the same tensor object).
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        # The mask is allocated on x's device, not hard-wired to CUDA.
        mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
        mask = torch.empty(mask_shape, dtype=torch.float32, device=x.device).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x


def madry_generate(model, x_natural, y, optimizer, step_size=0.003, epsilon=0.031, perturb_steps=10, distance='l_inf'):
    """Craft l_inf PGD adversarial examples (Madry et al.) around ``x_natural``.

    The attack starts from ``x_natural`` plus small Gaussian noise (std 0.001).
    It then runs ``perturb_steps`` signed-gradient ascent steps on the mean
    cross-entropy loss. After every step it projects onto the ``epsilon``
    l_inf ball around ``x_natural`` and clamps to [0, 1].

    The model runs in eval mode with parameter gradients disabled during the
    attack. Its previous training mode and every parameter's
    ``requires_grad`` flag are restored afterwards, even if the forward pass
    raises.

    Args:
        model: classifier. A tuple output is allowed; its first element is
            used as the logits.
        x_natural: clean inputs, assumed to lie in [0, 1].
        y: target class indices for the cross-entropy loss.
        optimizer: unused; kept for call-site compatibility.
        step_size: magnitude of each signed-gradient step.
        epsilon: l_inf radius of the perturbation.
        perturb_steps: number of ascent steps.
        distance: only ``'l_inf'`` runs the attack. Any other value returns
            the noisy starting point clamped to [0, 1].

    Returns:
        Detached adversarial inputs with the same shape as ``x_natural``.
    """
    criterion_ce = torch.nn.CrossEntropyLoss(reduction='none')
    was_training = model.training
    model.eval()

    # Freeze parameter gradients so only the input gradient is computed.
    original_requires_grad = {param: param.requires_grad for param in model.parameters()}
    for param in original_requires_grad:
        param.requires_grad = False

    try:
        x_nat = x_natural.detach()
        x_adv = x_nat + 0.001 * torch.randn_like(x_nat)

        if distance == 'l_inf':
            for _ in range(perturb_steps):
                x_adv.requires_grad_()
                # Build the graph even if the caller is inside torch.no_grad().
                with torch.enable_grad():
                    out = model(x_adv)
                    logits = out[0] if isinstance(out, tuple) else out
                    loss_ce = criterion_ce(logits, y).mean()
                grad = torch.autograd.grad(loss_ce, [x_adv])[0]
                x_adv = x_adv.detach() + step_size * grad.sign()
                # Project back onto the epsilon ball, then onto the valid pixel range.
                x_adv = torch.min(torch.max(x_adv, x_nat - epsilon), x_nat + epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0)
        else:
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    finally:
        # Restore the caller's parameter flags and train/eval mode.
        for param, flag in original_requires_grad.items():
            param.requires_grad = flag
        model.train(was_training)

    return x_adv.detach()

class TrainTinyImageNet(Dataset):
    """Tiny ImageNet training split.

    Samples are the files matching ``<root>/train/*/*/*.JPEG``. A sample's
    label is looked up in ``id`` using the directory name three path
    components from the end, i.e. the ``<wnid>`` in
    ``train/<wnid>/<subdir>/<file>.JPEG``.

    Args:
        root: dataset root directory.
        id: mapping from wnid string to integer class index.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, root, id, transform=None) -> None:
        super().__init__()
        # Sorted so sample indices are reproducible (glob order is filesystem-dependent).
        self.filenames = sorted(glob.glob(os.path.join(root, 'train', '*', '*', '*.JPEG')))
        self.transform = transform
        self.id_dict = id

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx: int):
        img_path = self.filenames[idx]
        # convert('RGB') returns a fully loaded copy, so the file can be closed
        # right away. It also turns grayscale or any other mode into 3-channel RGB.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # Split on the normalized path so both '/' and os.sep work.
        parts = os.path.normpath(img_path).split(os.sep)
        # The class directory (wnid) is the third path component from the end.
        label = self.id_dict[parts[-3]]
        if self.transform:
            image = self.transform(image)
        return image, label


class ValTinyImageNet(Dataset):
    """Tiny ImageNet validation split.

    Samples are the files matching ``<root>/val/images/*.JPEG``. Labels come
    from ``<root>/val/val_annotations.txt``, a tab-separated file whose first
    column is the image file name and second column the wnid. The wnid is
    mapped to a class index through ``id``.

    Args:
        root: dataset root directory.
        id: mapping from wnid string to integer class index.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, root, id, transform=None):
        super().__init__()
        # Sorted so sample indices are reproducible (glob order is filesystem-dependent).
        self.filenames = sorted(glob.glob(os.path.join(root, 'val', 'images', '*.JPEG')))
        self.transform = transform
        self.id_dict = id
        # Maps image file name -> class index.
        self.cls_dic = {}
        annotations_path = os.path.join(root, 'val', 'val_annotations.txt')
        with open(annotations_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Skip blank (e.g. trailing) lines instead of raising IndexError.
                    continue
                fields = line.split('\t')
                img, cls_id = fields[0], fields[1]
                self.cls_dic[img] = self.id_dict[cls_id]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx: int):
        img_path = self.filenames[idx]
        # convert('RGB') returns a fully loaded copy, so the file can be closed
        # right away. It also turns grayscale or any other mode into 3-channel RGB.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        # Annotations are keyed by the bare file name.
        filename = os.path.basename(os.path.normpath(img_path))
        label = self.cls_dic[filename]
        if self.transform:
            image = self.transform(image)
        return image, label


def load_tinyimagenet(args):
    """Build Tiny ImageNet train/val DataLoaders.

    Class indices are assigned in the order wnids appear in
    ``<args.data_dir>/wnids.txt`` (one wnid per line; blank lines ignored).
    Training uses a padded random 64x64 crop, a horizontal flip, tensor
    conversion and optional Cutout. Validation only converts to a tensor.

    Args:
        args: namespace with ``batch_size``, ``workers`` and ``data_dir``.
            It may also carry ``cutout`` (truthy to enable Cutout, default
            off) and ``cutout_length``.

    Returns:
        Tuple ``(train_loader, val_loader, num_classes)``.
    """
    batch_size = args.batch_size
    nw = args.workers
    root = args.data_dir

    wnids_path = os.path.join(root, 'wnids.txt')
    with open(wnids_path, 'r') as f:
        # Skip blank lines so a stray empty line does not become a bogus '' class.
        wnids = [line.strip() for line in f if line.strip()]
    id_dic = {wnid: i for i, wnid in enumerate(wnids)}
    num_classes = len(id_dic)

    train_transforms = [
        transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ]
    # Add Cutout only when args.cutout is present and truthy.
    if getattr(args, 'cutout', False):
        train_transforms.append(Cutout(args.cutout_length))
    print("Applied train transforms:", train_transforms)
    valid_transforms = [
        transforms.ToTensor()
    ]

    data_transform = {
        "train": transforms.Compose(train_transforms),
        "val": transforms.Compose(valid_transforms),
    }

    train_dataset = TrainTinyImageNet(root, id=id_dic, transform=data_transform["train"])
    val_dataset = ValTinyImageNet(root, id=id_dic, transform=data_transform["val"])

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=nw
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=nw
    )

    print("TinyImageNet Loading SUCCESS" +
          "\nlen of train dataset: " + str(len(train_dataset)) +
          "\nlen of val dataset: " + str(len(val_dataset)))

    return train_loader, val_loader, num_classes

