import sys
sys.path.append('../')
import argparse
import torch
from torch import nn
import numpy as np
import random
import os
from arguments import args, parse_args
from trainer_resnet import Trainer
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"

def main():
	"""Train on PACS with one held-out target domain and save per-run results.

	Overrides fields of the shared ``args`` namespace with the PACS experiment
	settings, then performs ``args.runs`` independent trainings. After every
	run the accumulated results are (re)written to
	``iclr_results/<target>_<network>_<drop_rate>.npz`` with the keys:

	* ``V`` -- validation results, one entry per run
	* ``T`` -- test results, one entry per run
	* ``X`` -- utilized features on the validation split, one entry per run
	* ``Y`` -- utilized features on the test split, one entry per run

	Reads ``args.target``, ``args.network`` and ``args.runs``, which must
	already be set before this function is called.

	Raises:
		ValueError: if ``args.target`` is not one of the PACS domains.
	"""
	args.dataset = 'PACS'
	domains = ['art_painting', 'cartoon', 'sketch', 'photo']
	if args.target not in domains:
		raise ValueError(
			"Unknown target domain {!r}; expected one of {}".format(args.target, domains))
	# Every domain except the held-out target is used as a source domain.
	args.source = [domain for domain in domains if domain != args.target]
	print("Source domain: {}".format(args.source))
	print("Target domain: {}".format(args.target))

	args.n_classes = 7
	args.batch_size = 64
	args.learning_rate = 0.01
	args.mask_learning_rate = 1.0
	args.drop_rate = 0.15
	args.ber_sample = 1
	args.freeze_bn = False
	args.evaluation_iter = 20
	args.sigmoid_head = True
	args.epochs = 100
	if args.network == 'ResNet18':
		args.feature_dim = 512
	else:
		# Any other backbone (ResNet50 is the expected one) falls back to 2048.
		args.feature_dim = 2048

	args.tau = [-1.0, 0.5, 0.6, 0.7, 0.0]  # -1.0 is equal to Ensemble
	device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

	val_results = []
	test_results = []
	val_used_features = []
	test_used_features = []

	# np.savez does not create missing directories, so make sure it exists.
	os.makedirs('iclr_results', exist_ok=True)

	for _ in range(args.runs):
		trainer = Trainer(args, device)
		results, utilized_features = trainer.do_training()
		val_results.append(results['val'])
		test_results.append(results['test'])
		# Keep each split in its own list so X/Y below match their names.
		val_used_features.append(utilized_features['val'])
		test_used_features.append(utilized_features['test'])
		# Rewrite the file after every run so partial results survive a crash.
		np.savez('iclr_results/{}_{}_{}.npz'.format(args.target, args.network, args.drop_rate),
			V=np.array(val_results),
			T=np.array(test_results),
			X=np.array(val_used_features),
			Y=np.array(test_used_features))
		print(args)

if __name__ == "__main__":
	# Script entry point: populate the shared `args` namespace before main()
	# reads it (presumably from the command line -- parse_args lives in
	# the project's `arguments` module; confirm there).
	parse_args(args)
	# Let cuDNN benchmark the available convolution algorithms and pick the fastest.
	torch.backends.cudnn.benchmark = True
	main()


