import argparse
import pickle
import numpy as np
import torch
import torch.nn as nn
import os
import random
import pdb
import time
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import StepLR
from tinyModels import ResNet
from datasets.folder import ImageFolder, ImageFolderPrune

# Per-channel mean and standard deviation handed to transforms.Normalize in
# both the training and the test pipelines built in main().
# NOTE(review): presumably RGB statistics of the Tiny ImageNet training split
# (the default model dir is "./tiny-imagenet/model") -- TODO confirm the source.
TRAIN_MEAN = [0.4802, 0.4481, 0.3975]
TRAIN_STD = [0.2302, 0.2265, 0.2262]


def seed_torch(seed=0):
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    and configures cuDNN for deterministic kernel selection.

    Args:
        seed: Integer seed applied to all generators.
    """
    print(f"Using seed {seed}")
    random.seed(seed)
    # Hash randomization of the running interpreter is fixed at startup, so
    # this only takes effect in child processes that inherit the environment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU; a no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different algorithms from run to run;
    # disable it explicitly so the deterministic flag above is effective.
    torch.backends.cudnn.benchmark = False

def _str2bool(value):
    """Convert a command-line string such as ``"False"`` into a real boolean.

    ``type=bool`` cannot be used with argparse because ``bool("False")`` is
    ``True`` (any non-empty string is truthy).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "on", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the training script's command-line options from ``sys.argv``.

    Boolean options (``--prune``, ``--save``) accept an explicit value
    (``--prune True`` / ``--prune False``) or may be given bare
    (``--prune``), which means ``True``.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=0, help="training batch size")
    parser.add_argument("--data_path", type=str, default="./data", help="path to save/load data")
    parser.add_argument("--epochs", type=int, default=0, help="train epochs")
    parser.add_argument("--logdir", default="./log", help="tensorboard log dir")
    parser.add_argument("--model_dir", default="./tiny-imagenet/model", help="dir to save model")
    parser.add_argument("--lr", default=0, help="learning rate", type=float)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--prune", type=_str2bool, nargs="?", const=True, default=False,
                        help="whether to train on a pruned dataset")
    parser.add_argument("--pruneId_path", default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--save", type=_str2bool, nargs="?", const=True, default=False)
    return parser.parse_args()

@torch.no_grad()
def eval(model, loader, criterion, device):
    """Evaluate ``model`` over ``loader`` without tracking gradients.

    The model is switched to evaluation mode and left in that mode. Each
    batch yielded by ``loader`` must unpack into three items; the first is
    ignored, the second is the model input and the third the target.

    Args:
        model: Module to evaluate.
        loader: Iterable of 3-item batches exposing ``__len__`` and ``dataset``.
        criterion: Callable ``criterion(output, target)`` returning a scalar loss.
        device: Device the inputs and targets are moved to.

    Returns:
        tuple: ``(mean of the per-batch losses, correct predictions divided
        by ``len(loader.dataset)``)``.
    """
    model.eval()
    num_correct = 0
    loss_sum = 0

    for _sample_id, images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        loss_sum += criterion(logits, labels).item()
        # Predicted class is the index of the largest output along dim 1.
        hits = torch.eq(logits.argmax(1), labels).float().sum()
        num_correct += hits.cpu().numpy()

    accuracy = num_correct / len(loader.dataset)
    mean_loss = loss_sum / len(loader)
    return mean_loss, accuracy

def main(args):
    """Train a ResNet, evaluate it after every epoch and optionally save it.

    Builds the train/test transforms and datasets (optionally dropping the
    sample ids loaded from ``args.pruneId_path``), trains with SGD and a
    step learning-rate schedule, writes training loss/accuracy to
    TensorBoard every 500 iterations, and prints per-epoch metrics.

    Args:
        args: Namespace with the attributes ``batch_size``, ``device``,
            ``model_dir``, ``prune``, ``pruneId_path``, ``lr``, ``logdir``,
            ``epochs`` and ``save``.

    Raises:
        ValueError: If ``args.prune`` is set but ``args.pruneId_path`` is not.
    """
    # Imported locally: the module-level imports never bring SummaryWriter
    # into scope, so the original reference raised NameError.
    from torch.utils.tensorboard import SummaryWriter

    device = args.device
    os.makedirs(args.model_dir, exist_ok=True)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(55),
        transforms.Resize(64),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(TRAIN_MEAN, TRAIN_STD),
    ])

    test_transform = transforms.Compose([
        transforms.Resize(int(64/0.875)),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(TRAIN_MEAN, TRAIN_STD),
    ])

    # NOTE(review): the dataset roots are the hard-coded relative paths
    # "train" and "test"; args.data_path is never consulted -- confirm
    # whether the roots should be joined onto it.
    if args.prune:
        if args.pruneId_path is None:
            raise ValueError("--pruneId_path is required when --prune is set")
        # SECURITY: pickle.load can execute arbitrary code; only point
        # --pruneId_path at files from a trusted source.
        with open(args.pruneId_path, "rb") as f:
            drop_id = pickle.load(f)
        print("Using prune id: ", args.pruneId_path)

        data_train = ImageFolderPrune(root="train",
                                      drop_ids=drop_id, transform=train_transform)
    else:
        # train on normal dataset
        data_train = ImageFolder("train", transform=train_transform)
    # The test set is identical whether or not the training set is pruned.
    data_test = ImageFolder("test", transform=test_transform)

    trainloader = DataLoader(data_train, batch_size=args.batch_size, pin_memory=True, num_workers=1, shuffle=True)
    print("trainset size:", len(trainloader.dataset))
    testloader = DataLoader(data_test, batch_size=32, pin_memory=True, num_workers=4)

    model = ResNet(pretrained=False)
    model = model.to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
    writer = SummaryWriter(args.logdir)

    try:
        # Counts optimizer steps across all epochs (TensorBoard x-axis).
        global_step = 0
        for epoch in range(args.epochs):
            # eval() leaves the model in evaluation mode, so switch back
            # to training mode at the start of every epoch.
            model.train()
            losses, accs = [], []
            for _, img, target in trainloader:
                img, target = img.to(device), target.to(device)
                output = model(img)
                loss = criterion(output, target)
                pred = output.argmax(1)

                acc = torch.eq(pred, target).float().mean().item()

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                optimizer.step()

                loss_value = loss.item()
                losses.append(loss_value)
                accs.append(acc)
                # write train loss/acc every 500 iterations
                if global_step % 500 == 0:
                    writer.add_scalar("train/loss", loss_value, global_step)
                    writer.add_scalar("train/acc", acc, global_step)
                global_step += 1

            if args.save:
                save_path = os.path.join(args.model_dir, f"epoch{epoch}.pth")
                torch.save(model.state_dict(), save_path)

            train_loss, train_acc = np.mean(losses), np.mean(accs)
            eval_loss, eval_acc = eval(model, testloader, criterion, device)
            print(f"Epoch: {epoch}, train loss: {train_loss}, train accs: {train_acc}, val loss: {eval_loss}, val acc: {eval_acc}")
            scheduler.step()
    finally:
        # Flush pending events and release the log file even if training fails.
        writer.close()

    print("Ended at ", time.asctime( time.localtime(time.time()) ))
    
if __name__ == "__main__":
    # Script entry point: parse the CLI options, seed all RNGs before any
    # model or data loader is created, then run training.
    cli_args = parse_args()
    seed_torch(seed=cli_args.seed)
    main(cli_args)