#!/usr/bin/env python3

import os
from torchvision import transforms
import torch
from torch.utils import data
import os
from tqdm import tqdm
from utils.metrics import AverageMeter, AverageMeterSubgroups, accuracy, get_subgroup_masks, accuracy_subgroup, regroup_by
from torch.utils import data
import matplotlib.pyplot as plt
from vanilla_builder import VanillaModels, set_seed
import kmedoids, torchvision
from pathlib import Path
from wandb_wrapper import WandbWrapper
import argparse


def zip_folder(folder_path, zip_file_path, new_name=None):
    """Copy a results folder into the shared datasets directory and zip it.

    The folder is first copied (``cp -r``) into
    ``$NAMING_BIASES_DATA_PATH/datasets``, any previous archive at
    ``zip_file_path`` is removed, and a fresh recursive zip is created.

    Args:
        folder_path: Folder to copy and archive.
        zip_file_path: Path of the archive to (re)create.
        new_name: Unused; kept so existing callers keep working.

    Raises:
        KeyError: If ``NAMING_BIASES_DATA_PATH`` is not set in the environment.
        FileNotFoundError: If the ``cp`` or ``zip`` executable is missing.
    """
    import subprocess  # local import: only this function shells out this way

    datasets_dir = f"{os.environ['NAMING_BIASES_DATA_PATH']}/datasets"

    def _run(command):
        # Argument lists (no shell) keep paths containing spaces or shell
        # metacharacters intact. A failing command is reported but, as with
        # the previous os.system calls, does not abort the pipeline.
        completed = subprocess.run(command, check=False)
        if completed.returncode != 0:
            print(f"Warning: '{command[0]}' exited with status {completed.returncode}")

    _run(["cp", "-r", str(folder_path), datasets_dir])
    _run(["rm", "-f", str(zip_file_path)])
    print("Removed old output before logging a zip")
    _run(["zip", str(zip_file_path), str(folder_path), "-r"])
    print("Archive created")
    print(f"Output copied in {datasets_dir}")

def clean_previous_results(directory_path):
    """Recursively delete ``directory_path`` and everything below it.

    Files are unlinked, sub-directories are cleaned recursively (each one
    prints its own confirmation), and finally the directory itself is
    removed. If the path is missing or is not a directory, a message is
    printed and nothing is deleted.
    """
    root = Path(directory_path)

    # Guard clause: nothing to do for missing paths or plain files.
    if not root.exists() or not root.is_dir():
        print(f"Directory '{directory_path}' does not exist or is not a directory.")
        return

    for entry in root.iterdir():
        if not entry.is_dir():
            entry.unlink()
            continue
        clean_previous_results(entry)

    # The directory is empty at this point, so rmdir succeeds.
    root.rmdir()
    print(f"Directory '{directory_path}' removed successfully.")

@torch.no_grad()
def evaluate_model(model, test_loader, num_classes, num_bias_attributes, criterion, device, wb):
    """Run ``model`` over ``test_loader`` and report running accuracies.

    For every batch the loss, top-1 accuracy and per-subgroup accuracy are
    accumulated; the running overall / "aligned" / "misaligned" accuracies
    are shown on the progress bar and, when ``wb`` is not None, logged
    through ``wb.log_output`` after each batch. Nothing is returned.

    Args:
        model: Network to evaluate; switched to eval mode here.
        test_loader: Loader yielding ``(data, labels, _)`` where ``labels[0]``
            is the target and ``labels[1]`` the bias attribute.
        num_classes: Number of target classes.
        num_bias_attributes: Number of bias attributes; the subgroup grid has
            ``num_bias_attributes + 1`` axes of size ``num_classes``.
        criterion: Loss callable applied to ``(output, target)``.
        device: Device the batches are moved to.
        wb: Logging wrapper exposing ``log_output``, or None to disable logging.
    """
    model.eval()

    subgroup_shape = (num_classes, ) * (num_bias_attributes + 1)
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()
    subgroup_meter = AverageMeterSubgroups(size=subgroup_shape, device=device)

    progress = tqdm(
        test_loader, total=int(len(test_loader)), leave=True, dynamic_ncols=True
    )

    for inputs, labels, _ in progress:
        inputs = inputs.to(device)
        target = labels[0].to(device)
        # Moved to the device like the target, but not read again below.
        bias_attr = labels[1].to(device)

        logits = model(inputs)
        batch_size = inputs.size(0)

        batch_loss = criterion(logits, target)
        loss_meter.update(batch_loss.item(), batch_size)

        batch_acc = accuracy(logits, target, topk=(1, ))
        masks = get_subgroup_masks(
            labels=labels,
            num_classes=subgroup_shape,
            device=device,
        )
        batch_subgroup_acc = accuracy_subgroup(logits, target, masks, num_classes=num_classes)

        top1_meter.update(batch_acc[0], batch_size)
        subgroup_meter.update(batch_subgroup_acc, masks)

        running_acc = top1_meter.avg
        running_aligned = regroup_by(subgroup_meter, ("aligned", ))[0].item()
        running_misaligned = regroup_by(subgroup_meter, ("misaligned",))[0].item()

        progress.set_postfix(
            acc1=running_acc,
            acc_a=running_aligned,
            acc_m=running_misaligned,
        )
        if wb is None:
            continue
        wb.log_output({"acc_1": running_acc, "acc_a": running_aligned, "acc_m": running_misaligned})

