import os
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import sys
from ast import literal_eval
from models.clevrcbm import ClevrCBM
from models.clevrcbmrec import ClevrCBMRec
from models.clevrdsldpl import ClevrDSLDPL
from models.clevrdsldplrec import ClevrDSLDPLRec
from datasets.clevr import CLEVR
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import accuracy_score, f1_score
import csv


# Base folder for data files (not referenced by the code visible in this file).
data_folder = "./"

# Metric identifiers (not referenced by the code visible in this file).
metrics = [
    "yf2",
    "betaf1",
    "cf1",
]

def parse_which_c(file_name):
    """Return how many entries the ``which_c`` literal in a file name holds.

    The name is expected to contain ``which_c_<literal>_c_sup`` where
    ``<literal>`` is a Python literal such as ``[3, 4, 1, 0]``.

    Args:
        file_name: File name (or path) to inspect.

    Returns:
        ``len()`` of the parsed literal, or ``None`` when either marker is
        missing or the text between them is not a sized Python literal.
    """
    prefix = "which_c_"
    try:
        start = file_name.index(prefix) + len(prefix)
        end = file_name.index("_c_sup", start)
        return len(literal_eval(file_name[start:end]))
    except (ValueError, SyntaxError, TypeError):
        # ValueError: marker missing or text is not a literal.
        # SyntaxError: malformed literal (e.g. an unclosed bracket).
        # TypeError: the literal has no length (e.g. a bare integer).
        return None

def parse_supervision_type(file_name):
    """Classify a file name by the first supervision tag it contains.

    Tags are tried in the order ``"entropy"``, ``"rec"``, ``"k_sup"``; the
    first one found as a substring of ``file_name`` is returned.

    Args:
        file_name: File name (or any string) to inspect.

    Returns:
        The matching tag, or ``"unknown"`` when none of them occurs.
    """
    # Order matters: a name containing several tags resolves to the earliest.
    for tag in ("entropy", "rec", "k_sup"):
        if tag in file_name:
            return tag
    return "unknown"

def get_concepts_and_labels_clevr(out_dict, true_concepts):
    """Turn model outputs and ground-truth concepts into index form.

    Both ``out_dict["pCS"]`` and ``true_concepts`` are viewed as
    ``(batch, 4, -1)`` and, for each of the 4 slots, four column ranges are
    reduced with an argmax: ``[0:8]``, ``[8:11]``, ``[11:13]`` and ``[13:15]``
    (named colors, shapes, materials and sizes in this file).

    Args:
        out_dict: Mapping with tensor entries ``"YS"`` (label scores, argmax
            taken over dim 1) and ``"pCS"`` (concept scores).
        true_concepts: Ground-truth concept tensor with the same layout as
            ``out_dict["pCS"]``.

    Returns:
        Tuple ``(predicted_labels, predicted_concepts, true_concepts_idx)``;
        the two concept tensors have shape ``(batch, 4, 4)``. An entry is
        ``-1`` wherever the maximum of its column range equals ``-1``.
    """
    # (start, stop) column ranges in the order: color, shape, material, size.
    attribute_spans = ((0, 8), (8, 11), (11, 13), (13, 15))

    def _indices_or_missing(block):
        # Argmax over the last axis, with -1 kept as the "missing" marker
        # when the whole range is filled with -1.
        best_values, best_indices = block.max(dim=-1)
        best_indices[best_values == -1] = -1
        return best_indices

    def _decode(flat):
        per_slot = flat.view(flat.shape[0], 4, -1)
        return torch.stack(
            [
                _indices_or_missing(per_slot[:, :, lo:hi])
                for lo, hi in attribute_spans
            ],
            dim=-1,
        )

    predicted_labels = torch.argmax(out_dict["YS"], dim=1)
    refactored_true_concepts = _decode(true_concepts)
    predicted_concepts = _decode(out_dict["pCS"])

    return predicted_labels, predicted_concepts, refactored_true_concepts

