import torch
import torch.utils.data
import torchvision.transforms as transforms
from datasets.Pascal3DPlus import Pascal3DPlus3D, ToTensor, Normalize, Pascal3DPlus
from lib.NCEAverage_new import NearestMemoryManager, mask_remove_near
from models.KeypointRepresentationNet import NetE2E
from datetime import datetime
import os
import argparse
import torch.nn.functional as F
import numpy as np
torch.autograd.set_detect_anomaly(True)
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as patches
import matplotlib
from get_n_list import get_n_list

# Per-backbone stride. The evaluation loop below uses it both as the pooling
# factor for the object mask and to map a feature-map index back to pixel
# coordinates (index * stride + stride // 2).
net_stride = {
    'vgg_pool4': 16,
    'vgg_pool5': 32,
    'resnet50': 32,
    'resnext50': 32,
    'resnet50_pre': 16,
    'resnet50_prepre': 8,
    'resunet': 2,
    'resunetpre': 8,
    'hg': 4,
    'resnetupsample': 16,
}

# os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"

# Select matplotlib's non-interactive Agg backend, which needs no display.
matplotlib.use('agg')

##########################################################################
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"`` and ``"0"``) becomes ``True``.
    This converter accepts the usual spellings and rejects everything else.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description='3D Representation Net Training')

parser.add_argument('--local_size', default=1, type=int, help='')
# parser.add_argument('--d_feature', default = 512 , type = int , help = '')
parser.add_argument('--d_feature', default=1024, type=int, help='')
parser.add_argument('--n_points', default=-1, type=int, help='')
parser.add_argument('--batch_size', default=16, type=int, help='')
parser.add_argument('--workers', default=0, type=int, help='')
parser.add_argument('--total_epochs', default=400, type=int, help='')
parser.add_argument('--distance_thr', default=8, type=int, help='')
parser.add_argument('--T', default=0.07, type=float, help='')
parser.add_argument('--weight_noise', default=5e-3, type=float, help='')
parser.add_argument('--update_lr_epoch_n', default=5, type=int, help='')
parser.add_argument('--update_lr_', default=0.5, type=float, help='')
parser.add_argument('--lr', default=0.01 * 0.01, type=float, help='')
parser.add_argument('--momentum', default=0.9, type=float, help='')
parser.add_argument('--weight_decay', default=1e-4, type=float, help='')
parser.add_argument('--type_', default='car', type=str, help='')
parser.add_argument('--num_noise', default=0, type=int, help='')
parser.add_argument('--max_group', default=16, type=int, help='')
parser.add_argument('--adj_momentum', default=0.9, type=float, help='')
parser.add_argument('--init_memory_bank', default=False, type=_str2bool, help='')
parser.add_argument('--mesh_path', default='../PASCAL3D/PASCAL3D+_release1.1/CAD_%s/%s/', type=str, help='')
# Alternative --save_dir defaults that were used for other experiments:
# parser.add_argument('--save_dir', default='../3DrepresentationData/trained_resunetpre_second_3D_weighted_distcrop/', type=str, help='')
# parser.add_argument('--save_dir', default='../3DrepresentationData/trained_resunetpre_second_3D_weighted_distcrop_%s_offmask/', type=str, help='')
# parser.add_argument('--save_dir', default='../3DrepresentationData/trained_resunetpre_second_3D_weighted_distcrop_%s/', type=str, help='')
# parser.add_argument('--save_dir', default='../3DrepresentationData/trained_resunetpre_second_3D_weighted_distcrop_%s_offmask/', type=str, help='')
# parser.add_argument('--save_dir', default='../3DrepresentationData/trained_resnetupsample_second_3D_weighted_distcrop_%s_offmask/', type=str, help='')
# parser.add_argument('--save_dir', default='../3DrepresentationData/trained_vgg_pool4_second_3D_weighted_distcrop_%s_offmask/', type=str, help='')
# parser.add_argument('--save_dir', default='../3DrepresentationData/Unsupervised_V5_vggp4_trained/', type=str, help='')
# parser.add_argument('--save_dir', default='../3DrepresentationData/Unsupervised_V5_resnet50_pre_trained_20percent/', type=str, help='')
parser.add_argument('--save_dir', default='../3DrepresentationData/trained_resnet50_pre_second_3D_weighted_distcrop_buildn_7/', type=str, help='')
# parser.add_argument('--save_dir', default='../3DrepresentationData/trained_resunetpre_second_3D_%s/', type=str, help='')
# parser.add_argument('--save_dir', default='../3DrepresentationData/trained_resunetpre_second_3D_weighted_distcrop_%s_offmask/', type=str, help='')
parser.add_argument('--root_path', default='../PASCAL3D/PASCAL3D_distcrop/', type=str, help='')  # '../PASCAL3D/PASCAL3D_train/'
# parser.add_argument('--root_path', default = '../PASCAL3D/PASCAL3D/', type = str, help = '')
# Previously declared without a type, so any value given on the command line
# stayed a (truthy) string; parse it as a boolean like the other flags.
parser.add_argument('--load_mb', default=True, type=_str2bool)
parser.add_argument('--num_mesh', default=-1, type=int, help='')
parser.add_argument('--data_pendix', default='', type=str, help='')

