from typing import List, Tuple
import os, sys
import math
import torch
from torch.functional import Tensor


def noise_magnitute(kwargs):
    """Return the ``noise`` attribute of ``kwargs["inverse_config"]``.

    Parameters
    ----------
    kwargs
        Mapping that holds the run configuration; only the
        ``"inverse_config"`` entry is read.

    Returns
    -------
    Whatever is stored in ``kwargs["inverse_config"].noise``.
    """
    # For apgd and afw the noise is already generated by the apgd package,
    # so only the inverse configuration is consulted here.
    inverse_cfg = kwargs["inverse_config"]
    return inverse_cfg.noise


def comp_logical_tensors(
    tensor1: torch.BoolTensor, tensor2: torch.BoolTensor
) -> torch.BoolTensor:
    """
    Compose a boolean mask with a second mask defined over its ``True`` entries.

    The k-th ``True`` position of ``tensor1`` stays ``True`` only if the k-th
    element of ``tensor2`` is ``True``; positions that are ``False`` in
    ``tensor1`` stay ``False``.

    comp_logical_tensors([True, False, True], [True, False]) -> [True, False, False]

    Parameters
    ----------
    tensor1
        The bigger boolean mask (expected to be one-dimensional). It is not
        modified; the result is built on a clone.
    tensor2
        The smaller boolean mask, holding one value per ``True`` entry of
        ``tensor1`` in order. Surplus trailing elements are ignored.

    Returns
    -------
    A new boolean tensor with the shape of ``tensor1``.

    Raises
    ------
    ValueError
        If ``tensor2`` has fewer elements than ``tensor1`` has ``True``
        entries.
    """
    composed = tensor1.clone()
    n_selected = int(tensor1.sum())

    # Nothing is selected by tensor1, so tensor2 is never consulted.
    if n_selected == 0:
        return composed

    if tensor2.numel() < n_selected:
        raise ValueError(
            f"tensor2 has {tensor2.numel()} elements but tensor1 selects "
            f"{n_selected} positions"
        )

    # A single masked assignment replaces the element-by-element loop: the
    # True slots of tensor1 are overwritten, in order, with tensor2's values.
    composed[tensor1] = tensor2[:n_selected].to(composed.device)
    return composed


# Disable
def blockPrint():
    """Silence ``print`` by pointing ``sys.stdout`` at the null device.

    A new write handle on ``os.devnull`` is opened on every call and
    installed as ``sys.stdout``; this function does not close it.
    """
    null_stream = open(os.devnull, mode="w")
    sys.stdout = null_stream


# Restore
def enablePrint():
    """Point ``sys.stdout`` back at the interpreter's original stdout.

    If the stream being replaced is a handle on ``os.devnull``, it is closed
    so that silencing and restoring output repeatedly does not leak file
    handles. Any other stream that happens to be installed is left open.
    """
    silenced = sys.stdout
    sys.stdout = sys.__stdout__

    # Only close handles that are recognisably the null device; other
    # replacement streams may still be in use by whoever installed them.
    if silenced is not sys.__stdout__ and getattr(silenced, "name", None) == os.devnull:
        silenced.close()


def n_restarts(kwargs):
    """Return the restart count configured for the single active method.

    Parameters
    ----------
    kwargs
        Mapping with the entries ``"RATIO_config"``, ``"inverse_config"`` and
        ``"use_generative_model"``.

    Returns
    -------
    ``kwargs["inverse_config"].n_restarts`` when the inverse method is
    active, otherwise ``kwargs["RATIO_config"].apgd.n_restarts`` when any of
    the remaining methods is active, and ``None`` if none is.

    Raises
    ------
    AssertionError
        If the number of activated methods is not exactly one.
    """
    # ToDo: add pgd + prior
    assert (
        sum(
            (
                kwargs["RATIO_config"].activate,
                kwargs["RATIO_config"].apgd.activate,
                kwargs["inverse_config"].activate,
                kwargs["RATIO_config"].frank_wolfe.activate,
                kwargs["use_generative_model"],
            )
        )
        == 1
    ), "Exactly one method has to be chosen!"

    inverse_cfg = kwargs["inverse_config"]
    if inverse_cfg.activate:
        return inverse_cfg.n_restarts

    # Every non-inverse method shares the apgd restart setting.
    ratio_cfg = kwargs["RATIO_config"]
    other_method_active = (
        ratio_cfg.activate
        or ratio_cfg.apgd.activate
        or ratio_cfg.frank_wolfe.activate
        or kwargs["use_generative_model"]
    )
    if other_method_active:
        return ratio_cfg.apgd.n_restarts
    return None