def retrive_concepts_and_labels(model, dataset, max_batches=None):
    """Run ``model`` over ``dataset`` and collect labels, concepts and NLL.

    Earlier revisions stopped unconditionally after the third batch (a
    leftover ``if i == 2: break``) while still dividing the summed NLL by the
    size of the whole dataset. The whole loader is now consumed by default
    and the NLL is averaged over the samples actually processed. Pass
    ``max_batches=3`` to reproduce the old truncated pass.

    Args:
        model: Callable returning a dict with ``"YS"`` and ``"pCS"`` tensors;
            must expose a ``device`` attribute.
        dataset: Iterable of ``(images, labels, concepts)`` tensor batches.
        max_batches: Optional cap on the number of batches to process;
            ``None`` processes every batch.

    Returns:
        Tuple ``(true_labels, predicted_labels, true_concepts,
        predicted_concepts, avg_nll)``; the first four are NumPy arrays
        concatenated over batches, the last is a float.

    Raises:
        ValueError: If no batch was processed.
    """
    true_labels, predicted_labels, true_concepts, predicted_concepts = [], [], [], []

    nll_loss = 0.0
    n_samples = 0

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for i, (images, labels, concepts) in enumerate(dataset):
            if max_batches is not None and i >= max_batches:
                break

            images = images.to(model.device)
            labels = labels.to(model.device)
            concepts = concepts.to(model.device)

            out_dict = model(images)

            # Smooth and renormalise the label scores before taking the log,
            # so that zero entries do not produce -inf.
            pred_y = out_dict["YS"].float() + 1e-6
            pred_y = pred_y / torch.sum(pred_y, dim=1, keepdim=True)
            loss = torch.nn.functional.nll_loss(
                pred_y.log().cpu(), labels.long().cpu(), reduction="sum"
            )
            nll_loss += loss.item()
            n_samples += labels.shape[0]

            out_label, out_concept, concepts = get_concepts_and_labels_clevr(
                out_dict, concepts
            )

            true_labels.append(labels.cpu().numpy())
            true_concepts.append(concepts.cpu().numpy())
            predicted_labels.append(out_label.detach().cpu().numpy())
            predicted_concepts.append(out_concept.cpu().numpy())

    if not true_labels:
        raise ValueError("retrive_concepts_and_labels: no batches were processed")

    # concatenate
    true_labels = np.concatenate(true_labels, axis=0)
    predicted_labels = np.concatenate(predicted_labels, axis=0)
    true_concepts = np.concatenate(true_concepts, axis=0)
    predicted_concepts = np.concatenate(predicted_concepts, axis=0)

    # Average over what was actually evaluated, not over the loader's size.
    avg_nll = nll_loss / n_samples

    assert true_labels.shape == predicted_labels.shape
    assert true_concepts.shape == predicted_concepts.shape, f"{true_concepts.shape} {predicted_concepts.shape}"

    return true_labels, predicted_labels, true_concepts, predicted_concepts, avg_nll


