import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

import argparse
import os
import copy
import time
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from utils import ImageFolderIPC, get_dataset, get_network
from torch.optim.lr_scheduler import CosineAnnealingLR


parser = argparse.ArgumentParser(description="PyTorch Post-Training")

parser.add_argument('--dataset', type=str, default='Tiny', help='dataset')
parser.add_argument('--model', type=str, default='ResNet18', help='model')
parser.add_argument('--val_model', type=str, default='ResNet18', help='model')
parser.add_argument('--cuda', type=int, default=0, help='GPU id')
# NOTE: for the supported datasets, --batch_size and --epochs are overwritten
# by the per-dataset recipe further down, so their command-line values
# currently have no effect.
parser.add_argument('--batch_size', type=int, default=64)

# NOTE: --filter and --feat are currently unused; they were formerly part of
# the synthetic-data folder name (syn_img_<filter>_<feat>).
parser.add_argument('--filter', type=str, default='low', help='low, high or poly')
parser.add_argument('--feat', type=str, default='avg', help='hw or avg')

parser.add_argument("--output_dir", default="./save", type=str)
parser.add_argument("--syn_data_path", default="/syn_data", type=str)
# Required: it is joined into the synthetic-data path below, and
# os.path.join cannot accept None.
parser.add_argument("--syn_folder", type=str, required=True)
parser.add_argument("--teacher_path", default="./ckpt", type=str)
parser.add_argument("--ipc", default=50, type=int)
parser.add_argument("--epochs", default=100, type=int)

args = parser.parse_args()

# Resolve the target device and make every path dataset-specific:
#   output / teacher:   <root>/<dataset>
#   synthetic data:     <root>/<dataset>/<syn_folder>
args.device = 'cuda:{}'.format(args.cuda)
args.output_dir = os.path.join(args.output_dir, args.dataset)
args.syn_data_path = os.path.join(args.syn_data_path, args.dataset, args.syn_folder)
args.teacher_path = os.path.join(args.teacher_path, args.dataset)

print(args.syn_data_path)

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(args.output_dir, exist_ok=True)

best_acc = 0  # best test accuracy seen so far (updated by test())
start_epoch = 0  # currently unused in this script

channel, im_size, num_classes, dst_train, dst_test = get_dataset(args.dataset)

