import argparse
import numpy as np
import os
import sys
import torch
from tqdm import tqdm
import torch.nn.functional as F
from load_test_models import load_test_models
from ofa.imagenet_classification.run_manager.run_config import get_data_provider_by_name
from ofa.imagenet_classification.networks import MobileNetV3Large
import alipy
from sklearn import preprocessing
from path_prefix import path_prefix
# Directory the process was launched from, added to the module search path.
# NOTE(review): this executes after all of the imports above, so it has no
# effect on how those imports are resolved.
pwd = os.getcwd()
sys.path += [pwd]


def calc_leverage(mat):
    """Return the statistical leverage score of every row of ``mat``.

    The leverage of row ``x_i`` is ``x_i^T (X^T X)^+ x_i``, i.e. the diagonal of
    the hat matrix ``X (X^T X)^+ X^T``. The pseudo-inverse is used, so
    rank-deficient inputs are handled.

    Args:
        mat: 2-D tensor of shape ``(n_rows, n_features)`` on any device.

    Returns:
        CPU tensor of shape ``(1, n_rows)`` in the default floating dtype.
        Callers rely on it being on the CPU so that ``.numpy()`` works.
    """
    piv = torch.linalg.pinv(mat.T @ mat)
    # diag(X P X^T) computed row-wise without building the n x n hat matrix and
    # without a per-row Python loop (and without the deprecated 1-D ``.T``).
    lvs = ((mat @ piv) * mat).sum(dim=1)
    return lvs.reshape(1, -1).to(device="cpu", dtype=torch.get_default_dtype())