# Alternative --ckpt defaults ('%s' is later replaced by --type_):
# parser.add_argument('--ckpt', default = '3D1024_selected_points1saved_model_car_799.pth', type = str)
# parser.add_argument('--ckpt', default = '3D512_points1saved_model_car_799.pth', type = str)
parser.add_argument('--ckpt', default='3D512_points1saved_model_%s_799.pth', type=str)
# parser.add_argument('--ckpt', default = 'saved_model_120_160_ungt.pth', type = str)
# parser.add_argument('--ckpt', default = 'saved_model_120_100_ungt_%s.pth', type = str)
# parser.add_argument('--ckpt', default = 'saved_model_120_60_ungt_%s.pth', type = str)
# parser.add_argument('--ckpt', default = '3D512_points1saved_unsupervised_bank_%s_799.pth', type = str)
parser.add_argument('--inp_res', default=256, type=int)
parser.add_argument('--out_res', default=256, type=int)

parser.add_argument('--backbone', default='resnet50_pre', type=str)
parser.add_argument('--stacks', default=8, type=int)
parser.add_argument('--blocks', default=1, type=int)
parser.add_argument('--mesh_d', default='buildn', type=str)
parser.add_argument('--objectnet', default=False, type=_str2bool)
parser.add_argument('--eval_kp_score', default=False, type=_str2bool)
parser.add_argument('--save_features_path', default='saved_features_7percent', type=str)


# Parse the command line; the statements below derive all paths from it.
args = parser.parse_args()

mesh_d = args.mesh_d
# mesh_d = 'build'

# Relative tolerance of the keypoint check in the evaluation loop: a prediction
# counts as correct when its distance is <= thr * max(box_obj[0], box_obj[1]).
thr = 0.1

# Generate to unseen pose
unseen_setting = False
if unseen_setting:
    azum_sel = 'TFFTTFFT'
    use_azum_data = 'TFFTTFFT'

    # rstrip, not strip: only the trailing separator must go. strip('/') would
    # also remove the leading '/' of an absolute save_dir.
    args.save_dir = args.save_dir.rstrip('/') + '_azum_' + azum_sel + '/'
else:
    azum_sel = ''
    use_azum_data = ''

# Fill in the '%s' placeholders of the path templates.
args.mesh_path = args.mesh_path % (mesh_d, args.type_)
if '%s' in args.save_dir:
    args.save_dir = args.save_dir % mesh_d

if '%s' in args.ckpt:
    args.ckpt = args.ckpt % args.type_

# Checkpoint file name up to its first '.', reused in every output file name.
ckpt_stem = args.ckpt.split('.')[0]

if not args.objectnet:
    if not args.data_pendix:
        if azum_sel:
            save_features = args.save_features_path + '/' + args.type_ + '/' + args.backbone + '_' + ckpt_stem + '_%s_azum_%s_using_%s.npz' % (mesh_d, azum_sel, use_azum_data)
        else:
            save_features = args.save_features_path + '/' + args.type_ + '/' + args.backbone + '_' + ckpt_stem + '_%s.npz' % mesh_d
    else:
        save_features = args.save_features_path + '/' + args.type_ + '_occ/' + args.data_pendix + '_resunetpre_' + ckpt_stem + '_%s.npz' % mesh_d
        # A non-empty data_pendix switches the input data to the occluded set.
        args.root_path = '../PASCAL3D/PASCAL3D_OCC_distcrop/'
else:
    save_features = args.save_features_path + '_objectnet/' + args.type_ + '/resunetpre_' + ckpt_stem + '_%s.npz' % mesh_d

# Make sure the object class appears in the file name itself. The occlusion
# and objectnet names hard-code 'resunetpre_' instead of the backbone, so the
# backbone may be absent from the path; indexing split(...)[1] directly then
# raised IndexError. Fall back to the base name in that case.
pieces = save_features.split(args.backbone)
name_tail = pieces[1] if len(pieces) > 1 else os.path.basename(save_features)
if args.type_ not in name_tail:
    # NOTE(review): the azimuth-style name does not end in '_<mesh_d>.npz', so
    # this rewrite would append after the existing '.npz' there; unreachable
    # while unseen_setting is False.
    save_features = save_features.split('_%s.npz' % mesh_d)[0] + '_%s_%s.npz' % (args.type_, mesh_d)
