import torch
from torch.autograd import Variable
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from torchvision import models, utils, datasets, transforms
import numpy as np
import sys
import os
from PIL import Image
from typing import Any
import glob
import argparse


class AvgrageMeter(object):
    """Running weighted average of a scalar metric.

    Attributes:
        sum: Weighted total of every value passed to ``update``.
        cnt: Total weight (number of samples) seen so far.
        avg: ``sum / cnt`` as of the latest ``update`` call.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Return the meter to its initial, empty state."""
        self.avg, self.sum, self.cnt = 0, 0, 0

    def update(self, val, n=1):
        """Fold ``val`` into the average, counting it ``n`` times.

        Args:
            val: Metric value (typically the mean over one batch).
            n: Weight of ``val``, e.g. the batch size.
        """
        weighted = val * n
        self.sum += weighted
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy percentages for a batch of predictions.

    Args:
        output: 2-D score tensor; the top-k search runs along dimension 1.
        target: Tensor of ground-truth class indices, one per row of ``output``.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk``, holding
        the percentage (0-100) of samples whose target is among the ``k``
        highest-scoring classes.
    """
    largest_k = max(topk)
    num_samples = target.size(0)

    # Indices of the best `largest_k` classes, laid out as (k, batch).
    top_indices = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
    hits = top_indices.eq(target.reshape(1, -1).expand_as(top_indices))

    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0) / num_samples
        for k in topk
    ]


class Cutout(object):
    """Transform that zeroes out one randomly placed square patch of an image.

    Args:
        length: Nominal side length of the patch; the patch extends
            ``length // 2`` pixels on each side of a random centre and is
            cropped at the image border.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        """Mask ``img`` in place and return it.

        Args:
            img: Tensor whose dimensions 1 and 2 are read as height and
                width (presumably a (C, H, W) image tensor).

        Returns:
            The same tensor object, with the patch set to zero.
        """
        height, width = img.size(1), img.size(2)
        half = self.length // 2

        # Row is drawn before column so the NumPy RNG stream is consumed
        # in a fixed order.
        centre_y = np.random.randint(height)
        centre_x = np.random.randint(width)

        top = np.clip(centre_y - half, 0, height)
        bottom = np.clip(centre_y + half, 0, height)
        left = np.clip(centre_x - half, 0, width)
        right = np.clip(centre_x + half, 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top: bottom, left: right] = 0.

        # Broadcast the 2-D mask over the remaining dimensions of the image.
        keep_tensor = torch.from_numpy(keep).expand_as(img)
        img *= keep_tensor
        return img


def _data_transforms_cifar10(args):
    """Build the train and validation transform pipelines for CIFAR-10.

    Args:
        args: Namespace providing ``cutout`` (truthy to enable Cutout) and,
            when enabled, ``cutout_length``.

    Returns:
        ``(train_transform, valid_transform)``. Training applies a padded
        32x32 random crop, a random horizontal flip, tensor conversion and
        optionally Cutout; validation only converts to a tensor.
    """
    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    # Cutout works on tensors, so it has to come after ToTensor.
    if args.cutout:
        train_steps.append(Cutout(args.cutout_length))

    train_transform = transforms.Compose(train_steps)
    valid_transform = transforms.Compose([transforms.ToTensor()])
    return train_transform, valid_transform


def _data_transforms_imagenet():
    """Build the train and validation transform pipelines for 64x64 images.

    Returns:
        ``(train_transform, valid_transform)``. Training applies a padded
        64x64 random crop, a random horizontal flip and tensor conversion;
        validation only converts to a tensor.
    """
    augmentations = [
        transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    train_transform = transforms.Compose(augmentations + [transforms.ToTensor()])
    valid_transform = transforms.Compose([transforms.ToTensor()])
    return train_transform, valid_transform


def _data_transforms_cifar100(args):
    """Build the train, validation and test transform pipelines for CIFAR-100.

    Args:
        args: Namespace providing ``cutout`` (truthy to enable Cutout) and,
            when enabled, ``cutout_length``.

    Returns:
        ``(train_transform, valid_transform, test_transform)``. Every
        pipeline converts to a tensor and normalizes with the per-channel
        statistics below; training additionally applies a padded 32x32
        random crop, a random horizontal flip and optionally Cutout.
    """
    channel_mean = [0.5071, 0.4865, 0.4409]
    channel_std = [0.2673, 0.2564, 0.2761]

    def _eval_pipeline():
        # Deterministic preprocessing shared by validation and test.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    # Cutout is applied to the already-normalized tensor.
    if args.cutout:
        train_steps.append(Cutout(args.cutout_length))

    train_transform = transforms.Compose(train_steps)
    valid_transform = _eval_pipeline()
    test_transform = _eval_pipeline()
    return train_transform, valid_transform, test_transform


def count_parameters_in_MB(model):
    """Return the number of model parameters, in millions.

    Parameters whose qualified name contains ``"auxiliary"`` are excluded.
    Despite the ``MB`` suffix, the value is an element count divided by
    1e6, not a size in bytes.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs (e.g. a ``torch.nn.Module``).

    Returns:
        The parameter count divided by 1e6, as a float.
    """
    # Builtin sum is used because np.sum over a generator is deprecated.
    total = sum(
        param.numel()
        for name, param in model.named_parameters()
        if "auxiliary" not in name
    )
    return total / 1e6


def save(model, epoch, optimizer, valid_acc, model_dir=""):
    """Write a training checkpoint and return its path.

    The file is named ``EV_model_acc<valid_acc>_<timestamp>.pt`` and holds a
    dict with the keys ``'epoch'``, ``'model'`` (model state dict),
    ``'optimizer'`` (optimizer state dict) and ``'valid_acc'``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        epoch: Epoch value recorded in the checkpoint.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        valid_acc: Validation accuracy; stored and formatted into the name.
        model_dir: Target directory. It is created if missing; an empty
            string writes relative to the current working directory.

    Returns:
        Path of the written checkpoint file.
    """
    # torch.save does not create parent directories.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    model_filename = f"EV_model_acc{valid_acc:.4f}_{timestamp}.pt"
    model_path = os.path.join(model_dir, model_filename)

    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'valid_acc': valid_acc,
    }, model_path)
    print(f"Model saved to {model_path}")
    return model_path