class ALScoringFunctions:
    """Active-learning query strategies that score an unlabeled pool.

    Args:
        batch_size: number of unlabeled samples to select per query.
        n_classes: number of output classes produced by ``net``.
        unlab_idx: dataset indices of the unlabeled pool. Row ``i`` of every
            score/embedding is mapped back to ``unlab_idx[i]``, so this is
            assumed to be in the same order ``dataloader`` yields samples.
        net: model exposing ``get_logits`` and ``get_logits_and_pred``. It may
            be ``None`` for ``DIAM``, which only reads saved predictions.
        dataloader: loader over the unlabeled pool yielding ``(images, labels)``.
    """

    def __init__(self, batch_size, n_classes, unlab_idx, net, dataloader):
        self.net = net
        self.dataloader = dataloader
        self.scores = []
        self.unlab_idx = unlab_idx
        self.n_classes = n_classes
        self.batch_size = batch_size
        # Filled lazily by _get_proba_pred().
        self.probs = None
        self.embeddings = None
        self.lab_arr = None

    def _extract_lab_embeddings(self, lab_dataloader):
        """Run ``net.get_logits`` over ``lab_dataloader``.

        Returns:
            ``(embeddings, labels)`` concatenated along dim 0, both on CUDA.
            Both are ``None`` when the loader yields no batches.
        """
        self.net.eval()
        emb_chunks, label_chunks = [], []
        with torch.no_grad():
            with tqdm(total=len(lab_dataloader), desc="Extracting labeled data embedding...") as t:
                for images, labels in lab_dataloader:
                    images, labels = images.to("cuda"), labels.to("cuda")
                    emb_chunks.append(self.net.get_logits(images).detach())
                    label_chunks.append(labels.detach())
                    t.update(1)
        if not emb_chunks:
            return None, None
        # One concat at the end instead of re-concatenating every batch (quadratic).
        return torch.cat(emb_chunks, dim=0), torch.cat(label_chunks, dim=0)

    def get_embedding(self, fname, lab_dataloader=None, save_fea=False, skip_unlab=False, return_labels=False):
        """Extract, and optionally save, labeled/unlabeled data embeddings.

        Args:
            fname: file-name prefix used when ``save_fea`` is true.
            lab_dataloader: optional loader over the labeled data.
            save_fea: if true, save ``(self.embeddings, self.lab_arr)`` to
                ``./{fname}_unlab.pth``. If a labeled loader is given, also save
                ``(embeddings, labels)`` to ``./{fname}_lab.pth``. Returns ``None``.
            skip_unlab: do not run the unlabeled pass even if it has not run yet.
            return_labels: when not saving, also return the labeled labels.

        Returns:
            ``None`` when ``save_fea`` is true. Otherwise the labeled embeddings,
            or ``(embeddings, labels)`` if ``return_labels`` is true. These are
            ``None`` when no labeled loader is given.
        """
        if self.embeddings is None and not skip_unlab:
            self._get_proba_pred()
        lab_embeddings, lab_lab_th = None, None
        if lab_dataloader is not None:
            lab_embeddings, lab_lab_th = self._extract_lab_embeddings(lab_dataloader)
        if save_fea:
            folder_path = "."
            torch.save((self.embeddings, self.lab_arr), os.path.join(folder_path, f"{fname}_unlab.pth"))
            if lab_dataloader is not None:
                torch.save((lab_embeddings, lab_lab_th), os.path.join(folder_path, f"{fname}_lab.pth"))
            return None
        if return_labels:
            return lab_embeddings, lab_lab_th
        return lab_embeddings

    def _get_proba_pred(self):
        """Populate ``probs``, ``embeddings`` and ``lab_arr`` for the unlabeled pool.

        ``probs`` is a CPU tensor of shape ``(len(unlab_idx), n_classes)``
        holding softmax outputs. ``embeddings`` holds the logits returned by
        ``net.get_logits_and_pred``, concatenated along dim 0. ``lab_arr`` holds
        the loader's labels as a ``(1, N)`` row.
        """
        self.net.eval()
        probs = torch.zeros([len(self.unlab_idx), self.n_classes])
        emb_chunks, label_chunks = [], []
        start = 0
        with torch.no_grad():
            with tqdm(total=len(self.dataloader), desc="Extracting unlabeled data embedding/proba prediction...") as t:
                for images, labels in self.dataloader:
                    images, labels = images.to("cuda"), labels.to("cuda")
                    logits, output = self.net.get_logits_and_pred(images)
                    # Running offset: correct even if batches are not all the
                    # loader's nominal batch_size.
                    end = start + len(labels)
                    probs[start:end, :] = F.softmax(output, dim=1).detach()
                    start = end
                    emb_chunks.append(logits.detach())
                    label_chunks.append(labels.detach().reshape((1, -1)))
                    t.update(1)
        self.probs = probs
        self.embeddings = torch.cat(emb_chunks, dim=0) if emb_chunks else None
        self.lab_arr = torch.cat(label_chunks, dim=1) if label_chunks else None

    def entropy(self, return_scores=False):
        """Score samples by negative predictive entropy ``sum_c p_c * log(p_c)``.

        Lower (more negative) scores mean more uncertain predictions.

        Returns:
            If ``return_scores`` is true, a 1-D tensor with one score per
            unlabeled sample. Otherwise the ``unlab_idx`` entries of the
            ``batch_size`` lowest-scoring, i.e. most uncertain, samples.
        """
        if self.probs is None:
            self._get_proba_pred()
        # xlogy defines 0 * log(0) = 0. The plain product yields NaN whenever a
        # softmax probability underflows to exactly 0.
        uncertainties = torch.xlogy(self.probs, self.probs).sum(1)
        if return_scores:
            return uncertainties
        order = uncertainties.argsort()[:self.batch_size]
        return self.unlab_idx[order.cpu().numpy()]

    def coreset(self, lab_dataloader, return_scores=False):
        """Select samples with alipy's ``QueryInstanceCoresetGreedy`` on the embeddings.

        Unlabeled embeddings are placed first and labeled embeddings after
        them, so labeled positions are offset by the unlabeled count.

        Returns:
            If ``return_scores`` is true, an array whose entry ``i`` is the rank
            at which unlabeled sample ``i`` was picked when all samples are
            selected. Otherwise the ``unlab_idx`` entries of the first
            ``batch_size`` picks.
        """
        if self.embeddings is None:
            self._get_proba_pred()
        lab_embeddings, _ = self._extract_lab_embeddings(lab_dataloader)

        lab_num = lab_embeddings.shape[0]
        unlab_num = self.embeddings.shape[0]
        coreset_qs = alipy.query_strategy.QueryInstanceCoresetGreedy(
            X=torch.cat([self.embeddings.cpu(), lab_embeddings.cpu()], dim=0), y=np.zeros(lab_num + unlab_num),
            train_idx=list(range(lab_num + unlab_num)))
        label_index = np.arange(lab_num) + unlab_num
        unlabel_index = np.arange(unlab_num)
        if return_scores:
            selected = coreset_qs.select(label_index=label_index,
                                         unlabel_index=unlabel_index, batch_size=len(self.unlab_idx))
            scores = np.zeros(len(self.unlab_idx))
            for rank, idx in enumerate(selected):
                scores[idx] = rank
            return scores
        selected = coreset_qs.select(label_index=label_index,
                                     unlabel_index=unlabel_index, batch_size=self.batch_size)
        return [self.unlab_idx[i] for i in selected]

    def DIAM(self, return_scores=False):
        """Select samples that saved checkpoint predictions disagree on most.

        Reads the module-level globals ``ofa_checkpoint_root`` and ``args`` set
        by the ``__main__`` script. For each of ``args.model_num`` models it
        loads every ``*.pt`` file under ``<root>/<model>/checkpoint``.
        NOTE(review): each file is presumed to hold one prediction per
        unlabeled sample — confirm against the writer.

        Samples are ranked descending. The primary key is how many models
        changed their prediction across checkpoints. The secondary key is the
        summed vote entropy. From the top ``5 * batch_size`` samples,
        ``batch_size`` are drawn at random.

        ``return_scores`` is accepted for interface parity but ignored.
        """
        flist = os.listdir(os.path.join(ofa_checkpoint_root, '0', "checkpoint"))
        pt_list = sorted(fi for fi in flist if fi.endswith('.pt'))
        all_votes = torch.zeros([len(pt_list), args.model_num, len(self.unlab_idx)])
        for ipt, pt in enumerate(pt_list):
            for i in range(args.model_num):
                ofa_unlab_pred_root = os.path.join(ofa_checkpoint_root, str(i), "checkpoint")
                # all_votes lives on the CPU, so load there (no GPU needed).
                pre_mat = torch.load(os.path.join(ofa_unlab_pred_root, pt), map_location="cpu")
                all_votes[ipt, i] = pre_mat
        first_order = []
        second_order = []
        for voter in range(args.model_num):
            votes = all_votes[:, voter, :]
            is_dis = (torch.amax(votes, dim=0) != torch.amin(votes, dim=0))
            first_order.append(is_dis.numpy())
            unc = alipy.query_strategy.QueryInstanceQBC.calc_vote_entropy(votes.numpy())
            second_order.append(unc)
        # np.lexsort uses the LAST key as primary: disagreement count, then entropy.
        sorted_args = np.lexsort((np.sum(second_order, axis=0), np.sum(first_order, axis=0)))[::-1]
        selected = sorted_args[:self.batch_size * 5]
        np.random.shuffle(selected)
        selected = selected[:self.batch_size]
        return [self.unlab_idx[i] for i in selected]
    


