import copy
import csv
import pickle
from collections import Counter

import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
from sentence_transformers import util
from torch import optim
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from tqdm import tqdm

from vds_module import CoupleDataset, Cluster
from vds_load import NeoLoader, Metric
from vds_shared import DEVICE, CACHE_OUTS_DIR, VISUAL_OUTS_DIR, DRYRUN_SAMPLE_NUM, CFG_MODE_DRYRUN, REPORT_OUTS_DIR
from vds_util import get_attr
from utils.dataset import load_dataset, get_context_allowed_shots
from utils.template import make_prompt


class Pipeline:
    def __init__(self, cfg=None):
        """Load tokenizer/model, cache the layer embeddings and build the cluster head.

        NOTE(review): ``cfg`` defaults to ``None`` but ``cfg.MODEL_NAME`` is read
        immediately, so a config object is effectively required.
        """
        self.cfg = cfg
        model_name = cfg.MODEL_NAME
        self.tokenizer = NeoLoader.load_tokenizer(model_name)
        self.config, self.model, self.attrs = NeoLoader.load_model(model_name)

        # GPT-2 configs name the hidden size / context length differently from
        # the other supported families, which all share the same attribute names.
        if 'gpt2' in model_name:
            dim_attr, context_attr = 'n_embd', 'n_positions'
        elif any(family in model_name for family in ('pythia', 'gemma', 'Qwen', 'llama')):
            dim_attr, context_attr = 'hidden_size', 'max_position_embeddings'
        else:
            raise NotImplementedError
        dim_size = getattr(self.model.config, dim_attr)
        self.vocab_size = self.model.config.vocab_size
        self.max_context_len = getattr(self.model.config, context_attr)

        # key used for cache files and output folders
        self.identifier = f'{self.cfg.DATA_CODE}.{self.cfg.MODEL_CODE}'

        # load_corpus() also populates self.vocab_labels / self.vocab_label2id,
        # which init_abs_anchors() below depends on
        train_comments, self.train_labels, test_comments, self.test_labels = self.load_corpus()
        self.train_embeds, self.test_embeds = self.dump_embeds(train_comments, test_comments)

        # only the 'mm' / 'de4mm' ablations score against the transposed LM head
        use_lm_head = self.cfg.ABLA in ['mm', 'de4mm']
        self.matrix = self.get_lm_head_matrix().transpose(0, 1).to(DEVICE) if use_lm_head else None

        self.cluster = Cluster(dim_size=dim_size, identifier=self.identifier).to(DEVICE)

        # the language model itself is frozen to save computation
        for parameter in self.model.parameters():
            parameter.requires_grad = False

        self.anchors = self.init_abs_anchors()

    def get_lm_head_matrix(self):
        """Return the detached ``weight`` of the module addressed by ``self.attrs['lm_head']``."""
        head_module = get_attr(self.model, self.attrs['lm_head'])
        return head_module.weight.detach()

    def init_abs_anchors(self):
        """Stack the LM-head rows of every label token into one anchor tensor on ``DEVICE``.

        Row ``i`` of the result corresponds to ``self.vocab_labels[i]``.
        """
        head_weight = self.get_lm_head_matrix()
        label_rows = [head_weight[vocab_label].detach() for vocab_label in self.vocab_labels]
        return torch.stack(label_rows, dim=0).to(DEVICE)

    # def compute_anchor_embed(self, matrix, token_label):
    #     # one-hot tensor
    #     # print(f'{embeddings_matrix.shape=}')
    #     one_hot_tensor = torch.zeros([1, self.vocab_size])
    #     one_hot_tensor[:, token_label] = 1.0
    #     # label_tensor = torch.tensor([token_label], dtype=torch.int64).to(device=DEVICE)
    #     # one_hot_tensor = torch.zeros(len(label_tensor), vocab_size, device=DEVICE)
    #     # one_hot_tensor = one_hot_tensor.scatter_(1, label_tensor.unsqueeze(1), 1.)
    #     heuristic_dist = one_hot_tensor  #.to(DEVICE)  #.requires_grad_(True)
    #     heuristic_repr = torch.linalg.lstsq(matrix.cpu(), heuristic_dist.T, driver='gels').solution
    #     # heuristic_repr = torch.matmul(heuristic_dist, matrix)
    #     heuristic_repr = heuristic_repr.flatten()
    #     return heuristic_repr

    def _collect_prompts(self, dataset):
        """Build the inference prompt and the label token id for every instance of ``dataset``."""
        label2id = dataset.label2id
        comments = list()
        labels = list()
        for ins in dataset.data:
            comments.append(make_prompt(ins, self.cfg.DATA_CODE, mode='inference'))
            # local label id -> vocabulary token id; the token id is kept because
            # the anchors are computed from vocabulary rows
            labels.append(self.vocab_labels[label2id[ins['label']]])
        return comments, labels

    def load_corpus(self):
        """Load the dataset and return prompts and label token ids for train and test.

        Side effects: sets ``self.vocab_labels`` (one token id per label verbaliser,
        taken as the last token of ``' ' + verb``) and ``self.vocab_label2id``
        (the inverse mapping to the local label index).

        Returns:
            ``(train_comments, train_labels, test_comments, test_labels)``; in
            dry-run mode each list is truncated to a multiple of
            ``DRYRUN_SAMPLE_NUM``.
        """
        train_data, dev_data = load_dataset(dataset=self.cfg.DATA_CODE)
        n_demo_shot = get_context_allowed_shots(dataset=self.cfg.DATA_CODE, context_len=self.max_context_len)

        self.vocab_labels = [self.tokenizer.encode(' ' + label_verb)[-1] for label_verb in train_data.id2verb]
        self.vocab_label2id = {vocab_label: label_id for label_id, vocab_label in enumerate(self.vocab_labels)}

        # NOTE(review): the few-shot prefix is currently unused (prompts are built
        # zero-shot). The subsample / prefix calls are kept in their original
        # order because their side effects (e.g. RNG use) are not visible here.
        copy_data = copy.deepcopy(train_data)
        copy_data.subsamplebyshot(n_demo_shot)
        make_prompt(copy_data, self.cfg.DATA_CODE, mode='train')

        # Both branches of the former "is this instance a demo shot?" test built
        # the same prompt, so the per-instance membership scan was dropped.
        train_comments, train_labels = self._collect_prompts(train_data)

        train_data.subsamplebyshot(n_demo_shot)
        make_prompt(train_data, self.cfg.DATA_CODE, mode='train')

        test_comments, test_labels = self._collect_prompts(dev_data)

        if CFG_MODE_DRYRUN:
            train_limit = DRYRUN_SAMPLE_NUM * 4
            test_limit = DRYRUN_SAMPLE_NUM * 1
            train_comments = train_comments[:train_limit]
            train_labels = train_labels[:train_limit]
            test_comments = test_comments[:test_limit]
            test_labels = test_labels[:test_limit]

        return train_comments, train_labels, test_comments, test_labels

    def dump_embeds(self, train_comments, test_comments):
        """Return ``(train_embeds, test_embeds)``, reading or writing the pickle cache.

        When both cache files exist they are loaded and the comments are ignored;
        otherwise the embeddings are computed with ``compose_data`` and pickled.

        NOTE(review): the cache key is only ``self.identifier``, so a cache
        written in dry-run mode (or for different comments) is reused as-is.
        The pickles are assumed to be locally produced, trusted files.
        """
        train_path = CACHE_OUTS_DIR / f'{self.identifier}.train.pkl'
        test_path = CACHE_OUTS_DIR / f'{self.identifier}.test.pkl'

        # cache hit: both splits must be present
        if train_path.is_file() and test_path.is_file():
            cached = list()
            for cache_path in (train_path, test_path):
                with open(cache_path, 'rb') as handle:
                    cached.append(pickle.load(handle))
            return cached[0], cached[1]

        # cache miss: compute and persist each split in turn
        CACHE_OUTS_DIR.mkdir(parents=True, exist_ok=True)
        train_embeds = self.compose_data(train_comments)
        with open(train_path, 'wb') as handle:
            pickle.dump(train_embeds, handle)
        test_embeds = self.compose_data(test_comments)
        with open(test_path, 'wb') as handle:
            pickle.dump(test_embeds, handle)

        return train_embeds, test_embeds

    def datum2tensor(self, datum):
        """Tokenise a whitespace-separated string term by term.

        Token ids outside the model vocabulary (``>= self.vocab_size``) are
        replaced by the first token id of ``' '``; this does not change the
        number of tokens. The previous hard-coded bound ``50256`` only matched
        GPT-2; ``self.vocab_size`` gives the same result for GPT-2 (50257) and
        the correct bound for the other model families.

        Returns:
            ``(token_labels, token_values)``: the token ids and their decoded strings.
        """
        vocab_size = self.vocab_size
        token_labels = [label for term in datum.split() for label in self.tokenizer.encode(term)]
        # only tokenise the fallback when it is actually needed
        if any(label >= vocab_size for label in token_labels):
            fallback = self.tokenizer.encode(' ')[0]
            token_labels = [fallback if label >= vocab_size else label for label in token_labels]
        token_values = [self.tokenizer.decode(label) for label in token_labels]
        return token_labels, token_values

    def _get_embeddings(self, token_labels):
        """Look up the input-embedding row of each token id and add a leading axis of size 1."""
        weight = get_attr(self.model, self.attrs['embedding']).weight
        rows = [weight[token_label] for token_label in token_labels]
        return torch.stack(rows).unsqueeze(0)

    def llm_gen(self, prompt):
        """Run one gradient-free forward pass of the model over ``prompt``.

        Inputs longer than ``self.max_context_len`` are truncated from the left,
        keeping the most recent tokens. Nothing is returned; callers read the
        activations through forward hooks.
        """
        inputs = self.tokenizer.encode_plus(prompt, return_tensors='pt', padding=True).to(device=DEVICE)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        if input_ids.shape[1] > self.max_context_len:
            input_ids = input_ids[:, -self.max_context_len:]
            attention_mask = attention_mask[:, -self.max_context_len:]
        with torch.no_grad():
            # Call the module rather than ``.forward`` so hooks registered on the
            # model itself also run.
            self.model(input_ids=input_ids, attention_mask=attention_mask)

    def extract_layer_embed(self, prompt):
        """Return the hooked layer's output at the last position for ``prompt``.

        A forward hook is attached to the module addressed by
        ``self.attrs['layers']``, the model is run once through ``llm_gen``, and
        the hook is always removed afterwards, including when the forward pass
        raises.

        Raises:
            KeyError: if the hooked module never ran during the forward pass.
        """
        features = dict()
        last_layer_str = 'last_layer'
        last_layer = get_attr(self.model, self.attrs['layers'])

        def last_layer_feature_hook(m, i, o):
            # o[0] is the first element of the layer output; for generation we
            # keep the entry at the last position (a mean over positions would be
            # the alternative for a "representation" use-case)
            layer_embeds = o[0].detach()
            features[last_layer_str] = layer_embeds[..., -1, :].cpu()

        handle = last_layer.register_forward_hook(last_layer_feature_hook)
        try:
            self.model.eval()
            # llm_gen supports truncation to the model's context length
            self.llm_gen(prompt)
        finally:
            # never leave a stale hook on the model, even when the pass failed
            handle.remove()
        return features[last_layer_str]

    def compose_data(self, comments):
        """Return one ``extract_layer_embed`` result per comment, in order.

        TODO: use these "actual" embeddings for a case study against the
        "virtual" ones (logits projected back through the LM head), e.g. whether
        refinement moves the actual case towards the virtual case / label family.
        """
        return [self.extract_layer_embed(comment) for comment in tqdm(comments)]

    def parse_response(self, gen_logits):
        gen_prob = torch.softmax(gen_logits, dim=-1)
        prob_per_cls = []
        for vocab_label in self.vocab_labels:
            prob_per_cls.append(gen_prob[:, vocab_label])
        # we filter the global_labels, to keep consist with the local_labels
        # pred = torch.argmax(torch.cat(prob_per_cls, dim=0)).tolist()
        label_id = torch.argmax(torch.cat(prob_per_cls, dim=0)).tolist()
        pred = self.vocab_labels[label_id]
        return pred

    def logits_to_output(self, logits):
        # pred = torch.argmax(logits).tolist()
        sorted_probs, sorted_labels = logits.sort(dim=-1, descending=True)
        argmax_probs = sorted_probs.cpu().numpy()[0].tolist()
        argmax_labels = sorted_labels.cpu().numpy()[0].tolist()

        # # filter sorted_indices using the domain vocabs
        # if self.allow_vocab_labels is not None:
        #     argmax_vocab_probs = list()
        #     argmax_vocab_labels = list()
        #     for argmax_prob, argmax_label in zip(argmax_probs, argmax_labels):
        #         if argmax_label in self.allow_vocab_labels:
        #             argmax_vocab_probs.append(argmax_prob)
        #             argmax_vocab_labels.append(argmax_label)
        #     argmax_probs = argmax_vocab_probs
        #     argmax_labels = argmax_vocab_labels
        # if self.block_vocab_labels is not None:
        #     argmax_vocab_probs = list()
        #     argmax_vocab_labels = list()
        #     for argmax_prob, argmax_label in zip(argmax_probs, argmax_labels):
        #         if argmax_label not in self.block_vocab_labels:
        #             argmax_vocab_probs.append(argmax_prob)
        #             argmax_vocab_labels.append(argmax_label)
        #     argmax_probs = argmax_vocab_probs
        #     argmax_labels = argmax_vocab_labels

        # we only care about top tokens
        # argmax_probs = argmax_probs[:MAX_FOCUSING_NUM]
        # argmax_labels = argmax_labels[:MAX_FOCUSING_NUM]
        argmax_prob = argmax_probs[0]
        argmax_label = argmax_labels[0]
        return argmax_prob, argmax_label

    # def _predict(self, prompt):
    #     inputs = self.tokenizer(prompt, return_tensors="pt").to(DEVICE)
    #     outputs = self.model.generate(
    #         **inputs,
    #         max_new_tokens=1,
    #         return_dict_in_generate=True,
    #     )
    #     logits = F.softmax(outputs.scores[0], dim=-1)
    #     return logits

    def compute_loss(self, embeds_z, labels):
        """Cross-entropy loss of the refined embeddings against their labels.

        Two scoring modes, selected by whether ``self.matrix`` is set:

        * anchor mode (``self.matrix is None``): logits are the cosine
          similarities between each embedding and every label anchor; targets
          are the local label indices. Example shapes recorded for this path:
          ``embeds_z`` ``[1024, 1, 512]``, ``self.anchors`` ``[8, 512]``,
          logits ``[1024, 8]``.
        * LM-head mode: logits are ``embeds_z @ self.matrix``; targets are the
          vocabulary token ids in ``labels``.

        The ``'de4sc'`` / ``'de4mm'`` ablations apply a softmax to the logits
        before the loss.
        """
        if self.matrix is None:
            # labels hold vocabulary token ids; the anchors are indexed by local id
            label_ids = [self.vocab_label2id[label] for label in labels.cpu().tolist()]
            logits = F.cosine_similarity(embeds_z, self.anchors, dim=-1).to(DEVICE)
            if self.cfg.ABLA == 'de4sc':
                logits = F.softmax(logits, dim=-1)
            truths = torch.tensor(label_ids).to(DEVICE)
            return CrossEntropyLoss()(logits, truths)

        scores = torch.matmul(embeds_z.float().to(DEVICE), self.matrix.float())
        logits = scores.squeeze(1).to(DEVICE)
        if self.cfg.ABLA == 'de4mm':
            logits = F.softmax(logits, dim=-1)
        truths = labels.to(DEVICE)
        return CrossEntropyLoss()(logits, truths)

    def train(self, data_loader, optimizer):
        """Run one optimisation pass of ``self.cluster`` over ``data_loader``."""
        self.cluster.train()
        for embeds, labels in tqdm(data_loader, leave=False):
            optimizer.zero_grad()
            refined = self.cluster(embeds.float().to(DEVICE))
            batch_loss = self.compute_loss(refined, labels)
            batch_loss.backward()
            optimizer.step()

    def predict(self, data_loader):
        """Return the cluster outputs for every batch of ``data_loader``.

        The cluster is put in eval mode and run under ``torch.no_grad()`` so no
        autograd graph is built for outputs that are detached anyway.

        Returns:
            A list with one detached output tensor per batch, in loader order.
        """
        self.cluster.eval()
        post_embeds = list()
        with torch.no_grad():
            for embeds, _labels in tqdm(data_loader, leave=False):
                embeds_z = self.cluster(embeds.float().to(DEVICE))
                post_embeds.append(embeds_z.detach())
        return post_embeds

    def optimize_reprs(self):
        """Train the cluster head, evaluate it on the test split and log the results.

        Steps: build un-shuffled loaders over the cached embeddings, train for
        ``cfg.EPOCH_NUM`` epochs with AdamW, refine the test embeddings, report
        LM-head ("gen") and anchor ("vds") accuracy, and append both to
        ``glance_<ABLA>.csv``. For the ``agnews`` / ``webss`` datasets the train
        split is refined too, t-SNE plots are written and the retrieval report
        is produced.
        """
        train_labels, test_labels = self.train_labels, self.test_labels

        train_loader = DataLoader(CoupleDataset(self.train_embeds, train_labels), batch_size=self.cfg.BATCH_SIZE)
        test_loader = DataLoader(CoupleDataset(self.test_embeds, test_labels), batch_size=self.cfg.BATCH_SIZE)

        # flat tensors of the un-refined embeddings, used as the "before" reference
        train_embeds = torch.cat(self.train_embeds, 0)
        test_embeds = torch.cat(self.test_embeds, 0)

        optimizer = optim.AdamW(self.cluster.parameters(), lr=3e-4)
        for _ in tqdm(range(self.cfg.EPOCH_NUM)):
            self.train(train_loader, optimizer)

        test_refined_embeds = torch.cat(self.predict(test_loader), 0).to(DEVICE).squeeze(1)

        logger.warning('TEST GEN ...')
        _, gen_post_labels = self.report_gen(test_embeds, test_refined_embeds, test_labels)
        test_gen_acc = Metric.same_accuracy(gen_post_labels, test_labels)
        logger.warning('TEST VAI ...')
        _, vds_post_labels = self.report_vds(test_embeds, test_refined_embeds, test_labels)
        test_vds_acc = Metric.same_accuracy(vds_post_labels, test_labels)

        # append one row per run; the header is written only for a new file
        REPORT_OUTS_DIR.mkdir(parents=True, exist_ok=True)
        save_results_file = REPORT_OUTS_DIR / f'glance_{self.cfg.ABLA}.csv'
        needs_header = not save_results_file.exists()
        with open(save_results_file, 'a+', newline='') as csvfile:
            writer = csv.writer(csvfile)
            if needs_header:
                writer.writerow(['llm', 'dataset', 'test_gen', 'test_vai'])
            writer.writerow([self.cfg.MODEL_CODE, self.cfg.DATA_CODE, test_gen_acc, test_vds_acc])

        # visualisation and retrieval reports are limited to these two datasets
        if self.cfg.DATA_CODE not in ('agnews', 'webss'):
            return

        train_refined_embeds = torch.cat(self.predict(train_loader), 0).to(DEVICE).squeeze(1)

        # the plots are coloured by local label index rather than token id
        train_label_ids = [self.vocab_label2id[label] for label in train_labels]
        test_label_ids = [self.vocab_label2id[label] for label in test_labels]
        self.visualize(train_embeds, train_label_ids, test_embeds, test_label_ids, mark='old')
        self.visualize(train_refined_embeds, train_label_ids, test_refined_embeds, test_label_ids, mark='new')
        self.report_rag(train_embeds, train_refined_embeds, train_labels, test_embeds, test_refined_embeds, test_labels)

    def visualize(self, train_embeds, train_labels, test_embeds, test_labels, mark='tsne'):
        """Save t-SNE scatter plots of the train and test embeddings.

        The t-SNE embedding is fitted on the train split and the test split is
        projected into it. Files are written to
        ``VISUAL_OUTS_DIR / self.identifier`` as ``tsne_<mark>_train.png`` and
        ``tsne_<mark>_test.png``.
        """
        import vds_vis

        import matplotlib.pyplot as plt
        from openTSNE import TSNE

        out_dir = VISUAL_OUTS_DIR / self.identifier
        out_dir.mkdir(parents=True, exist_ok=True)

        def render(points, labels, split):
            # draw one split and write it next to the other plots of this run
            vds_vis.plot(points, labels, colors=vds_vis.MOUSE_10X_COLORS)
            plt.savefig(out_dir / f'tsne_{mark}_{split}.png', bbox_inches='tight', dpi=300)

        reducer = TSNE(metric='cosine', n_jobs=8, random_state=42)
        train_points = reducer.fit(train_embeds.cpu())
        render(train_points, train_labels, 'train')
        test_points = train_points.transform(test_embeds.cpu())
        render(test_points, test_labels, 'test')

        # import umap
        #
        # manifold = umap.UMAP(metric='cosine', n_jobs=8, random_state=42).fit(train_embeds.cpu(), train_labels)
        # reduced_data = manifold.transform(train_embeds.cpu())
        # vds_vis.plot(reduced_data, train_labels, colors=vds_vis.MOUSE_10X_COLORS)
        # plt.savefig(f'umap_{mark}_train.png', bbox_inches='tight', dpi=300)
        # # plt.show()
        # reduced_data = manifold.transform(test_embeds.cpu())
        # vds_vis.plot(reduced_data, test_labels, colors=vds_vis.MOUSE_10X_COLORS)
        # plt.savefig(f'umap_{mark}_test.png', bbox_inches='tight', dpi=300)
        # # plt.show()

    def infer(self, embeds, matrix):
        """Classify each embedding through ``matrix`` and return the predicted label token ids.

        Every embedding is multiplied with ``matrix`` to obtain logits, which
        ``parse_response`` then restricts to the label tokens.
        """
        return [
            self.parse_response(
                torch.matmul(embed.float().to(DEVICE), matrix.float()).unsqueeze(0)
            )
            for embed in embeds
        ]

    def refer(self, anchors, embeds):
        # embed -> logit -> label
        # simis = F.cosine_similarity(embeds.unsqueeze(1), anchors, dim=-1)

        # todo make it be an utility function
        A = embeds.unsqueeze(1)
        B = anchors
        chunk_size = 10000
        # Get the total number of chunks
        num_chunks = (A.shape[0] + chunk_size - 1) // chunk_size
        # Initialize an empty list to store chunk-wise cosine similarity results
        similarity_chunks = []
        # Compute cosine similarity for each chunk
        for i in range(num_chunks):
            start_idx = i * chunk_size
            end_idx = min((i + 1) * chunk_size, A.shape[0])  # Handle the last chunk
            # Get the chunk of tensor A
            chunk_A = A[start_idx:end_idx, :]
            # Compute cosine similarity for the chunk
            similarity_chunk = F.cosine_similarity(chunk_A, B, dim=-1)
            # Append the chunk-wise result to the list
            similarity_chunks.append(similarity_chunk)
        # Concatenate the chunk-wise results along the specified dimension
        simis = torch.cat(similarity_chunks, dim=0)

        logits = F.softmax(simis, dim=-1).to(DEVICE)
        argmax_labels = list()
        for argmax_id in torch.argmax(logits, dim=-1):
            argmax_label = self.vocab_labels[argmax_id]
            argmax_labels.append(argmax_label)
        return argmax_labels

    def ragen(self, test_embeds, train_embeds, train_labels):
        """k-NN majority vote over the train set for every test embedding.

        Neighbours are ranked by cosine similarity, nearest first, and a vote is
        taken for each ``k`` in ``1, 2, 4, ..., 256``. The previous
        implementation used ``np.argpartition``, which does not order the
        selected indices, so the votes for small ``k`` were taken over arbitrary
        members of the top-256; it also failed when fewer than 256 train
        embeddings were available. Now ``k`` is capped at the number of train
        embeddings.

        Args:
            test_embeds: query embeddings; a singleton axis 1 is squeezed away.
            train_embeds: retrieval embeddings; a singleton axis 1 is squeezed away.
            train_labels: label of each train embedding, indexable by position.

        Returns:
            One list per query holding the voted label for each ``k``, in the
            order of the ``k`` values above.

        Raises:
            ValueError: if there is no train embedding to retrieve from.
        """
        ks = [2 ** exp for exp in range(9)]

        queries = test_embeds.squeeze(1).to(DEVICE)
        anchors = train_embeds.squeeze(1).to(DEVICE)
        simi_matrix = util.cos_sim(queries, anchors).detach()

        n_anchors = simi_matrix.shape[1]
        if n_anchors == 0:
            raise ValueError('ragen() needs at least one train embedding to retrieve from')

        # topk returns indices sorted by decreasing similarity, so a prefix of
        # length k really is the k nearest neighbours
        top_k = min(ks[-1], n_anchors)
        ranked = torch.topk(simi_matrix, k=top_k, dim=1, largest=True, sorted=True).indices
        all_neighbors = ranked.cpu().tolist()

        all_answers = list()
        for neighbors in all_neighbors:
            answers = list()
            for k in ks:
                # Counter keeps first-seen order, so ties go to the nearer neighbour
                votes = Counter(train_labels[neighbor] for neighbor in neighbors[:k])
                answers.append(votes.most_common(1)[0][0])
            all_answers.append(answers)
        return all_answers

    def report_vds(self, origin_embeds, reform_embeds, refer_labels):
        """Compare anchor-based predictions before and after refinement.

        Both embedding sets are labelled with ``refer`` against ``self.anchors``
        and the two label lists are passed to the ``Metric`` contrast reports.

        Returns:
            ``(pre_labels, post_labels)`` for the original and refined embeddings.
        """
        device_anchors = self.anchors.to(DEVICE)
        pre_labels = self.refer(device_anchors, origin_embeds.to(DEVICE))
        post_labels = self.refer(device_anchors, reform_embeds.to(DEVICE))
        for report in (Metric.contrast_gen_scoring, Metric.contrast_cluster_measuring):
            report(pre_labels, post_labels, refer_labels)
        return pre_labels, post_labels

    def report_gen(self, origin_embeds, reform_embeds, refer_labels):
        """Compare LM-head predictions before and after refinement.

        Both embedding sets are labelled with ``infer`` through the transposed
        LM-head weight and the two label lists are passed to the ``Metric``
        contrast reports.

        Returns:
            ``(pre_labels, post_labels)`` for the original and refined embeddings.
        """
        head_matrix = self.get_lm_head_matrix().transpose(0, 1).to(DEVICE)
        pre_labels = self.infer(origin_embeds, head_matrix)
        post_labels = self.infer(reform_embeds, head_matrix)
        for report in (Metric.contrast_gen_scoring, Metric.contrast_cluster_measuring):
            report(pre_labels, post_labels, refer_labels)
        return pre_labels, post_labels

    def report_rag(
            self,
            train_origin_embeds,
            train_reform_embeds,
            train_labels,
            test_origin_embeds,
            test_reform_embeds,
            test_labels
    ):
        """Report retrieval-vote quality before and after refinement for each ``k``.

        ``ragen`` is run once on the original and once on the refined
        embeddings; for every ``k`` in ``1, 2, 4, ..., 256`` the two answer
        lists are passed to ``Metric.contrast_gen_scoring``.
        """
        origin_by_query = self.ragen(test_origin_embeds, train_origin_embeds, train_labels)
        reform_by_query = self.ragen(test_reform_embeds, train_reform_embeds, train_labels)

        # position idx in each per-query answer list belongs to k == 2 ** idx
        for idx in range(9):
            k = 2 ** idx
            logger.debug(f'{k=}')
            origin_answers = [per_query[idx] for per_query in origin_by_query]
            reform_answers = [per_query[idx] for per_query in reform_by_query]
            Metric.contrast_gen_scoring(origin_answers, reform_answers, test_labels)
