import json
import scallopy
import torch
import torch._tensor

from torch.optim import Optimizer

import time

import sys

from tqdm import tqdm 
sys.path.append("../")

from models.sg_model import SceneGraphModel
from models.learning import (
    get_batched_scene_graph_predictions,
)
from typing import Dict
from utils.dataloader import VQARTaskIndexDataset
from torch.utils.data import DataLoader
import pickle
import os
from typing import List
from models.sg_model import test_SceneGraphModel
from models.idx2word import Idx2Word
from utils.utilities import AverageMeter, ProgressMeter
from torch.nn import BCELoss


import random
import numpy

from utils.utilities import adjust_learning_rate
from pruning.vqar_pruning_algorithm import structural_pruning, purification
from arguments import parser

# Binary cross-entropy criterion shared by loss_auc_grad. BCELoss expects
# probabilities in [0, 1] (here: Scallop query outputs), not raw logits.
loss_func = BCELoss(reduction="mean")

def get_recall(labels, logits, topk=5):
    """Return the top-k recall of ``logits`` against binary ``labels``.

    Candidates run along dim 0. A 1-D input is a single query; for a 2-D
    input each column is an independent query and the per-column recalls are
    averaged.

    The denominator is the number of positive labels capped at ``topk``, so a
    query with at least ``topk`` positives can still reach recall 1.0.

    Args:
        labels: 0/1 tensor with the same shape as ``logits``.
        logits: score (or probability) of each candidate.
        topk: number of highest-scoring candidates counted as retrieved.

    Returns:
        The recall as a Python float. It is NaN when a query has no positive
        label (0 / 0).
    """
    k = min(topk, logits.size(0))
    _, pred = logits.topk(k, 0, True, True)
    # Per-query number of positives among the top-k candidates. Summing over
    # dim 0 (not over everything) keeps the columns of a 2-D input separate.
    correct = torch.sum(labels.gather(0, pred), dim=0)
    correct_label = torch.clamp(torch.sum(labels, dim=0), 0, topk)
    return torch.mean(correct / correct_label).item()

def get_preds_from_preimage(preimage):
    """Return the set of every item that occurs in any proof of ``preimage``.

    Args:
        preimage: iterable of proofs, each an iterable of hashable items.

    Returns:
        The union of all proofs, as a ``set``.
    """
    return {item for proof in preimage for item in proof}

def _render_fact(rel, tup):
    """Render a fact as ``rel(c1,c2,...)`` (no spaces) for matching against proofs."""
    return f"{rel}({','.join(str(c) for c in tup)})"


def get_tps_from_preimage(og_tps, preimage):
    """Keep only the tagged facts that occur in at least one proof of ``preimage``.

    Args:
        og_tps: dict mapping a relation name to a list of samples, each sample
            being a list of ``(tag, tuple)`` facts for that relation.
        preimage: iterable of proofs; each proof is an iterable of fact
            strings in the ``rel(c1,c2,...)`` form produced by
            ``_render_fact``.

    Returns:
        A dict with the same keys and per-sample list layout as ``og_tps``,
        containing only the ``(tag, tuple)`` pairs whose rendered fact string
        appears in some proof. Fact order within a sample is preserved.
    """
    preimage_preds = get_preds_from_preimage(preimage)
    return {
        rel: [
            [(tag, tup) for tag, tup in tups if _render_fact(rel, tup) in preimage_preds]
            for tups in samples
        ]
        for rel, samples in og_tps.items()
    }

def loss_auc_grad(all_preds, correct_oids, all_oids, is_train=True):
    """Compute the BCE loss and top-k recall for one query's answer distribution.

    Args:
        all_preds: probability tensor with one entry per element of
            ``all_oids`` (the two must be aligned; BCELoss requires equal
            shapes).
        correct_oids: collection of object ids that count as correct answers.
        all_oids: sequence of candidate object ids, in the order of
            ``all_preds``.
        is_train: when True, back-propagate the loss before returning.

    Returns:
        ``(loss, recall)``: the BCE loss tensor and the float returned by
        ``get_recall`` (NaN when no candidate is correct).
    """
    all_labels = torch.tensor(
        [1 if obj in correct_oids else 0 for obj in all_oids],
        dtype=all_preds.dtype,
        device=all_preds.device,  # avoid a CPU/GPU mismatch inside BCELoss
    )

    loss = loss_func(all_preds, all_labels)
    recall = get_recall(all_labels, all_preds)

    if is_train:
        # For a scalar loss, backward() already seeds d(loss)/d(loss) = 1.
        # Passing ``loss`` as the gradient argument would scale every gradient
        # by the loss value. retain_graph=True keeps the graph usable for
        # callers that back-propagate through it again.
        loss.backward(retain_graph=True)

    return loss, recall

