import argparse
import os
import sys
import random

import alipy
import torch
import numpy as np
from tqdm import tqdm

from ofa.imagenet_classification.networks import MobileNetV3Large
from al_select import ALScoringFunctions, calc_leverage
from load_test_models import load_test_models
from ofa.imagenet_classification.run_manager.run_config import get_data_provider_by_name
from path_prefix import path_prefix
from torch.utils.data import TensorDataset, DataLoader
from torchmetrics import R2Score, MeanAbsoluteError, MeanSquaredError
pwd = os.getcwd()
sys.path.append(pwd)


if __name__ == "__main__":
    # ---- command-line interface ----
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--method",
        type=str,
        default="random",
        choices=["random", "qbc", "lewisweight", "coreset"],
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="biwi",
        choices=["flw", "biwi"],
    )
    parser.add_argument("--batch_size", type=int, default=500)
    parser.add_argument("--save_root", type=str, default="al_results")
    parser.add_argument("--al_iter", type=int, default=0)
    parser.add_argument("--model_num", type=int, default=50)

    args = parser.parse_args()
    args.replace = True
    args.save_root = os.path.join(path_prefix, args.save_root)

    # Iteration 0 has no per-method sub-directory; later iterations do.
    is_first_round = args.al_iter == 0
    ofa_checkpoint_root = (
        f"{path_prefix}exp/{args.al_iter}/{args.dataset}/"
        if is_first_round
        else f"{path_prefix}exp/{args.al_iter}/{args.dataset}/{args.method}/"
    )

    # Directory holding the index files of the *current* iteration.
    al_idx_save_root = os.path.join(args.save_root, str(args.al_iter), args.dataset)
    al_save_root = os.path.join(al_idx_save_root, args.method)
    os.makedirs(al_save_root, exist_ok=True)

    # Number of outputs of the prediction head: 2 for 'biwi', 10 otherwise.
    NCLASSES = 2 if args.dataset == 'biwi' else 10

    # The first round reads the indices stored next to the checkpoints; every
    # later round reads the ones written by the previous run of this script.
    idx_dir = ofa_checkpoint_root if is_first_round else al_save_root
    lab_idx = np.loadtxt(os.path.join(idx_dir, "lab_idx.txt"), dtype=int)
    unlab_idx = np.loadtxt(os.path.join(idx_dir, "unlab_idx.txt"), dtype=int)

    # From here on the save directory points at the *next* iteration.
    al_idx_save_root = os.path.join(args.save_root, str(args.al_iter + 1), args.dataset)
    al_save_root = os.path.join(al_idx_save_root, args.method)

    # Pick `selected_idx` (dataset indices to label next) with the requested
    # strategy.  For "lewisweight", `selected_prob` additionally holds one
    # importance weight per selected sample.
    if args.method == "random":
        # Baseline: shuffle the pool in place and take its first batch_size entries.
        np.random.shuffle(unlab_idx)
        selected_idx = unlab_idx[:args.batch_size]
    else:
        # The network is loaded only to obtain `image_size`; it is discarded right away.
        net, image_size = load_test_models(net_id=0, trained_weights=None)
        del net
        ########################### construct sequential unlab load dataloader ########################################
        DataProvider = get_data_provider_by_name(args.dataset)
        dpv = DataProvider(
            train_batch_size=256,
            test_batch_size=256,
            valid_size=500,
            n_worker=16,
            resize_scale=0.08,
            distort_color="tf",
            image_size=image_size,
            num_replicas=None,
            lab_idx=lab_idx,
            unlab_idx=unlab_idx,
        )
        unlab_dataloader = dpv.unlab
        ######################### query ##########################################
        if args.method == "qbc":
            # Query-by-committee: collect the predictions of every model and
            # rank pool samples by alipy's average KL divergence.
            all_scores = []
            for rs in tqdm(range(args.model_num), desc="extract prob. pred. from models"):
                ofa_net_weight = os.path.join(ofa_checkpoint_root, str(rs), "checkpoint/checkpoint.pth.tar")
                net, image_size = load_test_models(net_id=rs, n_classes=NCLASSES, trained_weights=ofa_net_weight)
                net.cuda()
                qs = ALScoringFunctions(batch_size=args.batch_size,
                                        n_classes=NCLASSES,
                                        unlab_idx=unlab_idx,
                                        net=net,
                                        dataloader=unlab_dataloader)
                qs._get_proba_pred()
                all_scores.append(qs.probs.cpu().numpy())
            qbc_scores = alipy.query_strategy.query_labels.QueryInstanceQBC.calc_avg_KL_divergence(all_scores)
            # argsort is ascending, so the tail holds the highest-scoring positions.
            selected_idx = np.argsort(qbc_scores)[-args.batch_size:]
            # Translate positions inside the pool into dataset indices.
            selected_idx = [unlab_idx[i] for i in selected_idx]

        elif args.method == "coreset":
            lab_dataloader = dpv.train
            teacher_model = MobileNetV3Large(
                n_classes=NCLASSES,
                bn_param=(0.1, 1e-5),
                dropout_rate=0,
                width_mult=1.0,
                ks=7,
                expand_ratio=6,
                depth_param=4,
            )
            teacher_model.load_state_dict(
                torch.load(".torch/ofa_checkpoints/0/ofa_D4_E6_K7", map_location="cpu")["state_dict"])
            teacher_model.cuda()
            qs = ALScoringFunctions(batch_size=args.batch_size,
                                    n_classes=NCLASSES,
                                    unlab_idx=unlab_idx,
                                    net=teacher_model,
                                    dataloader=unlab_dataloader)
            selected_idx = qs.coreset(lab_dataloader=lab_dataloader)
        elif args.method == "lewisweight":
            lab_dataloader = dpv.train
            # One score vector per model, computed from that model's embeddings.
            lv_scores = []
            for rs in tqdm(range(args.model_num), desc="test models"):
                ofa_net_weight = os.path.join(ofa_checkpoint_root, str(rs), "checkpoint/checkpoint.pth.tar")
                net, image_size = load_test_models(net_id=rs, n_classes=NCLASSES, trained_weights=ofa_net_weight)
                net.cuda()
                qs = ALScoringFunctions(batch_size=args.batch_size,
                                        n_classes=NCLASSES,
                                        unlab_idx=unlab_idx,
                                        net=net,
                                        dataloader=unlab_dataloader)
                qs._get_proba_pred()
                feature_mat = qs.embeddings
                # Very large feature matrices are moved to the CPU before scoring.
                if max(feature_mat.shape) > 1e+5:
                    feature_mat = feature_mat.to("cpu")
                lv_scores.append(calc_leverage(feature_mat).numpy()[0])
            lv_scores = np.array(lv_scores)
            # With several models, keep the element-wise maximum score over models.
            if args.model_num > 1:
                all_scores = np.max(lv_scores, axis=0)
            else:
                all_scores = lv_scores[0]

            selected_idx = []
            selected_prob = []
            prob = all_scores/np.sum(all_scores)
            np.random.seed(0)
            # position in selected_idx -> number of times that sample was drawn (>= 2)
            repeat_entry = dict()
            # total number of draws, including repeated ones
            lv_num_sample = 0
            # dataset index -> its position in selected_idx (constant-time duplicate check)
            selected_pos = dict()
            # Only entries with non-zero probability can ever be drawn, so asking
            # for more distinct samples than that would loop forever.
            n_target = min(args.batch_size, int(np.count_nonzero(prob)))
            if n_target < args.batch_size:
                print(f"only {n_target} samples have non-zero probability; "
                      f"selecting {n_target} instead of {args.batch_size}")
            while len(selected_idx) < n_target:
                sid = np.random.choice(a=len(prob), replace=True, p=prob)
                lv_num_sample += 1
                picked = unlab_idx[sid]
                rep_inx = selected_pos.get(picked)
                if rep_inx is None:
                    # first time this sample is drawn
                    selected_pos[picked] = len(selected_idx)
                    selected_idx.append(picked)
                    selected_prob.append(prob[sid])
                else:
                    # repeated draw: remember how often the sample came up
                    repeat_entry[rep_inx] = repeat_entry.get(rep_inx, 1) + 1
            # Weight of a sample: 1 / sqrt(total_draws * p), times sqrt(count)
            # for samples that were drawn more than once.
            selected_prob = [1/np.sqrt(lv_num_sample*p) for p in selected_prob]
            for k, v in repeat_entry.items():
                selected_prob[k] = selected_prob[k] * np.sqrt(v)

            if args.al_iter >= 1:
                # Prepend the weights saved by the previous iteration so the
                # file keeps one weight per sample selected so far.
                last_iter_weight_path = os.path.join(args.save_root, str(args.al_iter), args.dataset, args.method)
                last_weights = np.loadtxt(os.path.join(last_iter_weight_path, "select_prob.txt"), dtype=float)
                selected_prob = np.hstack((last_weights, selected_prob))

    # ---- move the queried samples from the pool into the labelled set ----
    assert set(selected_idx) <= set(unlab_idx)
    lab_idx = np.hstack((lab_idx, selected_idx))
    unlab_idx = np.setdiff1d(unlab_idx, lab_idx)
    np.random.shuffle(unlab_idx)

    # ---- write the index files (and the weights, for lewisweight) ----
    os.makedirs(al_save_root, exist_ok=True)
    print(f"save to {al_save_root}...")
    outputs = [
        ("lab_idx.txt", lab_idx, "%d"),
        ("unlab_idx.txt", unlab_idx, "%d"),
    ]
    if args.method == "lewisweight":
        outputs.append(("select_prob.txt", selected_prob, "%f"))
    for out_name, out_values, out_fmt in outputs:
        np.savetxt(os.path.join(al_save_root, out_name), out_values, fmt=out_fmt)

    ############################### train linear regression model #######################################
    def set_seed(seed: int = 42) -> None:
        """Seed NumPy, `random` and PyTorch and switch cuDNN to deterministic mode.

        Args:
            seed: value handed to every random number generator.
        """
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        # When running on the CuDNN backend, two further options must be set
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        # NOTE: the interpreter reads PYTHONHASHSEED at start-up, so this does
        # not change hashing in the current process; it is only visible to
        # Python processes launched afterwards.
        os.environ["PYTHONHASHSEED"] = str(seed)
        print(f"Random seed set as {seed}")


    class SingleLayerPredictor(torch.nn.Module):
        """A single linear layer followed by ReLU, used as a linear probe.

        Args:
            n_input: size of each input feature vector.
            n_output: number of predicted values per sample.
        """

        def __init__(self, n_input, n_output):
            super(SingleLayerPredictor, self).__init__()
            self.pred_layer = torch.nn.Sequential(
                torch.nn.Linear(n_input, n_output),   # the only trainable layer
                torch.nn.ReLU(),
            )

        def forward(self, x):
            """Return `pred_layer(x)`."""
            x = self.pred_layer(x)
            return x

        def reset_params(self):
            """Re-seed all RNGs with 0 and re-initialise every resettable layer.

            The only direct child is the `nn.Sequential` container, which has
            no `reset_parameters` method, so looping over `self.children()`
            never reached the Linear layer.  `self.modules()` walks the whole
            module tree and therefore does.
            """
            set_seed(0)
            for module in self.modules():
                if module is not self and hasattr(module, 'reset_parameters'):
                    module.reset_parameters()
        
        
    # Load a network once to learn the input resolution for the data provider.
    net, image_size = load_test_models(net_id=0, trained_weights=None)
    DataProvider = get_data_provider_by_name(args.dataset)
    provider_kwargs = dict(
        train_batch_size=256,
        test_batch_size=128,
        valid_size=500,
        n_worker=16,
        resize_scale=0.08,
        distort_color="tf",
        image_size=image_size,
        num_replicas=None,
        lab_idx=lab_idx,
        unlab_idx=unlab_idx,
    )
    dpv = DataProvider(**provider_kwargs)
    test_dl = dpv.test
    train_dl = dpv.train

    # Per-sample loss weights: for lewisweight, 500 unit weights followed by
    # the saved importance weights; a weight of 1 per labelled sample otherwise.
    tr_weight = (
        np.hstack([np.ones(500),  selected_prob]).tolist()
        if args.method == "lewisweight"
        else [1]*len(lab_idx)
    )

    # Prefer the GPU when one is visible to PyTorch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # For every model: extract embeddings, fit a weighted single-layer
    # regressor on the labelled set, and record its test MSE.
    performances = np.zeros([args.model_num, 1])
    tr_batch_size = 128
    # Loop-invariant values are prepared once instead of per model / per batch.
    results_saving_dir = f"./extracted_results/{args.model_num}/{args.dataset}/{args.method}/{args.al_iter}"
    sample_weights = torch.tensor(tr_weight, requires_grad=False).float().to(device)
    for rs in tqdm(range(args.model_num), desc="test models"):
        ofa_net_weight = os.path.join(ofa_checkpoint_root, str(rs), "checkpoint/checkpoint.pth.tar")
        net, image_size = load_test_models(net_id=rs, n_classes=NCLASSES, trained_weights=ofa_net_weight)
        net.cuda()
        qs = ALScoringFunctions(batch_size=args.batch_size,
                                n_classes=NCLASSES,
                                unlab_idx=unlab_idx,
                                net=net,
                                dataloader=None)
        train_fea, train_label = qs.get_embedding(fname='', lab_dataloader=dpv.train, save_fea=False, skip_unlab=True, return_labels=True)
        test_fea, test_label = qs.get_embedding(fname='', lab_dataloader=dpv.test, save_fea=False, skip_unlab=True, return_labels=True)

        lr_model = SingleLayerPredictor(train_fea.shape[1], NCLASSES).to(device)
        lr_model.reset_params()
        lr_model.train()
        X_torch_dataset = TensorDataset(train_fea.cpu().float(),
                                        train_label.cpu().float())
        # No shuffling: batch i must line up with rows i*tr_batch_size.. of the weights.
        X_dataloader = DataLoader(X_torch_dataset, batch_size=tr_batch_size, num_workers=4, pin_memory=True)
        optimizer = torch.optim.SGD(lr_model.parameters(), lr=1e-3, weight_decay=1e-1)
        # Element-wise loss so each sample can be re-weighted below.  The
        # deprecated `reduce` argument is not needed; reduction='none' is enough.
        loss_func = torch.nn.MSELoss(reduction='none')
        ttt = tqdm(range(50))
        for t in ttt:
            for i, (features, labels) in enumerate(X_dataloader):
                features = features.to(device)
                labels = labels.to(device)
                weights = sample_weights[i*tr_batch_size:(i+1)*tr_batch_size]

                optimizer.zero_grad()
                out = lr_model(features)
                loss = loss_func(out, labels).float()
                # Scale every row of the (batch, n_outputs) loss by its sample
                # weight; same values as diag(weights) @ loss without building
                # a batch x batch matrix.
                loss = (weights.unsqueeze(1) * loss).mean()
                loss.backward()
                optimizer.step()
            ttt.set_description(f"loss: {loss:.4f}")

        # test
        # Keep the metric on the same device as the model outputs.
        MSEmetric = MeanSquaredError().to(device)
        msescore = 0
        mean_msescore = 0
        all_num = 0
        lr_model.eval()
        X_torch_dataset = TensorDataset(test_fea.cpu().float(),
                                        test_label.cpu().float())
        test_dl = DataLoader(X_torch_dataset, batch_size=tr_batch_size, num_workers=4, pin_memory=True)
        with torch.no_grad():
            for i, (features, labels) in enumerate(test_dl):
                features, labels = features.to(device), labels.to(device)
                # compute output
                output = lr_model(features)
                msescore = MSEmetric(output, labels)
                all_num += output.size(0)
                # accumulate the batch score weighted by the batch size
                mean_msescore += msescore * output.size(0)

        mean_msescore /= all_num
        print(mean_msescore)
        performances[rs, 0] = float(mean_msescore)

        # Rewritten after every model so partial results survive an interruption.
        os.makedirs(results_saving_dir, exist_ok=True)
        np.savetxt(os.path.join(results_saving_dir, "performances_linear.txt"), performances, fmt="%f")
