import torch
from torch.autograd import Variable
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from torchvision import models, utils, datasets, transforms
import numpy as np
import sys
import os
from PIL import Image
from typing import Any
import glob
import argparse


class AvgrageMeter(object):
    """Track the weighted running mean of a stream of values.

    Attributes are plain public fields:
      * ``sum`` - accumulated ``val * n`` over all updates
      * ``cnt`` - accumulated weight ``n`` over all updates
      * ``avg`` - ``sum / cnt`` as of the latest update
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the average, the sum and the count back to zero."""
        self.avg = self.sum = self.cnt = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the average."""
        weighted_value = val * n
        self.sum += weighted_value
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: score tensor; the top-k search runs along dimension 1
            (assumed to be laid out as ``(batch, num_classes)``).
        target: tensor of ground-truth class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per requested k, holding the
        percentage of samples whose target is among the k highest scores.
    """
    largest_k = max(topk)
    num_samples = target.size(0)

    # Indices of the highest scores, transposed so that row i holds every
    # sample's (i+1)-th best guess.
    top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    hits = top_indices.eq(target.reshape(1, -1).expand_as(top_indices))

    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0) / num_samples
        for k in topk
    ]


class Cutout(object):
    """Zero out one randomly placed square patch of an image tensor.

    The patch is centred on a uniformly drawn pixel and has side ``length``
    (two halves of ``length // 2``); the part that falls outside the image
    is clipped away, so patches near the border are smaller.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        """Apply the cutout to ``img`` in place and return the same tensor.

        ``img`` is indexed as ``(channels, height, width)``: dimensions 1 and
        2 are used as the spatial size and the mask is shared by dimension 0.
        """
        height, width = img.size(1), img.size(2)
        half = self.length // 2

        # Draw the centre row first, then the column (order matters for
        # reproducibility under a fixed NumPy seed).
        centre_y = np.random.randint(height)
        centre_x = np.random.randint(width)

        top = np.clip(centre_y - half, 0, height)
        bottom = np.clip(centre_y + half, 0, height)
        left = np.clip(centre_x - half, 0, width)
        right = np.clip(centre_x + half, 0, width)

        keep = torch.ones(height, width, dtype=torch.float32)
        keep[top:bottom, left:right] = 0.
        img *= keep.expand_as(img)
        return img


def _data_transforms_cifar10(args):
    """Build the train and validation transform pipelines for CIFAR-10.

    Training: 32x32 random crop with 4 px padding, random horizontal flip,
    tensor conversion and, when ``args.cutout`` is truthy, a ``Cutout`` of
    side ``args.cutout_length``. Validation: tensor conversion only.
    No normalisation step is applied in either pipeline.

    Returns:
        ``(train_transform, valid_transform)`` as ``transforms.Compose``.
    """
    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    # Cutout works on tensors, so it has to come after ToTensor.
    if args.cutout:
        augmentations.append(Cutout(args.cutout_length))

    return (
        transforms.Compose(augmentations),
        transforms.Compose([transforms.ToTensor()]),
    )


def _data_transforms_imagenet():
    """Build the train and validation transform pipelines for 64x64 images.

    Training: 64x64 random crop with 4 px padding, random horizontal flip
    and tensor conversion. Validation: tensor conversion only.

    Returns:
        ``(train_transform, valid_transform)`` as ``transforms.Compose``.
    """
    augmentations = [
        transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    return (
        transforms.Compose(augmentations),
        transforms.Compose([transforms.ToTensor()]),
    )


def _data_transforms_cifar100(args):
    """Build the un-normalised train/validation transforms for CIFAR-100.

    Training: 32x32 random crop with 4 px padding, random horizontal flip,
    tensor conversion and, when ``args.cutout`` is truthy, a ``Cutout`` of
    side ``args.cutout_length``. Validation: tensor conversion only.
    The resulting training pipeline is printed to stdout.

    Returns:
        ``(train_transform, valid_transform)`` as ``transforms.Compose``.
    """
    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    # Cutout works on tensors, so it has to come after ToTensor.
    if args.cutout:
        augmentations.append(Cutout(args.cutout_length))
    train_transform = transforms.Compose(augmentations)

    valid_transform = transforms.Compose([transforms.ToTensor()])

    print("Applied train transforms:", train_transform)
    return train_transform, valid_transform


def _data_transforms_cifar100else(args):
    """Build the normalised train/validation transforms for CIFAR-100.

    Both pipelines end their tensor conversion with a per-channel
    ``Normalize`` using the mean/std constants defined below. The training
    pipeline additionally applies a 32x32 random crop (4 px padding), a
    random horizontal flip and, when ``args.cutout`` is truthy, a ``Cutout``
    of side ``args.cutout_length`` applied after normalisation.

    Returns:
        ``(train_transform, valid_transform)`` as ``transforms.Compose``.
    """
    mean = [0.5071, 0.4865, 0.4409]
    std = [0.2673, 0.2564, 0.2761]

    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    # Cutout is appended last, i.e. it zeroes values of the normalised tensor.
    if args.cutout:
        augmentations.append(Cutout(args.cutout_length))
    train_transform = transforms.Compose(augmentations)

    valid_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)]
    )

    return train_transform, valid_transform




def count_parameters_in_MB(model):
    """Return the number of parameters of ``model`` in millions.

    Parameters whose qualified name contains ``"auxiliary"`` are excluded
    from the count.

    The element counts are added with the builtin ``sum``: passing a
    generator to ``np.sum`` is deprecated NumPy usage. The result is a plain
    Python ``float``.

    Args:
        model: object exposing ``named_parameters()`` (e.g. ``nn.Module``).

    Returns:
        Total element count of the selected parameters divided by ``1e6``.
    """
    total_elements = sum(
        param.numel()
        for name, param in model.named_parameters()
        if "auxiliary" not in name
    )
    return total_elements / 1e6


def save(model, epoch, optimizer, valid_acc, model_dir=""):
    """Write a training checkpoint and return the path it was written to.

    The file is named ``EV_model_acc<valid_acc>_<timestamp>.pt`` (accuracy
    with four decimals, timestamp as ``YYYYmmdd_HHMMSS`` local time) and
    contains a dict with the keys ``'epoch'``, ``'model'`` (model state
    dict), ``'optimizer'`` (optimizer state dict) and ``'valid_acc'``.

    Args:
        model: object exposing ``state_dict()``.
        epoch: value stored under ``'epoch'``.
        optimizer: object exposing ``state_dict()``.
        valid_acc: validation accuracy; stored and embedded in the file name.
        model_dir: target directory. It is created (including parents) when
            missing; an empty string means the current working directory.

    Returns:
        The path of the written checkpoint file.
    """
    # torch.save does not create missing directories, so do it up front.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    model_filename = f"EV_model_acc{valid_acc:.4f}_{timestamp}.pt"
    model_path = os.path.join(model_dir, model_filename)

    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'valid_acc': valid_acc,
    }, model_path)
    print(f"Model saved to {model_path}")
    return model_path

def load(model, model_path):
    """Load weights from ``model_path`` into ``model`` in place.

    Two file layouts are accepted:
      * a checkpoint dict that keeps the weights under the ``'model'`` key
        (the layout written by ``save`` in this file), and
      * a bare state dict.

    Tensors are deserialised onto the CPU so that a checkpoint written on a
    GPU machine can be read on a CPU-only one; ``load_state_dict`` then
    copies the values into the model's existing parameters.

    Args:
        model: object exposing ``load_state_dict()``.
        model_path: path of the file produced by ``torch.save``.
    """
    checkpoint = torch.load(model_path, map_location="cpu")

    # Unwrap the full training checkpoint. A nested dict under 'model' cannot
    # be an ordinary state-dict entry, because those map names to tensors.
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get('model'), dict):
        checkpoint = checkpoint['model']

    model.load_state_dict(checkpoint)


def drop_path(x, drop_prob):
    """Randomly zero whole samples of ``x`` in place and rescale the rest.

    For ``drop_prob > 0`` each index along dimension 0 is kept with
    probability ``1 - drop_prob``; kept samples are divided by that keep
    probability and dropped samples become zero. ``x`` is modified in place
    and returned. For ``drop_prob <= 0`` the tensor is returned untouched.

    The mask is allocated with ``x.new_empty`` so it follows the device and
    dtype of ``x`` (works on CPU as well as CUDA) and it broadcasts over all
    trailing dimensions, whatever the rank of ``x``.

    Args:
        x: floating-point tensor with at least one dimension.
        drop_prob: probability of dropping each sample.

    Returns:
        The same tensor object ``x``.
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # One Bernoulli draw per sample, broadcast over the other dimensions.
        mask_shape = (x.size(0),) + (1,) * (x.dim() - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x


def madry_generate(model, x_natural, y, optimizer, step_size=0.003, epsilon=0.031, perturb_steps=10, distance='l_inf'):
    """Craft adversarial examples with iterated signed-gradient steps.

    Starting from ``x_natural`` plus small Gaussian noise, each of the
    ``perturb_steps`` iterations moves the input by ``step_size`` along the
    sign of the cross-entropy gradient, projects it back into the
    ``epsilon``-ball (max-norm) around ``x_natural`` and clamps it to
    ``[0, 1]``. For any ``distance`` other than ``'l_inf'`` no attack is run
    and only the clamped noisy input is returned.

    While the attack runs the model is in eval mode and its parameters have
    ``requires_grad`` disabled. On exit, including when the forward pass
    raises, every parameter's ``requires_grad`` flag and every sub-module's
    train/eval mode are restored to what they were on entry, so the function
    can be used from both training and evaluation loops.

    Args:
        model: network to attack; may return the logits directly or a tuple
            whose first element is the logits.
        x_natural: clean input batch (values expected in ``[0, 1]``).
        y: target class indices for the cross-entropy loss.
        optimizer: unused; kept for interface compatibility.
        step_size: size of each signed-gradient step.
        epsilon: radius of the allowed max-norm perturbation.
        perturb_steps: number of attack iterations.
        distance: only ``'l_inf'`` runs the attack.

    Returns:
        The adversarial batch, detached from the autograd graph.
    """
    criterion_ce = torch.nn.CrossEntropyLoss(reduction='none')

    # Snapshot the state that is changed temporarily below.
    training_modes = [(module, module.training) for module in model.modules()]
    grad_flags = [(param, param.requires_grad) for param in model.parameters()]

    model.eval()
    # Freeze the parameters: only the gradient w.r.t. the input is needed.
    for param, _ in grad_flags:
        param.requires_grad = False

    try:
        x_adv = x_natural.detach() + 0.001 * torch.randn_like(x_natural).detach()

        if distance == 'l_inf':
            for _ in range(perturb_steps):
                x_adv.requires_grad_()
                with torch.enable_grad():
                    out = model(x_adv)
                    if isinstance(out, tuple):
                        logits, _ = out
                    else:
                        logits = out
                    loss_ce = criterion_ce(logits, y).mean()
                grad = torch.autograd.grad(loss_ce, [x_adv])[0]
                x_adv = x_adv.detach() + step_size * grad.sign()
                # Project back into the epsilon-ball, then into the valid range.
                x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0)
        else:
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    finally:
        # Restore the parameter gradient flags and the train/eval modes.
        for param, flag in grad_flags:
            param.requires_grad = flag
        for module, mode in training_modes:
            module.training = mode

    return x_adv.detach()



