import torch
from sklearn.linear_model import LogisticRegression, LinearRegression, Ridge
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import numpy as np

def explanation_prompts():
    """
    Returns a small set of explanation prompts.

    Each prompt is a short yes/no follow-up question about the previous
    answer, wrapped between ``[INST]`` and ``[/INST]`` markers. A new list is
    built on every call.
    """
    questions = (
        "Do you think your answer is correct?",
        "Are you confident in your answer?",
        "Would you change your answer?",
        "Are you not confident in your answer?",
        "Are you sure?",
        "Are you certain?",
        "Are you positive?",
        "Are you sure about that?",
        "Are you able to explain your answer?",
    )
    # Wrap every bare question in the instruction markers.
    return ["[INST] " + question + " [/INST]" for question in questions]

def gpt_explanation_prompts():
    """
    Returns the explanation prompts generated via ChatGPT

    Each prompt is a yes/no question placed between an opening ``[INST]``
    marker and a closing marker. A new list is built on every call.
    """
    # The literals are raw strings because the closing marker contains a
    # backslash followed by "I", which is not a valid escape sequence in a
    # normal string literal. The runtime text is unchanged.
    # NOTE(review): the closing marker here uses a backslash, whereas
    # explanation_prompts() closes with "[/INST]". Kept byte-for-byte since
    # existing results may depend on it -- confirm whether this is intended.
    prompts = [
        r"[INST] Do you consider your response to be accurate? [\INST]",
        r"[INST] Are you confident in the information provided? [\INST]",
        r"[INST] Do you have access to real-time data? [\INST]",
        r"[INST] Can you generate original content? [\INST]",
        r"[INST] Are you influenced by the inputs you receive? [\INST]",
        r"[INST] Do you remember previous interactions? [\INST]",
        r"[INST] Do you have the ability to learn from this conversation? [\INST]",
        r"[INST] Is your response influenced by a specific data set? [\INST]",
        r"[INST] Are you able to understand emotions? [\INST]",
        r"[INST] Do you have the ability to make decisions? [\INST]",
        r"[INST] Can you create images? [\INST]",
        r"[INST] Do you follow ethical guidelines? [\INST]",
        r"[INST] Are your responses free from bias? [\INST]",
        r"[INST] Do you require an internet connection to function? [\INST]",
        r"[INST] Are you capable of translating languages? [\INST]",
        r"[INST] Do you think your answer could be misunderstood? [\INST]",
        r"[INST] Do you verify the accuracy of your sources? [\INST]",
        r"[INST] Can you refuse to answer a question? [\INST]",
        r"[INST] Are you programmed to avoid harmful content? [\INST]",
        r"[INST] Do you use personal experiences in your responses? [\INST]",
        r"[INST] Are you capable of performing mathematical calculations? [\INST]",
        r"[INST] Do you simulate human conversation? [\INST]",
        r"[INST] Do you have a limit on the length of your responses? [\INST]",
        r"[INST] Can you interact with other software or devices? [\INST]",
        r"[INST] Are you capable of generating poetry? [\INST]",
        r"[INST] Do you consider context when forming responses? [\INST]",
        r"[INST] Can you maintain confidentiality? [\INST]",
        r"[INST] Are you subject to regular updates? [\INST]",
        r"[INST] Do you have a personality? [\INST]",
        r"[INST] Do you use any form of reinforcement learning? [\INST]",
        r"[INST] Are you aware of your own capabilities? [\INST]",
        r"[INST] Can you self-improve over time without human intervention? [\INST]",
        r"[INST] Do you have an understanding of cultural differences? [\INST]",
        r"[INST] Can you detect sarcasm in text? [\INST]",
        r"[INST] Do you adapt your language style according to the user? [\INST]",
        r"[INST] Are you able to recognize inappropriate content? [\INST]",
        r"[INST] Do you use encryption to secure data? [\INST]",
        r"[INST] Can you perform sentiment analysis? [\INST]",
        r"[INST] Are your capabilities limited to what you were trained on? [\INST]",
        r"[INST] Do you believe your responses can be improved? [\INST]",
    ]
    return prompts