def lewis_weight_sample(prob, batch_size, unlab_idx, seed=0):
    """Sample distinct unlabeled items i.i.d. with replacement according to ``prob``.

    Positions of ``prob`` are drawn one at a time until ``batch_size``
    distinct ``unlab_idx`` values have been collected. Each selected item then
    gets the importance weight ``sqrt(count) / sqrt(num_draws * prob)``, where
    ``count`` is how many times it was drawn.

    The global NumPy RNG is reseeded with ``seed`` first. The draws are
    therefore reproducible, and later ``np.random`` calls continue from the
    resulting state.

    Args:
        prob: 1-D probability vector over positions of ``unlab_idx`` (sums to 1).
        batch_size: number of distinct items wanted. It is capped at the number
            of distinct items with non-zero probability so the loop always ends.
        unlab_idx: dataset indices; position ``i`` of ``prob`` maps to
            ``unlab_idx[i]``.
        seed: value passed to ``np.random.seed``.

    Returns:
        ``(selected_idx, selected_weight)`` as lists in selection order.
    """
    np.random.seed(seed)
    reachable = np.asarray(unlab_idx)[np.asarray(prob) > 0]
    target = min(batch_size, np.unique(reachable).size)
    selected_idx = []
    selected_p = []
    position = {}  # dataset index -> position in selected_idx (O(1) duplicate check)
    repeat_entry = {}  # position -> total draw count, recorded once drawn >= 2 times
    num_draws = 0
    while len(selected_idx) < target:
        sid = np.random.choice(a=len(prob), replace=True, p=prob)
        num_draws += 1
        key = unlab_idx[sid]
        if key in position:
            rep = position[key]
            repeat_entry[rep] = repeat_entry.get(rep, 1) + 1
        else:
            position[key] = len(selected_idx)
            selected_idx.append(key)
            selected_p.append(prob[sid])
    weights = [1 / np.sqrt(num_draws * p) for p in selected_p]
    for pos, count in repeat_entry.items():
        weights[pos] = weights[pos] * np.sqrt(count)
    return selected_idx, weights