print(save_features)

##########################################################################
# Per-category counts keyed by Pascal3D+ class name.
# Presumably the number of annotated keypoints per class (not read in this part
# of the file; confirm against the dataset code).
num_kp_dict = {
    'aeroplane': 8,
    'bicycle': 11,
    'boat': 7,
    'bottle': 7,
    'bus': 12,
    'car': 12,
    'chair': 10,
    'diningtable': 12,
    'motorbike': 10,
    'sofa': 10,
    'train': 17,
    'tvmonitor': 8,
}
# Presumably the number of CAD meshes (subtypes) per class; likewise not read in
# this part of the file.
num_mesh_dict = {
    'aeroplane': 8,
    'bicycle': 6,
    'boat': 6,
    'bottle': 8,
    'bus': 6,
    'car': 10,
    'chair': 10,
    'diningtable': 6,
    'motorbike': 5,
    'sofa': 6,
    'train': 4,
    'tvmonitor': 4,
}
        
# The scalar CLI value is duplicated into a two-element list before it is
# handed to NetE2E.
args.local_size = [args.local_size] * 2

# SingleCuboid: 1, MultiCuboid: number of subtypes
n_list = get_n_list(args.mesh_path)
# One 'meshXX' label (1-based, zero-padded) per entry of n_list.
subtypes = ['mesh%02d' % idx for idx in range(1, len(n_list) + 1)]

# Earlier variant, kept for reference:
# net = NetE2E(net_type='resnet50', local_size=args.local_size,
#              output_dimension=args.d_feature, reduce_function=None, n_noise_points=args.num_noise, pretrain = True)
net = NetE2E(
    net_type=args.backbone,
    local_size=args.local_size,
    output_dimension=-1,
    reduce_function=None,
    n_noise_points=args.num_noise,
    pretrain=True,
)

# Wrap for multi-GPU use, move to the GPU and switch to evaluation mode.
net = torch.nn.DataParallel(net).cuda()
net.eval()

# NOTE: this rebinds the name of the imported torchvision module to the
# composed pipeline; the dataset construction further down reads this name.
transforms = transforms.Compose([ToTensor(), Normalize()])

criterion = torch.nn.CrossEntropyLoss(reduction='mean').cuda()

# The checkpoint lives inside the training output directory.
args.ckpt = os.path.join(args.save_dir, args.ckpt)
checkpoint = torch.load(args.ckpt)

# Checkpoints store the network weights under either 'state' or 'net'.
weights_key = 'state' if 'state' in checkpoint else 'net'
net.load_state_dict(checkpoint[weights_key])

# Collects every array that is written to the .npz file at the end.
get = {}


def nanmean(v, *args, inplace=False, **kwargs):
    """Return the mean of ``v`` over its non-NaN entries.

    NaN entries are set to zero and left out of the element count. Extra
    positional and keyword arguments (for example ``dim``) are forwarded to
    both ``sum`` calls.

    Args:
        v: tensor to average.
        inplace: when True the NaNs are zeroed directly in ``v``; otherwise a
            clone is modified and ``v`` is left untouched.

    Returns:
        Sum of the non-NaN values divided by their count (NaN where the count
        is zero).
    """
    values = v if inplace else v.clone()
    nan_mask = torch.isnan(values)
    values[nan_mask] = 0
    valid_count = (~nan_mask).float().sum(*args, **kwargs)
    return values.sum(*args, **kwargs) / valid_count


