from models.resnet import resnet50, CIFAR_ResNet18
from torchvision.models import efficientnet_v2_s, convnext_tiny
from vit_pytorch import SimpleViT
from mlp_mixer_pytorch import MLPMixer
from datasets import load_dataset
import torchvision.utils as vutils
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import argparse
import numpy as np
import os
import random
import math
import time
# Print tensors in full instead of the default summarised form.
torch.set_printoptions(profile="full")

def accuracy(soft_targets, labels):
    """Count the rows whose highest-scoring column equals the label.

    Args:
        soft_targets: tensor of per-class scores; the maximum is taken
            along dimension 1.
        labels: tensor of class indices compared element-wise against the
            predicted indices.

    Returns:
        A 0-dim tensor holding the number of matches (a count, not a rate).
    """
    _, predicted = torch.max(soft_targets, 1)
    hits = predicted == labels
    return hits.sum()

def accuracy5(outs, labels):
    """Count the samples whose label is among the five highest-scoring classes.

    Args:
        outs: 2-D tensor of per-class scores, one row per sample.
        labels: tensor of class indices, one per row of ``outs``.

    Returns:
        A Python ``int``: the number of rows whose label appears in that
        row's top-k indices (a count, not a rate). ``k`` is 5, reduced to
        the number of classes when fewer than five are available.
    """
    if len(outs) == 0:
        return 0
    # Clamp k so the call does not fail when there are fewer than 5 classes.
    k = min(5, outs.size(-1))
    # One top-k over the whole batch instead of recomputing it per sample.
    top_idx = torch.topk(outs, k).indices
    hits = (top_idx == labels.reshape(-1, 1)).any(dim=1)
    return int(hits.sum().item())

def _init_fn(worker_id):
    """Seed NumPy's global RNG with ``args.seed + worker_id``.

    Handed to ``load_dataset`` as ``worker_init_fn``; offsetting the base
    seed by the worker id gives each worker id its own NumPy seed.
    """
    worker_seed = args.seed + worker_id
    np.random.seed(worker_seed)

# Command-line interface of the training script.
parser = argparse.ArgumentParser()
parser.add_argument("--bs", default=32, type=int, help="batch size")
# Forwarded unchanged to load_dataset(subset_num=...).
parser.add_argument("--subset_num", default=1, type=int)
# When set, the model is wrapped in nn.DataParallel.
parser.add_argument("--parallel", action="store_true")
# Index of the CUDA device the model and batches are moved to.
parser.add_argument("--gpu", default=0, type=int)
parser.add_argument("--lr", default=3e-4, type=float, help="learning rate")
# Passed to Adam as weight_decay.
parser.add_argument("--wd", default=0.0, type=float)
parser.add_argument("--epochs", default=200, type=int)
parser.add_argument("--seed", default=0, type=int, help="different for every model")
# The LR scheduler is only stepped once epoch >= warmup.
parser.add_argument("--warmup", default=10, type=int)
# Run name; the checkpoint is written under ./results/<name>/.
parser.add_argument("--name", default="c10_ss1_r18", type=str)
parser.add_argument("--dataset", default="cifar10", type=str, help="imagenet | cifar10 | Caltech256")
parser.add_argument("--dataroot", default="./data", type=str, help="data dir")
# The help text lists every value accepted by the model-selection chain.
parser.add_argument(
    "--model",
    default="CIFAR_ResNet18",
    type=str,
    help="CIFAR_ResNet18 | resnet50 | eff2s | convnext | vit | mlpm",
)
args = parser.parse_args()

# Seed every RNG used by the script (Python, NumPy, PyTorch CPU and CUDA)
# with the same user-supplied value, in this order.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(args.seed)

# Disable the cuDNN autotuner and request deterministic cuDNN kernels.
cudnn_backend = torch.backends.cudnn
cudnn_backend.benchmark = False
cudnn_backend.deterministic = True

# Build the data loaders. Below, `dim` is used as the number of output
# classes and `hw` as the input image size; `mean_image` is not used here.
train_loader, test_loader, dim, hw, mean_image = load_dataset(
    args.dataset,
    args.dataroot,
    bs=args.bs,
    worker_init_fn=_init_fn,
    drop_last=False,
    subset_num=args.subset_num,
    seed=args.seed,
)
device = torch.device("cuda:%d" % args.gpu)

# Extra arguments for the project-local ResNet constructors.
remove = 0
in_channels = 3