def train_epoch(
    epoch: int,
    sg_model,
    scallop_ctx,
    optimizers: Optimizer,
    train_loader_weak: DataLoader,
    train_samples_weak: List,
    train_samples_supervised: List,
    train_scene_graphs_and_features: Dict,
    idx2word: Idx2Word,
    args,
) -> float:
    """Train ``sg_model`` for one epoch through the Scallop reasoning layer.

    Each batch is made of the weakly-supervised samples selected by
    ``train_loader_weak`` plus every supervised sample. For each batch:
    1. The scene-graph model scores the facts.
    2. The facts are restricted to those appearing in the samples' proofs.
    3. Scallop evaluates each sample's query.
    4. The per-sample BCE losses are averaged and one optimizer step is taken.

    Args:
        epoch: epoch index, used only in the progress-bar text.
        sg_model: scene-graph model being trained.
        scallop_ctx: Scallop context with the program loaded; cloned per sample.
        optimizers: optimizer over ``sg_model``'s parameters.
        train_loader_weak: yields batches of indices into ``train_samples_weak``.
        train_samples_weak: weakly-supervised samples (fields listed below).
        train_samples_supervised: supervised samples appended to every batch.
            NOTE(review): annotated ``List`` but ``.values()`` is called on it,
            so it is presumably a dict; confirm against the pickled data.
        train_scene_graphs_and_features: per-image model inputs keyed by image id.
        idx2word: vocabulary helper passed to the model/pruning helpers.
        args: parsed command-line arguments.

    Returns:
        The running average of the per-sample classification loss.
    """
    loss_cls_log = AverageMeter("Loss@Cls", ":2.4f")
    auc_log = AverageMeter("Recall", ":2.4f")
    use_pruning = args.structured_pruning and (args.structure_k > 0 or args.percent > 0)

    sg_model.train()

    pbar = tqdm(train_loader_weak, total=len(train_loader_weak), desc="Training")
    for sample_ids in pbar:
        optimizers.zero_grad()

        # Each sample includes the following fields
        # - image_id: the corresponding image
        # - query: the Datalog query
        # - lineage: a triple (answer, proofs, ground truth), where
        # -- answer is the answer (bounding box) to the query; it is used below
        #    as the correct object id for the loss
        # -- proofs is a list of abductive proofs, one proof per entry
        # -- ground truth is a vector of length len(proofs) that is 1 where the
        #    corresponding abductive proof is true and 0 otherwise
        # - object_ids: the list of all object ids in the proofs
        # - object_id_pairs: the list of object combinations occurring in the proofs
        samples_in_batch = [train_samples_weak[sample_id] for sample_id in sample_ids]
        samples_in_batch += train_samples_supervised.values()
        scene_graphs_and_features = [
            train_scene_graphs_and_features[image_id] for image_id, _, _, _, _ in samples_in_batch
        ]
        queries = [query for _, query, _, _, _ in samples_in_batch]
        # Collected before pruning, which may replace samples_in_batch.
        object_ids = [oid for _, _, _, oid, _ in samples_in_batch]
        ground_truth = [[lineage[0]] for _, _, lineage, _, _ in samples_in_batch]

        new_gt = og_gt = pruned = total = 0
        if use_pruning:
            samples_in_batch, new_gt, og_gt, pruned, total = structural_pruning(
                samples_in_batch, scene_graphs_and_features, idx2word, args
            )

        # Per-sample model outputs keyed by relation name (layout consumed by
        # get_tps_from_preimage) and the matching Scallop rules.
        batch_predictions, scl_queries = get_batched_scene_graph_predictions(
            sg_model, samples_in_batch, scene_graphs_and_features, queries, idx2word, args.gpu
        )

        # Proofs after (optional) pruning; only facts they mention reach Scallop.
        preimage_proofs = [lineage[1] for _, _, lineage, _, _ in samples_in_batch]

        # Differentiable Scallop forward pass, one fresh context per query.
        query_results = []
        for tps, proofs, query, oids in zip(batch_predictions, preimage_proofs, scl_queries, object_ids):
            new_tps = get_tps_from_preimage(tps, proofs)
            current_ctx = scallop_ctx.clone()
            current_ctx.add_rule(query)
            reason = current_ctx.forward_function(output_mappings={"Q": oids}, retain_graph=True)
            query_results.append(reason(**new_tps)[0])

        losses = []
        for target_probs, correct_oids, all_oids in zip(query_results, ground_truth, object_ids):
            # is_train=False: gradients come only from the batch-mean backward
            # below. Also back-propagating inside loss_auc_grad would accumulate
            # every sample's gradient twice.
            loss, recall = loss_auc_grad(target_probs, correct_oids, all_oids, is_train=False)
            # get_recall returns NaN when a sample has no positive label; since
            # NaN >= 0 is False, such samples are left out of the recall average.
            if recall >= 0:
                auc_log.update(recall)
            loss_cls_log.update(loss.item())
            losses.append(loss)

        loss = torch.mean(torch.stack(losses))
        loss.backward()
        optimizers.step()

        desc = f"Epoch: [{epoch}] Loss: {loss_cls_log.avg:.4f} AUC: {auc_log.avg:.4f}"
        if use_pruning:
            # Guard against batches with no ground-truth proofs / nothing to prune.
            retained_pct = 100 * new_gt / og_gt if og_gt else 0.0
            pruned_pct = 100 * pruned / total if total else 0.0
            desc += (
                f" Retained GT: {new_gt}/{og_gt} ({retained_pct:.2f}%)"
                f" Pruned: {pruned}/{total} ({pruned_pct:.2f}%)"
            )
        pbar.set_description(desc)

    return loss_cls_log.avg