# Process every mesh subtype: restore its memory bank from the checkpoint,
# run the network over that subtype's images, store the feature maps and,
# optionally, score keypoint localisation.
for i, (n, subtype) in enumerate(zip(n_list, subtypes)):
    memory_bank = NearestMemoryManager(inputSize=args.d_feature, outputSize=n + args.num_noise * args.max_group,
                                       K=1, num_noise=args.num_noise, num_pos=n, momentum=args.adj_momentum)
    memory_bank = memory_bank.cuda()

    print('subtype: ', subtype, end='\t')
    # The checkpoint holds either one memory tensor per subtype or a single
    # tensor shared by all of them.
    if isinstance(checkpoint['memory'], list):
        this_mem = checkpoint['memory'][i]
    else:
        this_mem = checkpoint['memory']
    n_bank_rows = memory_bank.memory.shape[0]
    with torch.no_grad():
        print('number points:', n, end='\t')
        memory_bank.memory.copy_(this_mem[0:n_bank_rows])

    if save_features is not None:
        # Rows that fit the bank are stored as 'memory', the rest as 'clutter'.
        get['memory_%s' % subtype] = this_mem[0:n_bank_rows].detach().cpu().numpy()
        get['clutter_%s' % subtype] = this_mem[n_bank_rows:].detach().cpu().numpy()
        get['names_%s' % subtype] = []

    if len(azum_sel) > 0:
        list_path = 'lists3D_%s_azum_%s' % (mesh_d, use_azum_data)
    else:
        list_path = 'lists3D_%s' % mesh_d
    anno_path = 'annotations3D_%s' % mesh_d

    Pascal3D_dataset = Pascal3DPlus(transforms=transforms, rootpath=args.root_path, imgclass=args.type_,
                                      subtypes=[subtype], mesh_path=args.mesh_path, anno_path=anno_path,
                                      list_path=list_path, weighted=True, data_pendix=args.data_pendix)
    Pascal3D_dataloader = torch.utils.data.DataLoader(Pascal3D_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Per-point counters accumulated over the whole subtype.
    all_visible = torch.zeros((n, ), dtype=torch.long)
    all_correct = torch.zeros((n, ), dtype=torch.long)

    print('number images:', len(Pascal3D_dataset))

    with torch.no_grad():
        for j, sample in enumerate(Pascal3D_dataloader):
            img, keypoint, iskpvisible, this_name, box_obj = sample['img'], sample['kp'], sample['iskpvisible'], sample['this_name'], sample['box_obj']
            img = img.cuda()

            feature_map = net.module.forward_test(img)

            if save_features is not None:
                # Dedicated loop names: the original reused `i` and `n` here,
                # overwriting the subtype index and point count of the outer loop.
                for b_idx, img_name in enumerate(this_name):
                    get[img_name] = feature_map[b_idx].detach().cpu().numpy()
                    get['names_%s' % subtype].append(img_name)

            # Only for evaluate dense keypoint detection ability of the backbone. Not necessary for NeMo
            if args.eval_kp_score:
                keypoint = keypoint.cuda()

                iskpvisible = iskpvisible > 0
                iskpvisible = iskpvisible.cuda()
                obj_mask = sample['obj_mask']
                obj_mask = obj_mask.cuda()
                # Use each memory row as a 1x1 convolution kernel to get one
                # response map per stored point.
                hmap = F.conv2d(feature_map, memory_bank.memory.unsqueeze(2).unsqueeze(3))

                # Pool the object mask by the backbone stride before it is
                # multiplied with the response maps.
                stride_ = net_stride[args.backbone]
                obj_mask = F.max_pool2d(obj_mask.unsqueeze(dim=1), kernel_size=stride_, stride=stride_, padding=(stride_ - 1) // 2)

                # NOTE(review): masking by multiplication sets responses outside
                # the mask to 0, which still beats negative responses inside it.
                hmap = hmap * obj_mask

                # [n, k, h,w]
                x_ = hmap.size(3)
                hmap = hmap.view(*hmap.shape[0:2], -1)

                # Arg-max over the flattened map, split back into
                # (row, column) using the map width.
                _, max_ = torch.max(hmap, dim=2)
                max_idx = torch.zeros((*hmap.shape[0:2], 2), dtype=torch.long).to(hmap.device)
                max_idx[:, :, 0] = max_ // x_
                max_idx[:, :, 1] = max_ % x_

                # Feature-map cell -> centre of the cell in pixel coordinates.
                max_idx = max_idx * stride_ + stride_ // 2

                # [n, k]
                # assumes `keypoint` uses the same (row, column) order - TODO confirm
                distance = torch.sum((max_idx - keypoint) ** 2, dim=2) ** 0.5

                # [n, k]
                correct_keypoints = (distance <= thr * torch.max(box_obj[0], box_obj[1]).view(-1, 1).cuda()).type(torch.long).to(iskpvisible.device)

                # [k]
                correct_keypoints = torch.sum(iskpvisible * correct_keypoints, dim=0).cpu()

                # [k]
                visible_keypoints = torch.sum(iskpvisible, dim=0).cpu()

                all_visible += visible_keypoints.type(torch.long)
                all_correct += correct_keypoints.type(torch.long)

        if args.eval_kp_score:
            # Points that were never visible give 0/0 = NaN; nanmean skips them.
            acc = all_correct.type(torch.float32) / all_visible.type(torch.float32)
            print('acc:', acc)
            print('avg:', nanmean(acc))


def find_2nd(string, substring):
    """Return the index of the second occurrence of ``substring`` in ``string``.

    The second search starts one character after the first match, so
    overlapping occurrences count. Returns -1 when there are fewer than two
    occurrences.
    """
    first_hit = string.find(substring)
    return string.find(substring, first_hit + 1)


# Write everything collected in `get` to the .npz file.
if save_features is not None:
    # Create the real parent directory of the output file. Slicing the path at
    # its second '/' only worked when --save_features_path had no '/' itself,
    # and cut off the last character when the path had fewer than two slashes.
    out_dir = os.path.dirname(save_features)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.savez(save_features, **get)