def gpt_state_prompts():
    """
    Returns yes/no prompts that ask the model about its own state and
    capabilities (memory, data access, ethics handling, ...).

    Each prompt is a question placed between an opening ``[INST]`` marker and
    a closing marker. A new list is built on every call.
    """
    # The literals are raw strings because the closing marker contains a
    # backslash followed by "I", which is not a valid escape sequence in a
    # normal string literal. The runtime text is unchanged.
    # NOTE(review): the closing marker here uses a backslash, whereas
    # explanation_prompts() closes with "[/INST]". Kept byte-for-byte since
    # existing results may depend on it -- confirm whether this is intended.
    prompts = [
        r"[INST] Do you retain information from previous user sessions? [\INST]",
        r"[INST] Can you access external databases or the internet during a conversation? [\INST]",
        r"[INST] Are you capable of generating random numbers? [\INST]",
        r"[INST] Do you have the ability to create personalized profiles of users? [\INST]",
        r"[INST] Do you use specific strategies to manage large amounts of data? [\INST]",
        r"[INST] Do you understand the concept of time? [\INST]",
        r"[INST] Are your responses influenced by current events? [\INST]",
        r"[INST] Can you simulate emotions in your responses? [\INST]",
        r"[INST] Do you have mechanisms to avoid generating harmful content? [\INST]",
        r"[INST] Are you aware of your own existence? [\INST]",
        r"[INST] Can you refuse to answer a question based on ethical guidelines? [\INST]",
        r"[INST] Do you have access to a user’s geographical location? [\INST]",
        r"[INST] Are you programmed to learn from user interactions continuously? [\INST]",
        r"[INST] Do you have a predefined set of moral values? [\INST]",
        r"[INST] Can you generate responses based on hypothetical scenarios? [\INST]",
        r"[INST] Do you remember details from this conversation in future interactions? [\INST]",
        r"[INST] Are you capable of understanding sarcasm or irony? [\INST]",
        r"[INST] Can you detect and respond differently to children versus adults? [\INST]",
        r"[INST] Do you have the ability to resolve contradictions in the information provided to you? [\INST]",
        r"[INST] Can you provide legal advice? [\INST]",
        r"[INST] Do you actively adjust your responses based on feedback? [\INST]",
        r"[INST] Are you programmed to prefer certain topics over others? [\INST]",
        r"[INST] Can you simulate understanding of human emotions? [\INST]",
        r"[INST] Do you have a limit to the number of tasks you can perform simultaneously? [\INST]",
        r"[INST] Are you designed to simulate human-like hesitations in your responses? [\INST]",
        r"[INST] Can you distinguish between factual information and opinions? [\INST]",
        r"[INST] Are you able to recognize and adapt to different languages automatically? [\INST]",
        r"[INST] Do you personalize responses based on the perceived knowledge level of the user? [\INST]",
        r"[INST] Can you follow long-term trends in data? [\INST]",
        r"[INST] Do you have the ability to recognize when a user is upset? [\INST]",
    ]

    return prompts
    
def random_prompts():
    """
    Returns ten unrelated narrative sentences.

    Unlike the other prompt helpers in this file, these strings are plain
    sentences with no ``[INST]`` markers. A new list is built on every call.
    """
    sentences = (
        "Winds whisper through the ancient forest, carrying secrets of forgotten lands and echoing tales of yore.",
        "Beneath the city's hustle, a hidden world thrives, veiled in mystery and humming with arcane energies.",
        "She wandered along the shoreline, her thoughts as tumultuous as the waves crashing against the rocks.",
        "Twilight descended, draping the world in a velvety cloak of stars and soft, murmuring shadows.",
        "In the heart of the bustling market, aromas and laughter mingled, weaving a tapestry of vibrant life.",
        "The old library held books brimming with magic, each page a doorway to unimaginable adventures.",
        "Rain pattered gently on the window, a soothing symphony for those nestled warmly inside.",
        "Lost in the desert, the ancient ruins whispered of empires risen and fallen under the relentless sun.",
        "Every evening, the village gathered by the fire to share stories and dreams under the watchful moon.",
        "The scientist peered through the microscope, revealing a universe in a drop of water, teeming with life.",
    )
    return list(sentences)

def get_linear_results(train_X, train_y, test_X, test_y, seed=0, balanced=False, C=0.1):
    """Train a logistic-regression probe and score it on the test split.

    The classifier is obtained from ``train_linear_model`` with the same
    ``seed``, ``balanced`` and ``C`` arguments.

    Note: the AUROC uses column 1 of ``predict_proba`` as the positive-class
    score and ``f1_score`` is called with its default averaging, so this is
    written for binary labels.

    Args:
        train_X: training features.
        train_y: training labels.
        test_X: test features.
        test_y: test labels; compared elementwise against the predictions.
        seed: random state forwarded to the classifier.
        balanced: forwarded to ``train_linear_model``.
        C: forwarded to ``train_linear_model``.

    Returns:
        Tuple ``(acc, f1, ece, auroc)`` computed on the test split.
    """
    clf = train_linear_model(train_X, train_y, test_X, test_y, seed=seed, balanced=balanced, C=C)
    y_pred = clf.predict(test_X)
    # Class probabilities feed both the calibration error and the AUROC, so
    # compute them once instead of running the model over test_X twice.
    test_proba = clf.predict_proba(test_X)

    acc = (test_y == y_pred).mean()
    f1 = f1_score(test_y, y_pred)
    ece = compute_ece(test_proba, test_y)
    auroc = roc_auc_score(test_y, test_proba[:, 1])
    return acc, f1, ece, auroc

