import torch
import pickle
import os

def load_plain_teacher_data(path, teacher_model, class_tpr_min=None, od_exclusion_threshold=None, calibrate_temperature=True, ssl_config=None):
    """Load softmax confidences of a teacher model and its per-class TPR thresholds.

    Args:
        path: Directory containing ``{teacher_model}.pt`` (teacher logits) and a
            ``Thresholds`` sub-directory with the pickled threshold/temperature files.
        teacher_model: Name used to build the file names.
        class_tpr_min: Key used to index the pickled thresholds dict.
        od_exclusion_threshold: Not used for loading; only recorded in ``ssl_config``.
        calibrate_temperature: If True, read ``{teacher_model}_temperature.pickle``,
            divide the logits by it before the softmax, and use the
            ``_id_thresholds_T.pickle`` thresholds. Otherwise the temperature is 1.0
            and ``_id_thresholds.pickle`` is used.
        ssl_config: Optional dict that is updated in place with the settings used.

    Returns:
        Tuple ``(model_confidences, class_tpr_thresholds, temperature)``, where
        ``model_confidences`` is the softmax over dim 1 of the (scaled) logits.

    NOTE: ``pickle.load`` can execute arbitrary code; only use trusted files.
    """
    thresholds_dir = os.path.join(path, 'Thresholds')

    if calibrate_temperature:
        tpr_file = os.path.join(thresholds_dir, f'{teacher_model}_id_thresholds_T.pickle')
        temperature_file = os.path.join(thresholds_dir, f'{teacher_model}_temperature.pickle')
        with open(temperature_file, 'rb') as f:
            temperature = pickle.load(f)
        print(f'Teacher temperature {temperature}')
    else:
        tpr_file = os.path.join(thresholds_dir, f'{teacher_model}_id_thresholds.pickle')
        temperature = 1.0

    with open(tpr_file, 'rb') as f:
        class_tpr_thresholds = pickle.load(f)[class_tpr_min]

    tiny_image_target_file = os.path.join(path, f'{teacher_model}.pt')
    with torch.no_grad():
        logits = torch.load(tiny_image_target_file, map_location=torch.device('cpu'))
        # Only rescale when a calibrated temperature was actually loaded.
        if calibrate_temperature:
            logits = logits / temperature
        model_confidences = torch.softmax(logits, dim=1).detach()

    if ssl_config is not None:
        # Store the plain values (previous trailing commas stored 1-tuples).
        ssl_config['ID Min TPR'] = class_tpr_min
        ssl_config['ID OD Exclusion'] = od_exclusion_threshold
        ssl_config['Temperature'] = temperature if calibrate_temperature else None
        ssl_config['TinyImageTarget'] = teacher_model

    return model_confidences, class_tpr_thresholds, temperature


def load_additional_data(path, prefix, teacher_model):
    """Load extra teacher logits from ``{path}{prefix}_{teacher_model}.pt`` onto the CPU.

    ``path`` is concatenated directly rather than joined, so it must already end
    with a path separator when it names a directory.
    """
    file_name = f'{prefix}_{teacher_model}.pt'
    additional_target_file = f'{path}{file_name}'
    print(f'Using additional teacher file: {additional_target_file}')
    return torch.load(additional_target_file, map_location=torch.device('cpu'))