def rearrange_predictions_with_confusion_clevr(pred, gt):
    """Align the last-axis columns of ``pred`` with those of ``gt``.

    An agreement count is accumulated for every (pred column ``j``, gt column
    ``k``) pair over all samples and rows, ignoring positions where ``gt`` is
    ``-1``. The Hungarian algorithm then picks the assignment with maximal
    total agreement, and the columns of ``pred`` are moved to the positions of
    their assigned ``gt`` columns.

    Two defects of the previous version are fixed:

    * ``pred[i, :, col_ind]`` combines an integer and an array index across a
      slice, so NumPy puts the indexed axis first and the stored block was
      transposed.
    * ``col_ind[j]`` is the gt column assigned to pred column ``j``, so the
      aligned array needs ``out[..., col_ind[j]] = pred[..., j]``; indexing
      ``pred`` with ``col_ind`` applies the inverse permutation instead.

    Args:
        pred: Integer array of shape ``(N, rows, columns)``.
        gt: Integer array of the same shape, with ``-1`` marking missing
            entries.

    Returns:
        Tuple ``(pred_rearranged, col_ind)`` where ``col_ind[j]`` is the gt
        column matched to pred column ``j``.
    """
    pred = np.asarray(pred)
    gt = np.asarray(gt)

    # agreement[n, r, j, k] is True when pred column j equals gt column k at
    # (sample n, row r) and that gt entry is not the -1 "missing" marker.
    valid = (gt != -1)[:, :, None, :]
    agreement = (pred[:, :, :, None] == gt[:, :, None, :]) & valid
    global_cost_matrix = agreement.sum(axis=(0, 1)).astype(float)

    # Negated because linear_sum_assignment minimises by default.
    row_ind, col_ind = linear_sum_assignment(-global_cost_matrix)

    pred_rearranged = np.copy(pred)
    pred_rearranged[:, :, col_ind] = pred[:, :, row_ind]

    return pred_rearranged, col_ind

def build_confusion_matrix(preds1, preds2, n_classes):
    """Count co-occurrences of class indices from two aligned sequences.

    Entry ``[a, b]`` is the number of positions ``i`` with
    ``preds1[i] == a`` and ``preds2[i] == b``. Positions where either index
    is negative are skipped: this file uses ``-1`` as the "missing" marker,
    and a negative index would otherwise wrap around and be counted as the
    last class.

    Args:
        preds1: 1-D integer array-like of row indices.
        preds2: 1-D integer array-like of column indices, at least as long
            as ``preds1`` (only the first ``len(preds1)`` entries are used).
        n_classes: Side length of the square matrix.

    Returns:
        Integer array of shape ``(n_classes, n_classes)``.

    Raises:
        IndexError: If ``preds2`` is shorter than ``preds1`` or an index is
            ``>= n_classes``.
    """
    confusion_matrix = np.zeros((n_classes, n_classes), dtype=int)

    rows = np.asarray(preds1).reshape(-1)
    if rows.size == 0:
        return confusion_matrix

    cols = np.asarray(preds2).reshape(-1)
    if cols.shape[0] < rows.shape[0]:
        raise IndexError("preds2 is shorter than preds1")
    cols = cols[: rows.shape[0]]

    keep = (rows >= 0) & (cols >= 0)
    # np.add.at accumulates repeated (row, col) pairs; plain fancy-index
    # "+=" would count each distinct pair only once.
    np.add.at(confusion_matrix, (rows[keep], cols[keep]), 1)

    return confusion_matrix


def permutation_matrix_from_predictions(preds1, preds2, n_classes):
    """Build the permutation matrix that best maps ``preds1`` onto ``preds2``.

    The confusion matrix of the two index sequences is fed to the Hungarian
    algorithm with ``maximize=True``; each selected (row, column) pair becomes
    a ``1`` in the result.

    Args:
        preds1: Class indices used as rows of the confusion matrix.
        preds2: Class indices used as columns of the confusion matrix.
        n_classes: Number of classes (side length of the matrix).

    Returns:
        ``torch.Tensor`` of dtype float32 and shape ``(n_classes, n_classes)``.
    """
    counts = build_confusion_matrix(preds1, preds2, n_classes)
    matched_rows, matched_cols = linear_sum_assignment(counts, maximize=True)

    assignment = np.zeros((n_classes, n_classes), dtype=np.float32)
    assignment[matched_rows, matched_cols] = 1

    return torch.from_numpy(assignment)