@torch.no_grad()
def softmax_distribution_hist(network_outputs: torch.Tensor, targets, biases, target_class: int, epoch: int, wb):
    """Plot the softmax probability on ``target_class`` for aligned vs conflicting samples.

    Only samples whose target equals ``target_class`` are used. They are split
    into "aligned" (target == bias) and "conflicting" (target != bias); for
    each split a density histogram and a KDE of the softmax probability
    assigned to ``target_class`` is drawn. The figure is logged through ``wb``
    (when not None) and saved as ``Class_{target_class}_probs_epoch_{epoch}.pdf``
    in the current working directory.

    Note: this overwrites several global ``plt.rcParams`` entries.
    """
    import seaborn as sns

    # Restrict everything to the samples of the requested class.
    in_class = targets == target_class
    class_targets = targets[in_class]
    class_biases = biases[in_class]
    is_aligned = class_targets == class_biases
    is_conflicting = class_targets != class_biases

    # Double precision for numerical stability; clamp keeps probabilities off zero.
    class_logits = network_outputs[in_class].double()
    class_probs: torch.Tensor = torch.nn.functional.softmax(class_logits, dim=1).clamp(min=1e-6, max=1.0)
    prob_on_target = class_probs[:, target_class]

    aligned_values = prob_on_target[is_aligned].cpu().numpy()
    conflicting_values = prob_on_target[is_conflicting].cpu().numpy()

    plt.rcParams.update({
        "text.usetex": False,
        'font.family': 'serif',
        "font.size": 18,
        "font.weight": "bold",
        "axes.labelweight": "bold",
        'axes.labelsize': 'medium',
        'xtick.labelsize': 'small',
        'ytick.labelsize': 'small',
        'legend.fontsize': 'medium',
        'lines.linewidth': 5.0,
        'lines.markersize': 8,
    })

    aligned_color = '#377eb8'
    conflicting_color = '#ff7f00'

    plt.figure(figsize=(10, 10))
    plt.hist(aligned_values, bins=50, alpha=0.5, color=aligned_color, density=True, label='Aligned')
    plt.hist(conflicting_values, bins=50, alpha=0.5, color=conflicting_color, density=True, label='Conflicting')

    sns.kdeplot(aligned_values, color=aligned_color, linewidth=2)
    sns.kdeplot(conflicting_values, color=conflicting_color, linewidth=2)

    plt.axvline(0.5, color='r', linestyle='dashed', linewidth=2, label='Random Guess')
    plt.xlim((0.0, 1.0))
    plt.ylabel("Density (%)")
    plt.grid()
    plt.xlabel("Softmax Output")
    plt.legend()

    figure_name = f"Class_{target_class}_probs_epoch_{epoch}"
    if wb is not None:
        wb.log_output({figure_name: wb.backend.Image(plt)})

    plt.savefig(f"{figure_name}.pdf", format="pdf", dpi=1200)
    plt.close()

@torch.no_grad()
def misclassified_distance_from_target(network_outputs: torch.Tensor, y_true: torch.Tensor):
    """Mean half-gap between predicted-class and true-class probability on errors.

    For every misclassified sample the softmax probability of the predicted
    class minus the softmax probability of the true class is halved and
    clamped to ``[1e-6, 1.0]``; the mean over those samples is returned.

    Args:
        network_outputs: Logits, one row per sample and one column per class.
        y_true: Integer class targets, one per sample.

    Returns:
        ``(mean_distance, num_misclassified)``. When the batch contains no
        misclassified sample the result is ``(0.0 tensor, 0)`` instead of the
        NaN that a mean over an empty tensor would produce, so a weighted
        running average fed with these values is not poisoned.
    """
    network_outputs = network_outputs.double()     # numerical stability
    _, preds = torch.max(network_outputs, dim=1)   # extract predictions
    misclassified_mask = preds != y_true           # exclude correctly classified samples
    wrong_logits = network_outputs[misclassified_mask]

    num_misclassified = wrong_logits.size(0)
    if num_misclassified == 0:
        # Nothing to average: return a neutral value with zero weight.
        return network_outputs.new_zeros(()), 0

    # Logits to probabilities; the clamp keeps every entry in [1e-6, 1.0].
    softmax_probs: torch.Tensor = torch.nn.functional.softmax(wrong_logits, dim=1).clamp(min=1e-6, max=1.0)

    # Probability assigned to the true class of each misclassified sample.
    prob_on_target: torch.Tensor = torch.gather(
        softmax_probs, dim=1,
        index=torch.unsqueeze(y_true[misclassified_mask], dim=1)
    )

    # Probability assigned to the (wrong) predicted class of each sample.
    prob_on_pred: torch.Tensor = torch.gather(
        softmax_probs, dim=1,
        index=torch.unsqueeze(preds[misclassified_mask], dim=1)
    )

    # The gathered values are already clamped above, so only the final
    # half-gap needs clamping.
    dist_from_target = ((prob_on_pred - prob_on_target) / 2).clamp(min=1e-6, max=1.0)

    return dist_from_target.mean(), num_misclassified # Average, Size

