import os

import json
import random
from argparse import ArgumentParser
from tqdm import tqdm
import torch
import scallopy

from dataset_test_prog import *
from utils import *

def save_scl_file(datapoint_id, object_tps, scl_dir=""):
    """Render ``object_tps`` as Scallop facts and write them to ``<datapoint_id>.scl``.

    Args:
        datapoint_id: String identifier; ``'.scl'`` is appended to form the file name.
        object_tps: Value handed unchanged to ``construct_scl_facts``.
        scl_dir: Directory the file is written into (joined with the file name).
    """
    # Build the text first so a failure in fact construction leaves no file behind.
    facts_text = construct_scl_facts(object_tps)
    target_path = os.path.join(scl_dir, datapoint_id + '.scl')

    with open(target_path, 'w') as out_file:
        out_file.write(facts_text)
        
class Trainer():
    """Evaluate batches of datapoints against a shared Scallop program.

    The constructor builds a ``scallopy.ScallopContext`` from the common
    ``.scl`` file and stores its forward function as ``self.reason``.
    ``baseline_eval_batch`` converts one batch into Scallop input facts and
    runs it; ``test_scallop_progs`` sweeps the train and test loaders and
    reports how many batches produced no answer.
    """

    def __init__(self, train_loader, test_loader, device,
                 caption2scl, common_scl_path,
                 provenance="difftopkproofs", k=3, save_scl_dir=None,
                 model_dir=None, model_name=None, learning_rate=None,
                 latent_dim=64, model_layer=2, load_model=False, save_model=True,
                 save_video=False, video_save_dir=None, violation_weight=0.1,
                 with_violation=True, use_contrast=False, world_size=4, parallel=False, 
                 use_neg_sampling=True, train_num_top_pairs=100, test_num_top_pairs=300, report_dir=None):
        """Store the configuration and build the Scallop forward function.

        Args:
            train_loader: Iterable of batches used for the "train split" sweep.
            test_loader: Iterable of batches used for the "test split" sweep.
            device: Device the forward function is moved to (non-violation mode).
            caption2scl: Stored on the instance as-is.
            common_scl_path: Path of the ``.scl`` file that is both read into
                ``self.common_scl`` and imported into the Scallop context.
            provenance: Provenance name passed to ``scallopy.ScallopContext``.
            k: ``k`` passed to ``scallopy.ScallopContext``.
            save_scl_dir: Stored as ``self.save_scl_dir`` (may be ``None``).
            with_violation: Selects which output relations the forward
                function exposes (see below).
            report_dir: Directory used by ``test_scallop_progs`` for its
                ``no_answer_ids.json`` debug dump; nothing is written when
                it is empty or ``None``.
            save_model, save_video, video_save_dir, use_contrast,
            use_neg_sampling, train_num_top_pairs, test_num_top_pairs:
                Stored on the instance unchanged.
            model_dir, model_name, learning_rate, latent_dim, model_layer,
            load_model, violation_weight, world_size, parallel:
                Accepted for call-site compatibility; not used by the code
                in this class.
        """
        self.save_video = save_video
        self.video_save_dir = video_save_dir

        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.caption2scl = caption2scl
        self.template2action = {}
        # Read through a context manager so the file handle is closed promptly.
        with open(common_scl_path) as common_scl_file:
            self.common_scl = common_scl_file.read()
        self.save_model = save_model
        self.with_violation = with_violation
        self.use_contrast = use_contrast
        self.use_neg_sampling = use_neg_sampling
        self.train_num_top_pairs = train_num_top_pairs
        self.test_num_top_pairs = test_num_top_pairs

        self.scallop_ctx = scallopy.ScallopContext(provenance=provenance, k=k)
        self.scallop_ctx.import_file(common_scl_path)
        self.scallop_ctx.set_non_probabilistic(non_prob_gpt_prog_str_preds)
        self.epoch_ct = 0
        self.report_dir = report_dir

        if with_violation:
            self.reason = self.scallop_ctx.forward_function(output_mappings={
                "temporal_diff": None,
                "violation": (),
            }, retain_graph=True)
            # NOTE(review): the forward function is only moved to `device` in
            # the branch below; confirm whether this branch should do so too.
        else:
            self.reason = self.scallop_ctx.forward_function(output_mappings={
                "aligned_t1": None,
                "aligned_t2": None,
                "aligned_t3": None,
            }, retain_graph=True)
            self.reason.to(self.device)

        # Always define the attribute so later reads cannot raise AttributeError
        # when no directory was supplied.
        self.save_scl_dir = save_scl_dir

    def baseline_eval_batch(self, batch):
        """Run the Scallop program on one batch using ground-truth labels.

        Object categories come from the ground-truth labels; unary and binary
        predicates get a uniform weight over the keywords listed in each
        datapoint's spec.

        Args:
            batch: 6-tuple ``(ids, captions, obj_pairs, object_ids, gt_labels,
                gpt_specs)``. ``object_ids`` entries unpack as
                ``(vid, fid, oid)``, ``gt_labels`` entries as
                ``(vid, fid, label)`` and ``obj_pairs`` entries as
                ``(vid, fid, (sub, obj))``. ``vid`` is used as an index into
                the per-datapoint lists. Each spec is a mapping with the keys
                ``'unary_kws'``, ``'binary_kws'``, ``'consts'``, ``'args'``
                and ``'prog'``.

        Returns:
            ``{}`` when ``object_ids`` is empty, otherwise whatever
            ``self.reason`` returns for the assembled facts.
        """
        batched_ids, batched_captions, \
            batched_obj_pairs, batched_object_ids, \
            batched_gt_labels, batched_gpt_specs = batch

        if len(batched_object_ids) == 0:
            return {}

        # One de-duplicated keyword/constant list per datapoint. These lists
        # must start empty: pre-filling them with placeholders would make
        # `batched_unary_kws[vid]` return the placeholder instead of the
        # keywords appended here.
        batched_unary_kws = []
        batched_binary_kws = []
        batched_consts = []
        for spec in batched_gpt_specs:
            batched_unary_kws.append(list(set(spec['unary_kws'])))
            batched_binary_kws.append(list(set(spec['binary_kws'])))
            batched_consts.append(list(set(spec['consts'])))

        consts = [e for c in batched_consts for e in c]

        # Each constant gets a negative id; upper- and lower-cased spellings
        # resolve to the same id.
        const_lookup = {}
        for k, v in enumerate(consts):
            const_id = -k - 1
            const_lookup[v] = const_id
            const_lookup[v.upper()] = const_id
            const_lookup[v.lower()] = const_id

        # vid -> {oid -> label}; an object must carry the same label everywhere.
        batched_obj_labels = {}
        for (_, _, label), (vid, _, oid) in zip(batched_gt_labels, batched_object_ids):
            obj_labels = batched_obj_labels.setdefault(vid, {})
            if oid in obj_labels:
                assert label == obj_labels[oid]
            obj_labels[oid] = label

        # NOTE(review): these lists are ordered by first appearance of each
        # vid and are later zipped positionally with the batch; this assumes
        # every datapoint has objects and vids appear in batch order - confirm.
        batched_cate_pred_scl = []
        batched_unary_pred_scl = []
        for vid, obj_labels in batched_obj_labels.items():
            new_cate_pred_scl = []
            new_unary_pred_scl = []
            unary_kws = batched_unary_kws[vid]

            for oid, cate_name in obj_labels.items():
                if cate_name in const_lookup:
                    new_cate_pred_scl.append((torch.tensor(1.0), (oid, const_lookup[cate_name])))
                else:
                    # Unknown label: spread the mass uniformly over this
                    # datapoint's constants.
                    for const in batched_consts[vid]:
                        new_cate_pred_scl.append((torch.tensor(1.0 / len(batched_consts[vid])), (oid, const_lookup[const])))

                for unary_kw in unary_kws:
                    # A fact is a single (probability, tuple) pair.
                    # NOTE(review): the (oid, unary_kw) layout is taken from
                    # the original call; confirm it matches `sg_unary_atom`.
                    new_unary_pred_scl.append((torch.tensor(1.0 / len(unary_kws)), (oid, unary_kw)))

            batched_cate_pred_scl.append(new_cate_pred_scl)
            batched_unary_pred_scl.append(new_unary_pred_scl)

        # Process binary predicates
        batched_binary_pred_scl = []
        for vid, binary_kws in enumerate(batched_binary_kws):
            binary_pred_scl = []
            for binary_kw in binary_kws:
                for (pair_vid, fid, (sub, obj)) in batched_obj_pairs:
                    if vid == pair_vid:
                        binary_pred_scl.append((torch.tensor(1.0 / len(binary_kws)), (binary_kw, fid, sub, obj)))

            batched_binary_pred_scl.append(binary_pred_scl)

        batched_scl_tps = construct_batched_scl_tps(batched_consts, batched_object_ids)
        batched_scl_input_facts = []
        # `batched_ids` and `batched_captions` stay in the zip so the
        # shortest-sequence truncation is the same as before.
        for _data_id, _caption, scl_tp, cate_pred_tp, unary_pred_tp, binary_pred_tp, gpt_spec \
            in zip(batched_ids, batched_captions, batched_scl_tps, batched_cate_pred_scl,
                   batched_unary_pred_scl, batched_binary_pred_scl, batched_gpt_specs):

            scl_input_facts = {}
            scl_input_facts.update(scl_tp)
            scl_input_facts['name'] = cate_pred_tp
            scl_input_facts['sg_unary_atom'] = unary_pred_tp
            scl_input_facts['sg_binary_atom'] = binary_pred_tp
            # Variables are numbered from 1, one per spec argument.
            scl_input_facts['variable'] = [tuple([i + 1]) for i in range(len(gpt_spec['args']))]
            scl_input_facts['spec'] = [gpt_spec['prog']]

            batched_scl_input_facts.append(scl_input_facts)

        formatted_batched_scl_input_facts = process_batched_facts(batched_scl_input_facts)

        output = self.reason(**formatted_batched_scl_input_facts)

        return output

    @staticmethod
    def _is_no_answer(result):
        """Return truthy when ``result['aligned_t1'][0]`` is a single entry starting with -1."""
        first = result['aligned_t1'][0]
        return len(first) == 1 and first[0][0] == -1

    def _dump_no_answer_ids(self, no_answer_ids):
        """Write the ids collected so far to ``<report_dir>/no_answer_ids.json``, if a report dir is set."""
        if not self.report_dir:
            # No destination configured; skip instead of opening an empty path.
            return
        os.makedirs(self.report_dir, exist_ok=True)
        debug_path = os.path.join(self.report_dir, "no_answer_ids.json")
        with open(debug_path, 'w') as debug_file:
            json.dump(no_answer_ids, debug_file)

    def test_scallop_progs(self):
        """Run every train and test batch and print how many gave no answer.

        In the train sweep, batches whose result lacks ``'aligned_t1'`` are
        skipped without being counted as misses; in the test sweep they are
        counted as misses. The counters are not reset between the two sweeps,
        so the second printed line is cumulative over both splits.

        The first id of every no-answer train batch is collected and, when
        ``self.report_dir`` is set, dumped to ``no_answer_ids.json`` there.
        """
        no_answer_ids = []
        no_answer = 0
        total_questions = 0

        print("Testing train split")
        with torch.no_grad():
            for dp_list in tqdm(self.train_loader):
                total_questions += 1
                result = self.baseline_eval_batch(dp_list)

                if 'aligned_t1' not in result:
                    continue

                if self._is_no_answer(result):
                    print('Warning: No anwer')
                    no_answer += 1
                    no_answer_ids.append(dp_list[0][0])
                    # Rewritten after every miss so a partial list survives
                    # an interrupted run.
                    self._dump_no_answer_ids(no_answer_ids)

        print(f"We miss answers for {no_answer}/{total_questions}.")

        print("Testing test split")
        with torch.no_grad():
            for dp_list in tqdm(self.test_loader):
                total_questions += 1
                result = self.baseline_eval_batch(dp_list)
                if 'aligned_t1' not in result or self._is_no_answer(result):
                    print('Warning: No anwer')
                    no_answer += 1

        print(f"We miss answers for {no_answer}/{total_questions}.")
      