def get_hungarian_permutation(model, dataset):
    """Estimate the prediction-to-ground-truth permutations for ``model``.

    Concepts are collected with ``retrive_concepts_and_labels``; then a column
    assignment is computed with ``rearrange_predictions_with_confusion_clevr``
    and one class permutation matrix per concept column with
    ``permutation_matrix_from_predictions``.

    Args:
        model: Model forwarded to ``retrive_concepts_and_labels``.
        dataset: Data loader forwarded to ``retrive_concepts_and_labels``.

    Returns:
        Tuple ``(perm_idx, perm_color, perm_shapes, perm_material,
        perm_sizes)``; the four matrices are NumPy arrays of side 8, 3, 2
        and 2 respectively.
    """
    _, _, gt_concepts, pred_concepts, _ = retrive_concepts_and_labels(model, dataset)

    _, perm_idx = rearrange_predictions_with_confusion_clevr(pred_concepts, gt_concepts)

    # Number of classes of concept columns 0..3 (color, shape, material, size).
    classes_per_column = (8, 3, 2, 2)
    class_permutations = tuple(
        permutation_matrix_from_predictions(
            pred_concepts[:, :, column].flatten(),
            gt_concepts[:, :, column].flatten(),
            n_classes,
        ).numpy()
        for column, n_classes in enumerate(classes_per_column)
    )

    return (perm_idx, *class_permutations)
 

def retrive_concepts_and_labels_hungarian(model, perm_matrix, dataset):
    """Collect labels and concepts, remapping predicted concept classes.

    Each predicted class index selects a row of the matching permutation
    matrix; the argmax of that row is the remapped class.

    Args:
        model: Model forwarded to ``retrive_concepts_and_labels``.
        perm_matrix: 5-tuple ``(perm_idx, perm_color, perm_shapes,
            perm_material, perm_sizes)`` as returned by
            ``get_hungarian_permutation``. ``perm_idx`` is not applied here.
        dataset: Data loader forwarded to ``retrive_concepts_and_labels``.

    Returns:
        Tuple ``(true_labels, predicted_labels, true_concepts,
        predicted_concepts, avg_nll)`` with ``predicted_concepts`` remapped.
    """
    (
        true_labels,
        predicted_labels,
        true_concepts,
        raw_concepts,
        avg_nll,
    ) = retrive_concepts_and_labels(model, dataset)

    perm_idx, perm_color, perm_shapes, perm_material, perm_sizes = perm_matrix

    # One permutation matrix per concept column, in column order.
    remapped_columns = [
        np.argmax(permutation[raw_concepts[:, :, column]], axis=-1)
        for column, permutation in enumerate(
            (perm_color, perm_shapes, perm_material, perm_sizes)
        )
    ]
    predicted_concepts = np.stack(remapped_columns, axis=-1)

    return true_labels, predicted_labels, true_concepts, predicted_concepts, avg_nll

def pad(tensor, target_size=8):
    """Zero-pad or truncate ``tensor`` along dim 1 to ``target_size``.

    The previous version measured dim 1 but padded the *last* dimension,
    which is only the same axis for 2-D input. Padding is now always applied
    to dim 1; 2-D behaviour is unchanged.

    Args:
        tensor: Tensor with at least two dimensions.
        target_size: Desired size of dim 1.

    Returns:
        Tensor whose dim 1 has size ``target_size``: zeros appended at the
        end when it was shorter, the leading slice when it was longer or
        equal.
    """
    padding_size = target_size - tensor.size(1)
    if padding_size <= 0:
        return tensor[:, :target_size]

    # F.pad takes (left, right) pairs starting from the LAST dimension, so
    # every dimension after dim 1 needs an explicit (0, 0) pair first.
    trailing_dims = tensor.dim() - 2
    pad_spec = (0, 0) * trailing_dims + (0, padding_size)
    return torch.nn.functional.pad(tensor, pad_spec)

