import argparse
import math
import random
import sys
import os
import json
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler


import arguments
import models.data_utils.data_utils as data_utils
import models
import models.model_utils as model_utils

def _resolve_dist_env(env=None, device_count=None):
	"""Work out (rank, world_size, gpu) for this process from launcher variables.

	Two launch styles are recognised:

	* ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK`` (all three are read).
	* ``SLURM_PROCID`` for the rank, with the world size taken from
	  ``SLURM_NTASKS`` (falling back to ``WORLD_SIZE``) and the GPU index
	  chosen as ``rank % device_count``.

	Args:
		env: mapping to read the variables from; defaults to ``os.environ``.
		device_count: number of visible CUDA devices; defaults to
			``torch.cuda.device_count()``. Only used on the SLURM path.

	Returns:
		Tuple ``(rank, world_size, gpu)`` of ints.

	Raises:
		RuntimeError: if no supported launcher variables are present, the
			world size cannot be determined, or no CUDA device is visible
			on the SLURM path.
	"""
	if env is None:
		env = os.environ
	if 'RANK' in env and 'WORLD_SIZE' in env:
		return int(env['RANK']), int(env['WORLD_SIZE']), int(env['LOCAL_RANK'])
	if 'SLURM_PROCID' in env:
		rank = int(env['SLURM_PROCID'])
		world_size = env.get('SLURM_NTASKS', env.get('WORLD_SIZE'))
		if world_size is None:
			raise RuntimeError('SLURM_PROCID is set but neither SLURM_NTASKS nor WORLD_SIZE is; cannot determine world size')
		if device_count is None:
			device_count = torch.cuda.device_count()
		if device_count < 1:
			raise RuntimeError('no CUDA device is visible to this process')
		return rank, int(world_size), rank % device_count
	raise RuntimeError('distributed training needs RANK/WORLD_SIZE/LOCAL_RANK or SLURM_PROCID in the environment')


