import torchvision
from tqdm import tqdm
import torch
import os
import sys
import numpy as np
import timm

sys.path.append("//")

from subpopbench.models.networks import *
from subpopbench.dataset import datasets


def get_embeddings(arch, n, sc, ci, ai, split):
    """Run a pretrained model over one COCO task split and collect its outputs.

    The task is identified by a metadata CSV named
    ``task_coco_sc{sc}_ci{ci}_ai{ai}.csv`` (each value formatted with two
    decimals) located in the ``datasize{n}_seed0`` metadata directory.

    Args:
        arch: Model to use. ``"resnet"`` selects torchvision's ResNet-18,
            ``"clip"`` selects timm's ``vit_base_patch32_clip_224.openai``
            created with ``num_classes=0``.
        n: Dataset size; selects the ``datasize{n}_seed0`` metadata directory.
        sc: Numeric task parameter, used only to build the metadata file name.
        ci: Numeric task parameter, used only to build the metadata file name.
        ai: Numeric task parameter, used only to build the metadata file name.
        split: Split identifier forwarded unchanged to ``datasets.COCO``.

    Returns:
        A tuple ``(activ, ys, attrs)`` of NumPy arrays, each the axis-0
        concatenation over all batches of, respectively, the raw model
        outputs, the ``y`` values and the ``a`` values yielded by the loader.
        Returns ``None`` (after printing a message) when the metadata file
        does not exist.

    Raises:
        ValueError: If ``arch`` is neither ``"resnet"`` nor ``"clip"``.
    """
    # Reject unknown architectures up front so the error does not depend on
    # whether the metadata file happens to exist.
    if arch not in ("resnet", "clip"):
        raise ValueError(f"Architecture {arch} not supported")

    TASK_DIR = f"//output/div_explore/metashift/COCO-Cat-Dog-indoor-outdoor/metadata/datasize{n}_seed0"
    DATA_DIR = "//output/div_explore/metashift/COCO-Cat-Dog-indoor-outdoor/"

    # Metadata file names encode sc/ci/ai with exactly two decimal places.
    task_metadata_file = f"task_coco_sc{sc:.2f}_ci{ci:.2f}_ai{ai:.2f}.csv"
    task_metadata_path = os.path.join(TASK_DIR, task_metadata_file)

    # Check for the metadata file *before* building the model, so a missing
    # task does not pay the cost of loading pretrained weights.
    if not os.path.exists(task_metadata_path):
        print(f"Metadata file {task_metadata_path} does not exist.")
        return None

    hparams = {
        "last_layer_dropout": 0,
        "oversample": False,
        "undersample": False,
        "metadata": task_metadata_path,
    }

    if arch == "resnet":
        # NOTE(review): the ResNet-18 is used as-is (its final layer is not
        # replaced here), so the outputs are presumably classification logits
        # rather than pooled features -- confirm this is intended.
        model = torchvision.models.resnet18(pretrained=True)
    else:  # arch == "clip"
        model = timm.create_model(
            "vit_base_patch32_clip_224.openai", pretrained=True, num_classes=0
        )

    dataset = datasets.COCO(DATA_DIR, split, hparams, train_attr="yes")
    loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        drop_last=False,
        batch_size=64,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    activ, ys, attrs = [], [], []
    with torch.no_grad():
        # Each batch is a 4-tuple; the first element is not used here.
        for _, x, y, a in tqdm(loader):
            feat = model(x.to(device))
            # Raw model outputs are stored; no softmax/sigmoid is applied.
            activ.append(feat.detach().cpu().numpy())
            ys.append(y)
            attrs.append(a)

    activ = np.concatenate(activ, axis=0)
    ys = np.concatenate(ys, axis=0)
    attrs = np.concatenate(attrs, axis=0)

    return activ, ys, attrs


# Dataset sizes to process; each one selects a `datasize{n}_seed0` metadata folder.
n = [200, 500, 1000]