def clevr_logic(vector):
    """Derive the class label of a scene from its object descriptions.

    Each object is a 5-item sequence ``(presence, color, shape, material,
    size)`` of integer indices; objects with ``presence == 0`` are ignored.
    A class is satisfied when both of its required objects occur:

    * class 0: a large gray cube and a large cylinder;
    * class 1: a small metal cube and a small metal sphere;
    * class 2: a large blue sphere and a small yellow sphere.

    Args:
        vector: Iterable of object descriptions.

    Returns:
        The index (0, 1 or 2) of the class when exactly one is satisfied,
        otherwise 3.
    """
    colors = ["gray", "red", "blue", "green", "brown", "purple", "cyan", "yellow"]
    shapes = ["cube", "sphere", "cylinder"]
    materials = ["rubber", "metal"]
    sizes = ["large", "small"]

    # requirements[c] holds the two "seen" flags of class c.
    requirements = [[False, False], [False, False], [False, False]]

    for presence, color, shape, material, size in vector:
        if presence == 0:
            continue

        color_name = colors[color]
        shape_name = shapes[shape]
        material_name = materials[material]
        size_name = sizes[size]

        if size_name == 'large':
            if shape_name == 'cube' and color_name == 'gray':
                requirements[0][0] = True
            if shape_name == 'cylinder':
                requirements[0][1] = True
            # Check for Class 3 objects
            if color_name == 'blue' and shape_name == 'sphere':
                requirements[2][0] = True

        if size_name == 'small':
            if material_name == 'metal' and shape_name == 'cube':
                requirements[1][0] = True
            if shape_name == 'sphere' and material_name == 'metal':
                requirements[1][1] = True
            # Check for Class 3 objects
            if color_name == 'yellow' and shape_name == 'sphere':
                requirements[2][1] = True

    satisfied = [first and second for first, second in requirements]
    if satisfied.count(True) != 1:
        return 3  # no found or not interesting
    return satisfied.index(True)

def get_concepts_label():
    """Sample 100 random 4-object scenes with their rule-based labels.

    For every scene, presence, color, shape, material and size indices are
    drawn with ``np.random.randint`` (in that order) for 4 objects, and the
    label is computed with ``clevr_logic``. Each object is encoded as a
    16-value vector: presence (1) followed by one-hot color (8), shape (3),
    material (2) and size (2).

    Returns:
        Tuple ``(samples, labels)``: ``samples`` is a float tensor of shape
        ``(100, 16, 4)`` (objects along the LAST axis) and ``labels`` an
        int64 tensor of shape ``(100,)``.
    """
    one_hot = torch.nn.functional.one_hot
    samples, labels = [], []

    for _ in range(100):
        # The draw order is part of the behaviour under a fixed NumPy seed.
        presence = np.random.randint(0, 2, size=4)
        color = np.random.randint(0, 8, size=4)
        shape = np.random.randint(0, 3, size=4)
        material = np.random.randint(0, 2, size=4)
        size = np.random.randint(0, 2, size=4)

        # One row per object: (presence, color, shape, material, size).
        objects = np.stack([presence, color, shape, material, size], axis=-1)
        labels.append(torch.tensor([clevr_logic(objects)]).long())

        attributes = torch.tensor(objects)
        encoded = torch.cat(
            [
                attributes[:, :1].float(),
                one_hot(attributes[:, 1], 8).float(),
                one_hot(attributes[:, 2], 3).float(),
                one_hot(attributes[:, 3], 2).float(),
                one_hot(attributes[:, 4], 2).float(),
            ],
            dim=1,
        )
        # (4, 16) -> (16, 4): features first, objects last.
        samples.append(encoded.t())

    samples = torch.stack(samples, dim=0)
    labels = torch.cat(labels, dim=0)
    return samples, labels

