import json
import torch
import torch._tensor

import torch.nn.functional as F
from torch.optim import Optimizer

import time

import sys 

from models.sg_model import SceneGraphModel
from models.learning import (
    get_batched_scene_graph_predictions,
)
from typing import Dict
from utils.dataloader import VQARTaskIndexDataset
from torch.utils.data import DataLoader
import pickle
import os
from typing import List
from models.sg_model import test_SceneGraphModel
from models.idx2word import Idx2Word
from utils.utilities import AverageMeter, ProgressMeter
from ilp.formulation import ilp_pywrap
from benchmark.utilities import maintain_topk_properties

import random
import numpy

from utils.utilities import adjust_learning_rate
from arguments import parser

def _str_to_bool(value):
    """Convert a command-line token into a boolean.

    ``type=bool`` makes argparse call ``bool()`` on the raw string, so every
    non-empty value -- including "False" and "0" -- is parsed as True. This
    converter interprets the usual spellings explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    import argparse  # local import: only needed to report bad values

    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string is kept as False, matching what bool("") returned.
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value)
    )


# ILP-related parameters
parser.add_argument(
    "--epsilon_ilp",
    default=0.99,
    type=float,
    help="high-confidence selection threshold",
)

parser.add_argument(
    "--continuous-relaxation",
    default=True,
    type=_str_to_bool,
    help="returning continuous relaxations of the linear program",
)

# Top-k sizes for names, attributes and relations. They are passed to
# maintain_topk_properties and determine the length of the frequency vectors
# built in gold_distribution.
topk_names = 50
topk_attrs = 50
topk_relas = 50

# To estimate the gold label distribution of the training data, every scene
# graph is first filtered so that only its top-k properties are counted.
def gold_distribution(
    train_loader: DataLoader,
    train_samples: List,
    train_scene_graphs_and_features: Dict,
):
    """Estimate the relative frequency of each name and relation class.

    Iterates once over ``train_loader``, looks up the scene graph of every
    sample in each batch, filters it with ``maintain_topk_properties`` and
    counts how often each name id and each relation id occurs.

    Args:
        train_loader: yields batches of indices into ``train_samples``.
        train_samples: samples whose first of five fields is the image id.
        train_scene_graphs_and_features: maps an image id to an entry that
            holds the image's ``"scene_graph"``.

    Returns:
        A dict with two tensors: ``"name"`` (length ``topk_names``) and
        ``"rela"`` (length ``topk_relas``). Each holds counts divided by the
        total number of counted items of that kind. A kind with no counted
        items yields an all-zero tensor instead of raising
        ``ZeroDivisionError``.
    """

    def _normalize(counts, total):
        # Nothing was counted: return zeros rather than dividing by zero.
        if total == 0:
            return torch.as_tensor([0.0] * len(counts))
        return torch.as_tensor([float(count) / total for count in counts])

    name_freqs = [0] * topk_names
    rela_freqs = [0] * topk_relas

    for sample_ids in train_loader:
        samples_in_batch = [train_samples[sample_id] for sample_id in sample_ids]
        for image_id, _, _, _, _ in samples_in_batch:
            scene = train_scene_graphs_and_features[image_id]["scene_graph"]
            # Each scene graph should be filtered out so that we maintain only the desired classes.
            object2type, _, object2rels = maintain_topk_properties(
                scene, topk_names, topk_attrs, topk_relas
            )
            for name_id in object2type.values():
                name_freqs[name_id] += 1
            for obj in object2rels:
                for relation in object2rels[obj].values():
                    rela_freqs[relation] += 1

    # Every increment above adds exactly one item, so the totals are the sums.
    return {
        "name": _normalize(name_freqs, sum(name_freqs)),
        "rela": _normalize(rela_freqs, sum(rela_freqs)),
    }

def train_epoch(
    epoch: int,
    sg_model,
    optimizers: Optimizer,
    train_loader: DataLoader,
    train_samples: List,
    train_scene_graphs_and_features: Dict,
    idx2word: Idx2Word,
    args,
) -> float:
    """Train ``sg_model`` for one epoch on ILP-derived pseudo-labels.

    For every batch the model's scene-graph predictions are computed, the ILP
    wrapper turns them (together with the abductive proofs and the gold label
    distribution) into pseudo-labels, and the model is updated with the summed
    cross-entropy between its logits and those pseudo-labels, divided by the
    batch size.

    Args:
        epoch: epoch number, used only in the progress prefix.
        sg_model: the scene-graph model being trained.
        optimizers: the optimizer stepping ``sg_model``'s parameters.
        train_loader: yields batches of indices into ``train_samples``.
        train_samples: 5-field samples (image id, query, lineage, ...).
        train_scene_graphs_and_features: maps image id to its scene-graph entry.
        idx2word: vocabulary helper forwarded to the prediction and ILP calls.
        args: parsed arguments; reads ``epsilon_ilp``,
            ``continuous_relaxation`` and ``print_freq``.

    Returns:
        The mean batch loss over the batches on which an optimizer step was
        taken, or 0.0 if there was no such batch.
    """
    sg_model.train()

    batch_time = AverageMeter("Time", ":1.2f")
    loss_cls_log = AverageMeter("Loss@Cls", ":2.2f")
    sdd_create_log = AverageMeter("SDD Creation", ":1.2f")
    sdd_evaluate_log = AverageMeter("SDD Evaluation", ":1.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, sdd_create_log, sdd_evaluate_log, loss_cls_log],
        prefix="Epoch: [{}]".format(epoch),
    )

    # Structure of the form:
    # {
    # "name" -> [freq_name_i,...],
    # "rela" -> [freq_rela_i,...]
    # }
    distribution = gold_distribution(train_loader, train_samples, train_scene_graphs_and_features)

    loss_sum = 0.0
    num_steps = 0
    end = time.time()
    for i, sample_ids in enumerate(train_loader):
        # Each sample includes the following fields
        # - image_id: the corresponding image
        # - query: the Datalog query
        # - lineage: a triple of the form (answer, proofs, ground truth), where
        # -- answer is the answer (bounding box) to the query
        # -- proofs is a list of abductive proofs, one proof per entry in the list
        # -- ground truth is a vector of length len(proofs); an entry is 1 if
        #    the corresponding abductive proof is true and 0 otherwise
        # - object_ids: the list of all object ids in the proofs.
        # - object_id_pairs: the list of object combinations occurring in the proofs.
        samples_in_batch = [train_samples[sample_id] for sample_id in sample_ids]
        scene_graphs_and_features = [
            train_scene_graphs_and_features[image_id]
            for image_id, _, _, _, _ in samples_in_batch
        ]
        queries = [query for _, query, _, _, _ in samples_in_batch]
        proofs_n = [lineage[1] for _, _, lineage, _, _ in samples_in_batch]

        # Get the predictions of the deep model.
        # batch_logits is indexed below by batch position, then by
        # "name"/"rela", then by box (an object, or presumably an object pair
        # for "rela" -- TODO confirm against learning.py).
        batch_predictions, batch_logits = (
            get_batched_scene_graph_predictions(sg_model, samples_in_batch, scene_graphs_and_features, queries, idx2word)
        )

        # Get the pseudolabels via the ILP formulation. They are indexed the
        # same way as batch_logits: batch position -> "name"/"rela" -> box.
        pseudo_labels = ilp_pywrap(
            proofs_n,
            batch_predictions,
            idx2word,
            distribution,
            args.epsilon_ilp,
            args.continuous_relaxation,
        )

        # Predictions are unnormalized
        # Double check that the network is called with softmax=False in line 162 of learning.py
        box_losses = [
            F.cross_entropy(batch_logits[s][t][box], pseudo_labels[s][t][box])
            for s in range(len(samples_in_batch))
            for t in ("name", "rela")
            for box in batch_logits[s][t]
        ]

        # A batch without any box has no tensor loss to back-propagate;
        # skip the update instead of calling .backward() on a plain number.
        if box_losses:
            loss = sum(box_losses) / len(samples_in_batch)
            print(loss)
            # Zero the gradients
            optimizers.zero_grad()
            # Perform back propagation
            loss.backward()
            # Perform the optimizer training step
            optimizers.step()

            # Record the loss so the "Loss@Cls" column of the progress display
            # and the returned epoch average reflect it.
            loss_value = loss.item()
            loss_cls_log.update(loss_value)
            loss_sum += loss_value
            num_steps += 1

        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.display(i)

    return loss_sum / num_steps if num_steps else 0.0


def _load_split(samples_path, features_path, batch_size, workers):
    """Load one data split and build its randomly-sampled DataLoader.

    Returns:
        A triple ``(samples, features_by_image_id, loader)``.
    """
    # NOTE(review): pickle.load can execute arbitrary code while loading;
    # only point these paths at trusted files.
    with open(samples_path, "rb") as samples_file:
        samples = pickle.load(samples_file)
    with open(features_path, "rb") as features_file:
        features = pickle.load(features_file)
    features_by_image_id = {feature["image_id"]: feature for feature in features}

    dataset = VQARTaskIndexDataset(samples)
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=workers,
        shuffle=False,
        sampler=torch.utils.data.RandomSampler(dataset),
    )
    return samples, features_by_image_id, loader


def _append_result(exp_dir, epoch, name_recall, attr_recall, rela_recall, recall, best_recall, lr):
    """Append one evaluation line to ``result.log`` inside ``exp_dir``."""
    with open(os.path.join(exp_dir, "result.log"), "a+") as f:
        f.write(
            "Epoch {}: Name Acc {}, Attr Acc {}, Rela Acc {}, Acc {}, Best Acc {}. (lr {})\n".format(
                epoch,
                name_recall,
                attr_recall,
                rela_recall,
                recall,
                best_recall,
                lr,
            )
        )


def main():
    """Parse the arguments, then train and evaluate the scene-graph model.

    The untrained model is evaluated once (logged as epoch -1). After that,
    each epoch adjusts the learning rate, trains, evaluates on the test split,
    appends the result to ``result.log`` and saves the model to
    ``args.exp_dir``.
    """
    args = parser.parse_args()
    print(args)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(args.exp_dir, exist_ok=True)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        numpy.random.seed(args.seed)
    if args.gpu is not None:
        # NOTE(review): args.gpu is only reported here; nothing in this
        # function moves the model or the data to that device.
        print("Use GPU: {} for training".format(args.gpu))

    # Create the model
    with open(args.meta_f, "r") as meta_file:
        meta_info = json.load(meta_file)
    idx2word = Idx2Word(meta_info)

    sg_model = SceneGraphModel(
        feat_dim=args.feat_dim,
        meta_info=meta_info,
        model_dir=None,
    )

    # set optimizer
    optimizers = torch.optim.SGD(
        sg_model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # Load the training dataset
    train_samples, train_scene_graphs_and_features, train_loader = _load_split(
        args.train_file, args.train_features_file, args.batch_size, args.workers
    )

    # Load the testing dataset
    test_samples, test_scene_graphs_and_features, test_loader = _load_split(
        args.test_file, args.test_features_file, args.batch_size, args.workers
    )

    lr = optimizers.param_groups[0]["lr"]

    # Test the accuracy of the untrained model
    best_name_recall, best_attr_recall, best_rela_recall, best_recall = (
        test_SceneGraphModel(sg_model, test_loader, test_samples, test_scene_graphs_and_features, idx2word)
    )
    _append_result(
        args.exp_dir,
        -1,
        best_name_recall,
        best_attr_recall,
        best_rela_recall,
        best_recall,
        best_recall,
        lr,
    )

    # Start training
    for epoch in range(args.epochs):
        adjust_learning_rate(args, optimizers, epoch)

        train_epoch(
            epoch,
            sg_model,
            optimizers,
            train_loader,
            train_samples,
            train_scene_graphs_and_features,
            idx2word,
            args,
        )

        lr = optimizers.param_groups[0]["lr"]

        epoch_name_recall, epoch_attr_recall, epoch_rela_recall, epoch_recall = (
            test_SceneGraphModel(
                sg_model, test_loader, test_samples, test_scene_graphs_and_features, idx2word
            )
        )

        # Only the overall recall is tracked as "best" across epochs.
        if epoch_recall > best_recall:
            best_recall = epoch_recall

        _append_result(
            args.exp_dir,
            epoch + 1,
            epoch_name_recall,
            epoch_attr_recall,
            epoch_rela_recall,
            epoch_recall,
            best_recall,
            lr,
        )

        sg_model.save(args.exp_dir)


# Entry point: start training only when this file is executed as a script.
if __name__ == "__main__":
    main()