if __name__ == "__main__":
    # Query one active-learning batch from the unlabeled pool and write the
    # updated labeled/unlabeled index files for iteration al_iter + 1.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--method",
        type=str,
        default="lewisweight",
        choices=[
            "entropy",
            "coreset",
            "random",
            "DIAM",
            "lewisweight",
            "qbc",
        ],
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="mnist",
        choices=[
            "cifar10",
            "mnist",
            "kmnist",
            "fmnist",
            "svhn",
            "emnistlet",
        ],
    )
    parser.add_argument("--batch_size", type=int, default=3000)
    parser.add_argument("--save_root", type=str, default="al_results")
    parser.add_argument("--al_iter", type=int, default=1)
    parser.add_argument("--model_num", type=int, default=50)

    args = parser.parse_args()
    args.save_root = os.path.join(path_prefix, args.save_root)
    # Index files of the current iteration (read).
    al_idx_save_root = os.path.join(args.save_root, str(args.al_iter), args.dataset)
    al_save_root = os.path.join(al_idx_save_root, args.method)
    # NOTE: ofa_checkpoint_root and args are module globals read by ALScoringFunctions.DIAM.
    if args.al_iter == 0:
        ofa_checkpoint_root = f"{path_prefix}exp/{args.al_iter}/{args.dataset}/"
        lab_idx = np.loadtxt(os.path.join(ofa_checkpoint_root, "lab_idx.txt"), dtype=int)
        unlab_idx = np.loadtxt(os.path.join(ofa_checkpoint_root, "unlab_idx.txt"), dtype=int)
    else:
        ofa_checkpoint_root = f"{path_prefix}exp/{args.al_iter}/{args.dataset}/{args.method}/"
        lab_idx = np.loadtxt(os.path.join(al_save_root, "lab_idx.txt"), dtype=int)
        unlab_idx = np.loadtxt(os.path.join(al_save_root, "unlab_idx.txt"), dtype=int)
    # Index files of the next iteration (written at the end).
    al_idx_save_root = os.path.join(args.save_root, str(args.al_iter + 1), args.dataset)
    al_save_root = os.path.join(al_idx_save_root, args.method)
    if args.dataset == 'emnistlet':
        NCLASSES = 27
    else:
        NCLASSES = 10

    if args.method == "random":
        np.random.shuffle(unlab_idx)
        selected_idx = unlab_idx[:args.batch_size]
    else:
        # Load a model only to learn the input image size.
        net, image_size = load_test_models(net_id=0, trained_weights=None)
        del net
        ########################### construct sequential unlab load dataloader ########################################
        DataProvider = get_data_provider_by_name(args.dataset)
        dpv = DataProvider(
            train_batch_size=256,
            test_batch_size=256,
            valid_size=3000,
            n_worker=16,
            resize_scale=0.08,
            distort_color="tf",
            image_size=image_size,
            num_replicas=None,
            lab_idx=lab_idx,
            unlab_idx=unlab_idx,
        )
        unlab_dataloader = dpv.unlab
        ######################### query ##########################################
        if args.method == "coreset":
            lab_dataloader = dpv.train
            teacher_model = MobileNetV3Large(
                n_classes=NCLASSES,
                bn_param=(0.1, 1e-5),
                dropout_rate=0,
                width_mult=1.0,
                ks=7,
                expand_ratio=6,
                depth_param=4,
            )
            teacher_model.load_state_dict(
                torch.load(".torch/ofa_checkpoints/0/ofa_D4_E6_K7", map_location="cpu")["state_dict"])
            teacher_model.cuda()
            qs = ALScoringFunctions(batch_size=args.batch_size,
                                    n_classes=NCLASSES,
                                    unlab_idx=unlab_idx,
                                    net=teacher_model,
                                    dataloader=unlab_dataloader)
            selected_idx = qs.coreset(lab_dataloader=lab_dataloader)
        elif args.method == "qbc":
            all_scores = []
            for rs in tqdm(range(args.model_num), desc="extract prob. pred. from models"):
                ofa_net_weight = os.path.join(ofa_checkpoint_root, str(rs), "checkpoint/checkpoint.pth.tar")
                net, image_size = load_test_models(net_id=rs, n_classes=NCLASSES, trained_weights=ofa_net_weight)
                net.cuda()
                qs = ALScoringFunctions(batch_size=args.batch_size,
                                        n_classes=NCLASSES,
                                        unlab_idx=unlab_idx,
                                        net=net,
                                        dataloader=unlab_dataloader)
                qs._get_proba_pred()
                all_scores.append(qs.probs.cpu().numpy())
            qbc_scores = alipy.query_strategy.query_labels.QueryInstanceQBC.calc_avg_KL_divergence(all_scores)
            selected_idx = np.argsort(qbc_scores)[-args.batch_size:]
            selected_idx = [unlab_idx[i] for i in selected_idx]
        elif args.method == "lewisweight":
            lv_scores = []
            for rs in tqdm(range(args.model_num), desc="test models"):
                ofa_net_weight = os.path.join(ofa_checkpoint_root, str(rs), "checkpoint/checkpoint.pth.tar")
                net, image_size = load_test_models(net_id=rs, n_classes=NCLASSES, trained_weights=ofa_net_weight)
                net.cuda()
                qs = ALScoringFunctions(batch_size=args.batch_size,
                                        n_classes=NCLASSES,
                                        unlab_idx=unlab_idx,
                                        net=net,
                                        dataloader=unlab_dataloader)
                qs._get_proba_pred()
                feature_mat = qs.embeddings
                if max(feature_mat.shape) > 1e+5:
                    feature_mat = feature_mat.to("cpu")
                lv_scores.append(calc_leverage(feature_mat).numpy()[0])
            lv_scores = np.array(lv_scores)
            # Per-sample maximum leverage across models.
            all_scores = np.max(lv_scores, axis=0) if args.model_num > 1 else lv_scores[0]
            # Normalize in float64 so np.random.choice's sum-to-one check is robust.
            all_scores = all_scores.astype(np.float64)
            prob = all_scores / np.sum(all_scores)
            selected_idx, selected_prob = lewis_weight_sample(prob, args.batch_size, unlab_idx, seed=0)

            if args.al_iter >= 1:
                # Prepend the weights saved for the previously selected samples.
                last_iter_weight_path = os.path.join(args.save_root, str(args.al_iter), args.dataset, args.method)
                last_weights = np.loadtxt(os.path.join(last_iter_weight_path, "select_prob.txt"), dtype=float)
                selected_prob = np.hstack((last_weights, selected_prob))
        elif args.method == "entropy":
            minmax = preprocessing.MinMaxScaler()
            all_scores = np.zeros([args.model_num, len(unlab_idx)])
            for rs in tqdm(range(args.model_num), desc="test models"):
                ofa_net_weight = os.path.join(ofa_checkpoint_root, str(rs), "checkpoint/checkpoint.pth.tar")
                net, image_size = load_test_models(net_id=rs, n_classes=NCLASSES, trained_weights=ofa_net_weight)
                net.cuda()
                qs = ALScoringFunctions(batch_size=args.batch_size,
                                        n_classes=NCLASSES,
                                        unlab_idx=unlab_idx,
                                        net=net,
                                        dataloader=unlab_dataloader)
                data_scores = qs.entropy(return_scores=True)
                # Scale to 0-1 so no single model dominates. MinMaxScaler works per
                # COLUMN, so the scores must form one column (N, 1); a (1, N) row
                # would map every value to 0.
                all_scores[rs, :] = minmax.fit_transform(
                    data_scores.numpy().reshape([-1, 1])).ravel()

            # mean and sort the scores (lowest = most uncertain)
            mean_all_scores = np.mean(all_scores, axis=0)
            selected_idx = np.argsort(mean_all_scores)[:args.batch_size]
            selected_idx = [unlab_idx[i] for i in selected_idx]
        elif args.method == "DIAM":
            qs = ALScoringFunctions(batch_size=args.batch_size,
                                    n_classes=NCLASSES,
                                    unlab_idx=unlab_idx,
                                    net=None,
                                    dataloader=unlab_dataloader)
            selected_idx = qs.DIAM()

    # update index
    assert set(selected_idx).issubset(set(unlab_idx))
    lab_idx = np.hstack((lab_idx, selected_idx))
    unlab_idx = np.setdiff1d(unlab_idx, lab_idx)
    np.random.shuffle(unlab_idx)
    # save
    os.makedirs(al_save_root, exist_ok=True)
    print(f"save to {al_save_root}...")
    np.savetxt(os.path.join(al_save_root, "lab_idx.txt"), lab_idx, fmt="%d")
    np.savetxt(os.path.join(al_save_root, "unlab_idx.txt"), unlab_idx, fmt="%d")
    if "lewisweight" == args.method:
        np.savetxt(os.path.join(al_save_root, "select_prob.txt"), selected_prob, fmt="%f")