# Hyper-parameters read only by the "vit" (SimpleViT) and "mlpm" (MLPMixer)
# branches below; Caltech256 gets its own patch size and widths.
if args.dataset == "Caltech256":
    patch_size, Dim, depth, heads, mlp_dim = 32, 1024, 6, 16, 2048
    m_patch_size, m_Dim, m_depth = 16, 512, 12
else:
    patch_size, Dim, depth, heads, mlp_dim = 4, 512, 6, 8, 512
    m_patch_size, m_Dim, m_depth = 4, 512, 6


if args.parallel:
    # Use every visible GPU rather than assuming exactly eight, with the
    # selected GPU first: DataParallel uses device_ids[0] as its primary
    # (output) device.
    gpus = [args.gpu] + [i for i in range(torch.cuda.device_count()) if i != args.gpu]

if args.model == "CIFAR_ResNet18":
    M = CIFAR_ResNet18(pretrained=False, num_classes=dim, remove=remove, in_channels=in_channels)
elif args.model == "resnet50":
    M = resnet50(pretrained=False, num_classes=dim, remove=remove, in_channels=in_channels)
elif args.model == "eff2s":
    # `weights=None` is the non-deprecated spelling of `pretrained=False`.
    M = efficientnet_v2_s(num_classes=dim, weights=None)
elif args.model == "convnext":
    M = convnext_tiny(num_classes=dim, weights=None)
elif args.model == "vit":
    M = SimpleViT(image_size=hw, patch_size=patch_size, num_classes=dim, dim=Dim, depth=depth, heads=heads, mlp_dim=mlp_dim)
elif args.model == "mlpm":
    M = MLPMixer(image_size=hw, channels=in_channels, patch_size=m_patch_size, dim=m_Dim, depth=m_depth, num_classes=dim)
else:
    raise NotImplementedError

if args.parallel:
    model = nn.DataParallel(M, device_ids=gpus).to(device)
else:
    model = M.to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
# NOTE(review): T_max is the number of batches per epoch, but the training
# loop calls scheduler.step() once per epoch (after the warmup epochs), so
# the cosine period is len(train_loader) epochs. Confirm this is intended.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0, last_epoch=-1)

criterion = nn.CrossEntropyLoss().to(device)
softmax = nn.Softmax(-1)
best_acc = 0.0
# Directory that receives the best checkpoint of this run.
os.makedirs("./results/%s/" % args.name, exist_ok=True)

print(args)
for epoch in range(args.epochs):
    # ---- training pass over train_loader ----
    model.train()
    corr, corr5, total, train_loss = 0, 0, 0, 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        cost = criterion(logits, labels)
        # Accumulated as a sum over batches (not averaged) for the printout.
        train_loss += cost.item()
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()
        # Metrics only: detach so the softmax is not tracked by autograd.
        preds = softmax(logits.detach())
        corr += accuracy(preds, labels)
        corr5 += accuracy5(preds, labels)
        total += len(images)
    # corr accumulates the tensor returned by accuracy(), so tr_acc is a
    # 0-dim tensor; it is stored as-is in the checkpoint below.
    tr_acc, tr_acc5 = corr / total * 100, corr5 / total * 100
    print("Epoch [%d]: Train loss: %.3f" % (epoch, train_loss))
    print("Epoch [%d]: Train acc: %.3f | acc5: %.3f" % (epoch, tr_acc, tr_acc5))
    # The learning rate stays at its initial value for the first
    # `args.warmup` epochs; afterwards the scheduler advances once per epoch.
    if epoch >= args.warmup:
        scheduler.step()

    # ---- evaluation pass over test_loader ----
    corr, corr5, total = 0, 0, 0
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            preds = softmax(model(images))
            corr += accuracy(preds, labels)
            corr5 += accuracy5(preds, labels)
            total += len(images)
    te_acc, te_acc5 = corr / total * 100, corr5 / total * 100
    print("Epoch [%d]: Test acc: %.3f | acc5: %.3f" % (epoch, te_acc, te_acc5))

    # Checkpoint selection is driven by *training* accuracy; the test
    # metrics are only recorded alongside it.
    acc = tr_acc
    if acc > best_acc:
        print("NEW BEST!\n")
        # Save the bare module's weights when wrapped in DataParallel, so
        # the state-dict keys carry no "module." prefix.
        net = model.module if args.parallel else model
        state = {
            "train_acc": tr_acc,
            "train_acc5": tr_acc5,
            "test_acc": te_acc,
            "test_acc5": te_acc5,
            "epoch": epoch,
            "model": net.state_dict(),
        }
        torch.save(state, "./results/%s/best.pt" % (args.name))
        best_acc = tr_acc
    else:
        print("\n")