def evaluate_knowledge_clevr(model, pi):
    """Score the model's label predictor on synthetic ground-truth concepts.

    Random scenes are drawn with ``get_concepts_label``, their concepts are
    translated into the model's concept space with the permutations in
    ``pi``, and the result is fed to ``model.get_pred_from_prob``.

    ``get_concepts_label`` returns one-hot samples of shape ``(N, 16, 4)``
    (16 features, objects on the last axis). The previous version indexed
    them as ``(N, 4 objects, 4 attribute indices)`` with ``-1`` for missing
    objects, so it read one-hot bits as attribute indices and could not
    handle a partially-true presence mask. The samples are now transposed to
    ``(N, 4, 16)`` and decoded from the one-hot layout.

    NOTE(review): absent objects (presence 0) are encoded as an all-zero
    vector, which is what the previous code's ``else`` branch suggests was
    intended; confirm against ``get_pred_from_prob``.

    Args:
        model: Object exposing ``get_pred_from_prob(tensor, bool)`` and a
            ``device`` attribute.
        pi: 5-tuple ``(perm_idx, perm_color, perm_shapes, perm_material,
            perm_sizes)`` as returned by ``get_hungarian_permutation``;
            ``perm_idx`` is not used here.

    Returns:
        Tuple ``(accuracy, macro_f1)`` of the predicted labels against the
        rule-based labels.
    """
    samples, labels = get_concepts_label()

    _perm_idx, perm_color, perm_shapes, perm_material, perm_sizes = pi

    # (N, 16, 4) -> (N, objects, 16)
    per_object = torch.as_tensor(samples).transpose(1, 2).float()
    present = per_object[:, :, 0] > 0.5
    present_f = present.float()

    full_concept_vector = torch.zeros(
        (per_object.shape[0], per_object.shape[1], 16), dtype=torch.float32
    )
    full_concept_vector[:, :, 0] = present_f

    # Column ranges of the 16-value layout and their class permutations.
    attribute_layout = (
        (1, 9, perm_color),
        (9, 12, perm_shapes),
        (12, 14, perm_material),
        (14, 16, perm_sizes),
    )
    for start, stop, permutation in attribute_layout:
        true_index = per_object[:, :, start:stop].argmax(dim=-1)
        # The permutation maps predicted class (row) to true class (column);
        # its transpose maps a true class to the model's one-hot class.
        to_model_space = torch.tensor(
            np.asarray(permutation).T.astype(np.float32)
        )
        full_concept_vector[:, :, start:stop] = (
            to_model_space[true_index] * present_f.unsqueeze(-1)
        )

    y = model.get_pred_from_prob(
        full_concept_vector.to(model.device), True
    ).detach().cpu().numpy()
    y = np.argmax(y, axis=-1)

    y_true = labels.numpy()
    return accuracy_score(y_true, y), f1_score(y_true, y, average="macro")


def evaluate(net, args, dataset):
    """Evaluate ``net`` on concepts, labels and knowledge.

    The class permutations are estimated on the training loader, applied to
    the predictions on the test loader, and the label predictor is scored
    with ``evaluate_knowledge_clevr``.

    The device used to be hard-coded to ``"cuda:2"``, which fails on hosts
    with fewer than three GPUs. ``"cuda:2"`` is still preferred when it
    exists; otherwise the first GPU or the CPU is used.

    Args:
        net: Model to evaluate; its ``device`` attribute is overwritten.
        args: Run configuration (currently unused, kept for the interface).
        dataset: Object whose ``get_data_loaders()`` returns
            ``(train_loader, <ignored>, test_loader)``.

    Returns:
        Tuple ``(concept_accuracy, concept_f1_macro, label_accuracy,
        label_f1_macro, w_acc, w_f1)``.
    """
    if torch.cuda.is_available():
        net.device = "cuda:2" if torch.cuda.device_count() > 2 else "cuda:0"
    else:
        net.device = "cpu"

    net.to(net.device)
    if hasattr(net, "encoder"):
        net.encoder.to(net.device)
    if hasattr(net, "net"):
        net.net.to(net.device)

    train_loader, _, test_loader = dataset.get_data_loaders()

    # Evaluation only: avoid building autograd graphs.
    with torch.no_grad():
        pi = get_hungarian_permutation(net, train_loader)
        ind_data = retrive_concepts_and_labels_hungarian(net, pi, test_loader)
        w_acc, w_f1 = evaluate_knowledge_clevr(net, pi)

    concept_accuracy, concept_f1_macro, label_accuracy, label_f1_macro = compute_metrics(*ind_data)
    return (
        concept_accuracy, concept_f1_macro, label_accuracy, label_f1_macro,
        w_acc, w_f1
    )
    


