from typing import List, Tuple
import os, sys
import math
import torch
from torch.functional import Tensor
from itertools import groupby

def noise_magnitute(kwargs):
    """Return ``kwargs['inverse_config'].noise``.

    Only the inverse method's configuration is consulted: for APGD and AFW the
    noise is already generated by the apgd package.
    """
    inverse_cfg = kwargs['inverse_config']
    return inverse_cfg.noise


def comp_logical_tensors(tensor1: torch.BoolTensor, tensor2: torch.BoolTensor) -> torch.BoolTensor:
    """
    Composes first, bigger logical tensor with the second, smaller one.

    Every ``True`` position of ``tensor1`` receives, in order, the next value
    of ``tensor2``; every ``False`` position stays ``False``. ``tensor2`` is
    therefore a mask over the elements that ``tensor1`` already selected.

    comp_logical_tensors([True, False, True], [True, False]) -> [True, False, False]

    Parameters
    ----------
    tensor1
        Boolean mask.
    tensor2
        Boolean values, consumed in order, one per ``True`` entry of
        ``tensor1``. Extra trailing values are ignored.

    Returns
    -------
    A new boolean tensor shaped like ``tensor1``. Neither input is modified.

    Raises
    ------
    ValueError
        If ``tensor2`` has fewer elements than ``tensor1`` has ``True`` entries.
    """
    mask = tensor1.bool()
    n_selected = int(mask.sum())
    if tensor2.numel() < n_selected:
        raise ValueError(
            f"tensor2 has {tensor2.numel()} elements but tensor1 selects {n_selected}"
        )
    result = torch.zeros_like(mask)
    # Boolean-mask assignment fills the selected positions in row-major order,
    # i.e. the same order in which an element-wise loop would consume tensor2.
    # The source must match the destination's device and dtype for index_put.
    result[mask] = tensor2.reshape(-1)[:n_selected].to(device=result.device, dtype=result.dtype)
    return result

# State shared by blockPrint/enablePrint: the stream that was sys.stdout before
# printing was blocked, and the null-device handle installed in its place.
_stdout_before_block = None
_devnull_stdout = None


# Disable
def blockPrint():
    """Silence ``print`` by pointing ``sys.stdout`` at ``os.devnull``.

    The previously active stream is remembered so that ``enablePrint`` can put
    it back. Calling this while already blocked is a no-op, so repeated calls
    do not open additional null-device handles.
    """
    global _stdout_before_block, _devnull_stdout
    if _devnull_stdout is not None:
        if sys.stdout is _devnull_stdout:
            return
        # sys.stdout was replaced by someone else since we blocked it; release
        # our stale handle before opening a new one.
        _devnull_stdout.close()
    _stdout_before_block = sys.stdout
    _devnull_stdout = open(os.devnull, 'w')
    sys.stdout = _devnull_stdout


# Restore
def enablePrint():
    """Undo ``blockPrint``.

    Restores the stream that was active when printing was blocked and closes
    the null-device handle. That stream may be a wrapper, e.g. a notebook or
    test-capture stream, rather than the raw process stdout. If nothing was
    blocked, it falls back to ``sys.__stdout__``.
    """
    global _stdout_before_block, _devnull_stdout
    if _stdout_before_block is not None:
        sys.stdout = _stdout_before_block
    else:
        sys.stdout = sys.__stdout__
    if _devnull_stdout is not None:
        _devnull_stdout.close()
    _stdout_before_block = None
    _devnull_stdout = None

def n_restarts(kwargs):
    """Return the number of restarts for the single activated method.

    Exactly one of ``RATIO_config.activate``, ``RATIO_config.apgd.activate``,
    ``inverse_config.activate``, ``RATIO_config.frank_wolfe.activate`` and
    ``use_generative_model`` must be set; their sum has to equal 1.

    Returns
    -------
    ``inverse_config.n_restarts`` when the inverse method is active. Otherwise
    ``RATIO_config.apgd.n_restarts``, which is shared by all RATIO variants and
    the generative model.

    Raises
    ------
    AssertionError
        If not exactly one method is activated. It is raised explicitly, rather
        than via ``assert``, so the check survives ``python -O``.
    """
    # ToDo: add pgd + prior
    ratio_cfg = kwargs['RATIO_config']
    inverse_cfg = kwargs['inverse_config']
    n_active = (ratio_cfg.activate + ratio_cfg.apgd.activate + inverse_cfg.activate
                + ratio_cfg.frank_wolfe.activate + kwargs['use_generative_model'])
    if n_active != 1:
        raise AssertionError('Exactly one method has to be chosen!')

    if inverse_cfg.activate:
        return inverse_cfg.n_restarts
    # The check above guarantees that one of the RATIO variants or the
    # generative model is the active method here.
    return ratio_cfg.apgd.n_restarts



