import torch
from panda import Panda
from loaddata import PandaSetLoader, pandaset_test_collate_fn
from helper import loadconfig, lbl2comparison, lbl2distortion, lbl2sev
from torch.utils.data import DataLoader
import argparse
from deepspeed_train import collate_accuracy
from functools import partial
import json
from tqdm import tqdm

import warnings
# Silence every warning process-wide, presumably to keep library deprecation
# notices from cluttering the tqdm progress output.
# NOTE(review): this also hides potentially useful warnings; consider
# narrowing it with a category= / module= filter.
warnings.filterwarnings('ignore')

def load_json(path):
    """Read the JSON file at ``path`` and return its parsed contents.

    The file is decoded as UTF-8, the encoding JSON text is defined to use
    (RFC 8259), instead of the platform's locale-dependent default.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def _region_item(tensor, region):
    """Return ``tensor.squeeze(0)[region]`` as a Python scalar.

    ``squeeze(0)`` drops a leading batch dimension of size 1, so this helper
    (like its callers) only supports a batch size of 1.
    """
    return tensor.squeeze(0)[region].item()


def _forward_batch(model, batch, device):
    """Move one test batch to ``device``, run ``model`` and return the pred/GT dict.

    The returned dict is the second value produced by ``collate_accuracy``
    for this batch (keys such as ``"comparison_masked_preds"`` or
    ``"a_score_masked_gts"``).
    """
    imgA = batch["anchor"].to(device)
    imgB = batch["target"].to(device)
    severities = batch["severity"].to(device)
    distortions = batch["distortion"].to(device)
    comparisons = batch["comparison"].to(device)
    scores = batch["scores"].to(device)
    region_mask_flags = batch["region_mask_flags"].to(device)
    anchor_masks = batch["anchor_seg_masks"].to(device)
    target_masks = batch["target_seg_masks"].to(device)

    # bounding boxes are passed through as-is (the original never moved them)
    preds, _losses, valid_masks = model(imgA, imgB,
                                        batch["anchor_bbox"], batch["target_bbox"],
                                        anchor_masks, target_masks,
                                        severities, distortions,
                                        comparisons, scores,
                                        region_mask_flags)

    gts = [comparisons, distortions, severities, scores]
    _, pred_gt_dct = collate_accuracy(preds, gts, valid_masks)
    return pred_gt_dct


def _scene_relationships(relations, num_regions, predicate_classes):
    """Return scene-graph relationship entries for relations touching any region.

    Each ``(s_idx, o_idx, rel_id)`` triple whose subject or object index lies
    in ``range(num_regions)`` yields two entries (image ``"1"`` then ``"2"``)
    labelled ``predicate_classes[rel_id]``. Every triple is emitted once,
    even when both of its endpoints are regions.
    """
    region_ids = set(range(num_regions))
    entries = []
    for s_idx, o_idx, rel_id in relations:
        # int() in case the indices are 0-d tensors, which hash by identity
        # and would never match the ints in ``region_ids``.
        if int(s_idx) not in region_ids and int(o_idx) not in region_ids:
            continue
        # NOTE(review): image-2 entries reuse the anchor indices although the
        # target nodes are numbered ``idx + num_regions`` in "objects"/"art";
        # kept unchanged for output compatibility — confirm with the consumer.
        for image in ("1", "2"):
            entries.append({
                "predicate": predicate_classes[rel_id],
                "object": str(o_idx),
                "subject": str(s_idx),
                "image": image,
            })
    return entries


def _build_distortion_graph(batch, pred_gt_dct, batchsize, predicate_classes):
    """Assemble the predicted distortion graph dict for one (batch-size-1) batch."""
    names = batch["names"][0]
    relations = batch["relations"][0]

    comp_pred = pred_gt_dct["comparison_masked_preds"]
    # the number of regions R is read from the reshaped comparison GTs
    regions = pred_gt_dct["comparison_masked_gts"].reshape(batchsize, -1).shape[1]

    graph = {
        "objects": [],
        "attributes": [],
        "relationships": [],
        "art": [],
    }

    for region in range(regions):
        anchor_id = str(region)             # anchor nodes: 0 .. R-1
        target_id = str(region + regions)   # target nodes: R .. 2R-1
        object_name = str(names[region])

        graph["objects"].append({"id": anchor_id, "name": object_name, "image": "1"})
        graph["objects"].append({"id": target_id, "name": object_name, "image": "2"})

        # ART (<Anchor, Relation, Target>): same object seen in both images
        graph["art"].append({
            "predicate": str(lbl2comparison(int(_region_item(comp_pred, region)))),
            "object": anchor_id,
            "subject": target_id,
        })

        # Predicted attributes, in the order distortion, severity, score,
        # each as (anchor, target).
        attribute_values = [
            (lbl2distortion(_region_item(pred_gt_dct["a_dist_masked_preds"], region)), anchor_id, "1"),
            (lbl2distortion(_region_item(pred_gt_dct["t_dist_masked_preds"], region)), target_id, "2"),
            (lbl2sev(_region_item(pred_gt_dct["a_sev_masked_preds"], region)), anchor_id, "1"),
            (lbl2sev(_region_item(pred_gt_dct["t_sev_masked_preds"], region)), target_id, "2"),
            (round(_region_item(pred_gt_dct["a_score_masked_preds"], region), 4), anchor_id, "1"),
            (round(_region_item(pred_gt_dct["t_score_masked_preds"], region), 4), target_id, "2"),
        ]
        for value, node_id, image in attribute_values:
            graph["attributes"].append({
                "attribute": str(value),
                "object": node_id,
                "image": image,
            })

    graph["relationships"] = _scene_relationships(relations[0][0], regions, predicate_classes)
    return graph


def run_inference(model,
                  test_dataloader,
                  device,
                  name_of_exp,
                  batchsize=1,
                  psg_path="pandaset/psg/psg_annots/psg.json",
                  out_dir="inf_graphs"):
    """Run ``model`` over ``test_dataloader`` and dump one predicted graph per batch.

    For the k-th batch (1-based), ``{out_dir}/img_id_{k}_{name_of_exp}.json``
    is written with:

    * ``objects`` – one anchor node (ids ``0..R-1``, image ``"1"``) and one
      target node (ids ``R..2R-1``, image ``"2"``) per region;
    * ``art`` – one anchor/target edge per region carrying the predicted
      comparison label;
    * ``attributes`` – predicted distortion, severity and score (rounded to
      4 decimals) for every anchor and target node;
    * ``relationships`` – scene relations from ``batch["relations"]`` that
      touch at least one region, each emitted once per image.

    Args:
        model: called as ``model(imgA, imgB, ...)`` returning
            ``(preds, losses, valid_masks)``; ``main`` puts it in eval mode.
        test_dataloader: iterable of batch dicts from ``pandaset_test_collate_fn``.
        device: torch device the input tensors are moved to.
        name_of_exp: suffix used in the output file names.
        batchsize: loader batch size; the ``squeeze(0)`` / ``[0]`` indexing
            means only 1 is supported in practice.
        psg_path: PSG annotation JSON providing ``predicate_classes``.
        out_dir: directory for the output graphs; created if missing.

    Returns:
        dict with keys ``a_score_pred``, ``a_score_gt``, ``t_score_pred`` and
        ``t_score_gt``, each a list of per-batch numpy arrays of masked
        scores, intended for SRCC/PLCC computation.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    predicate_classes = load_json(psg_path)["predicate_classes"]

    score_keys = (
        ("a_score_pred", "a_score_masked_preds"),
        ("a_score_gt", "a_score_masked_gts"),
        ("t_score_pred", "t_score_masked_preds"),
        ("t_score_gt", "t_score_masked_gts"),
    )
    score_lists = {out_key: [] for out_key, _ in score_keys}

    with torch.no_grad():
        for img_id, batch in enumerate(tqdm(test_dataloader), start=1):
            pred_gt_dct = _forward_batch(model, batch, device)

            # keep the raw masked scores separately for SRCC/PLCC
            for out_key, dct_key in score_keys:
                score_lists[out_key].append(pred_gt_dct[dct_key].detach().cpu().numpy())

            graph = _build_distortion_graph(batch, pred_gt_dct, batchsize, predicate_classes)
            out_path = os.path.join(out_dir, f"img_id_{img_id}_{name_of_exp}.json")
            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(graph, f, indent=4)

    return score_lists
            
