import torchvision
from tqdm import tqdm
import torch
import os
import sys
import numpy as np
import timm

from subpopbench.models.networks import *
from subpopbench.dataset import datasets


# Backbones built so far, keyed by ``arch``. The main loop calls
# get_embeddings many times with the same ``arch``; caching avoids
# re-instantiating (and re-loading pretrained weights for) the model each call.
_MODEL_CACHE = {}


def _load_model(arch):
    """Return ``(model, device)`` for ``arch``, building and caching it on first use.

    The model is moved to CUDA when available (otherwise CPU) and switched to
    eval mode before being cached.

    Raises:
        ValueError: if ``arch`` is neither "resnet" nor "clip".
    """
    cached = _MODEL_CACHE.get(arch)
    if cached is not None:
        return cached

    if arch == "resnet":
        # NOTE(review): the torchvision fc head is kept, so outputs are the
        # 1000-way ImageNet logits rather than pooled features (contrast the
        # CLIP branch, where num_classes=0 drops the classifier). Confirm intent.
        model = torchvision.models.resnet18(pretrained=True)
    elif arch == "clip":
        # num_classes=0 makes timm return pooled features with no classifier.
        model = timm.create_model(
            "vit_base_patch32_clip_224.openai", pretrained=True, num_classes=0
        )
    else:
        raise ValueError(f"Architecture {arch} not supported")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    _MODEL_CACHE[arch] = (model, device)
    return model, device


def get_embeddings(arch, y, a, n, sc, ci, ai, split):
    """Run a pretrained backbone over one CelebA task split and collect outputs.

    Args:
        arch: "resnet" or "clip"; any other value raises ValueError.
        y, a, n: accepted for interface compatibility but currently unused.
            NOTE(review): the task CSV is chosen from ``sc``/``ci``/``ai`` only,
            so every (y, a, n) combination yields identical embeddings —
            confirm whether the metadata file name should encode them.
        sc, ci, ai: configuration values; each is formatted to two decimals
            to build ``task_celeba_sc{sc}_ci{ci}_ai{ai}.csv`` under TASK_DIR.
        split: split identifier forwarded to ``datasets.CelebA``.

    Returns:
        Tuple ``(activ, ys, attrs)`` of numpy arrays, each the per-batch
        results concatenated along axis 0: model outputs, labels, attributes.

    Raises:
        ValueError: for an unsupported ``arch``.
    """
    model, device = _load_model(arch)

    # Placeholders: point these at the task-metadata and CelebA data directories.
    TASK_DIR = "YOUR_DIR"
    DATA_DIR = "YOUR_DIR"

    # Two-decimal formatting so values match the task CSV naming scheme.
    sc, ci, ai = ("{:.2f}".format(v) for v in (sc, ci, ai))
    task_metadata_path = os.path.join(
        TASK_DIR, f"task_celeba_sc{sc}_ci{ci}_ai{ai}.csv"
    )
    hparams = {
        "last_layer_dropout": 0,
        "oversample": False,
        "undersample": False,
        "metadata": task_metadata_path,
    }

    dataset = datasets.CelebA(DATA_DIR, split, hparams, train_attr="yes")
    loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        drop_last=False,
        batch_size=64,
    )

    feats, labels, attrs = [], [], []
    with torch.no_grad():
        # Batches are unpacked as 4-tuples whose first element is ignored;
        # TODO confirm the tuple layout against datasets.CelebA.__getitem__.
        for _, x, batch_y, batch_a in tqdm(loader):
            feats.append(model(x.to(device)).cpu().numpy())
            labels.append(np.asarray(batch_y))
            attrs.append(np.asarray(batch_a))

    return (
        np.concatenate(feats, axis=0),
        np.concatenate(labels, axis=0),
        np.concatenate(attrs, axis=0),
    )


# (y, a) pairs iterated by the main loop and written into output file names.
# NOTE(review): presumably CelebA target/attribute indices — confirm.
task_list = [[2, 31], [21, 36], [8, 20], [25, 19]]
# Sample counts (n_samples) iterated by the main loop for every task.
n = [200, 500, 1000, 2000, 5000, 10000]

