import os, time, pickle, argparse

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets, transforms
from torchvision.models import resnet18
from torch.utils.data.sampler import SubsetRandomSampler

from models import MNISTSingleCNN
from train_mnist import train_epochs



# Ask cuDNN for deterministic kernels whenever a GPU is present, so that
# repeated runs with the same --seed produce the same results.
_CUDA_AVAILABLE = torch.cuda.is_available()
if _CUDA_AVAILABLE:
	torch.backends.cudnn.deterministic = True

# NOTE(review): machine-specific default; set args.root_dir to override it.
DEFAULT_ROOT_DIR = "/Users/MyDinh/gradprop/lib/Fair_LTR_enhanced/"


def _binarize(targets, target_num):
	"""Return 0/1 labels: 1 where *targets* equals *target_num*, 0 everywhere else."""
	return torch.where(targets == target_num, torch.tensor(1), torch.tensor(0))


def _subset_loader(dataset, idx_path, batch_size):
	"""Return a DataLoader over only the dataset indices saved at *idx_path*, visited in random order."""
	subset_indices = torch.load(idx_path)
	# shuffle must stay False when a sampler is given; SubsetRandomSampler does the shuffling.
	return DataLoader(
		dataset, batch_size=batch_size, shuffle=False,
		sampler=SubsetRandomSampler(subset_indices),
	)


def main(args):
	"""Train one binary (one-vs-rest) MNIST classifier per digit 0-9 and pickle its metrics.

	For each digit ``i`` in ``range(10)``:

	* relabel MNIST so that digit ``i`` becomes 1 and every other digit 0;
	* train a fresh ``MNISTSingleCNN`` with Adam + NLLLoss via ``train_epochs``
	  on the subset indexed by ``<root>/data/mnist/train_{i}_sample_idx.pt``;
	* write the returned accuracy/loss lists to
	  ``<root>/results/mnist/mnist_<method>_<i>_<seed>.p``.

	Args:
		args: parsed CLI namespace. Reads ``seed``, ``device`` and ``method``.
			An optional ``root_dir`` attribute overrides ``DEFAULT_ROOT_DIR``.
			``args.target_num`` is overwritten with each digit in turn, so any
			value supplied on the command line is ignored.
	"""
	ROOT_DIR = getattr(args, "root_dir", None) or DEFAULT_ROOT_DIR
	RANDOM_SEED = args.seed
	LEARNING_RATE = 0.001
	BATCH_SIZE = 64
	NUM_EPOCHS = 10
	DEVICE = args.device

	# Load both MNIST splits once; the per-digit loop only swaps their labels.
	mnist_transform = transforms.Compose([
		transforms.ToTensor(),
		transforms.Normalize((0.1307,), (0.3081,)),
	])
	train_data = datasets.MNIST('./pytorch/data/', train=True, download=True, transform=mnist_transform)
	test_data = datasets.MNIST('./pytorch/data/', train=False, download=True, transform=mnist_transform)
	# Keep the original digit labels, because .targets is overwritten with 0/1 labels on every iteration.
	train_digits = train_data.targets.clone()
	test_digits = test_data.targets.clone()

	data_dir = os.path.join(ROOT_DIR, "data/mnist")
	results_dir = os.path.join(ROOT_DIR, "results/mnist")
	# Create the output directory up front so a missing folder cannot crash the run after training.
	os.makedirs(results_dir, exist_ok=True)

	for target_num in range(10):
		# Still exposed on args, as before, for any caller that inspects it afterwards.
		args.target_num = target_num
		print("*"*100)
		print('Training for number {}'.format(target_num))

		train_data.targets = _binarize(train_digits, target_num)
		train_loader = _subset_loader(
			train_data,
			os.path.join(data_dir, "train_{}_sample_idx.pt".format(target_num)),
			BATCH_SIZE,
		)

		test_data.targets = _binarize(test_digits, target_num)
		# NOTE(review): test_loader is built but never used. train_epochs receives only
		# train_loader, so the returned "valid_*" metrics cannot come from it. Confirm intent.
		test_loader = _subset_loader(
			test_data,
			os.path.join(data_dir, "test_{}_sample_idx.pt".format(target_num)),
			BATCH_SIZE,
		)

		# Reset the global torch RNG so each digit's model and sampling start from the same state.
		torch.manual_seed(RANDOM_SEED)
		model = MNISTSingleCNN()
		model.to(DEVICE)

		optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
		# NLLLoss expects log-probabilities; assumes MNISTSingleCNN ends in log_softmax. TODO confirm.
		criterion = nn.NLLLoss()

		valid_acc_lst, valid_loss_lst = train_epochs(model, optimizer, criterion, train_loader, DEVICE, NUM_EPOCHS)

		results_dump = {
			"seed": RANDOM_SEED,
			"target_num": target_num,
			"valid_acc_lst": valid_acc_lst,
			"valid_loss_lst": valid_loss_lst,
			"method": args.method,
		}
		meta_filename = os.path.join(
			results_dir,
			"mnist_" + args.method + "_" + str(target_num) + "_" + str(args.seed) + '.p',
		)
		# The context manager guarantees the file is flushed and closed, even if pickling fails.
		with open(meta_filename, 'wb') as f:
			pickle.dump(results_dump, f)





if __name__ == "__main__":

	parser = argparse.ArgumentParser(description='PyTorch Mnist Model')

	# Command-line options.
	parser.add_argument('--seed', type=int, default=1111,
		help='random seed')
	parser.add_argument('--device', type=str, default="cpu",
		help='torch device to train on, e.g. "cpu" or "cuda"')
	# main() loops over every digit and overwrites args.target_num, so this flag has no effect.
	parser.add_argument('--target_num', type=int, default=0,
		help='digit to train (currently ignored: every digit 0-9 is trained)')
	parser.add_argument('--method', choices=['single', 'multi'], default='single',
		help='method tag stored in the results pickle and used in its filename')
	args = parser.parse_args()

	main(args)