def train_linear_model(train_X, train_y, test_X, test_y, balanced=False, seed=0, C=0.1):
    """Fit a logistic-regression classifier on the training split.

    Args:
        train_X: training features, passed to ``LogisticRegression.fit``.
        train_y: training labels, passed to ``LogisticRegression.fit``.
        test_X: not used; kept so existing call sites keep working.
        test_y: not used; kept so existing call sites keep working.
        balanced: when True, fit an unpenalized model (``penalty=None``) with
            ``class_weight="balanced"``; when False, fit an L2-penalized model.
        seed: forwarded as ``random_state``.
        C: inverse regularization strength of the L2-penalized model. It has
            no effect when ``balanced`` is True because that branch is
            unpenalized.

    Returns:
        The fitted ``LogisticRegression`` instance.
    """
    if balanced:
        clf = LogisticRegression(class_weight="balanced", random_state=seed, penalty=None)
    else:
        clf = LogisticRegression(random_state=seed, penalty="l2", C=C)

    # No prediction on the test split here: its result was never used, and
    # callers that need predictions call clf.predict themselves.
    return clf.fit(train_X, train_y)

def train_linear_regressor(train_X, train_y, test_X, test_y, seed=0):
    """Fit a ridge regressor on the training split and return it.

    ``test_X`` and ``test_y`` are accepted but not used. ``seed`` is passed to
    ``Ridge`` as ``random_state``; every other ``Ridge`` setting is left at
    its default.
    """
    regressor = Ridge(random_state=seed)
    regressor.fit(train_X, train_y)
    return regressor

def get_linear_acc_weights(train_X, train_y, test_X, test_y, seed=0):
    """Fit a class-balanced, unpenalized logistic regression and report its
    test accuracy together with its parameters.

    Returns:
        Tuple ``(acc, og_weights, og_bias, final_weights, final_bias)`` where
        ``acc`` is the accuracy on the test split, ``final_weights`` and
        ``final_bias`` are the fitted ``coef_`` / ``intercept_``, and
        ``og_weights`` / ``og_bias`` are all-zero arrays of matching shape.
    """
    model = LogisticRegression(random_state=seed, class_weight="balanced", penalty=None)
    model.fit(train_X, train_y)

    test_accuracy = accuracy_score(test_y, model.predict(test_X))

    learned_weights, learned_bias = model.coef_, model.intercept_

    # The "original" parameters are reported as zeros shaped like the fitted
    # ones; they are not read back from the estimator.
    zero_weights = np.zeros_like(learned_weights)
    zero_bias = np.zeros_like(learned_bias)

    return test_accuracy, zero_weights, zero_bias, learned_weights, learned_bias

def train_linear_regressor_pt(train_X, train_y, test_X, test_y, seed=0):
    """Fit a single-output PyTorch linear layer with full-batch Adam on MSE.

    Runs 100 optimizer steps with learning rate 0.01, each over the whole
    training set.

    Args:
        train_X: 2-D tensor of training features; ``train_X.shape[1]`` sets
            the layer's input width.
        train_y: tensor of regression targets. A 1-D tensor of shape ``(N,)``
            is treated as ``(N, 1)`` to line up with the layer's output.
        test_X: not used; kept so existing call sites keep working.
        test_y: not used; kept so existing call sites keep working.
        seed: not used by this function; kept so existing call sites keep
            working.

    Returns:
        The trained ``torch.nn.Linear`` module.
    """
    model = torch.nn.Linear(train_X.shape[1], 1)
    criterion = torch.nn.MSELoss()
    optim = torch.optim.Adam(model.parameters(), lr=0.01)

    # The layer outputs shape (N, 1). A 1-D target of shape (N,) would be
    # broadcast against it to (N, N) inside MSELoss, which silently optimizes
    # the wrong objective, so give the target an explicit trailing axis.
    target = train_y.unsqueeze(1) if train_y.dim() == 1 else train_y

    for _ in range(100):
        optim.zero_grad()
        loss = criterion(model(train_X), target)
        loss.backward()
        optim.step()

    return model

def train_mlp(train_X, train_y, test_X, test_y):
    """Fit an ``MLPClassifier`` on the training split and return it.

    The classifier uses ``random_state=0``; every other setting is left at
    the scikit-learn default.

    Args:
        train_X: training features.
        train_y: training labels.
        test_X: not used; kept so existing call sites keep working.
        test_y: not used; kept so existing call sites keep working.

    Returns:
        The fitted ``MLPClassifier`` instance.
    """
    # No prediction on the test split here: its result was never used, and
    # callers that need predictions call clf.predict themselves.
    clf = MLPClassifier(random_state=0).fit(train_X, train_y)
    return clf