# One row per configuration: (sc, ci, ai). Storing rows keeps the three
# coordinates of a configuration aligned by construction; the per-coordinate
# lists below are derived from it for the main loop.
#
# A previously used alternative grid (25 rows) was a one-factor-at-a-time
# sweep that varied one coordinate while holding the other two at 0.50, in
# this row order:
#   sc over 0.01, 0.05, 0.10, 0.30, 0.50, 0.70, 0.90, 0.95, 0.99  (9 rows)
#   ci over 0.01, 0.05, 0.10, 0.30, 0.70, 0.90, 0.95, 0.99        (8 rows)
#   ai over 0.01, 0.05, 0.10, 0.30, 0.70, 0.90, 0.95, 0.99        (8 rows)
_SHIFT_CONFIGS = [
    (0.68, 0.35, 0.47),
    (0.5, 0.47, 0.31),
    (0.3, 0.49, 0.71),
    (0.51, 0.36, 0.27),
    (0.4, 0.11, 0.65),
    (0.57, 0.46, 0.36),
    (0.17, 0.4, 0.49),
    (0.51, 0.65, 0.7),
    (0.55, 0.24, 0.47),
    (0.47, 0.22, 0.52),
    (0.49, 0.29, 0.55),
    (0.52, 0.42, 0.16),
    (0.23, 0.24, 0.64),
    (0.12, 0.26, 0.8),
    (0.36, 0.39, 0.45),
    (0.15, 0.39, 0.68),
    (0.5, 0.7, 0.65),
    (0.49, 0.65, 0.66),
    (0.48, 0.4, 0.57),
    (0.26, 0.67, 0.17),
    (0.28, 0.38, 0.75),
    (0.68, 0.81, 0.69),
    (0.27, 0.39, 0.53),
    (0.25, 0.81, 0.12),
    (0.23, 0.6, 0.21),
    (0.59, 0.42, 0.46),
    (0.72, 0.58, 0.61),
    (0.67, 0.85, 0.69),
    (0.64, 0.24, 0.36),
    (0.7, 0.42, 0.33),
]

sc_values = [row[0] for row in _SHIFT_CONFIGS]
ci_values = [row[1] for row in _SHIFT_CONFIGS]
ai_values = [row[2] for row in _SHIFT_CONFIGS]

# Command-line entry point: embed every (task, n_samples, config) combination.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Get embeddings from models.")
    # choices mirrors the architectures get_embeddings accepts, so a typo
    # fails at parse time instead of after the output directory is created.
    parser.add_argument(
        "--arch",
        type=str,
        default="resnet",
        choices=["resnet", "clip"],
        help="Model to use",
    )
    # Required: without it argparse would pass split=None to the dataset.
    parser.add_argument(
        "--split",
        type=str,
        choices=["tr", "te"],
        required=True,
        help="Dataset split to embed",
    )

    args = parser.parse_args()

    arch = args.arch
    split = args.split

    # Placeholder output directory.
    # NOTE(review): the file names written below encode neither arch nor split,
    # so runs with different --arch/--split overwrite each other unless this
    # path differs per run.
    path = "YOUR_DIR"
    os.makedirs(path, exist_ok=True)

    # The three config lists are consumed in lockstep; refuse to run (rather
    # than silently truncate) if they ever disagree in length.
    if not len(sc_values) == len(ci_values) == len(ai_values):
        raise ValueError("sc_values, ci_values and ai_values must have equal length")

    for i, (y, a) in enumerate(task_list):
        for n_samples in n:
            print(f"Getting embeddings for task {i}, n_samples {n_samples}")
            for sc, ci, ai in zip(sc_values, ci_values, ai_values):
                activ, ys, attrs = get_embeddings(
                    arch, y, a, n_samples, sc, ci, ai, split
                )
                print(activ.shape)
                # Bundle features, labels and attributes in a single .npz.
                np.savez(
                    os.path.join(
                        path,
                        f"celeba_y{y}_a{a}_n{n_samples}_sc{sc}_ci{ci}_ai{ai}.npz",
                    ),
                    activ=activ,
                    ys=ys,
                    attrs=attrs,
                )