if __name__ == "__main__":
    # Script entry point: resolve dataset paths, parse the command line, build
    # the data loaders and run `Trainer.test_scallop_progs` on both splits.

    def _parse_bool(value):
        """Parse a command-line boolean so that ``--flag False`` is really False."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        # argparse reports a ValueError from a `type` callable as a usage error.
        raise ValueError(f"expected a boolean, got {value!r}")

    dataset = "open_pvsg"
    cache_file_name = "gpt_specs_prog_str.json"
    data_file_name = 'pvsg.json'

    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f"../../../data/{dataset}"))
    data_nl_dir = os.path.join(data_dir, 'nl2spec') 
    cache_path = os.path.join(data_nl_dir, cache_file_name)
    scl_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), '../scl'))

    assert (os.path.exists(data_dir))
    # Tolerates the directory already existing or being created concurrently.
    os.makedirs(data_nl_dir, exist_ok=True)
    assert os.path.exists(scl_dir)

    video_save_dir = os.path.join(data_dir, 'pred_video')
    model_dir = os.path.join(data_dir, 'model_orig')

    parser = ArgumentParser(dataset)
    parser.add_argument("--phase", type=str, default='test')
    parser.add_argument("--n-epochs", type=int, default=50)
    # Parsed explicitly: without a `type`, any supplied string (including
    # "False") would be truthy.
    parser.add_argument("--load_model", type=_parse_bool, default=False)
    parser.add_argument("--save_model", type=_parse_bool, default=False)
    parser.add_argument("--video_save_dir", type=str, default=video_save_dir)
    parser.add_argument("--model_type", type=str, default="contrast")
    parser.add_argument("--violation-weight", type=float, default=0.05)
    parser.add_argument("--with-violation",  action='store_true')
    parser.add_argument("--use-contrast",  action='store_true')
    parser.add_argument("--parallel",  action='store_true')
    parser.add_argument("--train-num-top-pairs", type=int, default=100)
    parser.add_argument("--test-num-top-pairs", type=int, default=100)

    parser.add_argument("--max_video_len", type=int, default=12)
    parser.add_argument("--report-dir", type=str, default="")

    # setup question path
    parser.add_argument("--train_num", type=int, default=5000)
    parser.add_argument("--val_num", type=int, default=1000)
    parser.add_argument("--train-percentage", type=int, default=100)
    parser.add_argument("--test-percentage", type=int, default=100)

    # Training hyperparameters
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--learning-rate", type=float, default=0.000001)
    # A dimension is an integer; parsing it as float would hand a float to
    # whatever consumes it.
    parser.add_argument("--latent-dim", type=int, default=64)
    parser.add_argument("--model-layer", type=int, default=2)
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--provenance", type=str, default="difftopkproofs")
    # NOTE(review): --train-top-k / --test-top-k are parsed but not forwarded
    # to Trainer, which therefore uses its own default `k`.
    parser.add_argument("--train-top-k", type=int, default=3)
    parser.add_argument("--test-top-k", type=int, default=3)
    parser.add_argument("--model-name", type=str, default=None)
    parser.add_argument("--model-dir", type=str, default=model_dir)
    parser.add_argument("--data-dir", type=str, default=data_dir)
    parser.add_argument("--use-cuda", action="store_true")
    parser.add_argument("--gpu", type=int, default=0)

    args = parser.parse_args()
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    common_scl_path = os.path.join(scl_dir, 'ltl_disj.scl')
    assert os.path.exists(common_scl_path)

    # NOTE(review): the device depends only on CUDA availability; --use-cuda
    # and --gpu are currently not consulted.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    data_path = os.path.join(data_dir, data_file_name)
    with open(data_path, 'r') as f:
        anno = json.load(f)

    # Context manager so the cache file handle is closed after loading.
    with open(cache_path, 'r') as cache_file:
        caption2scl = json.load(cache_file)

    # See video id in anno['split'].
    # Indexed by video id; not read again later in this script.
    data = {data_dict['video_id']: data_dict for data_dict in anno['data']}
    train_dataset, valid_dataset, train_loader, test_loader = open_pvsg_loader(
            cache_path=cache_path,
            dataset_dir=data_dir, 
            dataset_name=data_file_name, 
            batch_size=args.batch_size, 
            device=device, 
            training_percentage=args.train_percentage, 
            testing_percentage=args.test_percentage, 
            max_video_len=args.max_video_len)

    trainer = Trainer(train_loader=train_loader, test_loader=test_loader, device=device, caption2scl=caption2scl, 
                      save_scl_dir=scl_dir, common_scl_path=common_scl_path,
                      latent_dim=args.latent_dim, model_layer=args.model_layer,
                      model_dir=args.model_dir, model_name=args.model_name,
                      learning_rate=args.learning_rate, load_model=args.load_model,
                      provenance=args.provenance,
                      save_model=args.save_model,
                      video_save_dir=args.video_save_dir,
                      violation_weight=args.violation_weight,
                      with_violation=args.with_violation,
                      use_contrast=args.use_contrast, 
                      parallel=args.parallel,
                      train_num_top_pairs=args.train_num_top_pairs,
                      test_num_top_pairs=args.test_num_top_pairs,
                      report_dir=args.report_dir)

    print(args.model_name)
    trainer.test_scallop_progs()

    print("end")