def train_model_dist_from_target(model, train_loader, device, criterion, optimizer, num_classes, epochs=10, check_dist=10, warmup=50, wb=None, make_figures=True):
    """Train ``model`` and checkpoint it when the misclassification distance peaks.

    The initial weights are saved as ``{model_name}-biased-init.pt``. During
    training, the running mean of ``misclassified_distance_from_target`` is
    tracked per epoch; every ``check_dist`` batches, if that mean exceeds the
    best value seen so far (and the warm-up condition holds), the weights are
    saved as ``{model_name}-biased-updated.pt``.

    Relies on the module-level globals ``model_name``, ``PATH_TO_MODELS`` and
    ``num_bias_attributes`` being defined before the call.

    Args:
        model: Network to train.
        train_loader: Loader yielding ``(data, labels, _)`` with ``labels[0]``
            the target and ``labels[1]`` the bias attribute.
        device: Device the batches are moved to.
        criterion: Loss callable applied to ``(output, target)``.
        optimizer: Optimizer stepped once per batch.
        num_classes: Number of target classes.
        epochs: Number of passes over ``train_loader``.
        check_dist: Checkpoint condition is evaluated every this many batches.
        warmup: Minimum value of ``(epoch + 1) * (batch + 1)`` before a
            checkpoint may be written.
        wb: Logging wrapper forwarded to the figure helper, or None.
        make_figures: When True, per-epoch softmax histograms are produced for
            classes 0 and 1.
    """
    max_dist = 0
    cur_model_name = f"{model_name}-biased-init.pt"
    print(f"Saving {cur_model_name}")
    torch.save(model.state_dict(), os.path.join(PATH_TO_MODELS, cur_model_name))

    groups_size = (num_classes, ) * (num_bias_attributes + 1)

    for epoch in range(epochs):
        model.train(True)
        loss_task_tot = AverageMeter()
        top1 = AverageMeter()
        dist_avg = AverageMeter()
        subgroup_top1 = AverageMeterSubgroups(groups_size, device=device)
        tk0 = tqdm(
            train_loader, total=int(len(train_loader)), leave=True, dynamic_ncols=True
        )
        epoch_outputs = []
        epoch_targets = []
        epoch_biases  = []
        for batch, (dat, labels, _) in enumerate(tk0):
            dat = dat.to(device)
            target = labels[0].to(device)
            bias_l = labels[1].to(device)
            output = model(dat)

            if make_figures:
                # Detach so the stored logits do not keep every batch's
                # autograd graph alive for the whole epoch.
                epoch_outputs.append(output.detach())
                epoch_targets.append(target)
                epoch_biases.append(bias_l)

            dist, num_elems = misclassified_distance_from_target(network_outputs=output, y_true=target)
            # A batch without misclassified samples carries no distance
            # information (and may yield NaN); skip it so the running mean
            # is not corrupted.
            if num_elems > 0:
                dist_avg.update(dist.clamp(min=1e-6, max=1.0), num_elems)

            loss_task = criterion(output, target)
            loss_task_tot.update(loss_task.item(), dat.size(0))
            loss_task.backward()
            optimizer.step()
            optimizer.zero_grad()

            acc1 = accuracy(output, target, topk=(1,))
            subgroup_masks = get_subgroup_masks(labels, num_classes=groups_size, device=device)
            subgroup_acc1 = accuracy_subgroup(output, target, subgroup_masks, num_classes=num_classes)
            top1.update(acc1[0], dat.size(0))
            subgroup_top1.update(subgroup_acc1, subgroup_masks)

            running_dist = torch.as_tensor(dist_avg.avg).item()
            tk0.set_postfix(epoch=epoch,
                            acc1=top1.avg,
                            acc_a = regroup_by(subgroup_top1, ("aligned",))[0].item(),
                            acc_m = regroup_by(subgroup_top1, ("misaligned",))[0].item(),
                            dist = running_dist)

            if batch % check_dist == 0:
                # NOTE(review): the warm-up test multiplies epoch and batch
                # counters rather than counting global steps; kept as-is,
                # confirm this is the intended schedule.
                if max_dist < running_dist and ((epoch + 1) * (batch + 1)) >= warmup:
                    max_dist = running_dist
                    cur_model_name = f"{model_name}-biased-updated.pt"
                    print(f"Saving {cur_model_name}")
                    torch.save(model.state_dict(), os.path.join(PATH_TO_MODELS, cur_model_name))

        if make_figures:
            epoch_outputs = torch.cat(epoch_outputs, dim=0)
            epoch_targets = torch.cat(epoch_targets, dim=0)
            epoch_biases  = torch.cat(epoch_biases, dim=0)
            softmax_distribution_hist(epoch_outputs, epoch_targets, epoch_biases, target_class=0, epoch=epoch, wb=wb)
            softmax_distribution_hist(epoch_outputs, epoch_targets, epoch_biases, target_class=1, epoch=epoch, wb=wb)
    
