import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
from netcal.metrics import ACE, ECE


# Multiple-choice answer letter -> class index used for the logits' last axis.
_LABEL_TO_INDEX = {'A': 0, 'B': 1, 'C': 2, 'D': 3}


def label_map(label):
    """Convert a sequence of answer letters ('A'-'D') to integer class indices.

    The input is iterated by value rather than by position. A pandas Series
    whose index was shuffled (e.g. by ``train_test_split``) is therefore
    converted element by element in order, instead of being looked up by
    index label.

    Args:
        label: Iterable of answer letters (list, numpy array, pandas Series, ...).

    Returns:
        list[int]: One index per input element, 'A' -> 0 ... 'D' -> 3.

    Raises:
        ValueError: If any element is not one of 'A', 'B', 'C', 'D'.
    """
    answers = []
    for value in label:
        try:
            answers.append(_LABEL_TO_INDEX[value])
        except (KeyError, TypeError):
            # TypeError covers unhashable elements, which are also invalid labels.
            raise ValueError(
                f"Unsupported answer label {value!r}; expected one of 'A', 'B', 'C', 'D'"
            ) from None
    return answers


def split_and_convert(logits, labels, correct, test_size=0.7, random_state=1, label_map_func=None):
    """Split logits/labels/correctness into calibration and test parts and convert types.

    ``test_size`` is passed straight to ``train_test_split``. With the default
    of 0.7, 70% of the samples go to the test split and the rest to calibration.

    Args:
        logits: Per-sample logits, anything ``torch.tensor`` accepts after splitting.
        labels: Per-sample gold labels.
        correct: Per-sample correctness flags.
        test_size: Fraction (or count) of samples placed in the test split.
        random_state: Seed forwarded to ``train_test_split``.
        label_map_func: Optional callable applied to each label split before
            tensor conversion (skipped when falsy).

    Returns:
        tuple: (cal_logits, test_logits, cal_labels, test_labels) as torch
        tensors, followed by (cal_correct, test_correct) as numpy arrays.
    """
    (cal_logits, test_logits,
     cal_labels, test_labels,
     cal_correct, test_correct) = train_test_split(
        logits, labels, correct, test_size=test_size, random_state=random_state
    )

    def _to_label_tensor(raw):
        # Map labels first (e.g. letters -> indices) when a mapper is supplied.
        encoded = label_map_func(raw) if label_map_func else raw
        return torch.tensor(encoded)

    return (
        torch.tensor(cal_logits),
        torch.tensor(test_logits),
        _to_label_tensor(cal_labels),
        _to_label_tensor(test_labels),
        np.array(cal_correct),
        np.array(test_correct),
    )


def loss_function(logits, logits_pre, t, epsilon=1e-8):
    """DACA temperature loss: cross-entropy from pre-model probs to scaled post-model probs.

    Only samples where the post model (``logits``) and the pre model
    (``logits_pre``) predict the same argmax class contribute. Disagreeing
    samples add zero. The result is averaged over the whole batch, including
    the zeroed samples.

    The log-probabilities are computed with ``log_softmax`` instead of
    ``log(clamp(softmax(...)))``. Clamping sets the gradient to zero for any
    probability outside ``[epsilon, 1 - epsilon]``. For very confident samples
    (top probability ~1, others tiny) that removes their entire gradient with
    respect to ``t``.

    Args:
        logits: Post-model logits, classes on the last axis.
        logits_pre: Pre-model logits with the same shape as ``logits``.
        t: Temperature tensor (scalar); gradients flow back into it.
        epsilon: Unused; kept so existing callers passing it keep working.

    Returns:
        torch.Tensor: Scalar loss.
    """
    device = logits.device
    logits_pre = logits_pre.to(device)
    t = t.to(device)

    # argmax over logits equals argmax over their softmax (softmax is monotonic).
    agree = logits.argmax(dim=-1) == logits_pre.argmax(dim=-1)

    log_probs_tuned = torch.log_softmax(logits / t, dim=-1)
    probs_pre = torch.softmax(logits_pre, dim=-1)

    loss_per_sample = -(probs_pre * log_probs_tuned).sum(dim=-1)
    loss_per_sample = torch.where(agree, loss_per_sample, torch.zeros_like(loss_per_sample))

    return loss_per_sample.mean()