# Task parameters. The three lists are parallel: position j of each list
# together describes one task (sc_values[j], ci_values[j], ai_values[j]).
#
# Disabled sweep (25 tasks) -- varies one parameter at a time around 0.50:
# sc_values = [0.01, 0.05, 0.10, 0.30, 0.50, 0.70, 0.90, 0.95, 0.99] + [0.50] * 16
# ci_values = [0.50] * 9 + [0.01, 0.05, 0.10, 0.30, 0.70, 0.90, 0.95, 0.99] + [0.50] * 8
# ai_values = [0.50] * 17 + [0.01, 0.05, 0.10, 0.30, 0.70, 0.90, 0.95, 0.99]
#
# Disabled sweep (30 tasks):
# sc_values = [
#     0.68, 0.5, 0.3, 0.51, 0.4, 0.57, 0.17, 0.51, 0.55, 0.47,
#     0.49, 0.52, 0.23, 0.12, 0.36, 0.15, 0.5, 0.49, 0.48, 0.26,
#     0.28, 0.68, 0.27, 0.25, 0.23, 0.59, 0.72, 0.67, 0.64, 0.7,
# ]
# ci_values = [
#     0.35, 0.47, 0.49, 0.36, 0.11, 0.46, 0.4, 0.65, 0.24, 0.22,
#     0.29, 0.42, 0.24, 0.26, 0.39, 0.39, 0.7, 0.65, 0.4, 0.67,
#     0.38, 0.81, 0.39, 0.81, 0.6, 0.42, 0.58, 0.85, 0.24, 0.42,
# ]
# ai_values = [
#     0.47, 0.31, 0.71, 0.27, 0.65, 0.36, 0.49, 0.7, 0.47, 0.52,
#     0.55, 0.16, 0.64, 0.8, 0.45, 0.68, 0.65, 0.66, 0.57, 0.17,
#     0.75, 0.69, 0.53, 0.12, 0.21, 0.46, 0.61, 0.69, 0.36, 0.33,
# ]

# Active sweep (30 tasks).
sc_values = [
    0.79, 0.34, 0.52, 0.54, 0.12, 0.26, 0.62, 0.53, 0.09, 0.51,
    0.64, 0.31, 0.28, 0.4, 0.3, 0.8, 0.45, 0.12, 0.17, 0.71,
    0.51, 0.26, 0.28, 0.55, 0.4, 0.37, 0.12, 0.36, 0.57, 0.54,
]
ci_values = [
    0.46, 0.12, 0.5, 0.1, 0.65, 0.46, 0.78, 0.29, 0.2, 0.43,
    0.48, 0.56, 0.5, 0.47, 0.48, 0.91, 0.45, 0.47, 0.78, 0.8,
    0.55, 0.49, 0.67, 0.5, 0.13, 0.26, 0.91, 0.72, 0.1, 0.67,
]
ai_values = [
    0.36, 0.76, 0.39, 0.54, 0.36, 0.52, 0.74, 0.32, 0.87, 0.82,
    0.71, 0.58, 0.36, 0.84, 0.67, 0.79, 0.2, 0.47, 0.22, 0.59,
    0.27, 0.66, 0.37, 0.57, 0.65, 0.65, 0.13, 0.28, 0.43, 0.62,
]
# Command-line entry point: compute the model outputs for every configured
# task and dataset size, and save each result as one .npz file.
import argparse

parser = argparse.ArgumentParser(description="Get embeddings from models.")
parser.add_argument("--model", type=str, default="resnet", help="Model to use")
# --split is mandatory: when it was omitted, the value None ended up in the
# output directory name and was passed on to the dataset.
parser.add_argument(
    "--split",
    type=str,
    choices=["tr", "te"],
    required=True,
    help="Dataset split to process",
)
parser.add_argument(
    "--datafolder",
    type=str,
    default="coco_v2",
    help="Sub-folder of //exps/div_explore/ that receives the output",
)

args = parser.parse_args()

arch = args.model
split = args.split
folder = args.datafolder

path = f"//exps/div_explore/{folder}/{arch}_embeddings_{split}/"
# Create the output directory if needed; exist_ok avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(path, exist_ok=True)

for n_samples in n:
    # The three lists are parallel: walk them in lockstep, one task per step.
    for sc, ci, ai in zip(sc_values, ci_values, ai_values):
        res = get_embeddings(arch, n_samples, sc, ci, ai, split)
        if res is None:
            # Metadata for this task/size is missing; nothing to save.
            continue
        activ, ys, attrs = res
        print(activ.shape)
        # Store outputs, labels and attributes together in a single file.
        # NOTE(review): the file name interpolates the raw floats (0.4 ->
        # "sc0.4"), unlike the two-decimal metadata file names ("sc0.40").
        # Left unchanged in case downstream readers rely on these names.
        np.savez(
            os.path.join(
                path,
                f"coco_n{n_samples}_sc{sc}_ci{ci}_ai{ai}.npz",
            ),
            activ=activ,
            ys=ys,
            attrs=attrs,
        )
        print(
            f"Getting embeddings for n_samples {n_samples}, sc {sc}, ci {ci}, ai {ai}"
        )
