import numpy as np
import torch
import torch
from typing import Dict, List, Tuple

import sys 
sys.path.append("../")
from preprocessing.utilities import getArguments, getPredicate
import math


# Module-level timeout constant; it is not referenced by the code in this
# section. NOTE(review): units presumably seconds — confirm at the call sites.
TIME_OUT: int = 120


class PredictionLookupError(RuntimeError, KeyError, IndexError):
    """Raised when an atom has no entry in the neural predictions.

    Besides ``RuntimeError`` (the type ``get_prediction`` always meant to
    signal), this also subclasses ``KeyError`` and ``IndexError``. Those are
    the exceptions that previously leaked out of the lookup, so existing
    ``except KeyError`` / ``except IndexError`` handlers keep working.
    """

    # KeyError.__str__ would wrap the message in quotes; use the plain form.
    __str__ = RuntimeError.__str__


def get_prediction(
    atom: str,
    argument_to_prediction_index: Dict[str, Dict[Tuple, int]],
    predictions: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """Return the neural prediction associated with a ground ``atom``.

    The atom is split into its predicate and arguments. The arguments are
    normalised into the tuple key used by ``argument_to_prediction_index``,
    and the resulting position is used to index ``predictions[predicate]``.

    Args:
        atom: Textual atom, parsed with ``getPredicate`` / ``getArguments``.
        argument_to_prediction_index: Maps a predicate to a mapping from
            argument tuples to a position in that predicate's predictions.
            NOTE(review): positions presumably ints — confirm against caller.
        predictions: Maps a predicate to the object holding its predictions
            (annotated as a tensor).

    Returns:
        ``predictions[predicate][index]`` for the atom.

    Raises:
        PredictionLookupError: if the predicate, the argument tuple, or the
            resulting index is missing. It is a ``RuntimeError``, ``KeyError``
            and ``IndexError``, and is chained from the original lookup error.
    """
    # The tuples returned by TGs are of the form (predicate, subject, object)
    predicate = getPredicate(atom)
    arguments = getArguments(atom)

    if predicate in ("name", "attr"):
        # Keys for these predicates are (int(arg1), arg0): the two arguments
        # are swapped and the first key element is converted to int.
        arguments = (int(arguments[1]), arguments[0])
    elif predicate == "rela":
        # Keys for relations are (arg0, int(arg1), int(arg2)).
        arguments = (arguments[0], int(arguments[1]), int(arguments[2]))

    try:
        index = argument_to_prediction_index[predicate][arguments]
        return predictions[predicate][index]
    except (KeyError, IndexError) as err:
        # Previously this caught ValueError (never raised by these lookups)
        # and *returned* the exception object instead of raising it.
        raise PredictionLookupError(
            "The index of atom {} cannot be found in the neural predictions".format(
                atom
            )
        ) from err

class AverageMeter(object):
    """Keep the most recent value of a metric together with its running mean.

    ``name`` is the label shown by ``str()``. ``fmt`` is a format-spec suffix,
    for example ``":f"`` or ``":.3f"``, spliced into both the value and the
    average fields of the printed template.
    """

    def __init__(self, name, fmt=":f"):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Set the current value, average, sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Store ``val`` as the latest value, weighted by ``n`` samples.

        ``sum`` grows by ``val * n``, ``count`` grows by ``n``, and ``avg``
        is recomputed as ``sum / count``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "".join(["{name} {val", self.fmt, "} ({avg", self.fmt, "})"])
        return template.format(**vars(self))


class ProgressMeter(object):
    """Print a single-line progress report built from a list of meters.

    ``num_batches`` fixes the width of the ``[current/total]`` counter.
    ``meters`` are rendered with ``str()``. ``prefix`` is prepended to the
    counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print prefix, batch counter and every meter, tab-separated.

        The line ends with a carriage return, so successive calls overwrite
        the same terminal line.
        """
        counter = self.prefix + self.batch_fmtstr.format(batch)
        fields = [counter]
        fields.extend(map(str, self.meters))
        print("\t".join(fields), end='\r')

    def _get_batch_fmtstr(self, num_batches):
        """Build a ``"[{:Nd}/total]"`` template, N being the digits of total."""
        width = len(str(num_batches // 1))
        field = "{:%dd}" % width
        return "[%s/%s]" % (field, field.format(num_batches))


def adjust_learning_rate(args, optimizer, epoch):
    """Compute the learning rate for ``epoch`` and apply it to ``optimizer``.

    Reads ``args.lr``, ``args.cosine``, ``args.lr_decay_rate``,
    ``args.epochs`` (cosine mode) and ``args.lr_decay_epochs`` (step mode).

    * Cosine mode: anneal from ``args.lr`` at epoch 0 down to
      ``args.lr * args.lr_decay_rate**3`` at ``epoch == args.epochs``.
    * Step mode: multiply ``args.lr`` by ``args.lr_decay_rate`` once for
      every milestone in ``args.lr_decay_epochs`` strictly below ``epoch``.

    The result is written to ``"lr"`` in every entry of
    ``optimizer.param_groups``.
    """
    base_lr = args.lr
    if not args.cosine:
        milestones_passed = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        new_lr = (
            base_lr * (args.lr_decay_rate**milestones_passed)
            if milestones_passed > 0
            else base_lr
        )
    else:
        floor = base_lr * (args.lr_decay_rate**3)
        cos_term = 1 + math.cos(math.pi * epoch / args.epochs)
        new_lr = floor + (base_lr - floor) * cos_term / 2

    for group in optimizer.param_groups:
        group["lr"] = new_lr