def load_teacher_data(path, teacher_model, selection_model=None, class_tpr_min=None, od_exclusion_threshold=None,
                      calibrate_temperature=True, ssl_config=None):
    """Load teacher logits, selection logits and per-class selection thresholds.

    A separate ``selection_model`` may be used for choosing images while the
    teacher labels them. In that case the thresholds and the selection logits
    come from the selection model, and the logits and temperature come from the
    teacher.

    Args:
        path: Directory containing ``{model}.pt`` logit files and a
            ``Thresholds`` sub-directory with the pickled threshold/temperature files.
        teacher_model: Name of the labeling model.
        selection_model: Optional name of the model used for selection. Defaults to
            the teacher when None.
        class_tpr_min: Key into ``{model}_id_thresholds.pickle``. ``None`` or the
            string ``'None'`` disables TPR thresholds, giving one zero per class.
        od_exclusion_threshold: Key into ``{model}_id_thresholds_from_od.pickle``.
            When set (and not ``'None'``), each class threshold becomes the
            element-wise maximum of the TPR and OD thresholds.
        calibrate_temperature: If True, read ``{teacher_model}_temperature.pickle``.
            Otherwise the temperature is 1.0.
        ssl_config: Optional dict that is updated in place with the settings used.

    Returns:
        Tuple ``(teacher_logits, selection_logits, class_thresholds, temperature)``.
        ``selection_logits`` is the very same object as ``teacher_logits`` when no
        selection model is given.

    Raises:
        ValueError: if TPR thresholds are disabled and the TPR pickle holds no
            entries, so the number of classes cannot be inferred.

    NOTE: ``pickle.load`` can execute arbitrary code; only use trusted files.
    """
    cpu = torch.device('cpu')
    thresholds_dir = os.path.join(path, 'Thresholds')

    if calibrate_temperature:
        temperature_file = os.path.join(thresholds_dir, f'{teacher_model}_temperature.pickle')
        with open(temperature_file, 'rb') as f:
            temperature = pickle.load(f)
        print(f'Teacher temperature {temperature}')
    else:
        temperature = 1.0

    # Allow a different model for image selection than for image labeling:
    # thresholds come from the selection model, confidences and temperature from the teacher.
    threshold_model = teacher_model if selection_model is None else selection_model
    tpr_file = os.path.join(thresholds_dir, f'{threshold_model}_id_thresholds.pickle')
    od_file = os.path.join(thresholds_dir, f'{threshold_model}_id_thresholds_from_od.pickle')

    with open(tpr_file, 'rb') as f:
        all_tpr_thresholds = pickle.load(f)

    if class_tpr_min is not None and class_tpr_min != 'None':
        print(f'TPR: {class_tpr_min}')
        class_thresholds = all_tpr_thresholds[class_tpr_min]
    else:
        # No TPR constraint: use a zero threshold per class. Any stored entry gives the class count.
        try:
            any_thresholds = next(iter(all_tpr_thresholds.values()))
        except StopIteration:
            raise ValueError(f'No thresholds found in {tpr_file}; cannot infer number of classes') from None
        class_thresholds = torch.zeros(len(any_thresholds), dtype=torch.float)

    if od_exclusion_threshold is not None and od_exclusion_threshold != 'None':
        print(f'OD Exclusion: {od_exclusion_threshold}')
        with open(od_file, 'rb') as f:
            all_od_thresholds = pickle.load(f)
        class_od_thresholds = all_od_thresholds[od_exclusion_threshold]
        # Per class, the stricter (larger) of the TPR and OD thresholds wins.
        class_thresholds = torch.maximum(class_thresholds, class_od_thresholds)

    print(f'Max threshold: {torch.max(class_thresholds).item()}')

    tiny_image_target_file = os.path.join(path, f'{teacher_model}.pt')
    print(f'Using teacher file: {tiny_image_target_file}')
    teacher_logits = torch.load(tiny_image_target_file, map_location=cpu)

    if selection_model is None:
        # Without a selection model, the teacher logits are used for selection as well.
        selection_logits = teacher_logits
    else:
        selection_target_file = os.path.join(path, f'{selection_model}.pt')
        print(f'Using selection file: {selection_target_file}')
        selection_logits = torch.load(selection_target_file, map_location=cpu)

    if ssl_config is not None:
        # Store the plain values (previous trailing commas stored 1-tuples).
        ssl_config['ID Min TPR'] = class_tpr_min
        ssl_config['ID OD Exclusion'] = od_exclusion_threshold
        ssl_config['Temperature'] = temperature if calibrate_temperature else None
        ssl_config['Teacher Model'] = teacher_model
        ssl_config['Selection Model'] = selection_model

    return teacher_logits, selection_logits, class_thresholds, temperature