def compute_metrics(
    true_labels,
    predicted_labels,
    true_concepts,
    predicted_concepts,
    avg_nll,
):
    """Compute concept- and label-level accuracy and macro F1.

    For each of the four concept columns (last axis, indices 0..3) the
    entries whose ground truth is ``-1`` are dropped, and accuracy and macro
    F1 are computed on the rest; the four values are then averaged.

    The previous version recomputed three of the F1 scores twice through
    copy-pasted calls; each score is now computed once.

    Args:
        true_labels: Ground-truth labels.
        predicted_labels: Predicted labels.
        true_concepts: Array of shape ``(N, objects, 4)`` with ``-1`` for
            missing entries.
        predicted_concepts: Array with the same shape as ``true_concepts``.
        avg_nll: Unused; accepted so the tuple returned by
            ``retrive_concepts_and_labels_hungarian`` can be splatted in.

    Returns:
        Tuple ``(concept_accuracy, concept_f1_macro, label_accuracy,
        label_f1_macro)``.
    """
    accuracies, f1_macros = [], []

    # Columns 0..3 are color, shape, material and size.
    for column in range(4):
        truth = true_concepts[:, :, column].reshape(-1)
        guess = predicted_concepts[:, :, column].reshape(-1)

        known = truth != -1
        truth, guess = truth[known], guess[known]

        accuracies.append(accuracy_score(truth, guess))
        f1_macros.append(f1_score(truth, guess, average="macro"))

    concept_accuracy = np.mean(accuracies)
    concept_f1_macro = np.mean(f1_macros)

    label_accuracy = accuracy_score(true_labels, predicted_labels)
    label_f1_macro = f1_score(true_labels, predicted_labels, average="macro")

    return concept_accuracy, concept_f1_macro, label_accuracy, label_f1_macro


def _clevr_strategies(model_name):
    """Return ``(checkpoint model tag, file-name suffix, label)`` triples."""
    if "cbm" in model_name:
        base, rec = "clevrcbm", "clevrcbmrec"
    else:
        base, rec = "clevrdsldpl", "clevrdsldplrec"

    strategies = [
        (base, "c_sup_1.0_entropy_2.0", "entropy"),
        (rec, "c_sup_1.0_rec_1.0", "rec + contrastive"),
    ]
    strategies += [
        (base, f"c_sup_1.0_k_sup_{fraction}", f"{percent}% k sup")
        for fraction, percent in (
            ("0.0", 0), ("0.2", 20), ("0.4", 40),
            ("0.6", 60), ("0.8", 80), ("1.0", 100),
        )
    ]
    return strategies


def _build_clevr_net(model_name, args, encoder, decoder):
    """Instantiate the network class that corresponds to ``model_name``."""
    if model_name == "clevrcbm":
        return ClevrCBM(args=args, encoder=encoder, nr_classes=3)
    if model_name == "clevrcbmrec":
        return ClevrCBMRec(args=args, encoder=encoder, decoder=decoder, nr_classes=4)
    if model_name == "clevrdsldpl":
        return ClevrDSLDPL(args=args, encoder=encoder, nr_classes=19)
    if model_name == "clevrdsldplrec":
        return ClevrDSLDPLRec(args=args, encoder=encoder, decoder=decoder, nr_classes=4)
    raise ValueError(f"unknown model_name {model_name!r}")