def return_conf(logits, t=1):
    """Return the max softmax probability per sample after dividing logits by ``t``.

    Args:
        logits: Tensor of logits, classes on the last axis.
        t: Temperature divisor (default 1, i.e. unscaled).

    Returns:
        numpy.ndarray: Top-class probability for each sample, detached and on CPU.
    """
    scaled = logits / t
    top_prob = torch.softmax(scaled, dim=-1).max(dim=-1).values
    return top_prob.detach().cpu().numpy()


def _fit_temperature(cal_logits_pre, cal_logits_post, epochs, batch_size, lr):
    """Learn a scalar temperature for the post logits by minimising ``loss_function``."""
    dataset = TensorDataset(cal_logits_pre, cal_logits_post)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    t = torch.tensor(1.0, requires_grad=True, dtype=torch.float32)
    optimizer = torch.optim.Adam([t], lr=lr)

    for _ in range(epochs):
        # Every epoch is a full shuffled mini-batch pass over the calibration split.
        for batch_logits_pre, batch_logits_post in dataloader:
            optimizer.zero_grad()
            loss = loss_function(batch_logits_post, batch_logits_pre, t)
            loss.backward()
            optimizer.step()

    # NOTE(review): t is unconstrained and could in principle drift to <= 0;
    # consider optimising log(t) if that is ever observed.
    return t.item()


def daca(results_pre, results_post, *, epochs=400, batch_size=256, lr=0.1):
    """Fit a DACA temperature on a calibration split and print test-set ECE before/after.

    Args:
        results_pre: Mapping with 'logits', 'correct_answer', 'is_correct' for the
            pre model.
        results_post: Same keys for the post model. Samples must be in the same
            order as ``results_pre``.
        epochs: Number of passes over the calibration split.
        batch_size: Mini-batch size for temperature optimisation.
        lr: Adam learning rate for the temperature.

    Returns:
        tuple: (t, ece, ece_scaled). ``t`` is the fitted temperature (float),
        ``ece`` is the ECE of the unscaled post-model confidences on the test
        split, and ``ece_scaled`` is the ECE after dividing the logits by ``t``.

    Raises:
        ValueError: If the pre and post results contain a different number of samples.
    """
    logits_pre, label_pre, correct_pre = results_pre['logits'], results_pre['correct_answer'], results_pre['is_correct']
    logits_post, label_post, correct_post = results_post['logits'], results_post['correct_answer'], results_post['is_correct']

    # Pre/post samples are paired by position. Both splits below use the same
    # random_state, which only yields the same permutation for equal lengths.
    if len(logits_pre) != len(logits_post):
        raise ValueError(
            f"pre and post results must have the same number of samples "
            f"(got {len(logits_pre)} and {len(logits_post)})"
        )

    pre_data = split_and_convert(logits_pre, label_pre, correct_pre, label_map_func=label_map)
    post_data = split_and_convert(logits_post, label_post, correct_post, label_map_func=label_map)

    cal_logits_pre, _test_logits_pre, _cal_labels_pre, _test_labels_pre, _cal_correct_pre, _test_correct_pre = pre_data
    cal_logits_post, test_logits_post, _cal_labels_post, _test_labels_post, _cal_correct_post, test_correct_post = post_data

    t = _fit_temperature(cal_logits_pre, cal_logits_post, epochs, batch_size, lr)

    conf_test = return_conf(test_logits_post)
    conf_test_scaled = return_conf(test_logits_post, t)

    ece_score = ECE(10)
    ece = ece_score.measure(np.asarray(conf_test), np.asarray(test_correct_post))
    ece_scaled = ece_score.measure(np.asarray(conf_test_scaled), np.asarray(test_correct_post))

    print(40 * '=' + 'ECE' + 40 * '=')
    print(ece)

    print(40 * '=' + 'ECE Scaled' + 40 * '=')
    print(ece_scaled)

    return t, ece, ece_scaled