# Per-dataset training recipe:
#   dataset: (epochs, batch_size, crop transform, normalize mean, normalize std)
# NOTE: epochs and batch_size here override any --epochs / --batch_size given
# on the command line.
_TRAIN_RECIPES = {
    'Tiny': (100, 64, transforms.RandomResizedCrop,
             [0.4802, 0.4481, 0.3975], [0.2302, 0.2265, 0.2262]),
    'ImageNet': (300, 100, transforms.RandomResizedCrop,
                 [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    # NOTE(review): if the synthetic CIFAR images are already im_size[0] square,
    # RandomCrop(im_size[0]) without padding returns them unchanged — confirm
    # whether a padding argument was intended.
    'CIFAR-100': (1000, 64, transforms.RandomCrop,
                  [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
    'CIFAR-10': (1000, 64, transforms.RandomCrop,
                 [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
}

if args.dataset not in _TRAIN_RECIPES:
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; expected one of {sorted(_TRAIN_RECIPES)}"
    )

args.epochs, args.batch_size, _crop, _mean, _std = _TRAIN_RECIPES[args.dataset]

transform_train = transforms.Compose(
    [
        _crop(im_size[0]),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=_mean, std=_std),
    ]
)


# Report the size of the real test split, then build the synthetic training
# set (capped at args.ipc images per class by ImageFolderIPC) and report its size.
num_test_samples = len(dst_test)
print(num_test_samples)
print("=> Using IPC setting of ", args.ipc)
trainset = ImageFolderIPC(root=args.syn_data_path, transform=transform_train, ipc=args.ipc)
print(len(trainset))

# Shuffled loader over the synthetic images; ordered loader over the real test set.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)
testloader = torch.utils.data.DataLoader(
    dst_test,
    batch_size=256,
    shuffle=False,
    num_workers=16,
    pin_memory=True,
)

# Model
print("==> Building model..")


def _freeze(model):
    """Set ``requires_grad = False`` on every parameter of *model* and return it."""
    for p in model.parameters():
        p.requires_grad = False
    return model


def _resnet_3x3_stem(arch, num_classes):
    """Build a randomly initialised torchvision *arch* with a modified stem.

    ``conv1`` is replaced by a 3x3, stride-1, padding-1 convolution
    (3 -> 64 channels, no bias) and ``maxpool`` by ``nn.Identity``. Used for
    both the teacher (whose checkpoint is loaded into this architecture) and
    the student on Tiny and CIFAR.

    Assumes *arch* is a ResNet-style torchvision model exposing ``conv1`` and
    ``maxpool`` with a 64-channel stem — TODO confirm for other --val_model values.
    """
    model = torchvision.models.get_model(arch, weights=None, num_classes=num_classes)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    return model


if args.dataset == 'ImageNet':
    # get_model resolves names case-insensitively, so both 'resnet18' and the
    # default 'ResNet18' work. pretrained=True is torchvision's legacy
    # (deprecated) flag that selects the same weights as before.
    model_teacher = _freeze(torchvision.models.get_model(args.model, pretrained=True))

    print(args.val_model)
    net = torchvision.models.get_model(args.val_model, weights=None, num_classes=num_classes)

elif args.dataset == 'Tiny':
    model_teacher = _resnet_3x3_stem('resnet18', num_classes)
    ckpt_path = os.path.join(args.teacher_path, 'ResNet18.pth')
    model_teacher.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
    _freeze(model_teacher)

    print(args.val_model)
    net = _resnet_3x3_stem(args.val_model, num_classes)

else:  # CIFAR
    # NOTE: here the student always uses the teacher's architecture;
    # --val_model is ignored.
    if args.model == 'ConvNetW128':
        ckpt_path = os.path.join(args.teacher_path, 'ConvNetW128.pth')
        model_teacher = get_network('ConvNetW128', channel=channel, num_classes=num_classes, im_size=im_size, dist=False)
        model_teacher.load_state_dict(torch.load(ckpt_path, weights_only=True))
        _freeze(model_teacher)

        net = get_network('ConvNetW128', channel=channel, num_classes=num_classes, im_size=im_size, dist=False)

    elif args.model == 'ResNet18':
        model_teacher = _resnet_3x3_stem('resnet18', num_classes)
        ckpt_path = os.path.join(args.teacher_path, 'ResNet18.pth')
        model_teacher.load_state_dict(torch.load(ckpt_path, map_location='cpu', weights_only=True))
        _freeze(model_teacher)

        net = _resnet_3x3_stem('resnet18', num_classes)

    else:
        raise ValueError(
            f"Unsupported teacher model {args.model!r} for {args.dataset}; "
            "expected 'ConvNetW128' or 'ResNet18'"
        )

# The frozen teacher is only used for inference; the student is trained.
model_teacher = model_teacher.to(args.device)
model_teacher.eval()

net = net.to(args.device)
net.train()


# Per-dataset optimizer recipe for the student network.
if args.dataset == 'Tiny':
    optimizer = torch.optim.SGD(net.parameters(), lr=0.2, momentum=0.9, weight_decay=1e-4)
elif args.dataset == 'ImageNet':
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-4)
elif args.dataset == 'CIFAR-100':
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
elif args.dataset == 'CIFAR-10':
    optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3, weight_decay=5e-4)
else:
    raise ValueError(f"No optimizer recipe for dataset {args.dataset!r}")

# Every recipe anneals the LR with a cosine schedule over the whole run
# (the training loop steps the scheduler once per epoch).
scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)

criterion = nn.CrossEntropyLoss()  # hard-label loss, used for test-set reporting
loss_function_kl = nn.KLDivLoss(reduction="batchmean")  # distillation loss used in train()

# Softmax temperature applied to both teacher and student logits in train().
args.temperature = 30 if 'CIFAR' in args.dataset else 20


def mixup_data(x, y, alpha=0.8):
    """Blend each sample of a batch with a randomly chosen partner (mixup).

    One coefficient ``lam ~ Beta(alpha, alpha)`` is drawn from NumPy's global
    RNG. A random permutation of the batch, drawn from torch's RNG and then
    moved to ``x.device``, picks each sample's partner.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y`` itself and
        ``y_b = y[perm]``.
    """
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.shape[0]).to(x.device)

    partner_x = x[perm, :]
    partner_y = y[perm]

    mixed_x = lam * x + (1 - lam) * partner_x
    return mixed_x, y, partner_y, lam