def main():
    """Command-line entry point: load config and checkpoint, then run test inference.

    Reads ``--configpath`` and builds the PandaSet test loader and the
    ``Panda`` model. It then restores the checkpoint at
    ``config['inference']['ckpt']`` and calls ``run_inference``. The
    experiment name used in output file names is the checkpoint's parent
    directory name, or the checkpoint file stem if it has no parent directory.

    Raises:
        ValueError: if the config defines no ``inference.ckpt`` path.
    """
    from pathlib import Path

    parser = argparse.ArgumentParser(description="DistortionGraph!")
    parser.add_argument('--configpath', type=str, required=True, help='Config Path.')
    args = parser.parse_args()

    config = loadconfig(args.configpath)

    # Fail fast, before the dataset and model are built.
    ckpt_path = config['inference'].get('ckpt', None)
    if ckpt_path is None:
        raise ValueError("No ckpt path defined.")

    general = config["general"]
    test_dgbench = PandaSetLoader(general["datapath"],
                                  general["stats"],
                                  general["resize_shape"],
                                  mode="test")
    h = w = general['resize_shape']
    # shuffle=False: run_inference names output files by iteration index, so
    # a fixed order is required for img_id_<k> to denote the same test pair
    # across runs and experiments.
    test_dataloader = DataLoader(test_dgbench,
                                 batch_size=1,
                                 shuffle=False,
                                 collate_fn=partial(pandaset_test_collate_fn, h=h, w=w))
    print(f"Total Images to Process: {len(test_dataloader)}")

    device_no = general["device"]
    device = torch.device(f"cuda:{device_no}" if torch.cuda.is_available() else "cpu")
    model = Panda(config, device)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'], strict=True)
    print("Model Loaded!")
    model = model.to(device)
    model.eval()  # put in eval mode

    ckpt = Path(ckpt_path)
    name_of_exp = ckpt.parent.name or ckpt.stem
    run_inference(model,
                  test_dataloader,
                  device,
                  name_of_exp)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()