import os
import sys
import json
import pickle
import argparse
import importlib
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
import time
import yaml
import glob

from torch.utils.data import DataLoader
from datetime import datetime
from tqdm import tqdm
from copy import deepcopy
from munch import Munch

sys.path.append(os.path.join(os.getcwd())) # HACK add the root folder
from config.config import CONF

from lib.joint.dataset_2clip import ScannetReferenceDataset
from lib.joint.eval_ground import get_eval_clip
from lib.joint.eval_ground import get_eval_m3d as get_eval_m3d

from models.jointnet.jointnet_2clip import JointNet
from data.scannet.model_util_scannet import ScannetDatasetConfig, SunToScannetDatasetConfig


print('Import Done', flush=True)

# Annotation files for each supported value of CONF.dataset, as
# (sub-directory of CONF.PATH.DATA, train file name, val file name).
_ANNOTATION_FILES = {
    "ScanRefer": ("scanrefer", "ScanRefer_train_new.json", "ScanRefer_val_new.json"),
    "nr3d": ("nr3d", "nr3d_train_sorted.json", "nr3d_val_sorted.json"),
    "sr3d": ("sr3d", "sr3d_train_sorted.json", "sr3d_val_sorted.json"),
    "multi3drefer": ("multi3drefer", "multi3drefer_train_sorted.json", "multi3drefer_val_sorted.json"),
}


def _load_json(path):
    """Parse the JSON file at ``path``, closing the file handle afterwards."""
    with open(path) as json_file:
        return json.load(json_file)


# NOTE: for any other CONF.dataset value the two annotation globals are left
# undefined (as before); code that needs them will fail with a NameError.
if CONF.dataset in _ANNOTATION_FILES:
    _subdir, _train_name, _val_name = _ANNOTATION_FILES[CONF.dataset]
    SCANREFER_TRAIN = _load_json(os.path.join(CONF.PATH.DATA, _subdir, _train_name))
    SCANREFER_VAL = _load_json(os.path.join(CONF.PATH.DATA, _subdir, _val_name))
# SCANREFER_VAL = json.load(open(os.path.join(CONF.PATH.DATA, "ScanRefer_filtered_test.json")))

# Dataset configuration object; its num_class / class2type are passed to the model.
DC = ScannetDatasetConfig() if CONF.pretrain_data == "scannet" else SunToScannetDatasetConfig()

def count_parameters(model):
    """Return the total element count of the parameters of ``model`` that require gradients.

    :param model: object exposing ``parameters()``; each yielded item must
        provide ``requires_grad`` and ``numel()``.
    :return: sum of ``numel()`` over the parameters with ``requires_grad`` set.
    """
    total = 0
    for param in model.parameters():
        # Parameters that do not require gradients are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

def get_dataloader(args, scanrefer, scanrefer_new, all_scene_list, cfg):
    """Build the reference dataset and a shuffled ``DataLoader`` over it.

    :param args: run options; ``dataset``, ``lang_num_max`` and ``batch_size`` are read.
    :param scanrefer: annotations forwarded to the dataset as ``scanrefer``.
    :param scanrefer_new: grouped annotations forwarded as ``scanrefer_new``.
    :param all_scene_list: scene ids forwarded as ``scanrefer_all_scene``.
    :param cfg: data config section; ``with_label`` and ``repeat`` are optional
        and default to ``False`` and ``0`` when the key is absent.
    :return: tuple ``(dataset, dataloader)``.
    """
    dataset_kwargs = {
        "scanrefer": scanrefer,
        "scanrefer_new": scanrefer_new,
        "scanrefer_all_scene": all_scene_list,
        "data_root": cfg.data_root,
        "prefix": cfg.prefix,
        "suffix": cfg.suffix,
        "voxel_cfg": cfg.voxel_cfg,
        "training": cfg.training,
        # Optional config entries fall back to a default when missing.
        "with_label": False if "with_label" not in cfg else cfg.with_label,
        "repeat": 0 if "repeat" not in cfg else cfg.repeat,
        "name": args.dataset,
        "lang_num_max": args.lang_num_max,
    }
    dataset = ScannetReferenceDataset(**dataset_kwargs)
    print("evaluate on {} samples".format(len(dataset)))

    # Batches are assembled by the dataset's own collate function.
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )
    return dataset, loader

