import argparse
import pickle
import numpy as np
import torch
import torch.nn as nn
import os
import random
import pdb
import time
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import MultiStepLR
from model import ResNet50
from datasets.dataset import CIFAR100, CIFAR100Prune

def seed_torch(seed=42):
    """Seed every RNG used by this script and request deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), then configures cuDNN for reproducible kernel choice.

    Args:
        seed: Integer seed shared by all generators.
    """
    print(f"Using seed {seed}")
    random.seed(seed)
    # NOTE: the hash seed of the running interpreter is fixed at startup, so
    # this only influences Python subprocesses launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers multi-GPU runs
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes convolution algorithms per run, which can pick
    # different kernels between runs; pin it off so seeding is meaningful.
    torch.backends.cudnn.benchmark = False

def _str2bool(value):
    """Convert a command-line token such as "true", "False" or "1" to a bool.

    ``type=bool`` cannot be used with argparse because ``bool("False")`` is
    ``True`` (every non-empty string is truthy).

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, matching what bool("") returned before.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of the training script.

    Boolean options (``--resume``, ``--prune``, ``--save``) accept an explicit
    value (``--save True`` / ``--save false``) or may be given bare
    (``--save``), which means ``True``.

    Note that ``--batch_size``, ``--epochs`` and ``--lr`` default to 0, so
    they have to be passed explicitly for a meaningful training run.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=0, help="training batch size")
    parser.add_argument("--data_path", type=str, default="./data", help="path to save/load data")
    parser.add_argument("--epochs", type=int, default=0, help="train epochs")
    parser.add_argument("--model_dir", default="./", help="dir to save model")
    parser.add_argument("--lr", default=0, help="learning rate", type=float)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--resume", type=_str2bool, nargs="?", const=True, default=False)
    parser.add_argument("--prune", type=_str2bool, nargs="?", const=True, default=False,
                        help="whether to train on a pruned dataset")
    parser.add_argument("--pruneId_path", default=None)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--save", type=_str2bool, nargs="?", const=True, default=False)
    # The training loop reads args.logdir for its TensorBoard writer; None
    # lets SummaryWriter fall back to its own default directory.
    parser.add_argument("--logdir", type=str, default=None, help="dir for tensorboard logs")
    args = parser.parse_args()
    return args

@torch.no_grad()
def eval(model, loader, criterion, device):
    """Evaluate ``model`` over ``loader`` with gradients disabled.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: Module called on each image batch.
        loader: Iterable yielding ``(_, img, target)`` triples; must support
            ``len()`` and expose a sized ``dataset`` attribute.
        criterion: Callable ``criterion(output, target)`` returning a scalar
            tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(loss, acc)``: the per-batch criterion values averaged over the
        number of batches, and the number of correct argmax predictions
        divided by ``len(loader.dataset)``.
    """
    model.eval()
    loss_sum = 0
    num_correct = 0

    for _, images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        loss_sum += criterion(logits, labels).item()
        # A prediction is the index of the largest logit along dim 1.
        hits = torch.eq(logits.argmax(1), labels).float().sum()
        num_correct += hits.cpu().numpy()

    accuracy = num_correct / len(loader.dataset)
    mean_loss = loss_sum / len(loader)
    return mean_loss, accuracy
    
def main(args):
    """Train ResNet50 on CIFAR-100 (optionally a pruned subset) and evaluate it.

    After every epoch the model is evaluated on the test split and a summary
    line is printed; training loss/accuracy are written to TensorBoard every
    500 iterations. When ``args.save`` is set, the model ``state_dict`` is
    written to ``args.model_dir`` after each epoch.

    Args:
        args: Parsed options. Reads ``batch_size``, ``data_path``, ``epochs``,
            ``model_dir``, ``lr``, ``device``, ``prune``, ``pruneId_path``,
            ``save`` and, if present, ``logdir``. ``args.resume`` is not
            acted on by this function.

    Raises:
        ValueError: if ``args.prune`` is set without ``args.pruneId_path``.
    """
    # Imported locally so the writer used below is always in scope.
    from torch.utils.tensorboard import SummaryWriter

    batch_size = args.batch_size
    device = args.device
    os.makedirs(args.model_dir, exist_ok=True)

    # Per-channel normalisation statistics shared by both pipelines.
    normalize = transforms.Normalize(
        (0.5070751592371323, 0.48654887331495095, 0.4409178433670343),
        (0.2673342858792401, 0.2564384629170883, 0.27615047132568404),
    )
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if args.prune:
        if args.pruneId_path is None:
            raise ValueError("--pruneId_path must be given when --prune is enabled")
        # SECURITY: pickle can execute arbitrary code while loading; only
        # point --pruneId_path at files from a trusted source.
        with open(args.pruneId_path, "rb") as f:
            drop_id = pickle.load(f)
        print("Using prune id: ", args.pruneId_path)

        data_train = CIFAR100Prune(root=args.data_path, train=True,
                                   drop_id=drop_id, transform=train_transform,
                                   download=True)
    else:
        data_train = CIFAR100(root=args.data_path, train=True,
                              transform=train_transform, download=True)
    # The test split is identical whether or not the train split is pruned.
    data_test = CIFAR100(root=args.data_path, train=False,
                         transform=test_transform, download=True)

    trainloader = DataLoader(data_train, shuffle=True, batch_size=batch_size,
                             num_workers=10, pin_memory=True)
    testloader = DataLoader(data_test, batch_size=128, num_workers=4, pin_memory=True)

    print("Starting at ", time.asctime(time.localtime(time.time())))
    print("trainset size:", len(trainloader.dataset))
    model = ResNet50(100)
    model = model.to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    scheduler = MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
    # Tolerate namespaces without a logdir attribute; None makes
    # SummaryWriter pick its default log directory.
    writer = SummaryWriter(getattr(args, "logdir", None))

    global_step = 0
    try:
        for epoch in range(args.epochs):
            model.train()
            losses, accs = [], []
            for _, img, target in trainloader:
                img, target = img.to(device), target.to(device)
                output = model(img)
                loss = criterion(output, target)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                acc = torch.eq(output.argmax(1), target).float().mean().item()
                losses.append(loss_value)
                accs.append(acc)
                # Write train loss/accuracy every 500 iterations.
                if global_step % 500 == 0:
                    writer.add_scalar("train/loss", loss_value, global_step)
                    writer.add_scalar("train/acc", acc, global_step)

                global_step += 1

            if args.save:
                save_path = os.path.join(args.model_dir, f"epoch{epoch}.pth")
                torch.save(model.state_dict(), save_path)

            train_loss, train_acc = np.mean(losses), np.mean(accs)
            eval_loss, eval_acc = eval(model, testloader, criterion, device)
            print(f"Epoch: {epoch}, train loss: {train_loss}, train accs: {train_acc}, val loss: {eval_loss}, val acc: {eval_acc}")
            scheduler.step()
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()
    print("Ended at ", time.asctime(time.localtime(time.time())))
    
if __name__ == "__main__":
    # Script entry point: read the CLI options, seed all RNGs from --seed,
    # then run training.
    cli_args = parse_args()
    seed_torch(seed=cli_args.seed)
    main(cli_args)