import argparse
import time

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import confusion_matrix

from planning_dataset import TSNDataSet
from planning_models import VideoModel
from utils.utils import plot_confusion_matrix

from colorama import init
from colorama import Fore, Back, Style
from tqdm import tqdm

from planning_main import calc_ed


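# the model jointly predicts FUTURE_LENGTH future steps, each a distribution
# over ACTION_NUM_CLASSES action classes; the flat output size is their product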
FUTURE_LENGTH = 8
ACTION_NUM_CLASSES = 27
init(autoreset=True)

# options
parser = argparse.ArgumentParser(description="Standard video-level testing")
parser.add_argument('class_file', type=str,
                    help='class index file, e.g. classInd.txt')
parser.add_argument(
    'modality',
    type=str,
    choices=['RGB', 'Flow', 'RGBDiff', 'RGBDiff2', 'RGBDiffplus'])
parser.add_argument('test_list', type=str)
parser.add_argument('weights', type=str)

# ========================= Model Configs ==========================
parser.add_argument('--arch', type=str, default="resnet101")
parser.add_argument('--test_segments', type=int, default=5)

parser.add_argument(
    '--add_fc',
    default=1,
    type=int,
    metavar='M',
    help=
    'number of additional fc layers (excluding the last fc layer) (e.g. 0, 1, 2, ...)'
)
parser.add_argument('--fc_dim',
                    type=int,
                    default=512,
                    help='dimension of added fc')
parser.add_argument('--baseline_type',
                    type=str,
                    default='frame',
                    choices=['frame', 'video', 'tsn'])
parser.add_argument(
    '--frame_aggregation',
    type=str,
    default='avgpool',
    choices=['avgpool', 'rnn', 'temconv', 'trn-m', 'none'],
    help='aggregation of frame features (none if baseline_type is not video)')
parser.add_argument('--dropout_i', type=float, default=0)
parser.add_argument('--dropout_v', type=float, default=0)

#------ RNN ------
parser.add_argument('--n_rnn',
                    default=1,
                    type=int,
                    metavar='M',
                    help='number of RNN layers (e.g. 0, 1, 2, ...)')
parser.add_argument('--rnn_cell',
                    type=str,
                    default='LSTM',
                    choices=['LSTM', 'GRU'])
parser.add_argument('--n_directions',
                    type=int,
                    default=1,
                    choices=[1, 2],
                    help='1 = uni-directional RNN, 2 = bi-directional RNN')
parser.add_argument('--n_ts',
                    type=int,
                    default=5,
                    help='number of temporal segments')

# ========================= DA Configs ==========================
parser.add_argument('--share_params',
                    type=str,
                    default='Y',
                    choices=['Y', 'N'])
parser.add_argument('--use_bn',
                    type=str,
                    default='none',
                    choices=['none', 'AdaBN', 'AutoDIAL'])
parser.add_argument('--use_attn_frame',
                    type=str,
                    default='none',
                    choices=['none', 'TransAttn', 'general', 'DotProduct'],
                    help='attention-mechanism for frames only')
parser.add_argument('--use_attn',
                    type=str,
                    default='none',
                    choices=['none', 'TransAttn', 'general', 'DotProduct'],
                    help='attention-mechanism')
parser.add_argument('--n_attn',
                    type=int,
                    default=1,
                    help='number of discriminators for transferable attention')

# ========================= Monitor Configs ==========================
parser.add_argument('--top',
                    default=[1, 3, 5],
                    nargs='+',
                    type=int,
                    help='show top-N categories')
parser.add_argument('--verbose', default=False, action="store_true")

# ========================= Runtime Configs ==========================
parser.add_argument('--save_confusion', type=str, default=None)
parser.add_argument('--save_scores', type=str, default=None)
parser.add_argument('--save_attention', type=str, default=None)
parser.add_argument('--max_num',
                    type=int,
                    default=-1,
                    help='number of videos to test')
parser.add_argument('-j',
                    '--workers',
                    default=4,
                    type=int,
                    metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--bS',
                    default=2,
                    help='batch size',
                    type=int,
                    required=False)
parser.add_argument('--gpus', nargs='+', type=int, default=None)
parser.add_argument('--flow_prefix', type=str, default='')
parser.add_argument('--feat_path', type=str,
                    help='path to the pre-extracted feature files')

args = parser.parse_args()

# class_names = [line.strip().split(' ', 1)[1] for line in open(args.class_file)]
# num_class = len(class_names)
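# flattened output: FUTURE_LENGTH anticipation steps x ACTION_NUM_CLASSES classes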
num_class = FUTURE_LENGTH * ACTION_NUM_CLASSES

#=== Load the network ===#
print(Fore.CYAN + 'preparing the model......')
net = VideoModel(
    num_class,
    args.baseline_type,
    args.frame_aggregation,
    args.modality,
    train_segments=args.test_segments if args.baseline_type == 'video' else 1,
    val_segments=args.test_segments if args.baseline_type == 'video' else 1,
    base_model=args.arch,
    add_fc=args.add_fc,
    fc_dim=args.fc_dim,
    share_params=args.share_params,
    dropout_i=args.dropout_i,
    dropout_v=args.dropout_v,
    use_bn=args.use_bn,
    partial_bn=False,
    n_rnn=args.n_rnn,
    rnn_cell=args.rnn_cell,
    n_directions=args.n_directions,
    n_ts=args.n_ts,
    use_attn=args.use_attn,
    n_attn=args.n_attn,
    use_attn_frame=args.use_attn_frame,
    verbose=args.verbose)

checkpoint = torch.load(args.weights)

validation_score = checkpoint['prec1']
print("model epoch {} prec@1: {}".format(checkpoint['epoch'],
                                         checkpoint['prec1']))

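# strip the 'module.' prefix that nn.DataParallel prepends to parameter names
# when a checkpoint is saved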
base_dict = {
    '.'.join(k.split('.')[1:]): v
    for k, v in list(checkpoint['state_dict'].items())
}
net.load_state_dict(base_dict)

#=== Data loading ===#
print(Fore.CYAN + 'loading data......')

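# new_length = frames per sample: 1 for appearance (RGB-based) features,
# 5 stacked frames for optical flow (the usual TSN convention)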
data_length = 1 if args.modality == "RGB" else 5
with open(args.test_list) as f:
    num_test = sum(1 for _ in f)

rgb_like = ['RGB', 'RGBDiff', 'RGBDiff2', 'RGBDiffplus']
image_tmpl = ("img_{:05d}.t7" if args.modality in rgb_like
              else args.flow_prefix + "{}_{:05d}.t7")

data_set = TSNDataSet(
    "",
    args.test_list,
    args.feat_path,
    num_dataload=num_test,
    num_segments=args.test_segments,
    new_length=data_length,
    modality=args.modality,
    image_tmpl=image_tmpl,
    test_mode=True,
)
data_loader = torch.utils.data.DataLoader(data_set,
                                          batch_size=args.bS,
                                          shuffle=False,
                                          num_workers=args.workers,
                                          pin_memory=True)

data_gen = tqdm(data_loader)

#--- GPU processing ---#
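# eval() disables dropout and freezes batch-norm statistics for inference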
net = torch.nn.DataParallel(net.cuda())
net.eval()

output = []  # used only by the commented-out score saving at the bottom
attn_values = torch.Tensor()  # used only by the commented-out attention saving


#############################################################
def eval_video(video_data):
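    """Forward one (possibly zero-padded) batch; return raw logits and labels."""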
    i, data, label = video_data

    data = data.cuda()
    label = label.cuda(non_blocking=True)  # pytorch 0.4.X

    num_crop = 1  # kept for the commented-out top-K aggregation below

    # data holds pre-extracted features, e.g. [batch, num_segments x 2048]

    with torch.no_grad():
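        # the model returns several intermediate tensors; only the final
        # prediction logits (out) are used downstream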
        _, _, _, _, _, attn, out, _, _, _ = net(data, data, [0, 0, 0], 0,
                                                is_train=False, reverse=False)
        # print(out.shape, flush=True)
        # print(label.shape, flush=True)
        return out, label
        # out = nn.Softmax(dim=1)(out).topk(max(args.top))
        # prob = out[0].data.cpu().numpy().copy(
        # )  # rst.shape = [sample #, top class #]
        # pred_labels = out[1].data.cpu().numpy().copy(
        # )  # rst.shape = [sample #, top class #]

        # if args.baseline_type == 'video':
        #     prob_video = prob.reshape(
        #         (data.size(0), num_crop, max(args.top))).mean(axis=1).reshape(
        #             (data.size(0), max(args.top)))
        # else:
        #     prob_video = prob.reshape(
        #         (data.size(0), num_crop, args.test_segments,
        #          max(args.top))).mean(axis=1).reshape(
        #              (data.size(0), args.test_segments, max(args.top)))
        #     prob_video = np.mean(prob_video, axis=1)

        # return i, prob_video, pred_labels, label.cpu().numpy(), attn.cpu()


#############################################################

proc_start_time = time.time()
max_num = args.max_num if args.max_num > 0 else len(data_loader.dataset)

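# the counters below feed only the commented-out top-K accuracy code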
count_correct_topK = [0 for i in range(len(args.top))]
count_total = 0
video_pred = [[] for i in range(max(args.top))]
video_labels = []

all_preds = []
all_labels = []
#=== Testing ===#
print(Fore.CYAN + 'start testing......')
for i, (data, label) in enumerate(data_gen):
    data_size_ori = data.size()  # original shape
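    # zero-pad the last incomplete batch up to args.bS so every forward pass
    # sees a fixed batch size; the dummy rows are stripped after inference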
    if data_size_ori[0] < args.bS:
        data_dummy = torch.zeros(args.bS - data_size_ori[0], data_size_ori[1],
                                 data_size_ori[2], dtype=data.dtype)
        data = torch.cat((data, data_dummy))
        label_dummy = torch.zeros(args.bS - data_size_ori[0], label.shape[1],
                                  dtype=label.dtype)
        # print(label.shape, label_dummy.shape)
        label = torch.cat((label, label_dummy))

    if i * args.bS >= max_num:  # --max_num counts videos; i indexes batches
        break
    preds, labels = eval_video((i, data, label))

    # remove the dummy part (zero-padded rows added for the last batch)
    preds = preds[:data_size_ori[0]]
    labels = labels[:data_size_ori[0]]

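    # split the flattened logits into per-step predictions:
    # [batch, FUTURE_LENGTH * C] -> [batch, FUTURE_LENGTH, C]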
    preds = preds.reshape(-1, FUTURE_LENGTH, num_class // FUTURE_LENGTH)

    all_preds.append(preds)
    all_labels.append(labels)

    # attn_values = torch.cat((attn_values, attn))  # save the attention values

    # accumulate
    # for j in range(len(args.top)):
    #     for k in range(args.top[j]):
    #         count_correct_topK[j] += ((preds[:, k] == labels) * 1).sum()
    # count_total += (preds[:, 0].shape)[0]
    # acc_topK = [
    #     float(count_correct_topK[j]) / float(count_total)
    #     for j in range(len(args.top))
    # ]

    # for k in range(max(args.top)):
    #     video_pred[k] += preds[:, k].tolist()  # save the top-K prediction

    # video_labels += labels.tolist()

    # cnt_time = time.time() - proc_start_time
    # line_print = '                                                            '  # leave a large space for the tqdm progress bar
    # for j in range(len(args.top)):
    #     line_print += 'Pred@%d %f, ' % (args.top[j], acc_topK[j])

    # line_print += 'average %f sec/video\r' % (float(cnt_time) /
    #   (i + 1) / args.bS)
    # data_gen.set_description(line_print)

all_labels = torch.cat(all_labels)
all_preds = torch.cat(all_preds)
# print(all_labels)
# print(all_preds)
auedit, res = calc_ed(all_preds, all_labels, k=5, logits=True)
print(auedit)


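# AUED is an edit-distance-based error (lower is better); report 1 - AUED so
# that higher is better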
ed_final = 1 - auedit["action_AUED"]
# print(('Testing Results: recall top5 {recall:.3f} prec top5 {prec:.3f}'.format(
#     recall=recall, prec=prec)))

print("validation score:", validation_score)
print("testing score:", ed_final)

# if args.save_attention:
#     np.savetxt(args.save_attention + '.txt',
#                attn_values.cpu().numpy(),
#                fmt="%s")

# cf = [
#     confusion_matrix(video_labels,
#                      video_pred[k],
#                      labels=list(range(num_class)))
#     for k in range(max(args.top))
# ]

# plot_confusion_matrix(args.save_confusion + '.png',
#                       cf[0],
#                       classes=[i for i in range(num_class)],
#                       normalize=True,
#                       title='Normalized confusion matrix')

# #--- overall accuracy ---#
# cls_cnt = cf[0].sum(axis=1)
# cls_hit = np.array([np.diag(cf[i]) for i in range(max(args.top))])
# cls_acc_topK = [cls_hit[:j].sum(axis=0) / cls_cnt for j in args.top]

# if args.verbose:
#     for i in range(len(cls_acc_topK[0])):
#         line_print = ''
#         for j in range(len(args.top)):
#             line_print += str(cls_acc_topK[j][i]) + ' '
#         print(line_print)

# final_line = ''
# for j in args.top:
#     final_line += Fore.YELLOW + 'Pred@{:d} {:.02f}% '.format(
#         j,
#         np.sum(cls_hit[:j].sum(axis=0)) / np.sum(cls_cnt) * 100)
# print(final_line)

# if args.save_confusion:
#     class_acc_file = open(
#         args.save_confusion + '-top' + str(args.top) + '.txt', 'w')

#     for i in range(len(cls_acc_topK[0])):
#         line_print = ''
#         for j in range(len(args.top)):
#             line_print += str(cls_acc_topK[j][i]) + ' '
#         class_acc_file.write(line_print + '\n')

#     class_acc_file.close()

# if args.save_scores is not None:
#     # reorder before saving
#     name_list = [x.strip().split()[0] for x in open(args.test_list)]

#     order_dict = {e: i for i, e in enumerate(sorted(name_list))}

#     reorder_output = [None] * len(output)
#     reorder_label = [None] * len(output)

#     for i in range(len(output)):
#         idx = order_dict[name_list[i]]
#         reorder_output[idx] = output[i]
#         reorder_label[idx] = video_labels[i]

#     np.savez(args.save_scores, scores=reorder_output, labels=reorder_label)