def load_data(model_name):
    """Evaluate every available checkpoint and write one CSV per run.

    For each combination of supervised-concept set, strategy and seed, the
    checkpoint ``clevr_<model>_<seed>_which_c_<which_c>_<strategy>.pth`` is
    looked up in the current working directory. Runs whose ``.csv`` already
    exists or whose ``.pth`` is missing are skipped; the others are loaded,
    evaluated with ``evaluate`` and their metrics written to the ``.csv``.

    Changes with respect to the previous version:

    * an unknown ``model_name`` fails immediately instead of leaving the
      network as ``None`` until the first checkpoint is loaded;
    * checkpoints are loaded with ``map_location="cpu"`` so they do not
      depend on the GPU they were saved from;
    * the arguments, dataset and network, which did not depend on the loop
      variables, are created once instead of once per concept set.

    Args:
        model_name: One of ``"clevrcbm"``, ``"clevrcbmrec"``,
            ``"clevrdsldpl"`` or ``"clevrdsldplrec"``.

    Raises:
        ValueError: If ``model_name`` is not one of the names above.
    """
    from argparse import Namespace

    known_models = ("clevrcbm", "clevrcbmrec", "clevrdsldpl", "clevrdsldplrec")
    if model_name not in known_models:
        raise ValueError(
            f"unknown model_name {model_name!r}; expected one of {known_models}"
        )

    args = Namespace(
        backbone="conceptizer",  #
        preprocess=0,
        finetuning=0,
        batch_size=32,#256,
        n_epochs=20,
        validate=1,
        dataset="clevr",
        lr=0.001,
        exp_decay=0.99,
        warmup_steps=1,
        wandb=None,
        task="clevr",
        boia_model="ce",
        model=model_name,
        c_sup=1,
        which_c=[-1],
        joint=False,
        boia_ood_knowledge=True,
        splitted=False,
        eps_sym=0.5,
        eps_rul=0.5
    )

    dataset = CLEVR(args)
    encoder, decoder = dataset.get_backbone()
    net = _build_clevr_net(model_name, args, encoder, decoder)
    strategies = _clevr_strategies(model_name)

    header = ["cacc", "cf1", "lacc", "lf1", "kacc", "kf1"]
    seeds = [1011, 1213, 1415, 1617, 1819]
    # [-2] appears in file names as-is.
    which_c_settings = [
        [-2],
        [3, 4],
        [3, 4, 1, 0],
        [3, 4, 1, 0, 2, 8],
        [3, 4, 1, 0, 2, 8, 9, 5],
        [3, 4, 1, 0, 2, 8, 9, 5, 6, 7],
    ]

    for which_c in which_c_settings:
        for model, strategy, _supervision_type in strategies:
            # NOTE(review): `net` is built for `model_name`, while `model`
            # comes from the strategy table and can name a different
            # architecture; load_state_dict may then fail on mismatched
            # keys. Confirm which checkpoints are meant to be loaded.
            for seed in seeds:
                base_filename = f"clevr_{model}_{seed}_which_c_{which_c}_{strategy}"
                csv_path = base_filename + ".csv"
                checkpoint_path = base_filename + ".pth"

                if os.path.exists(csv_path):
                    print(f"{base_filename} already processed")
                    continue
                if not os.path.exists(checkpoint_path):
                    print(base_filename, "does not exists")
                    continue

                print("Loading....")
                # Load on CPU; load_state_dict copies onto the net's device.
                net.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

                concept_accuracy, concept_f1_macro, label_accuracy, label_f1_macro, w_acc, w_f1 = evaluate(net, args, dataset)
                row = [concept_accuracy, concept_f1_macro, label_accuracy, label_f1_macro, w_acc, w_f1]

                # Save CSV file
                with open(csv_path, mode="w", newline="") as file:
                    writer = csv.writer(file)
                    writer.writerow(header)  # Write header
                    writer.writerow(row)  # Write data row
                print(base_filename, "done!")

if __name__ == "__main__":
    # Architecture whose checkpoints are evaluated when run as a script.
    model_name = "clevrdsldplrec"
    load_data(model_name=model_name)
    