def _load_pickle(path):
    """Load and return the object pickled in ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file; the data-file arguments must point at trusted, locally produced files.
    """
    with open(path, "rb") as pickle_file:
        return pickle.load(pickle_file)


def _load_split(samples_file, features_file, args):
    """Load one dataset split and wrap it in a randomly-sampled DataLoader.

    Returns:
        ``(samples, scene_graphs_and_features, loader)``, where the features
        are re-keyed by each entry's ``"image_id"`` field.
    """
    samples = _load_pickle(samples_file)
    features = {feature["image_id"]: feature for feature in _load_pickle(features_file)}
    dataset = VQARTaskIndexDataset(samples)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        shuffle=False,  # must be False when an explicit sampler is given
        sampler=torch.utils.data.RandomSampler(dataset),
    )
    return samples, features, loader


def main():
    """Parse CLI arguments, build the model and data, then train.

    After every epoch:
    - the model is evaluated with ``test_SceneGraphModel``;
    - a summary line is appended to ``<exp_dir>/result.log``;
    - ``sg_model.save(args.exp_dir)`` is called.

    When ``args.early_stopping > 0``, training stops once the test recall has
    not improved for that many epochs; there is no other epoch cap in that
    mode. Otherwise exactly ``args.epochs`` epochs are run (at least one).
    """
    args = parser.parse_args()
    print(args)

    os.makedirs(args.exp_dir, exist_ok=True)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        numpy.random.seed(args.seed)
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # Create the model
    with open(args.meta_f, "r") as meta_file:
        meta_info = json.load(meta_file)
    idx2word = Idx2Word(meta_info)

    sg_model = SceneGraphModel(
        feat_dim=args.feat_dim,
        meta_info=meta_info,
        model_dir=args.load_model_dir,
    )
    sg_model = sg_model.cuda(args.gpu)

    # The Scallop program lives next to this script, in scl/vqar.scl.
    scallop_ctx = scallopy.ScallopContext(provenance="difftopkproofs", train_k=5, test_k=5)
    scl_file_name = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scl", "vqar.scl")
    scallop_ctx.import_file(scl_file_name)

    optimizers = torch.optim.SGD(
        sg_model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # Weakly-supervised training split.
    train_samples_weak, train_scene_graphs_and_features, train_loader = _load_split(
        args.train_file_weak, args.train_features_file, args
    )

    # A small number of supervised samples, blended into every weak batch.
    train_samples_supervised = _load_pickle(args.train_file_supervised)

    # Test split.
    test_samples, test_scene_graphs_and_features, test_loader = _load_split(
        args.test_file, args.test_features_file, args
    )

    best_recall = 0.0
    best_epoch = 0
    epoch = 0
    while True:
        adjust_learning_rate(args, optimizers, epoch)

        train_epoch(
            epoch,
            sg_model,
            scallop_ctx,
            optimizers,
            train_loader,
            train_samples_weak,
            train_samples_supervised,
            train_scene_graphs_and_features,
            idx2word,
            args,
        )

        lr = optimizers.param_groups[0]["lr"]
        print("Train epoch {}: lr {}".format(epoch, lr) + "\n")

        epoch_name_recall, epoch_attr_recall, epoch_rela_recall, epoch_recall = (
            test_SceneGraphModel(
                sg_model, test_loader, test_samples, test_scene_graphs_and_features, idx2word, args.gpu
            )
        )

        if epoch_recall > best_recall:
            best_recall = epoch_recall
            best_epoch = epoch

        with open(os.path.join(args.exp_dir, "result.log"), "a+") as f:
            f.write(
                "Epoch {}: Name Acc {}, Attr Acc {}, Rela Acc {}, Acc {}, Best Acc {}. (lr {})\n".format(
                    epoch + 1,
                    epoch_name_recall,
                    epoch_attr_recall,
                    epoch_rela_recall,
                    epoch_recall,
                    best_recall,
                    lr,
                )
            )

        # NOTE(review): saved every epoch, so this presumably keeps the latest
        # weights rather than the best-recall ones; confirm this is intended.
        sg_model.save(args.exp_dir)

        print("Best epoch: ", best_epoch)
        if args.early_stopping > 0:
            if epoch - best_epoch >= args.early_stopping:
                print("Early stopping")
                break
        elif epoch + 1 >= args.epochs:
            # epoch is 0-based, so this stops after exactly args.epochs epochs.
            print("Training finished")
            break
        epoch += 1


if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()
