from __future__ import print_function
import sys
import argparse
import os
import time
import numpy as np
import copy
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

import warnings
warnings.filterwarnings("ignore", category=UserWarning) 

# Fix the RNG seed so that repeated evaluation runs produce the same results.
seed = 100
torch.manual_seed(seed)
if not torch.cuda.is_available():
  print("WARNING: CUDA not available")
else:
  # Seed the CUDA generator as well when a GPU is present.
  torch.cuda.manual_seed(seed)

# Command-line options are declared by the project-local `opts` module, which
# registers them on the parser through `opts.train_opts`.
import opts
parser = argparse.ArgumentParser(description='pytorch action')
opts.train_opts(parser)
args = parser.parse_args()
# Echo the parsed options so the run configuration is visible in the log.
print(args)

# Make the GPU chosen with `args.gpu` the current CUDA device.
torch.cuda.set_device(args.gpu)

# Sizes handed to the network constructors below: number of output classes and
# per-frame input width (presumably 50 joints x 3 coordinates -- TODO confirm).
outputclass = 60
# With --geo_aug the input is widened by an extra 897*2 features.
indim = 50*3 + 897*2 if args.geo_aug else 50*3

# Plain aliases for frequently used command-line options.
batch_size = args.batch_size
seq_len = args.seq_len
gradientclip_value = args.gradclipvalue

# A U_bound of 0 means "derive it": use MAG ** (1 / seq_len), computed in
# log10 space; any other value is taken as given.
U_bound = (
  args.U_bound
  if args.U_bound != 0
  else np.power(10, (np.log10(args.MAG) / args.seq_len))
)

# Build the network requested with --model. Each architecture lives in its own
# project-local module, imported only in the branch that needs it and always
# bound to the name `Indrnn_network`.
if args.model=='plainIndRNN':
  import Indrnn_plainnet as Indrnn_network
  model = Indrnn_network.stackedIndRNN_encoder(indim, outputclass)
elif args.model=='residualIndRNN':
  import Indrnn_residualnet_preact as Indrnn_network
  model = Indrnn_network.ResidualNet(indim, outputclass)
elif args.model=='denseIndRNN':
  import Indrnn_densenet as Indrnn_network
  if args.time_diff:
    # --time_diff swaps in the "_FA" variant by rebinding the same alias.
    import Indrnn_densenet_FA as Indrnn_network
  from ast import literal_eval
  # block_config arrives as a string; literal_eval parses it as a Python literal.
  block_config = literal_eval(args.block_config)
  model = Indrnn_network.DenseNet(indim, outputclass, growth_rate=args.growth_rate, block_config=block_config,
                                        num_init_features=args.growth_rate * args.num_first)
elif args.model=='adaptiveIndRNN':
  import Indrnn_adaptive as Indrnn_network
  from ast import literal_eval
  block_config = literal_eval(args.block_config)
  model = Indrnn_network.DenseNet(indim, outputclass, growth_rate=args.growth_rate, block_config=block_config,
                                        num_init_features=args.growth_rate * args.num_first,
                                        hard=args.hard, type=args.type, threshold=args.threshold)
elif args.model=='fullAdaptiveIndRNN':
  import Indrnn_adaptive_full as Indrnn_network
  from ast import literal_eval
  block_config = literal_eval(args.block_config)
  model = Indrnn_network.DenseNet(indim, outputclass, growth_rate=args.growth_rate, block_config=block_config,
                                        num_init_features=args.growth_rate * args.num_first,
                                        hard=args.hard, type=args.type, threshold=args.threshold)
else:
  # Raise instead of asserting: an assert is stripped under `python -O`, which
  # would let execution fall through to a NameError on `model` below.
  raise ValueError(
    'set the model type: plainIndRNN, residualIndRNN, denseIndRNN, '
    'adaptiveIndRNN, fullAdaptiveIndRNN (got %r)' % (args.model,))
# Move the freshly built network onto the current CUDA device.
model.cuda()
criterion = nn.CrossEntropyLoss()
# Report the model size. `numel()` counts every element of a parameter tensor
# whatever its rank, so parameters with more than two dimensions are counted
# in full rather than by their first two dimensions only.
params = list(model.parameters()) + list(criterion.parameters())
total_params = sum(p.numel() for p in params)
print('Args:', args)
print('Model total parameters:', total_params)

# Names of the train/test splits: the "CV" pair when --test_CV is set,
# otherwise the default pair.
if not args.test_CV:
  train_datasets, test_dataset = 'train_ntus', 'test_ntus'
else:
  train_datasets, test_dataset = 'train_CV_ntus', 'test_CV_ntus'

# Module-level copies of the augmentation switches from the command line.
geo_aug = args.geo_aug
data_randtime_aug = args.data_randtime_aug

# Test-set reader from the project-local `data_reader` module.
from data_reader import DataHandler
dh_test= DataHandler(batch_size,seq_len,train_or_eval='test')
# Number of batches needed to cover the dataset once, rounding up
# (the +0.0 forces float division).
num_test_batches=int(np.ceil(dh_test.GetDatasetSize()/(batch_size+0.0)))