def get_model(args, cfg):
    """Construct ``JointNet`` on the GPU, load pretrained weights and switch to eval mode.

    :param args: run options; forwarded to the model and used to select the
        pretrained backbone (``pretrain_model``) and checkpoint folder (``folder``).
    :param cfg: parsed config object, forwarded to the model unchanged.
    :return: the model, in ``eval()`` mode.
    """
    # One channel when height is enabled, plus three when colour is used.
    feature_dim = int(args.use_color) * 3 + int(not args.no_height)

    net = JointNet(
        num_class=DC.num_class,
        class_name=DC.class2type,
        input_feature_dim=feature_dim,
        width=args.width,
        hidden_size=args.hidden_size,
        num_proposal=args.num_proposals,
        num_target=args.num_target,
        no_caption=args.no_caption,
        use_topdown=args.use_topdown,
        num_locals=args.num_locals,
        query_mode=args.query_mode,
        use_lang_classifier=(not args.no_lang_cls),
        use_bidir=args.use_bidir,
        no_reference=args.no_reference,
        args=args,
        cfg=cfg,
    )
    net = net.cuda()

    if args.pretrain_model == "softgroup":
        print("loading pretrained SoftGroup...")
        backbone_ckpt = torch.load(CONF.PATH.SOFTGROUP_PRETRAIN)
        # strict=False: keys that do not match the sub-module are tolerated.
        net.softgroup.load_state_dict(backbone_ckpt['net'], strict=False)

    # The language-classifier checkpoint lives at <BASE>/pretrained/<folder>/model.pth.
    ckpt_path = os.path.join(CONF.PATH.BASE, 'pretrained', args.folder, "model.pth")
    print("loading pretrained text_classifier...")
    lang_cls_state = torch.load(ckpt_path)
    net.lang_cls.load_state_dict(lang_cls_state, strict=False)
    net.eval()

    return net

def get_scannet_scene_list(split):
    """Return the sorted scene ids listed in ``scannetv2_<split>.txt``.

    The file is looked up under ``CONF.PATH.SCANNET_META`` and is read one
    scene id per line; trailing whitespace is stripped from every line.

    :param split: split name substituted into the file name (e.g. ``"val"``).
    :return: sorted list of scene id strings.
    """
    meta_path = os.path.join(CONF.PATH.SCANNET_META, "scannetv2_{}.txt".format(split))
    # Context manager so the handle is closed instead of being left to the GC.
    with open(meta_path) as meta_file:
        scene_list = sorted(line.rstrip() for line in meta_file)

    return scene_list