def extract_samples(model, device, train_loader, criterion, num_classes, num_bias_attributes):
    """Split the sample ids of ``train_loader`` into correct / incorrect per class.

    Every batch is run through ``model`` under ``torch.no_grad()``. A correctly
    classified sample is filed under its target class; a misclassified sample
    is filed under the class it was (wrongly) predicted as. Running accuracies
    are shown on the progress bar.

    Args:
        model: Network used for inference (its train/eval mode is not changed here).
        device: Device the batches are moved to.
        train_loader: Loader yielding ``(data, labels, idx)`` where ``labels[0]``
            is the target and ``idx`` the per-sample ids.
        criterion: Loss callable applied to ``(output, target)``.
        num_classes: Number of target classes.
        num_bias_attributes: Number of bias attributes (sizes the subgroup grid).

    Returns:
        ``{class: {"correct": {"ids": t}, "incorrect": {"ids": t}}}`` where each
        ``t`` is a tensor of shape ``(1, n)`` holding the selected sample ids.
    """
    groups_size = (num_classes, ) * (num_bias_attributes + 1)

    # Ids are stored shifted by +1 so that 0 can mean "not in this class";
    # the all-zero seed row also fixes the floating dtype of the result.
    # NOTE(review): ids are therefore carried as floats — with the default
    # float32 this is exact only below 2**24; confirm dataset sizes stay under it.
    seed_row = torch.zeros(size=(1, num_classes), device=device)
    correct_chunks = [seed_row]
    incorrect_chunks = [seed_row]

    loss_task_tot = AverageMeter()
    top1 = AverageMeter()
    subgroup_top1 = AverageMeterSubgroups(groups_size, device=device)
    tk0 = tqdm(
        train_loader, total=int(len(train_loader)), leave=True, dynamic_ncols=True, desc="Misclassified Samples Extraction"
    )

    with torch.no_grad():
        for dat, labels, idx in tk0:
            dat = dat.to(device)
            target = labels[0].to(device)
            idx = idx.to(device)

            output = model(dat)
            target_one_hot = torch.nn.functional.one_hot(target, num_classes)
            predicted = torch.argmax(output, dim=1)
            predicted_one_hot = torch.nn.functional.one_hot(predicted, num_classes)

            aligned_mask    = (predicted == target).unsqueeze(-1)
            misaligned_mask = torch.logical_not(aligned_mask)

            # Correct samples go to their target column, wrong ones to the
            # column of the class they were predicted as.
            aligned_mask    = aligned_mask * target_one_hot
            misaligned_mask = misaligned_mask * predicted_one_hot

            # Collect per-batch pieces and concatenate once after the loop
            # instead of re-copying the growing tensor on every batch.
            shifted_idx = idx.unsqueeze(-1) + 1
            correct_chunks.append(shifted_idx * aligned_mask)
            incorrect_chunks.append(shifted_idx * misaligned_mask)

            loss_task = criterion(output, target)
            loss_task_tot.update(loss_task.item(), dat.size(0))
            acc1 = accuracy(output, target, topk=(1,))
            subgroup_masks = get_subgroup_masks(labels, num_classes=groups_size, device=device)
            subgroup_acc1 = accuracy_subgroup(output, target, subgroup_masks, num_classes=num_classes)
            top1.update(acc1[0], dat.size(0))
            subgroup_top1.update(subgroup_acc1, subgroup_masks)
            tk0.set_postfix(acc1=top1.avg,
                            acc_a = regroup_by(subgroup_top1, ("aligned",))[0].item(),
                            acc_m = regroup_by(subgroup_top1, ("misaligned",))[0].item())

        idx_correct = torch.cat(correct_chunks, dim=0)
        idx_incorrect = torch.cat(incorrect_chunks, dim=0)

        # Keep the non-zero entries of each class column and undo the +1 shift.
        images = {i: {"correct": {"ids": idx_correct[:,i][idx_correct[:,i].nonzero()].transpose(1,0) - 1},
                    "incorrect": {"ids": idx_incorrect[:,i][idx_incorrect[:,i].nonzero()].transpose(1,0) -1}} for i in range(num_classes)}

        return images
    