class TrainTinyImageNet(Dataset):
    """Training split of a Tiny-ImageNet style directory tree.

    Image files are discovered with the glob ``<root>/train/*/*/*.JPEG``.
    The label of an image is looked up from the name of the directory two
    levels above the file (third path component from the end).
    """

    def __init__(self, root, id, transform=None) -> None:
        """Collect the image paths.

        Args:
            root: dataset root directory.
            id: mapping from class directory name to integer label.
            transform: optional callable applied to each loaded image.
        """
        super().__init__()
        pattern = os.path.join(root, 'train', '*', '*', '*.JPEG')
        self.filenames = glob.glob(pattern)
        self.transform = transform
        self.id_dict = id

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.filenames)

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for the file at position ``idx``."""
        path = self.filenames[idx]
        image = Image.open(path)
        # Single-channel (mode 'L') images are expanded to RGB.
        if image.mode == 'L':
            image = image.convert('RGB')

        # Normalise the separators, then take the class directory, which is
        # assumed to sit third from the end of the path.
        class_dir = os.path.normpath(path).split(os.sep)[-3]
        label = self.id_dict[class_dir]

        if self.transform:
            image = self.transform(image)
        return image, label


class ValTinyImageNet(Dataset):
    """Validation split of a Tiny-ImageNet style directory tree.

    Image files are discovered with the glob ``<root>/val/images/*.JPEG``.
    Labels come from ``<root>/val/val_annotations.txt``, a tab-separated
    file whose first column is the image file name and whose second column
    is the class id.
    """

    def __init__(self, root, id, transform=None):
        """Collect the image paths and parse the annotation file.

        Args:
            root: dataset root directory.
            id: mapping from class id string to integer label.
            transform: optional callable applied to each loaded image.
        """
        val_dir = os.path.join(root, 'val')
        self.filenames = glob.glob(os.path.join(val_dir, 'images', '*.JPEG'))
        self.transform = transform
        self.id_dict = id
        # Image file name -> integer label.
        self.cls_dic = {}
        with open(os.path.join(val_dir, 'val_annotations.txt'), 'r') as annotations:
            for row in annotations:
                columns = row.split('\t')
                self.cls_dic[columns[0]] = self.id_dict[columns[1]]

    def __len__(self):
        """Return the number of discovered image files."""
        return len(self.filenames)

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for the file at position ``idx``."""
        path = self.filenames[idx]
        image = Image.open(path)
        # Single-channel (mode 'L') images are expanded to RGB.
        if image.mode == 'L':
            image = image.convert('RGB')

        # The bare file name (last path component) is the annotation key.
        file_name = os.path.normpath(path).split(os.sep)[-1]
        label = self.cls_dic[file_name]

        if self.transform:
            image = self.transform(image)
        return image, label





def load_tinyimagenet(args):
    """Create the train and validation data loaders for Tiny-ImageNet.

    Reads ``args.batch_size``, ``args.workers`` and ``args.data_dir``; the
    optional ``args.cutout`` flag (default: off) adds a ``Cutout`` of side
    ``args.cutout_length`` to the training pipeline. Class labels are the
    line indices of ``<data_dir>/wnids.txt``. No Normalize step is applied.

    Returns:
        ``(train_loader, val_loader, num_classes)``; the training loader
        shuffles, the validation loader does not.
    """
    batch_size = args.batch_size
    nw = args.workers
    root = args.data_dir

    # Each (stripped) line of wnids.txt becomes a class id; its line index
    # is the integer label.
    with open(os.path.join(root, 'wnids.txt'), 'r') as wnid_file:
        id_dic = {line.strip(): index for index, line in enumerate(wnid_file)}
    num_classes = len(id_dic)

    train_ops = [
        transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    # Add Cutout only when args.cutout is present and truthy.
    if getattr(args, 'cutout', False):
        train_ops.append(Cutout(args.cutout_length))
    print("Applied train transforms:", train_ops)

    val_ops = [transforms.ToTensor()]

    train_dataset = TrainTinyImageNet(
        root, id=id_dic, transform=transforms.Compose(train_ops)
    )
    val_dataset = ValTinyImageNet(
        root, id=id_dic, transform=transforms.Compose(val_ops)
    )

    # Settings shared by both loaders; only the shuffling differs.
    shared_loader_args = dict(batch_size=batch_size, pin_memory=True, num_workers=nw)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, shuffle=True, **shared_loader_args
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, shuffle=False, **shared_loader_args
    )

    print(
        "TinyImageNet Loading SUCCESS"
        f"\nlen of train dataset: {len(train_dataset)}"
        f"\nlen of val dataset: {len(val_dataset)}"
    )

    return train_loader, val_loader, num_classes