def load(model, model_path):
    """Load weights from ``model_path`` into ``model`` in place.

    Accepts both a bare state dict and the checkpoint dict written by
    ``save`` in this module, where the weights sit under the ``'model'``
    key.

    NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.

    Args:
        model: Module receiving the weights via ``load_state_dict``.
        model_path: Path to the checkpoint file.
    """
    # Deserialize onto the CPU so checkpoints written on a GPU can be read
    # on CPU-only hosts; load_state_dict copies to the model's own device.
    checkpoint = torch.load(model_path, map_location='cpu')

    # A real state dict maps names to tensors, so a nested dict under
    # 'model' identifies the wrapped checkpoint format.
    state_dict = checkpoint
    if isinstance(checkpoint, dict) and isinstance(checkpoint.get('model'), dict):
        state_dict = checkpoint['model']

    model.load_state_dict(state_dict)


def drop_path(x, drop_prob):
    """Randomly zero whole samples of ``x`` in place and rescale the rest.

    Each sample along dimension 0 is kept with probability
    ``1 - drop_prob``; kept samples are divided by that probability. When
    ``drop_prob`` is not positive, ``x`` is returned untouched.

    Args:
        x: Batch tensor, modified in place. The mask has shape
            ``(N, 1, 1, 1)``, so ``x`` is expected to be 4-D.
        drop_prob: Probability of dropping a sample.

    Returns:
        The same tensor object ``x``.
    """
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # Allocate the mask on x's own device and dtype so this also works
        # on CPU tensors.
        mask = x.new_empty((x.size(0), 1, 1, 1)).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x


def madry_generate(model, x_natural, y, optimizer, step_size=0.003, epsilon=0.031, perturb_steps=10, distance='l_inf'):
    """Generate adversarial examples with iterative signed-gradient ascent.

    Starting from ``x_natural`` plus small Gaussian noise, each step moves
    along the sign of the cross-entropy gradient, projects back into the
    ``epsilon`` L-infinity ball around ``x_natural`` and clamps to
    ``[0, 1]``. While attacking, the model runs in eval mode with
    parameter gradients disabled. Afterwards every module's train/eval mode
    and every parameter's ``requires_grad`` flag is restored to what it was
    on entry, even if an exception is raised.

    Args:
        model: Module to attack; may return logits or a tuple whose first
            element is the logits.
        x_natural: Clean inputs, expected to lie in ``[0, 1]``.
        y: Target class indices for the cross-entropy loss.
        optimizer: Unused; kept for interface compatibility.
        step_size: Step size of each signed-gradient update.
        epsilon: Radius of the L-infinity ball.
        perturb_steps: Number of attack iterations.
        distance: Only ``'l_inf'`` runs the attack; any other value returns
            the noise-perturbed input clamped to ``[0, 1]``.

    Returns:
        Detached tensor of adversarial examples shaped like ``x_natural``.
    """
    criterion_ce = torch.nn.CrossEntropyLoss(reduction='none')

    # Snapshot state so the caller's configuration can be restored exactly,
    # rather than forcing train mode on exit.
    saved_modes = [(module, module.training) for module in model.modules()]
    saved_flags = [(param, param.requires_grad) for param in model.parameters()]

    model.eval()
    for param, _ in saved_flags:
        param.requires_grad = False

    try:
        x_adv = x_natural.detach() + 0.001 * torch.randn_like(x_natural).detach()

        if distance == 'l_inf':
            for _ in range(perturb_steps):
                x_adv.requires_grad_()
                with torch.enable_grad():
                    out = model(x_adv)
                    logits = out[0] if isinstance(out, tuple) else out
                    loss_ce = criterion_ce(logits, y).mean()
                grad = torch.autograd.grad(loss_ce, [x_adv])[0]
                x_adv = x_adv.detach() + step_size * grad.sign()
                # Project onto the epsilon ball, then the valid pixel range.
                x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0)
        else:
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    finally:
        for param, flag in saved_flags:
            param.requires_grad = flag
        for module, mode in saved_modes:
            module.training = mode

    return x_adv.detach()



