from math import ceil

import numpy as np
import torch
import os

from models.distil_bert import Distil_Bert
from models.roberta import Roberta
from utils.constants import BERT_SENTIMENT_PATH, DISTIL_BERT_SENTIMENT_PATH, ROBERTA_SENTIMENT_PATH, DEVICE, \
    SAVED_TRAINED_MODELS_PATH, SENTIMENT_PATH
from models.bert import Bert


def get_fine_tuned_model(architecture, setup_name):
    """Construct the fine-tuned classifier wrapper for a given setup.

    Args:
        architecture: Model family to load.
            * For ``setup_name == 'cebab'``: ``'bert'``, ``'distil_bert'``,
              ``'roberta'``, or ``'challenge'``. ``'challenge'`` is a RoBERTa
              checkpoint loaded from
              ``./saved_models/overall_sentiment_roberta_food_mention/roberta-base``.
            * For ``setup_name == 'stance'``: only ``'roberta'``.
        setup_name: Experimental setup, ``'cebab'`` or ``'stance'``.

    Returns:
        A ``Bert``, ``Distil_Bert`` or ``Roberta`` instance built from the
        checkpoint path associated with the requested combination.

    Raises:
        NotImplementedError: If the (architecture, setup_name) combination is
            not supported.
    """
    if setup_name == 'cebab':
        if architecture == 'bert':
            return Bert(pretrained_model_path=BERT_SENTIMENT_PATH)
        if architecture == 'distil_bert':
            return Distil_Bert(pretrained_model_path=DISTIL_BERT_SENTIMENT_PATH)
        if architecture == 'roberta':
            return Roberta(pretrained_model_path=ROBERTA_SENTIMENT_PATH)
        if architecture == 'challenge':
            return Roberta(
                pretrained_model_path="./saved_models/overall_sentiment_roberta_food_mention/roberta-base")
    elif setup_name == 'stance':
        if architecture == 'roberta':
            # The stance checkpoint has a 3-label classification head.
            return Roberta(
                pretrained_model_path='./saved_models/stance_setup/original_instruction_to_label/roberta-base',
                num_labels=3)
    # Every combination not handled above is unsupported. Raise instead of
    # falling off the end and silently returning None to the caller.
    raise NotImplementedError(
        f"Unsupported combination: architecture={architecture!r}, setup_name={setup_name!r}")


def get_correlation_models_paths(
        path_dir="/home/XXXXXX/MatchingBasedCausalExplanation/saved_models/sentiment_models/challenge_setups"):
    """Map each challenge-setup entry in ``path_dir`` to its ``roberta-base`` path.

    Args:
        path_dir: Directory whose entries are per-setup model folders. Defaults
            to the originally hard-coded location. Pass a different path to
            run from another checkout or machine.

    Returns:
        A dict whose keys are each entry's name up to its first ``.`` and whose
        values are ``os.path.join(path_dir, entry, 'roberta-base')``. Every
        entry returned by ``os.listdir`` is included; plain files are not
        filtered out. Entries are visited in sorted order, so when two entries
        share the same prefix the result is deterministic (the
        lexicographically last one wins).
    """
    return {
        entry.split(".")[0]: os.path.join(path_dir, entry, 'roberta-base')
        for entry in sorted(os.listdir(path_dir))
    }