def build_subdata_subdataloader(dataset, images, num_classes):
    """Build per-class "correct" / "incorrect" subsets and loaders over them.

    Args:
        dataset: Indexable dataset the subsets are taken from.
        images: Mapping ``{class: {"correct": {"ids": t}, "incorrect": {"ids": t}}}``
            where each ``t`` is a tensor of shape ``(1, n)`` of dataset positions.
        num_classes: Number of classes; keys ``0 .. num_classes - 1`` are read.

    Returns:
        ``(ds, dls)``: two mappings ``{class: {subset_name: ...}}`` holding a
        ``torch.utils.data.Subset`` and a non-shuffling ``DataLoader``
        (batch size 128) respectively. A subset with no ids is left out of
        both mappings, so every entry that is present is non-empty.
    """
    ds = {}
    dls = {}

    for target_class in range(num_classes):
        ds[target_class] = {}
        dls[target_class] = {}

        for subset in ("correct", "incorrect"):
            ids = images[target_class][subset]["ids"]
            # Test the ids themselves: the enclosing dict always has exactly
            # one key, so its length says nothing about emptiness.
            if ids.numel() == 0:
                continue

            positions = ids[0].cpu().numpy().astype(int)
            subset_data = torch.utils.data.Subset(dataset, positions)

            ds[target_class][subset] = subset_data
            dls[target_class][subset] = data.DataLoader(
                subset_data,
                batch_size=128,
                shuffle=False,
                num_workers=8,
                pin_memory=True
            )

    return ds, dls

def extract_bottleneck_reps(model, bottleneck, dls, num_classes):
    """Collect bottleneck features, labels and sample ids per class and subset.

    Each loader in ``dls`` is run through ``model``; after every forward pass
    the activation exposed as ``bottleneck.output`` is read. The gathered
    features, targets and ids are moved to the CPU and each wrapped in a
    ``DataLoader`` whose batch size equals the number of collected rows, so
    iterating it yields everything in a single batch.

    Classes that have no "correct" loader are skipped entirely. Relies on the
    module-level global ``device``.

    Returns:
        ``(f_dls, l_dls, i_dls)``: mappings ``{class: {subset: DataLoader}}``
        over features, labels and ids; subsets with nothing collected are absent.
    """
    subset_names = ["correct", "incorrect"]

    def _wrap_in_loaders(store):
        # One loader per non-empty (class, subset) entry, sized to emit a single batch.
        wrapped = {}
        for cls in range(num_classes):
            per_class = {}
            for name in subset_names:
                payload = store[cls][name]
                if len(payload) > 0:
                    per_class[name] = data.DataLoader(
                        payload,
                        batch_size=len(payload),
                        shuffle=False,
                        num_workers=8,
                        pin_memory=True
                    )
            wrapped[cls] = per_class
        return wrapped

    def _empty_store():
        return {cls: {name: [] for name in subset_names} for cls in range(num_classes)}

    features = _empty_store()
    targets = _empty_store()
    indices = _empty_store()

    model.eval()
    with torch.no_grad():
        for target in range(num_classes):
            class_loaders = dls[target]
            if "correct" not in class_loaders.keys():
                print(f"Skipping class {target}, as it has no correctly classified samples", end="\r", flush=True)
                continue

            for subset in class_loaders.keys():
                loader = class_loaders[subset]
                tq = tqdm(
                    loader,
                    desc=f"Class {target} ({subset}) (bottleneck representations)",
                    total=int(len(loader)),
                    leave=True,
                    dynamic_ncols=True,
                )
                for img, _labels, idx in tq:
                    img = img.to(device)
                    _labels = _labels[0].to(device)
                    idx = idx.to(device)

                    # The forward pass is only run for its side effect of
                    # refreshing bottleneck.output.
                    _ = model(img)
                    out_bot = bottleneck.output.clone().detach().squeeze(-1).squeeze(-1)

                    # Activations that still have more than two dims keep
                    # only index 0 along dim 1.
                    if len(out_bot.size()) > 2:
                        representation = out_bot[:, 0]
                    else:
                        representation = out_bot

                    features[target][subset].append(representation)
                    targets[target][subset].append(_labels)
                    indices[target][subset].append(idx)

                features[target][subset] = torch.cat(features[target][subset], dim=0).cpu()
                targets[target][subset]  = torch.cat(targets[target][subset],  dim=0).cpu()
                indices[target][subset]  = torch.cat(indices[target][subset],  dim=0).cpu()

    f_dls = _wrap_in_loaders(features)
    l_dls = _wrap_in_loaders(targets)
    i_dls = _wrap_in_loaders(indices)

    return f_dls, l_dls, i_dls