# Run one batch through the freshly built model BEFORE loading the weights.
# NOTE(review): presumably this forward pass lets the model create any lazily
# built state so the checkpoint keys match -- confirm against the model code.
# Side effects to be aware of: it consumes one batch from `dh_test`, and the
# size unpacking below rebinds the module-level `seq_len` and `batch_size`.
inputs,targets,index=dh_test.GetBatch()
# Swap the first two axes so the tensor unpacks as (seq_len, batch, joints, _).
inputs=inputs.transpose(1,0,2,3)
inputs=Variable(torch.from_numpy(inputs).cuda())
seq_len, batch_size, joints_no,_=inputs.size()
# Flatten the per-joint values into one feature vector per frame
# (assumes the last axis has size 3 -- TODO confirm).
inputs=inputs.view(seq_len,batch_size,3*joints_no)

output=model(inputs)

# Load the trained weights, mapping stored tensors onto the selected GPU.
model.load_state_dict(torch.load(args.weights, map_location="cuda:{}".format(args.gpu)))

def set_bn_train(m):
    """Put `m` into train mode if its class name contains 'BatchNorm'.

    Written to be passed to `model.apply(...)`; modules whose class name does
    not contain 'BatchNorm' are left untouched.
    """
    if 'BatchNorm' in m.__class__.__name__:
        m.train()

# Module-level accumulators filled by test(): one entry is appended per
# evaluated batch.
selection_weights = []  # values returned by model.get_selection_weights()
true_labels = []  # ground-truth labels of the batch as a NumPy array
def test(dh,num_batches,use_bn_trainstat=False):
  """Evaluate the global `model` on batches drawn from `dh` and print accuracy.

  Args:
    dh: data handler providing `GetBatch()` -> (inputs, targets, index) and
      `GetDatasetSize()`.
    num_batches: number of batches in one pass over the dataset; the loop runs
      `num_batches * args.test_no` batches in total.
    use_bn_trainstat: when True, modules whose class name contains 'BatchNorm'
      are switched back to train mode after `model.eval()`.

  Returns:
    Batch-level accuracy: correct predictions divided by the number of samples
    actually evaluated.

  Side effects:
    Appends one entry per batch to the module-level `selection_weights` and
    `true_labels` lists, and prints two summary lines. The first line shows
    the batch-level accuracy and a second accuracy obtained by summing each
    sample's outputs over all passes before taking the argmax.
  """
  model.eval()
  if use_bn_trainstat:
    model.apply(set_bn_train)

  total_batches = num_batches * args.test_no
  eval_selections = 0
  tacc = 0
  count = 0
  num_samples = 0
  # Size the per-sample buffers from the handler that is actually evaluated.
  total_testdata = dh.GetDatasetSize()
  total_ave_acc = np.zeros((total_testdata, outputclass))
  testlabels = np.zeros((total_testdata))

  # Evaluation needs no gradients; disabling autograd also stops the appended
  # selection weights from keeping computation graphs alive.
  with torch.no_grad():
    while True:
      inputs, targets, index = dh.GetBatch()
      # Swap the first two axes so the tensor unpacks as (seq_len, batch, joints, _).
      inputs = inputs.transpose(1, 0, 2, 3)
      testlabels[index] = targets
      inputs = Variable(torch.from_numpy(inputs).cuda())
      targets = Variable(torch.from_numpy(np.int64(targets)).cuda())
      seq_len, batch_size, joints_no, _ = inputs.size()
      # Flatten per-joint values into one feature vector per frame
      # (assumes the last axis has size 3 -- TODO confirm).
      inputs = inputs.view(seq_len, batch_size, 3 * joints_no)

      output = model(inputs)
      pred = output.data.max(1)[1]  # index of the highest score per sample
      accuracy = pred.eq(targets.data).cpu().sum().numpy()
      # Sum the raw outputs per dataset sample across repeated passes.
      total_ave_acc[index] += output.data.cpu().numpy()

      sel_weights = model.get_selection_weights()
      selection_weights.append(sel_weights)
      true_labels.append(targets.cpu().numpy())

      tacc += accuracy
      count += 1
      # Track the real sample count so batches of differing size do not skew
      # the accuracy denominator.
      num_samples += targets.size(0)

      eval_selections += model.regularizer().detach().cpu().numpy()

      # `>=` also terminates when the batch budget is 0 or not a whole number.
      if count >= total_batches:
        break

  top = np.argmax(total_ave_acc, axis=-1)
  eval_acc = np.mean(np.equal(top, testlabels))
  batch_acc = tacc / (num_samples + 0.0)
  print ("test accuracy: ", batch_acc, eval_acc)
  print("test selections: ", eval_selections/count)

  return batch_acc

# Evaluate the loaded model on the test set, with BatchNorm layers kept in
# train mode.
test_acc = test(dh_test, num_test_batches, use_bn_trainstat=True)

# 'base of the spine'
# 'middle of the spine'
# 'neck'
# 'head'
# 'left shoulder'
# 'left elbow'
# 'left wrist'
# 'left hand'
# 'right shoulder'
# 'right elbow'
# 'right wrist'
# 'right hand'
# 'left hip'
# 'left knee'
# 'left ankle'
# 'left foot'
# 'right hip'
# 'right knee'
# 'right ankle'
# 'right foot'
# 'spine'
# 'tip of the left hand'
# 'left thumb'
# 'tip of the right hand'
# 'right thumb'