def get_scanrefer(args, split):
    """Select the annotations to evaluate and group them per scene.

    Two modes, chosen by ``args.no_detection``:

    * falsy: one placeholder annotation (a deep copy of ``SCANREFER_TRAIN[0]``
      with its ``scene_id`` replaced) is created for every scene returned by
      ``get_scannet_scene_list(split)``.
    * truthy: the train annotations are used when ``split == 'train'``,
      otherwise the val annotations, restricted to the first
      ``args.num_scenes`` scene ids in sorted order (``-1`` keeps all).

    In both modes the selected annotations are then split into chunks of
    consecutive entries that share a ``scene_id``, each chunk holding at most
    ``args.lang_num_max`` entries. Previously the grouped list was only built
    in the second mode, so the first mode failed at ``return`` with an
    ``UnboundLocalError``.

    :param args: run options; ``no_detection``, ``num_scenes`` and
        ``lang_num_max`` are read.
    :param split: ``'train'`` or any other value (treated as validation).
    :return: tuple ``(scanrefer, scene_list, scanrefer_new)`` where
        ``scanrefer_new`` is the list of chunks described above.
    """
    if not args.no_detection:
        scene_list = get_scannet_scene_list(split)
        scanrefer = []
        for scene_id in scene_list:
            data = deepcopy(SCANREFER_TRAIN[0])
            data["scene_id"] = scene_id
            scanrefer.append(data)
    else:
        scanrefer = SCANREFER_TRAIN if split == 'train' else SCANREFER_VAL
        scene_list = sorted(set(data["scene_id"] for data in scanrefer))
        if args.num_scenes != -1:
            scene_list = scene_list[:args.num_scenes]

        # Set membership keeps the filter linear in the number of annotations.
        kept_scenes = set(scene_list)
        scanrefer = [data for data in scanrefer if data["scene_id"] in kept_scenes]

    # Chunk consecutive annotations: a chunk is closed when the scene changes
    # or when it already holds lang_num_max entries.
    scanrefer_new = []
    chunk = []
    for data in scanrefer:
        scene_changed = bool(chunk) and chunk[0]["scene_id"] != data["scene_id"]
        if chunk and (scene_changed or len(chunk) >= args.lang_num_max):
            scanrefer_new.append(chunk)
            chunk = []
        chunk.append(data)
    if chunk:
        scanrefer_new.append(chunk)

    return scanrefer, scene_list, scanrefer_new

def decode_stimulus_string(s):
    """
    Parse a stimulus string of the form
    ``scene_id-instance_label-n_objects-target_id[-distractor_id-...]``.

    Underscores in the instance label are turned into spaces; the counts and
    ids are converted to ``int``. The distractor part may be missing.

    :param s: the stimulus string
    :return: ``(scene_id, instance_label, n_objects, target_id, distractors_ids)``
    :raises AssertionError: when the number of distractor ids is not
        ``n_objects - 1``
    """
    # At most five fields: everything after the fourth dash is the distractor list.
    fields = s.split('-', maxsplit=4)
    if len(fields) == 4:
        # No distractor section present.
        fields.append("")
    scene_id, label_raw, count_raw, target_raw, distractors_raw = fields

    instance_label = label_raw.replace('_', ' ')
    n_objects = int(count_raw)
    target_id = int(target_raw)

    distractors_ids = []
    for token in distractors_raw.split('-'):
        if token != '':
            distractors_ids.append(int(token))
    assert len(distractors_ids) == n_objects - 1

    return scene_id, instance_label, n_objects, target_id, distractors_ids


def _seed_everything(seed):
    """Seed torch (CPU and CUDA) and numpy, and put cuDNN in deterministic mode."""
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _move_batch_to_gpu(batch):
    """Move every entry of ``batch`` that exposes ``.cuda()`` to the GPU, in place."""
    for key in batch:
        value = batch[key]
        if hasattr(value, 'cuda'):
            batch[key] = value.cuda()
        elif key in ('gt_masks', 'prop'):
            # These two entries are containers; their elements are moved one by one.
            for i, item in enumerate(value):
                value[i] = item.cuda()


def _fmt(value):
    """Format a metric for console output, rounded to four decimals."""
    return str(round(value, 4))