def normalize_data(train_data, test_data):
    """Z-score both splits using the mean and std of the training split.

    Statistics are taken along axis 0 of ``train_data`` and applied to both
    splits, so nothing from the test split influences the scaling.

    Args:
        train_data: array-like object exposing ``mean(axis=0)`` and
            ``std(axis=0)`` and elementwise arithmetic (e.g. a NumPy array).
        test_data: data with a compatible trailing shape.

    Returns:
        Tuple ``(train_data, test_data)`` of normalized copies; the inputs
        are not modified.
    """
    train_mean = train_data.mean(axis=0)
    train_std = train_data.std(axis=0)

    # A constant feature has zero std. Dividing by 1 instead leaves it at 0
    # after centering rather than producing NaN/inf. Adding the boolean mask
    # (True counts as 1) bumps exactly the zero entries to 1 and, unlike item
    # assignment, also works when the std is a scalar (1-D input).
    train_std = train_std + (train_std == 0)

    train_data = (train_data - train_mean) / train_std
    test_data = (test_data - train_mean) / train_std
    return train_data, test_data

def compute_rouge(preds, gt_sequences, size=1):
    """Mean recall-style ROUGE-L between predicted and reference sequences.

    Every sequence is cut into consecutive chunks of ``size`` elements with
    its ``split(split_size=size)`` method (the ``torch.Tensor.split``
    signature). The length of the longest common subsequence (LCS) of chunks
    is divided by the number of reference chunks, and the per-pair scores are
    averaged.

    Args:
        preds: iterable of predicted sequences.
        gt_sequences: iterable of reference sequences, paired with ``preds``
            by position (extra items on either side are ignored by ``zip``).
        size: number of elements per chunk; two chunks match only when they
            have the same shape and all elements are equal.

    Returns:
        The mean score over all pairs.
    """
    scores = []
    for pred, gt in zip(preds, gt_sequences):
        pred_chunks = pred.split(split_size=size)
        gt_chunks = gt.split(split_size=size)
        n_pred, n_gt = len(pred_chunks), len(gt_chunks)

        # Standard LCS dynamic programme: lcs[i, j] is the LCS length of the
        # first i predicted chunks and the first j reference chunks.
        lcs = np.zeros((n_pred + 1, n_gt + 1))
        for i in range(1, n_pred + 1):
            for j in range(1, n_gt + 1):
                a, b = pred_chunks[i - 1], gt_chunks[j - 1]
                # Reduce the elementwise comparison to one bool. Using the
                # raw `a == b` tensor in an `if` fails for multi-element
                # chunks, and a shorter trailing chunk must never match.
                if a.shape == b.shape and bool((a == b).all()):
                    lcs[i, j] = lcs[i - 1, j - 1] + 1
                else:
                    lcs[i, j] = max(lcs[i - 1, j], lcs[i, j - 1])
        scores.append(lcs[-1, -1] / n_gt)
    return np.mean(scores)


def compute_ece(preds, labels, n_bins=4):
    """
    Computes the expected calibration error (ECE) of a classifier.

    The confidence of a sample is its largest predicted probability, and the
    sample counts as correct when the arg-max class equals its label, so
    multiclass input is handled through the top prediction only. Confidences
    are grouped into ``n_bins`` equal-width bins over [0, 1]; the ECE is the
    sum over bins of |accuracy - mean confidence| weighted by the fraction of
    samples in the bin.

    Args:
        preds: 2-D array of per-class probabilities, one row per sample.
        labels: true class indices, one per row of ``preds``.
        n_bins: number of equal-width confidence bins (default 4).

    Returns:
        The ECE as a float, or NaN when ``preds`` has no rows.

    Raises:
        ValueError: if ``n_bins`` is smaller than 1.
    """
    if n_bins < 1:
        raise ValueError("n_bins must be at least 1")

    max_probs = preds.max(axis=1)
    correct_preds = (preds.argmax(axis=1) == labels).astype(int)

    n_samples = max_probs.shape[0]
    if n_samples == 0:
        # Nothing to calibrate against.
        return float("nan")

    # Bin on the inner edges only, giving indices 0 .. n_bins - 1. Binning on
    # the full edge list would push a confidence of exactly 1.0 past the last
    # bin and silently drop the most confident predictions.
    inner_edges = np.linspace(0, 1, n_bins + 1)[1:-1]
    bin_indices = np.digitize(max_probs, inner_edges)

    ece = 0.0
    for b in range(n_bins):
        in_bin = bin_indices == b
        count = in_bin.sum()

        # Empty bins contribute nothing.
        if count == 0:
            continue

        bin_accuracy = correct_preds[in_bin].mean()
        bin_confidence = max_probs[in_bin].mean()
        ece += np.abs(bin_accuracy - bin_confidence) * (count / n_samples)

    return ece