def get_thresholds_from_folder(thresholds_folder: str, start: int, end: int, class_labels: List[str]) -> List[float]:
    """Collect per-sample thresholds stored in ``*_last.txt`` files.

    Only files directly inside ``thresholds_folder`` are scanned; sub-directories
    are ignored. A file is used when:

    - its name contains ``'_last.txt'``, and
    - the integer prefix before ``'_from'`` lies in ``[start, end]`` (inclusive).

    The target class name is the third ``':'``-separated field of the file
    name. The threshold is the line of the file at that class's position in
    ``class_labels``, floored to 6 decimal places.

    Returns
    -------
    The thresholds ordered by the file-name index. Files with equal indices keep
    directory-listing order.

    Raises
    ------
    ValueError
        If a target class name is not in ``class_labels`` or a prefix/line is not
        numeric.
    """
    _, _, filenames = next(os.walk(thresholds_folder))
    indexed_thresholds = []  # (file-name index, threshold)
    for filename in filenames:
        if '_last.txt' not in filename:
            continue
        index = int(filename.split('_from')[0])
        if not start <= index <= end:
            continue
        target_class_name = filename.split(':')[2]
        index_target = class_labels.index(target_class_name)
        # Context manager closes the handle; a bare open().read() leaks it.
        with open(os.path.join(thresholds_folder, filename), 'r') as threshold_file:
            end_conf = float(threshold_file.read().split('\n')[index_target])
        end_conf = math.floor(end_conf * 1e6) / 1e6
        print(end_conf, filename)
        indexed_thresholds.append((index, end_conf))
    indexed_thresholds.sort(key=lambda pair: pair[0])  # stable, like sorted()
    return [threshold for _, threshold in indexed_thresholds]


def get_images_from_folder(thresholds_folder: str, start: int, end: int, class_labels: List[str], device: str) -> Tuple[Tensor, Tensor]:
    """Load saved image tensors (``*_last.pt``) together with their target labels.

    Only files directly inside ``thresholds_folder`` are scanned. A file is used
    when:

    - its name contains ``'_last.pt'``, and
    - the integer prefix before ``'_from'`` lies in ``[start, end]`` (inclusive).

    The target class name is the third ``':'``-separated field of the file name.

    Returns
    -------
    images_batch
        The loaded tensors concatenated along dim 0, ordered by file-name index,
        on ``device``.
    labels
        1-D int64 tensor of indices into ``class_labels``, in the same order, on
        ``device``.

    Raises
    ------
    RuntimeError
        If no file in the folder matches the criteria.
    ValueError
        If a target class name is not in ``class_labels``.
    """
    _, _, filenames = next(os.walk(thresholds_folder))
    samples = []  # (file-name index, target label, loaded tensor)
    for filename in filenames:
        if '_last.pt' not in filename:
            continue
        index = int(filename.split('_from')[0])
        if not start <= index <= end:
            continue
        target_class_name = filename.split(':')[2]
        index_target = class_labels.index(target_class_name)
        # map_location loads straight onto the target device. This also lets
        # files saved on a GPU be read on a CPU-only machine.
        # NOTE(review): torch.load unpickles its input; only use trusted folders.
        image = torch.load(os.path.join(thresholds_folder, filename), map_location=device)
        samples.append((index, index_target, image))

    if not samples:
        raise RuntimeError(
            f"No '_last.pt' files with index in [{start}, {end}] found in {thresholds_folder}"
        )

    samples.sort(key=lambda sample: sample[0])  # stable, like sorted()
    labels = torch.tensor([label for _, label, _ in samples], dtype=torch.long, device=device)
    images_batch = torch.cat([image for _, _, image in samples], dim=0).to(device)
    return images_batch, labels