def extract_kmedoids(model, f_dls, l_dls, i_dls, ds, num_classes, dataset_name, K=10):
    """Write exemplar images and id listings for every class.

    For the "correct" subset of a class, K medoids of the bottleneck features
    are selected with ``kmedoids.fasterpam`` and saved. For the "incorrect"
    subset, all samples are saved, listed by decreasing distance from the
    centroid of the class's "correct" features.

    Output layout under ``./medoids_results/{dataset_name}_mined_bias_exemplars``:
    ``imgs/`` (with one sub-folder per class for "imagenet-a") holds the
    de-normalised images named ``{sample_id}.png``; ``biases-k-{K}/{class}/``
    holds ``correctly-classified.txt`` / ``incorrectly-classified.txt`` with
    one sample id per line.

    Relies on the module-level global ``nb_features``.

    Args:
        model: Network; only switched to eval mode here.
        f_dls, l_dls, i_dls: ``{class: {subset: loader}}`` over features,
            labels and sample ids, aligned row by row with ``ds``.
        ds: ``{class: {subset: dataset}}`` whose items start with the image tensor.
        num_classes: Number of classes to process.
        dataset_name: Dataset key; "celeba" subsamples the correct features by
            10, "imagenet-a" uses per-class image folders.
        K: Number of medoids (an int or a numeric string).
    """
    base_dir = f"./medoids_results/{dataset_name}_mined_bias_exemplars"
    img_avg = torch.as_tensor([0.485, 0.456, 0.406])[None, :, None, None]
    img_std = torch.as_tensor([0.229, 0.224, 0.225])[None, :, None, None]

    # K may come straight from the command line as a string.
    # NOTE(review): a subset with fewer than K samples is not handled here.
    num_medoids = int(K)
    per_class_img_dirs = dataset_name == "imagenet-a"

    def _save_exemplar(target, subset, position, sample_id, listing):
        # Write one de-normalised image plus its id line; `position` is the
        # row inside ds[target][subset], `sample_id` the id used for naming.
        img = ds[target][subset][position][0]
        img = (img * img_std) + img_avg
        if per_class_img_dirs:
            out_path = os.path.join(base_dir, "imgs", str(target), f"{sample_id}.png")
        else:
            out_path = os.path.join(base_dir, "imgs", f"{sample_id}.png")
        torchvision.utils.save_image(img, out_path)
        listing.write(f"{sample_id}\n")

    torch.cuda.empty_cache()
    os.makedirs(base_dir, exist_ok=True)
    os.makedirs(os.path.join(base_dir, "imgs"), exist_ok=True)
    os.makedirs(os.path.join(base_dir, f"biases-k-{K}"), exist_ok=True)
    model.eval()
    with torch.no_grad():
        for target in range(num_classes):
            if per_class_img_dirs:
                os.makedirs(os.path.join(base_dir, "imgs", str(target)), exist_ok=True)
            os.makedirs(os.path.join(base_dir, f"biases-k-{K}", str(target)), exist_ok=True)

            # One centroid per class: set by the "correct" subset and reused
            # by the "incorrect" one (zeros only if no "correct" pass ran).
            class_centroid = torch.zeros((1, nb_features))

            for subset in f_dls[target].keys():
                tk0 = tqdm(zip(f_dls[target][subset], l_dls[target][subset], i_dls[target][subset]),
                    desc=f"Class {target} ({subset}) (medoids)",
                    total=int(len(f_dls[target][subset])),
                    leave=True,
                    dynamic_ncols=True
                )
                listing_path = os.path.join(base_dir, f"biases-k-{K}", str(target), f"{subset}ly-classified.txt")

                for out_bot, labels, idx in tk0:
                    if subset == "correct":
                        # CelebA keeps every 10th sample; the ids are
                        # subsampled the same way and the stride maps a
                        # medoid back to its row in the full subset.
                        stride = 10 if dataset_name == "celeba" else 1
                        out_bot = out_bot[::stride]
                        idx = idx[::stride]

                        distance_matrix = torch.cdist(out_bot.squeeze(-1), out_bot.squeeze(-1), p=2).numpy()
                        kmed = kmedoids.fasterpam(diss=distance_matrix, medoids=num_medoids, max_iter=100)
                        class_centroid = torch.mean(out_bot, dim=0).unsqueeze(0)

                        with open(listing_path, mode="w+") as f:
                            for m in kmed.medoids:
                                row = int(m)
                                _save_exemplar(target, subset, row * stride, idx[row].item(), f)

                    elif subset == "incorrect":
                        dists_from_centroid = torch.cdist(class_centroid, out_bot, p=2).squeeze(0)
                        sorting_idx = torch.argsort(dists_from_centroid, descending=True)

                        with open(listing_path, mode="w+") as f:
                            # Index the dataset with the sorted row so the
                            # saved image belongs to the id it is named after.
                            for row in sorting_idx.tolist():
                                _save_exemplar(target, subset, row, idx[row].item(), f)
                        