def get_thresholds_from_folder(
    thresholds_folder: str, start: int, end: int, class_labels: List[str]
) -> List[float]:
    """Read per-sample confidence thresholds from the ``*_last.txt`` files of a folder.

    Only files directly inside ``thresholds_folder`` are considered. For a
    file whose name contains ``"_last.txt"``:

    * the text before ``"_from"`` is parsed as the integer sample index,
    * the third ``":"``-separated field of the name is the target class name,
    * the file's line at the position of that class in ``class_labels`` is
      parsed as a float and truncated (floored) to six decimal places.

    Each accepted value is printed together with its file name.

    Parameters
    ----------
    thresholds_folder
        Directory holding the threshold files.
    start, end
        Inclusive range of sample indices to keep.
    class_labels
        Class names; the position of a name selects the line read from the file.

    Returns
    -------
    The thresholds ordered by ascending sample index.

    Raises
    ------
    ValueError
        If a file name does not start with an integer index or names a class
        that is not in ``class_labels``.
    """
    _, _, filenames = next(os.walk(thresholds_folder))

    indexed_thresholds = []
    for filename in filenames:
        if "_last.txt" not in filename:
            continue
        index = int(filename.split("_from")[0])
        if not start <= index <= end:
            continue

        target_class_name = filename.split(":")[2]
        index_target = class_labels.index(target_class_name)

        # Use a context manager so the handle is released right after reading.
        with open(os.path.join(thresholds_folder, filename), "r") as handle:
            lines = handle.read().split("\n")

        # Truncate (not round) to six decimal places.
        end_conf = math.floor(float(lines[index_target]) * 1e6) / 1e6
        print(end_conf, filename)
        indexed_thresholds.append((index, end_conf))

    # Sort on the index only, keeping encounter order for equal indices.
    indexed_thresholds.sort(key=lambda pair: pair[0])
    return [threshold for _, threshold in indexed_thresholds]


def get_images_from_folder(
    thresholds_folder: str, start: int, end: int, class_labels: List[str], device: str
) -> Tuple[Tensor, Tensor]:
    """Load the ``*_last.pt`` files of a folder into one batch with labels.

    Only files directly inside ``thresholds_folder`` are considered. For a
    file whose name contains ``"_last.pt"``, the text before ``"_from"`` is
    parsed as the integer sample index and the third ``":"``-separated field
    of the name is looked up in ``class_labels`` to obtain the label.

    Parameters
    ----------
    thresholds_folder
        Directory holding the saved files.
    start, end
        Inclusive range of sample indices to keep.
    class_labels
        Class names; the position of a name is used as the label.
    device
        Device both returned tensors are moved to.

    Returns
    -------
    ``(images_batch, labels)``: the loaded objects concatenated along
    dimension 0 and the matching label tensor, both ordered by ascending
    sample index.
    """
    _, _, filenames = next(os.walk(thresholds_folder))

    records = []
    for filename in filenames:
        if "_last.pt" not in filename:
            continue
        index = int(filename.split("_from")[0])
        if index < start or index > end:
            continue
        # Resolve the label first so an unknown class name fails before loading.
        class_position = class_labels.index(filename.split(":")[2])
        loaded = torch.load(os.path.join(thresholds_folder, filename))
        records.append((index, class_position, loaded))

    # One stable sort on the sample index keeps labels and images aligned.
    records.sort(key=lambda record: record[0])

    label_parts = [torch.tensor([position]) for _, position, _ in records]
    labels = torch.cat(label_parts, dim=0).to(device)

    image_parts = [image for _, _, image in records]
    images_batch = torch.cat(image_parts, dim=0).to(device)
    return images_batch, labels