def train(args):
	"""Run distributed (NCCL / DistributedDataParallel) training.

	Initialises the process group from the environment, seeds the RNGs with
	``args.seed``, builds the datasets and model, then loops over
	``args.num_epochs`` epochs. Every ``args.eval_every_n`` global steps the
	model is evaluated on the training and validation sets, the summary is
	passed to the logger, and rank 0 calls ``model_supervisor.save_model()``
	when the validation exact accuracy or loss is at least as good as the
	logger's best so far.

	Args:
		args: parsed command-line namespace. ``args.prog_vocab_size`` is set
			here, and ``args.max_prog_len`` is filled in when it is None.

	Raises:
		RuntimeError: if the distributed environment cannot be resolved.
	"""
	print('Training:')
	rank, world_size, gpu = _resolve_dist_env()
	torch.cuda.set_device(gpu)
	torch.distributed.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=rank)
	torch.manual_seed(args.seed)
	torch.cuda.manual_seed_all(args.seed)
	np.random.seed(args.seed)
	random.seed(args.seed)

	prog_vocab, prog_vocab_list = data_utils.load_vocab(args)
	args.prog_vocab_size = len(prog_vocab)

	train_data = data_utils.get_dataset(args.train_dataset, args.prog_format, prog_vocab, prog_vocab_list, args)
	val_data = data_utils.get_dataset(args.val_dataset, args.prog_format, prog_vocab, prog_vocab_list, args)
	# NOTE(review): torch.load unpickles the file; only point this at trusted dumps.
	val_random_data = torch.load('../datasets/Karel/val.random.thdump')

	if args.max_prog_len is None:
		# Use the longest program seen in either split as the length limit.
		max_prog_len = data_utils.cal_data_stat(train_data)
		print('max training prog len: ', max_prog_len)
		args.max_prog_len = max_prog_len
		max_prog_len = data_utils.cal_data_stat(val_data)
		print('max val prog len: ', max_prog_len)
		args.max_prog_len = max(args.max_prog_len, max_prog_len)
	model_supervisor = models.create_model(args, prog_vocab, prog_vocab_list)
	logger = model_utils.Logger(args)

	# The CUDA device is the node-local GPU index; the *global* rank is only
	# used to pick the single process that prints and saves checkpoints.
	is_main_process = rank == 0
	device = torch.device("cuda", gpu)
	sampler = DistributedSampler(train_data)
	model_supervisor.model.to(device)
	model_supervisor.model = torch.nn.parallel.DistributedDataParallel(model_supervisor.model, device_ids=[gpu], output_device=gpu)
	# args.batch_size is the global batch size; each process gets an equal share.
	loader = DataLoader(dataset=train_data, batch_size=args.batch_size // world_size, sampler=sampler)
	for epoch in range(args.num_epochs):
		# Re-seed the sampler so every epoch uses a different shuffle.
		sampler.set_epoch(epoch)
		for batch_idx, data in enumerate(loader):
			batch_data = data_utils.get_data(data, args.prog_format, prog_vocab)
			train_loss = model_supervisor.train(batch_data)
			if is_main_process:
				print(epoch, batch_idx, 'train loss: ', train_loss)
			if model_supervisor.global_step % args.eval_every_n == 0:
				train_loss, train_exact_acc, train_generalization_acc, train_semantic_acc = model_supervisor.eval(train_data, args.eval_max_size)
				# evaluate() in this file unpacks six values from the same call;
				# the starred target accepts either five or six.
				val_loss, val_exact_acc, val_generalization_acc, val_semantic_acc, val_functional_acc, *_ = model_supervisor.eval(val_data, random_data=val_random_data)
				val_summary = {
					'train_loss': train_loss,
					'train_exact_acc': train_exact_acc,
					'train_generalization_acc': train_generalization_acc,
					'train_semantic_acc': train_semantic_acc,
					'val_loss': val_loss,
					'val_exact_acc': val_exact_acc,
					'val_generalization_acc': val_generalization_acc,
					'val_semantic_acc': val_semantic_acc,
					'val_functional_acc': val_functional_acc,
				}
				improved = val_summary['val_exact_acc'] >= logger.best_val_exact_acc or val_summary['val_loss'] <= logger.best_val_loss
				if is_main_process and improved:
					model_supervisor.save_model()
				val_summary['model_type'] = args.model_type
				val_summary['global_step'] = model_supervisor.global_step
				logger.write_summary(val_summary)
			if args.lr_decay_steps is not None and model_supervisor.global_step % args.lr_decay_steps == 0:
				# .module reaches through the DistributedDataParallel wrapper.
				model_supervisor.model.module.lr_decay(args.lr_decay_rate)
		

def _log2_joint_ratio(joint_acc, semantic_acc, functional_acc):
	"""Return ``log2(joint_acc / (semantic_acc * functional_acc))``.

	Degenerate inputs do not raise: the result is ``nan`` when the product
	in the denominator is zero (the ratio is undefined) and ``-inf`` when
	only the joint accuracy is zero.
	"""
	joint = float(joint_acc)
	denominator = float(semantic_acc) * float(functional_acc)
	if denominator == 0:
		return float('nan')
	if joint == 0:
		return float('-inf')
	return math.log2(joint / denominator)


def _report_split(model_supervisor, label, data, random_data):
	"""Evaluate the model on one split and print its metrics, prefixed by ``label``."""
	loss, exact_acc, generalization_acc, semantic_acc, functional_acc, semantic_functional_acc = model_supervisor.eval(data, random_data=random_data)
	print(label + ' loss: ', loss)
	print(label + ' exact acc: ', exact_acc)
	print(label + ' generalization acc: ', generalization_acc)
	print(label + ' semantic acc: ', semantic_acc)
	print(label + ' functional acc: ', functional_acc)
	print(label + ' semantic functional: ', _log2_joint_ratio(semantic_functional_acc, semantic_acc, functional_acc))


def evaluate(args):
	"""Evaluate a model on the validation and test sets and print the metrics.

	For each split this prints the loss, the exact / generalization /
	semantic / functional accuracies, and the log2 ratio of the joint
	semantic-functional accuracy to the product of the two individual
	accuracies.

	Args:
		args: parsed command-line namespace. ``args.prog_vocab_size`` is set
			here from the loaded vocabulary.
	"""
	print('Evaluation:')
	prog_vocab, prog_vocab_list = data_utils.load_vocab(args)
	args.prog_vocab_size = len(prog_vocab)
	model_supervisor = models.create_model(args, prog_vocab, prog_vocab_list)

	# NOTE(review): torch.load unpickles the file; only point this at trusted dumps.
	val_data = data_utils.get_dataset(args.val_dataset, args.prog_format, prog_vocab, prog_vocab_list, args)
	val_random_data = torch.load('../datasets/Karel/val.random.thdump')
	_report_split(model_supervisor, 'Val', val_data, val_random_data)

	test_data = data_utils.get_dataset(args.test_dataset, args.prog_format, prog_vocab, prog_vocab_list, args)
	test_random_data = torch.load('../datasets/Karel/test.random.thdump')
	_report_split(model_supervisor, 'Test', test_data, test_random_data)



if __name__ == "__main__":
	# Script entry point: parse the Karel command-line options, decide
	# whether CUDA is used, then run evaluation or training.
	args = arguments.get_arg_parser('Karel').parse_args()
	# CUDA is used only when --cpu was not requested and a device is available.
	args.cuda = False if args.cpu else torch.cuda.is_available()

	entry_point = evaluate if args.eval else train
	entry_point(args)