# Command-line interface. Boolean options are passed as the strings
# "true" / "false" and interpreted by the entry point.
parser = argparse.ArgumentParser()
# --dataset is mandatory, so it carries no default.
parser.add_argument("--dataset", type=str, required=True, help="dataset name. choose in [waterbirds, bar, celeba, imagenet-a]")
parser.add_argument("--use_wb", type=str, default="true", help="whether to use weights and biases logging or not, default=true")
parser.add_argument("--retrain", type=str, default="false", help="repeat experiment and overwrite vanilla model, default=false")
parser.add_argument("--model", type=str, default="resnet50", help="which model to use, default=resnet50. vitb16 available for Waterbirds and Imagenet-A, swinv2b available for ImageNet-A")
parser.add_argument("--evaluate_test", type=str, default="false", help="run model in inference against misaligned samples and extract exemplars")
# K is a count: parse it as an int so a user-supplied value has the same
# type as the default.
parser.add_argument("--k", type=int, default=10, help="K for the K-medoids step, default=10")
parser.add_argument("--ablation_on_k", type=str, default="false", help="Ablation study on K for Waterbirds and ResNet-50, default=false, if set to true overwrites other arguments")

if __name__ == "__main__":    
    args = parser.parse_args()
    
    dataset_name = args.dataset        
    model_arch = "resnet50" if args.model is None else args.model
    use_wb = args.use_wb == "true" if args.use_wb else True
    setup_test = args.evaluate_test == "true" if args.evaluate_test else False
    retrain = args.retrain == "true" if args.retrain else False

    ablation_on_k = args.ablation_on_k == "true" if args.ablation_on_k else False

    if ablation_on_k:
        dataset_name = "waterbirds"
        model_arch   = "resnet50"
        setup_test   = False
    
    print("Using w&b: ", use_wb)

    match dataset_name:
        case "waterbirds": 
            build_dict, config = VanillaModels.WaterbirdsModel(setup_test=setup_test, model_name=model_arch)
        case "celeba":
            build_dict, config = VanillaModels.CelebAModel(setup_test=setup_test)
        case "bar":
            build_dict, config = VanillaModels.BARModel()
        case "imagenet-a": 
            build_dict, config = VanillaModels.ImageNetAModel(setup_test=setup_test, model_name=model_arch)    
        case _:
            print("Unsupported Dataset. Choose one from [waterbirds, bar, celeba, imagenet-a]")
            exit(-1)

    PATH_TO_MODELS = VanillaModels.PATH_TO_MODELS
    K = args.k

    config["medoids_K"] = K

    model_name          = config["model_name"]
    num_classes         = config["num_classes"]
    dataset_args        = config["dataset_args"]
    num_bias_attributes = config["num_bias_attributes"]
    seed                = config["seed"]

    model               = build_dict["model"]
    device              = build_dict["device"]
    bottleneck          = build_dict["bottleneck"]
    nb_features         = build_dict["nb_features"]
    dataset             = build_dict["dataset"]
    warmup              = build_dict["warmup"]
    check_dist          = build_dict["check_dist"]
    train_epochs        = build_dict["train_epochs"]
    train_loader        = build_dict["train_loader"]
    test_loader         = build_dict["test_loader"] if setup_test else "none"
    criterion           = build_dict["criterion"]
    optimizer           = build_dict["optimizer"]

    
    set_seed(seed)
    if use_wb:
        wb = WandbWrapper(
            project_name="Bias-Mining", 
            config=config
        ) 
    else: wb = None

    if setup_test:
        if dataset_name != "imagenet-a" and retrain == False: 
            print("Loading pre-trained vanilla model")
            model.load_state_dict(torch.load(f"./data/saved_models/{model_name}-biased-updated.pt"))                    
            model_name = f"{model_name}-biased-updated.pt"
        elif dataset_name != "imagenet-a" and retrain:
            train_model_dist_from_target(model, train_loader, device, criterion, optimizer, num_classes, epochs=train_epochs, warmup=warmup, check_dist=check_dist, wb=wb, make_figures=False)
            model.load_state_dict(torch.load(f"./data/saved_models/{model_name}-biased-updated.pt"))    
        
        model.eval()    
        evaluate_model(model, train_loader, num_classes, num_bias_attributes, criterion, device, wb)
        exit(0)

    try:
        if retrain == True and dataset_name != "imagenet-a":
                train_model_dist_from_target(model, train_loader, device, criterion, optimizer, num_classes, epochs=train_epochs, warmup=warmup, check_dist=check_dist, wb=wb, make_figures=False)
        if dataset_name != "imagenet-a":
            print("Loading pre-trained vanilla model")
            model.load_state_dict(torch.load(f"./data/saved_models/{model_name}-biased-updated.pt"))
    except:    
        if dataset_name != "imagenet-a":
            print("Model unavailable, training one")
            train_model_dist_from_target(model, train_loader, device, criterion, optimizer, num_classes, epochs=train_epochs, warmup=warmup, check_dist=check_dist, wb=wb, make_figures=False)
            model.load_state_dict(torch.load(f"./data/saved_models/{model_name}-biased-updated.pt"))    
    
    model.eval()    
    images = extract_samples(model, device, train_loader, criterion, num_classes, num_bias_attributes)
    dataset.transform = torchvision.transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    ds, dls = build_subdata_subdataloader(dataset, images, num_classes)
    f_dls, l_dls, i_dls = extract_bottleneck_reps(model, bottleneck, dls, num_classes)

    clean_previous_results(f"./medoids_results/{dataset_name}_mined_bias_exemplars")
    if ablation_on_k:
        for k in [1, 5, 10, 25, 50]:
            extract_kmedoids(model, f_dls, l_dls, i_dls, ds, num_classes, dataset_name, K=k)
    else:
        extract_kmedoids(model, f_dls, l_dls, i_dls, ds, num_classes, dataset_name, K=K)

    if dataset_name == "imagenet-a":
        for dir in sorted(os.listdir(f"./medoids_results/{dataset_name}_mined_bias_exemplars/imgs")):
            if str(dir).split("/")[-1] in {"124", "306", "313", "314"}:
                os.system(f"mv ./medoids_results/{dataset_name}_mined_bias_exemplars/imgs/{dir}/*.png ./medoids_results/{dataset_name}_mined_bias_exemplars/imgs")
                clean_previous_results(f"./medoids_results/{dataset_name}_mined_bias_exemplars/imgs/{dir}")    
            clean_previous_results(f"./medoids_results/{dataset_name}_mined_bias_exemplars/imgs/{dir}")
        for dir in sorted(os.listdir(f"./medoids_results/{dataset_name}_mined_bias_exemplars/biases-k-{args.k}")):
            if str(dir).split("/")[-1] in {"124", "306", "313", "314"}:
                continue
            clean_previous_results(f"./medoids_results/{dataset_name}_mined_bias_exemplars/biases-k-{args.k}/{dir}")    
    
    zip_name = None
    match dataset_name:
        case "waterbirds"   : 
            if model_arch in {"vitb16", "swinv2b"}:                    
                os.system(f"mv ./medoids_results/waterbirds_mined_bias_exemplars ./medoids_results/waterbirds_mined_bias_exemplars_{model_arch}")
                zip_folder(f"./medoids_results/waterbirds_mined_bias_exemplars_{model_arch}", f"./waterbirds_mined_bias_exemplars_{model_arch}.zip")
                zip_name = f"waterbirds_mined_bias_exemplars_{model_arch}.zip"
            else:
                os.system(f"mv ./medoids_results/waterbirds_mined_bias_exemplars ./medoids_results/waterbirds_mined_bias_exemplars_seed_{seed}")
                zip_folder(f"./medoids_results/waterbirds_mined_bias_exemplars_seed_{seed}", f"./waterbirds_mined_bias_exemplars_seed_{seed}.zip")
                zip_name = f"waterbirds_mined_bias_exemplars_seed_{seed}.zip"
        case "celeba": 
            os.system(f"mv ./medoids_results/celeba_mined_bias_exemplars ./medoids_results/celeba_mined_bias_exemplars_v5")
            zip_folder(f"./medoids_results/celeba_mined_bias_exemplars_v5", f"./celeba_mined_bias_exemplars_v5.zip")
            zip_name = f"celeba_mined_bias_exemplars_v5.zip"    
        case "bar": 
            os.system(f"mv ./medoids_results/bar_mined_bias_exemplars ./medoids_results/bar_mined_bias_exemplars-v2")
            zip_folder(f"./medoids_results/bar_mined_bias_exemplars-v2", f"./bar_mined_bias_exemplars-v2.zip")
            zip_name = f"bar_mined_bias_exemplars-v2.zip"
        case "imagenet-a": 
            if model_arch in {"vitb16", "swinv2b"}:                    
                os.system(f"mv ./medoids_results/imagenet-a_mined_bias_exemplars ./medoids_results/insects-on-hand_mined_bias_exemplars_{model_arch}")
                zip_folder(f"./medoids_results/insects-on-hand_mined_bias_exemplars_{model_arch}", f"./insects-on-hand_mined_bias_exemplars_{model_arch}.zip")
                zip_name = f"insects-on-hand_mined_bias_exemplars_{model_arch}.zip"
            else: 
                os.system(f"mv ./medoids_results/imagenet-a_mined_bias_exemplars ./medoids_results/insects-on-hand_mined_bias_exemplars")
                zip_folder(f"./medoids_results/insects-on-hand_mined_bias_exemplars", f"./insects-on-hand_mined_bias_exemplars.zip")   
                zip_name = f"insects-on-hand_mined_bias_exemplars.zip"
                    

    if wb is not None:
        model_name = f"{model_name}-biased-updated.pt"
        wb_model_name = f"{dataset_name}_vanilla.pt"
        wb.log_model(model=model, model_name=wb_model_name)
        print("Logging output zip to w&b...")
        zip_artifact = wb.backend.Artifact(zip_name, type="dataset")
        zip_artifact.add_file(zip_name)
        wb.backend.log_artifact(zip_artifact)

        wb.finish()