# Train
def train(epoch):
    """Run one epoch of teacher-to-student distillation on the synthetic set.

    Each batch is mixed with ``mixup_data``. The student ``net`` is then
    trained to match the frozen teacher's temperature-softened predictions on
    the mixed images, using KL divergence scaled by ``T**2``.

    The printed Acc@1 compares student predictions on the *mixed* images
    against the original (unmixed) labels, so it is only a rough progress
    signal.

    Args:
        epoch: epoch index, used only for logging.
    """
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    t1 = time.time()
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(args.device), targets.to(args.device)
        # Only the mixed images are used; the label pair and lambda are
        # discarded because supervision comes from the teacher's soft labels.
        inputs, _, _, _ = mixup_data(inputs, targets)

        optimizer.zero_grad()
        outputs = net(inputs)

        # The teacher is frozen; no_grad makes explicit that no graph is built.
        with torch.no_grad():
            teacher_logits = model_teacher(inputs)
        outputs_ = F.log_softmax(outputs / args.temperature, dim=1)
        soft_label = F.softmax(teacher_logits / args.temperature, dim=1)

        # KL(teacher || student) on temperature-softened distributions, scaled by T^2.
        loss = args.temperature * args.temperature * loss_function_kl(outputs_, soft_label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches += 1
    t2 = time.time()

    # Guard against an empty loader instead of dividing by zero.
    acc = 100.0 * correct / total if total else 0.0
    avg_loss = train_loss / num_batches if num_batches else 0.0
    print(f"Epoch: [{epoch}], Acc@1 {acc:.3f}, Loss {avg_loss:.4f}, Time {t2-t1:.4f}")

# Test
def test(epoch):
    """Evaluate ``net`` on the real test set and save the latest checkpoint.

    Prints top-1 accuracy and the mean per-batch cross-entropy. It then writes
    ``{"state_dict", "acc", "epoch"}`` to ``<output_dir>/ckpt.pth``,
    overwriting any previous file, so the most recently evaluated model is
    kept rather than the best one. The global ``best_acc`` is updated to the
    highest accuracy seen so far.

    Args:
        epoch: epoch index stored in the checkpoint.
    """
    global best_acc
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches += 1

    # Guard against an empty loader instead of dividing by zero.
    acc = 100.0 * correct / total if total else 0.0
    avg_loss = test_loss / num_batches if num_batches else 0.0
    print(f"Test: Acc@1 {acc:.3f}, Loss {avg_loss:.4f}")

    # Always save the most recent checkpoint (not only on improvement).
    state = {
        "state_dict": net.state_dict(),
        "acc": acc,
        "epoch": epoch,
    }
    torch.save(state, os.path.join(args.output_dir, "ckpt.pth"))
    best_acc = max(best_acc, acc)


# Main training loop: train every epoch; evaluate every 50 epochs and after
# the final epoch; step the LR scheduler once per epoch.
start_time = time.time()
final_epoch = args.epochs - 1
for epoch in range(args.epochs):
    train(epoch)
    should_evaluate = (epoch + 1) % 50 == 0 or epoch == final_epoch
    if should_evaluate:
        test(epoch)
    scheduler.step()
end_time = time.time()

elapsed = end_time - start_time
print(f"total time: {elapsed} s")
print(args.syn_data_path)