def eval_ref(args):
    """Evaluate grounding on the ``val`` split and write the results to disk.

    Steps: select annotations, build the dataloader and model, run the model
    over the data once per seed, then compute IoU-based metrics for the dataset
    named by ``args.dataset``. Everything is written into a fresh time-stamped
    directory under ``CONF.PATH.OUTPUT``:

    * ``ious_val.json``        per-sample IoUs,
    * ``val_test_result.json`` aggregated metrics,
    * ``weight_val.json``      per-annotation ``coarse_weights`` from the model.

    NOTE: the accumulators are reset at the start of every seed, so when
    ``args.repeat > 1`` only the last seed's results are reported and saved.

    :param args: run options; read here are ``config``, ``seed``, ``repeat``,
        ``nodetect``, ``is_eval`` and ``dataset`` (the helpers read more).
    :raises ValueError: if ``args.config`` is ``None``.
    """
    print("evaluate localization...")
    if args.config is None:
        # Without a config neither the dataloader nor the model can be built.
        raise ValueError("eval_ref requires args.config to point to a YAML config file")

    # split = "train"
    split = "val"
    save_iou = 'y'

    # init dataset
    print("preparing data...")
    scanrefer, scene_list, scanrefer_new = get_scanrefer(args, split)
    with open(args.config, 'r') as cfg_file:
        cfg = Munch.fromDict(yaml.safe_load(cfg_file))

    # dataloader
    data_cfg = cfg.data.train_eval if split == 'train' else cfg.data.test
    _, dataloader = get_dataloader(args, scanrefer, scanrefer_new, scene_list, data_cfg)

    # model
    model = get_model(args, cfg)
    print("\nparam stats:")
    print("model:", count_parameters(model))
    print("backbone:", count_parameters(model.softgroup))

    # random seeds: the configured one followed by 0, 2, 4, ...
    seeds = [args.seed] + [2 * i for i in range(args.repeat - 1)]

    # All artefacts of this run go into <OUTPUT>/_<timestamp>/.
    stamp = "_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    root = os.path.join(CONF.PATH.OUTPUT, stamp)
    os.makedirs(root, exist_ok=True)

    # evaluate
    print("save on ", split, " ... ")
    if args.nodetect:
        print("Do Not Detect!")

    for seed in seeds:
        # reproducibility
        _seed_everything(seed)
        print("generating the scores for seed {}...".format(seed))
        predictions = {}
        ious, topk_ious, max_ious, nt_labels = [], [], [], []
        scan_ids, object_ids, ann_ids = [], [], []
        meta_datas, view_dependents = [], []
        for data in tqdm(dataloader):
            _move_batch_to_gpu(data)
            # feed
            with torch.no_grad():
                data = model(data, is_eval=args.is_eval)
                scan_ids.extend(data['scan_ids'])
                ann_ids.extend(data['ann_ids'])
                object_ids.extend(data['object_id_list'])
                if data["meta_datas"][0] is not None:
                    meta_datas.extend(data['meta_datas'])
                    if data["view_dependents"][0] is not None:
                        view_dependents.extend(data['view_dependents'])
                if args.dataset == 'multi3drefer':
                    data = get_eval_m3d(data_dict=data)
                    ious.extend(data["ref_iou"])
                    nt_labels.extend(data["nt_labels"])
                else:
                    data = get_eval_clip(data_dict=data, k=3)
                    ious.extend(data["ref_iou"])
                    topk_ious.extend(data["topk_iou"])
                    max_ious.extend(data["max_iou"])

                # Collect the coarse weights of every valid annotation slot.
                batch_size = data['ann_id_list'].shape[0]
                lang_num = data["lang_num"]
                weights = data["coarse_weights"]
                for i in range(batch_size):
                    scene_id = data["scene_ids"][i]
                    ann_id_list = data["ann_id_list"][i].cpu().numpy()
                    scene_predictions = predictions.setdefault(scene_id, {})

                    for j, ann_id in enumerate(ann_id_list):
                        # Only the first lang_num[i] slots of a sample are used.
                        if not j < lang_num[i]:
                            break
                        entry = scene_predictions.setdefault(int(ann_id), {})
                        entry["weights"] = weights[i][j].detach().cpu().numpy().tolist()
            torch.cuda.empty_cache()

    ious = np.array(ious)
    output = {}
    if args.dataset == 'multi3drefer':
        eval_dict = {"zt_w_d": 0, "zt_wo_d": 1, "st_w_d": 2, "st_wo_d": 3, "mt": 4}
        eval_type_mask = np.empty(len(scan_ids))
        iou_out = []
        for idx, scan_id in enumerate(scan_ids):
            eval_type = meta_datas[idx]['eval_type']
            eval_type_mask[idx] = eval_dict[eval_type]
            # A set nt_label zeroes the IoU of the sample ...
            if nt_labels[idx]:
                ious[idx] = 0.0
            # ... except for the two "zt" types, which score 1.0 exactly when
            # the label is set and 0.0 otherwise.
            if eval_type in ("zt_wo_d", "zt_w_d"):
                ious[idx] = 1.0 if nt_labels[idx] else 0.0

            iou_out.append({
                'scene_id': scan_id,
                'ann_id': ann_ids[idx],
                'object_id': object_ids[idx],
                'eval_type': eval_type,
                'iou': float(ious[idx]),
            })

        if save_iou == 'y':
            # `root` already contains CONF.PATH.OUTPUT; joining it a second
            # time breaks when the output directory is a relative path.
            iou_path = os.path.join(root, 'ious_' + split + '.json')
            with open(iou_path, "w") as json_file:
                json.dump(iou_out, json_file, indent=2)

        for sub_group in ("zt_w_d", "zt_wo_d", "st_w_d", "st_wo_d", "mt"):
            selected = ious[eval_type_mask == eval_dict[sub_group]]
            sub_miou = selected.mean()
            sub_acc25 = (selected > 0.25).mean()
            sub_acc50 = (selected > 0.5).mean()
            print(sub_group + ":\tmiou:" + _fmt(sub_miou) + ' \tAcc_25: ' + _fmt(sub_acc25) + ' \tAcc_50: ' + _fmt(sub_acc50))
            output[sub_group] = {"miou": sub_miou, "acc_25": sub_acc25, "acc_50": sub_acc50}

    else:
        if split in ('val', 'train'):
            if save_iou == 'y':
                iou_out = {}
                for idx, scan_id in enumerate(scan_ids):
                    sample = scan_id + '_' + str(ann_ids[idx]) + '_' + str(object_ids[idx])
                    iou_out[sample] = {'iou': float(ious[idx])}

                iou_path = os.path.join(root, 'ious_' + split + '.json')
                with open(iou_path, "w") as json_file:
                    json.dump(iou_out, json_file, indent=2)

            # scanrefer
            if args.dataset == 'ScanRefer':
                with open(os.path.join(CONF.PATH.DATA, "lookup_new.json"), 'r') as load_f:
                    unique_multi_lookup = json.load(load_f)
                unique, multi = [], []
                for idx, scan_id in enumerate(scan_ids):
                    # Lookup value 0 puts the sample in the "unique" subset.
                    if unique_multi_lookup[scan_id][str(object_ids[idx][0])][str(ann_ids[idx])] == 0:
                        unique.append(ious[idx])
                    else:
                        multi.append(ious[idx])
                groups = {"unique": np.array(unique), "multi": np.array(multi), "overall": ious}
                for name in groups:
                    output[name] = {}
                for u in [0.25, 0.5]:
                    acc = {name: (values > u).mean() for name, values in groups.items()}
                    print(f'Acc@{u}: \tunique: ' + _fmt(acc["unique"]) + ' \tmulti: ' + _fmt(acc["multi"]) + ' \tall: ' + _fmt(acc["overall"]))
                    for name, value in acc.items():
                        output[name][str(u)] = value
                mious = {name: values.mean() for name, values in groups.items()}
                print('mIoU:\t \tunique: ' + _fmt(mious["unique"]) + ' \tmulti: ' + _fmt(mious["multi"]) + ' \tall: ' + _fmt(mious["overall"]))
                for name, value in mious.items():
                    output[name]["miou"] = value

            # referit3d
            elif args.dataset == 'nr3d' or args.dataset == 'sr3d':
                # Third field of the stimulus string (n_objects) decides easy vs hard.
                hardness = [decode_stimulus_string(meta_data)[2] for meta_data in meta_datas]
                ious_vd, ious_vind, ious_easy, ious_hard = [], [], [], []
                for idx, scan_id in enumerate(scan_ids):
                    piou = ious[idx]
                    if len(meta_datas) > 0:
                        if hardness[idx] > 2:
                            ious_hard.append(piou.item())
                        else:
                            ious_easy.append(piou.item())

                        if view_dependents[idx]:
                            ious_vd.append(piou.item())
                        else:
                            ious_vind.append(piou.item())

                # NOTE: the JSON key "esay" is a typo of "easy"; it is kept so
                # existing result files and their consumers stay compatible.
                groups = {
                    "view_dep": np.array(ious_vd),
                    "view_indep": np.array(ious_vind),
                    "esay": np.array(ious_easy),
                    "hard": np.array(ious_hard),
                    "overall": ious,
                }
                for name in groups:
                    output[name] = {}
                for u in [0.25, 0.5]:
                    acc = {name: (values > u).mean() for name, values in groups.items()}
                    print(f'Acc@{u}: \tview_dep: ' + _fmt(acc["view_dep"]) + ' \tview_indep: ' + _fmt(acc["view_indep"]))
                    print(f'Acc@{u}: \teasy: ' + _fmt(acc["esay"]) + ' \thard: ' + _fmt(acc["hard"]))
                    print(f'Acc@{u}: \toverall: ' + _fmt(acc["overall"]))
                    for name, value in acc.items():
                        output[name][str(u)] = value
                mious = {name: values.mean() for name, values in groups.items()}
                print('mIoU:\t \tview_dep: ' + _fmt(mious["view_dep"]) + ' \tview_indep: ' + _fmt(mious["view_indep"]))
                # This line used to be printed under the view_dep/view_indep labels.
                print('mIoU:\t \teasy: ' + _fmt(mious["esay"]) + ' \thard: ' + _fmt(mious["hard"]))
                print('mIoU:\t \toverall: ' + _fmt(mious["overall"]))
                for name, value in mious.items():
                    output[name]['miou'] = value

        topk_ious = np.array(topk_ious)
        topk_acc25 = (topk_ious > 0.25).sum().astype(float) / topk_ious.size
        topk_acc50 = (topk_ious > 0.5).sum().astype(float) / topk_ious.size
        max_ious = np.array(max_ious)
        max_acc25 = (max_ious > 0.25).sum().astype(float) / max_ious.size
        max_acc50 = (max_ious > 0.5).sum().astype(float) / max_ious.size
        print("topk_acc25: ", topk_acc25)
        print("topk_acc50: ", topk_acc50)
        output["topk_acc25"] = topk_acc25
        output["topk_acc50"] = topk_acc50
        output["max_acc25"] = max_acc25
        output["max_acc50"] = max_acc50

    miou = ious.mean()
    acc25 = (ious > 0.25).sum().astype(float) / ious.size
    acc50 = (ious > 0.5).sum().astype(float) / ious.size
    print("miou: ", miou)
    print("ref_acc25: ", acc25)
    print("ref_acc50: ", acc50)
    output["miou"] = miou
    output["ref_acc25"] = acc25
    output["ref_acc50"] = acc50

    output_path = os.path.join(root, split + '_test_result.json')
    with open(output_path, "w") as json_file:
        json.dump(output, json_file, indent=2)

    # save output
    # save_weight = input('If you want to save the results? (y/n)')
    save_weight = 'y'
    if save_weight == 'y':
        weight_path = os.path.join(root, 'weight_' + split + '.json')
        with open(weight_path, "w") as json_file:
            json.dump(predictions, json_file, indent=2)
    else:
        print('Not saving results.')
        sys.exit()

    print("Done!")


if __name__ == "__main__":
    eval_ref(CONF)