class TrainTinyImageNet(Dataset):
    """Training split read from ``<root>/train/<class_id>/<subdir>/*.JPEG``.

    Args:
        root: Dataset root directory.
        id: Mapping from class-id directory name to integer label.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, root, id, transform=None) -> None:
        super().__init__()

        pattern = os.path.join(root, 'train', '*', '*', '*.JPEG')
        # glob order depends on the filesystem; sorting keeps the
        # index-to-sample mapping reproducible across runs and machines.
        self.filenames = sorted(glob.glob(pattern))
        self.transform = transform
        self.id_dict = id

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for the sample at ``idx``."""
        img_path = self.filenames[idx]
        # Convert every mode (grayscale, CMYK, palette, ...) to RGB so all
        # samples have three channels; the context manager closes the file.
        with Image.open(img_path) as handle:
            image = handle.convert('RGB')

        # Path layout is .../<class_id>/<subdir>/<file>, so the class id is
        # the third component from the end.
        parts = os.path.normpath(img_path).split(os.sep)
        label = self.id_dict[parts[-3]]

        if self.transform:
            image = self.transform(image)
        return image, label


class ValTinyImageNet(Dataset):
    """Validation split read from ``<root>/val/images/*.JPEG``.

    Labels come from ``<root>/val/val_annotations.txt``, a tab-separated
    file whose first two columns are the image file name and its class id.

    Args:
        root: Dataset root directory.
        id: Mapping from class id to integer label.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, root, id, transform=None):
        super().__init__()

        pattern = os.path.join(root, 'val', 'images', '*.JPEG')
        # glob order depends on the filesystem; sorting keeps the
        # index-to-sample mapping reproducible.
        self.filenames = sorted(glob.glob(pattern))
        self.transform = transform
        self.id_dict = id
        self.cls_dic = {}

        annotations_path = os.path.join(root, 'val', 'val_annotations.txt')
        with open(annotations_path, 'r') as f:
            for line in f:
                # Drop the line terminator so a two-column row does not
                # keep '\n' in its class id, and skip blank lines.
                line = line.rstrip('\r\n')
                if not line.strip():
                    continue
                fields = line.split('\t')
                img, cls_id = fields[0], fields[1]
                self.cls_dic[img] = self.id_dict[cls_id]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for the sample at ``idx``."""
        img_path = self.filenames[idx]
        # Convert every mode (grayscale, CMYK, palette, ...) to RGB so all
        # samples have three channels; the context manager closes the file.
        with Image.open(img_path) as handle:
            image = handle.convert('RGB')

        # Annotations are keyed by bare file name.
        filename = os.path.normpath(img_path).split(os.sep)[-1]
        label = self.cls_dic[filename]

        if self.transform:
            image = self.transform(image)
        return image, label





def load_tinyimagenet(args):
    """Create the TinyImageNet train and validation data loaders.

    Class ids are read from ``<data_dir>/wnids.txt``, one per line, and
    numbered in file order.

    Args:
        args: Namespace providing ``batch_size``, ``workers`` and
            ``data_dir``; optionally ``cutout`` and ``cutout_length``.

    Returns:
        ``(train_loader, val_loader, num_classes)``.
    """
    batch_size, num_workers, root = args.batch_size, args.workers, args.data_dir

    with open(os.path.join(root, 'wnids.txt'), 'r') as wnid_file:
        id_dic = {line.strip(): index for index, line in enumerate(wnid_file)}
    num_classes = len(id_dic)

    # Normalize is left out of both pipelines, so tensors keep the value
    # range produced by ToTensor.
    train_steps = [
        transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    # Append Cutout only when args.cutout is set.
    if getattr(args, 'cutout', False):
        train_steps.append(Cutout(args.cutout_length))
    print("Applied train transforms:", train_steps)

    eval_steps = [transforms.ToTensor()]

    train_pipeline = transforms.Compose(train_steps)
    val_pipeline = transforms.Compose(eval_steps)

    train_dataset = TrainTinyImageNet(root, id=id_dic, transform=train_pipeline)
    val_dataset = ValTinyImageNet(root, id=id_dic, transform=val_pipeline)

    # Options shared by both loaders; only shuffling differs.
    loader_options = dict(
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )
    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_options)

    summary = "".join([
        "TinyImageNet Loading SUCCESS",
        "\nlen of train dataset: ", str(len(train_dataset)),
        "\nlen of val dataset: ", str(len(val_dataset)),
    ])
    print(summary)

    return train_loader, val_loader, num_classes
