import os
import sys
import argparse
from collections import defaultdict

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
from torch.nn import functional as F

from tqdm.auto import tqdm, trange

import laft
from laft import CLIP, PromptLearner

torch.set_grad_enabled(False)
torch.set_float32_matmul_precision("high")


def parse_range(range_strs):
    """Expand "start:end(:step)" specs into one sorted list of unique ints.

    Each spec is split on ":" and every field is converted with ``int``.
    The first two fields are the start and the exclusive end; the third,
    when present, is the step (1 otherwise). Fields after the third are
    converted but not used. The expanded ranges are merged, so overlapping
    specs contribute each value only once.
    """

    def expand(spec):
        # Star-unpacking keeps the optional step (and any extra fields).
        start, end, *extra = (int(field) for field in spec.split(":"))
        stride = extra[0] if extra else 1
        return range(start, end, stride)

    merged = set()
    for spec in range_strs:
        merged |= set(expand(spec))
    return sorted(merged)


def get_train_loop(
    model: CLIP,
    train_features: torch.Tensor,
    valid_features: torch.Tensor,
    test_features: torch.Tensor,
    valid_attrs: torch.Tensor,
    valid_attend_labels: torch.Tensor,
    valid_ignore_labels: torch.Tensor,
    test_attend_labels: torch.Tensor,
    test_ignore_labels: torch.Tensor,
    transform,
):
    """Return a ``train`` closure bound to the given model, features and labels.

    ``transform`` is called as ``transform(features, prompt_basis)`` to map
    features before k-NN scoring. ``valid_ignore_labels`` is accepted but is
    not read anywhere in the returned closure.
    """

    # Autograd is switched off at module level, so it is re-enabled for the
    # duration of each ``train`` call.
    @torch.set_grad_enabled(True)
    def train(
        prompt_learner: PromptLearner,
        n_neighbors: int,
        n_components: int,
        n_samples: int,
        n_epochs: int,
        lr: float,
        weight_decay: float = 0.0,
        seed: int = 0,
        print_every: int = 50,
        autocast: bool = True,
        verbose: bool = True,
    ):
        """Optimise ``prompt_learner`` with Adam and return the best test scores.

        A subset of the validation set is chosen with ``laft.sample_index``.
        The loop runs ``n_epochs + 1`` iterations; each one encodes the
        learner's prompts, derives a PCA basis from the prompt pairs,
        transforms the train and sampled validation features with it, and
        minimises binary cross-entropy between the clamped k-NN scores and
        the sampled attend labels.

        Every ``print_every`` epochs (and on the last one) the AUROC on the
        sampled validation subset is computed and the test set is scored;
        the test scores belonging to the highest validation AUROC are kept.

        Returns:
            The retained test scores, or ``None`` when ``laft.pca`` raises
            ``RuntimeError`` or when no evaluation exceeded an AUROC of 0.0
            (including ``print_every <= 0``, which disables evaluation).
        """
        sampled_idx = laft.sample_index(valid_attrs, n_samples=n_samples, seed=seed)
        fit_features = valid_features[sampled_idx]
        fit_labels = valid_attend_labels[sampled_idx]

        optimizer = torch.optim.Adam(prompt_learner.parameters(), lr=lr, weight_decay=weight_decay)

        if verbose:
            # tqdm.write keeps messages from breaking the progress bar.
            log, epochs = tqdm.write, trange(n_epochs + 1, ncols=80)
        else:
            log, epochs = print, range(n_epochs + 1)

        best_auroc = 0.0
        best_scores = None

        compile_off = False
        encode_prompt = torch.compile(model.encode_prompt, disable=compile_off)
        make_prompt_pair = torch.compile(laft.prompt_pair, disable=compile_off)
        knn = torch.compile(laft.knn, disable=compile_off)

        for epoch in epochs:
            with torch.autocast(device_type=model.device.type, enabled=autocast):
                text_features = encode_prompt(prompt_learner.prompt, prompt_learner.eot_idxs)
                prompt_pair = make_prompt_pair(text_features)

            # PCA is run on a float32 copy, outside the autocast region.
            try:
                prompt_basis = laft.pca(prompt_pair.float(), n_components=n_components, center=True)
            except RuntimeError:
                log(f"Failed to compute PCA with {n_components} components.")
                return None

            train_projected = transform(train_features, prompt_basis)
            fit_projected = transform(fit_features, prompt_basis)

            fit_scores = knn(train_projected, fit_projected, n_neighbors=n_neighbors)
            loss = F.binary_cross_entropy(fit_scores.clamp(0, 1), fit_labels.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Short-circuit on print_every first so a value of 0 never reaches the modulo.
            is_eval_epoch = print_every > 0 and (epoch % print_every == 0 or epoch == n_epochs)
            if not is_eval_epoch:
                continue

            with torch.inference_mode():
                valid_auroc = laft.binary_auroc(fit_scores, fit_labels)

                # Evaluation reuses the basis computed before this epoch's optimizer step.
                test_projected = transform(test_features, prompt_basis)
                test_scores = knn(train_projected, test_projected, n_neighbors=n_neighbors).detach()

                attend_auroc = laft.binary_auroc(test_scores, test_attend_labels)
                ignore_auroc = laft.binary_auroc(test_scores, test_ignore_labels)

            if verbose:
                log(f"AUROC {epoch:4d}: {valid_auroc:.3f} / {attend_auroc:.4f} / {ignore_auroc:.4f}")

            # Model selection uses the validation-subset AUROC, not the test AUROC.
            if valid_auroc > best_auroc:
                best_auroc, best_scores = valid_auroc, test_scores

        return best_scores

    return train


if __name__ == "__main__":
    # Script entry point: learn prompts for one dataset/guidance pair over a
    # range of PCA component counts and write the metrics table to disk.
    parser = argparse.ArgumentParser()
    parser.add_argument("-r", "--results-dir", default="results")
    parser.add_argument("-m", "--model-name", default="ViT-B-16:dfn2b")
    # Validated by argparse rather than `assert`, so the check still runs
    # under `python -O` and reports a usage error.
    parser.add_argument(
        "-d", "--dataset-name", type=str, required=True,
        choices=("celeba", "waterbirds", "color_mnist"),
    )
    parser.add_argument("-k", "--n-neighbors", type=int, default=30)
    parser.add_argument("-g", "--guidance", type=str, required=True)
    parser.add_argument("-c", "--n-ctx", type=int, default=4)
    parser.add_argument("-s", "--n-samples", type=int, default=-1)
    parser.add_argument("-rg", "--range", type=str, nargs="+", default=("3:10:1",))
    parser.add_argument("-ns", "--n-seeds", type=int, default=1)
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3)
    args = parser.parse_args()

    ds = laft.prompts.get_ds(args.dataset_name)
    model, data = laft.get_cached_features(args.model_name, ds.DATASET_NAME, ds.DATASET_CONFIG)
    print()

    # Build a second CLIP via laft.load_clip and copy the cached model's
    # weights into it; this copy is the model used for prompt learning.
    # NOTE(review): the architecture is hard-coded to "ViT-B/16" regardless of
    # --model-name, so load_state_dict presumably only succeeds for
    # ViT-B/16-shaped checkpoints -- confirm.
    oai_model, _ = laft.load_clip("ViT-B/16")
    oai_model.float()
    oai_model.load_state_dict(model.state_dict())

    def replace_quickgelu_with_gelu(module):
        """Recursively replace every child module named ``gelu`` with ``torch.nn.GELU()``."""
        for name, child in module.named_children():
            if name == "gelu":
                setattr(module, name, torch.nn.GELU())
            else:
                replace_quickgelu_with_gelu(child)

    replace_quickgelu_with_gelu(oai_model)

    model = oai_model

    ls = ds.get(data, args.guidance)
    attend_name = ls["attend_name"]
    ignore_name = ls["ignore_name"]
    train_features = ls["train_features"]
    valid_features = ls["valid_features"]
    test_features = ls["test_features"]
    valid_attrs = ls["valid_attrs"]
    valid_attend_labels = ls["valid_attend_labels"]
    valid_ignore_labels = ls["valid_ignore_labels"]
    test_attend_labels = ls["test_attend_labels"]
    test_ignore_labels = ls["test_ignore_labels"]
    # Only the first 200 words are used to initialise the prompt learner.
    words = ls["words"][:200]

    model.float()

    n_components_range = parse_range(args.range)

    # metrics[<attribute name>][<column name>] -> list of per-run metric results.
    metrics = {
        attend_name: defaultdict(list),
        ignore_name: defaultdict(list),
    }

    # No guidance
    scores = laft.knn(train_features, test_features, n_neighbors=args.n_neighbors)
    metrics[attend_name]["KNN"].append(laft.binary_metrics(scores, test_attend_labels))
    metrics[ignore_name]["KNN"].append(laft.binary_metrics(scores, test_ignore_labels))

    # LAFT CoOp
    guidance_parts = args.guidance.split("_")
    metric_name = f"{guidance_parts[0].capitalize()}/{guidance_parts[1].capitalize()}"
    table_path = os.path.join(
        args.results_dir, ds.DATASET_NAME, "laft_coop", args.guidance,
        f"{args.n_neighbors}_ns{args.n_samples}_nc{args.n_ctx}.txt"
    )

    train = get_train_loop(
        model, train_features, valid_features, test_features, valid_attrs,
        valid_attend_labels=valid_attend_labels,
        valid_ignore_labels=valid_ignore_labels,
        test_attend_labels=test_attend_labels,
        test_ignore_labels=test_ignore_labels,
        transform=(laft.inner if args.guidance.startswith("guide") else laft.orthogonal),
    )

    for n_components in n_components_range:
        column = f"{metric_name}/{n_components}"
        metrics[attend_name][column] = attend_metrics = []
        metrics[ignore_name][column] = ignore_metrics = []

        for seed in range(args.n_seeds):
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)

            # First attempt uses autocast; if training returns None, retry
            # once with a fresh PromptLearner and autocast disabled.
            scores = None
            for use_autocast in (True, False):
                prompt_learner = PromptLearner(model, words, n_ctx=args.n_ctx)
                scores = train(
                    prompt_learner, args.n_neighbors, n_components, args.n_samples, n_epochs=3000,
                    lr=args.learning_rate, seed=seed, verbose=True,
                    autocast=use_autocast,
                )
                if scores is not None:
                    break

            if scores is None:
                print(f"Failed to train CoOp for {n_components} components.")
                # Store the NaN placeholder in `metrics` itself: rebinding the
                # local list names would leave the table entry untouched.
                # Wrapped in a one-element list to match the other entries;
                # assumes build_table accepts dicts with these keys -- TODO confirm.
                nan_metrics = {
                    "auroc": float("nan"), "auprc": float("nan"), "accuracy": float("nan"), "f1": float("nan"), "fpr95": float("nan"),
                }
                metrics[attend_name][column] = [dict(nan_metrics)]
                metrics[ignore_name][column] = [dict(nan_metrics)]
                break

            attend_metrics.append(laft.binary_metrics(scores, test_attend_labels))
            ignore_metrics.append(laft.binary_metrics(scores, test_ignore_labels))

        # Rebuild and save the table after every component count so partial
        # results are on disk before the next run starts.
        table = laft.utils.build_table(metrics, group_headers=("Guide", "Attr.", "#C"))
        print(table)
        laft.utils.save_table(table, table_path)
