

import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
from models import Model
import torch.nn.functional as F
from dataloader import load_data, load_out_t
from utils import (
    get_logger,
    get_evaluator,
    set_seed,
    get_training_config,
    check_writable,
    check_readable,
    compute_min_cut_loss,
    graph_split,
    feature_prop,
)
from train_and_eval import distill_run_transductive, distill_run_inductive

def train_student(student, label):
    """Hard-label classification loss for the student.

    Args:
        student: raw (un-normalised) class scores; a log-softmax is taken
            over the last dimension.
        label: target class indices, passed straight to ``F.nll_loss``.

    Returns:
        The negative log-likelihood loss (``F.nll_loss`` default reduction).
    """
    # NOTE: per the original author, the inference path calls this function too.
    log_probs = F.log_softmax(student, dim=-1)
    return F.nll_loss(log_probs, label)


def euc_distance(x, y, eval_mode=False):
    """Squared Euclidean distance between the rows of ``x`` and ``y``.

    The result is ``|x|^2 + |y|^2 - 2 * <x, y>``; no square root is taken.

    Args:
        x (Tensor): shape (N1, d).
        y (Tensor): shape (N2, d) if ``eval_mode`` else (N1, d).
        eval_mode (bool, optional): when True, compare every row of ``x``
            with every row of ``y``; otherwise compare rows pairwise.
            Defaults to False.

    Returns:
        Tensor of shape (N1, N2) if ``eval_mode`` else (N1, 1).
    """
    sq_norm_x = (x * x).sum(dim=-1, keepdim=True)
    sq_norm_y = (y * y).sum(dim=-1, keepdim=True)

    if not eval_mode:
        # Row-wise mode needs one y row per x row.
        assert x.shape[0] == y.shape[0], 'The shape of x and y do not match.'
        inner = (x * y).sum(dim=-1, keepdim=True)
        return sq_norm_x + sq_norm_y - 2 * inner

    # All-pairs mode: broadcast (N1, 1) against (1, N2).
    inner = x @ y.t()
    return sq_norm_x + sq_norm_y.t() - 2 * inner

def KD(student, teacher):
    """KL divergence between teacher and student class distributions.

    Both inputs are raw scores (no softmax applied yet). Each is turned into
    log-probabilities over its last dimension before the divergence is taken.

    Args:
        student: student scores.
        teacher: teacher scores, same shape as ``student``.

    Returns:
        Scalar tensor ``KL(teacher || student)`` with ``batchmean`` reduction.
    """
    # Normalise both along the same (last) axis. The student previously used
    # dim=1, which only matches dim=-1 for 2-D input.
    student_log_prob = F.log_softmax(student, dim=-1)
    teacher_log_prob = F.log_softmax(teacher, dim=-1)
    return F.kl_div(
        student_log_prob, teacher_log_prob, reduction="batchmean", log_target=True
    )

class Distill_loss(nn.Module):
    """Distillation loss between student and teacher outputs.

    ``kd_type`` selects the loss:
        * ``"KD"``  - ``la * KD(student, teacher)``.
        * ``"RAW"`` - ``0 * KD(student, teacher)`` (distillation switched off).
        * ``"CRD"`` - the KD term plus two class-prototype terms, weighted by
          ``lambda1`` and ``lambda2``.

    ``get_center`` (``"CRD"`` only) selects how nodes are weighted when the
    per-class prototype is built:
        * ``"mean"`` - uniform average over the nodes of the class.
        * ``"t"``    - weights from the teacher's prediction confidence.
        * ``"s"``    - weights from the student's prediction confidence.

    ``T1`` scales the node-to-prototype distances; ``T2`` divides the
    prototype-to-prototype distances.
    """

    def __init__(self, kd_type, get_center="mean", T1=2, T2=10, lambda1=0.1, lambda2=0.1):
        super().__init__()

        self.kd_type = kd_type

        self.get_center = get_center
        self.T1 = T1
        self.T2 = T2
        self.lambda1 = lambda1
        self.lambda2 = lambda2

    @staticmethod
    def _confidence_weights(logits, membership):
        """Weight each node of a class by ``exp(sum_c p_c * log p_c)``.

        ``membership`` is the (C, N) 0/1 class-assignment matrix. Nodes outside
        a class get weight 0 (unlike a softmax), and each row is normalised to
        sum to (almost) one.
        """
        # log_softmax avoids the 0 * log(0) = NaN that softmax followed by log
        # produces when a probability underflows.
        log_prob = F.log_softmax(logits, dim=-1)
        # (N, 1), always <= 0; values closer to 0 mean a more confident node.
        neg_entropy = (log_prob.exp() * log_prob).sum(dim=-1, keepdim=True)
        unnormalised = torch.exp(neg_entropy).t() * membership  # (C, N)
        return unnormalised / (unnormalised.sum(dim=-1, keepdim=True) + 1e-5)

    def _center_weights(self, student, teacher, label_teacher):
        """Return the (C, N) matrix that maps node outputs to class prototypes."""
        membership = label_teacher.t()  # (C, N)
        if self.get_center == "mean":
            counts = membership.sum(dim=1, keepdim=True)
            # Empty classes have an all-zero row, so clamping their count to 1
            # leaves the row at zero without dividing by zero.
            return membership / counts.clamp(min=1)
        if self.get_center == "t":
            return self._confidence_weights(teacher, membership)
        if self.get_center == "s":
            return self._confidence_weights(student, membership)
        raise ValueError(
            f"Wrong type to generate center point: {self.get_center!r} "
            "(expected 'mean', 't' or 's')"
        )

    def forward(self, student, teacher, truelabel=None, la=1):
        """Compute the distillation loss.

        Args:
            student: (N, C) raw student scores (no softmax).
            teacher: (N, C) raw teacher scores (no softmax).
            truelabel: optional ground-truth class indices. When given (the
                inductive setting), they replace the teacher's predictions for
                assigning nodes to class prototypes.
            la: weight on the plain KD term.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if ``kd_type`` or ``get_center`` is not a known option.
        """
        if self.kd_type == "KD":
            return la * KD(student, teacher)
        if self.kd_type == "RAW":
            # Multiplying by 0 removes the term but still returns a tensor.
            return 0 * KD(student, teacher)
        if self.kd_type != "CRD":
            raise ValueError(
                f"Wrong KD loss type: {self.kd_type!r} (expected 'KD', 'RAW' or 'CRD')"
            )

        loss_kd = la * KD(student, teacher)

        teacher_pred = teacher.max(dim=1)[1]
        class_index = teacher_pred if truelabel is None else truelabel
        label_teacher = F.one_hot(class_index, num_classes=student.size(-1)).float()

        center_weights = self._center_weights(student, teacher, label_teacher)

        # Class prototypes in student output space.
        center_point = torch.mm(center_weights, student)

        # Node-to-prototype distances. Going through prototypes needs N * C
        # distances instead of N * N and smooths the signal.
        simi_student_center = euc_distance(student, center_point, eval_mode=True)

        # Pull every node towards the prototype of its teacher-predicted class.
        # NOTE(review): the target is always the teacher prediction, even when
        # `truelabel` was used to build the prototypes - confirm this is intended.
        loss_add = F.nll_loss(
            F.log_softmax(-self.T1 * simi_student_center, dim=-1), teacher_pred
        )

        # Match the prototype-to-prototype distance pattern of the teacher.
        center_point_teacher = torch.mm(center_weights, teacher)
        dis_teacher = euc_distance(center_point_teacher, center_point_teacher, True)
        dis_student = euc_distance(center_point, center_point, True)

        loss_add2 = F.kl_div(
            F.log_softmax(dis_student / self.T2, dim=-1),
            F.softmax(dis_teacher / self.T2, dim=-1),
            reduction="batchmean",
        )

        return loss_kd + self.lambda1 * loss_add + self.lambda2 * loss_add2


def get_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads the process command line.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="PyTorch DGL implementation")
    parser.add_argument("--device", type=int, default=-1, help="CUDA device, -1 means CPU")
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--log_level",
        type=int,
        default=20,
        help="Logger levels for run {10: DEBUG, 20: INFO, 30: WARNING}",
    )
    parser.add_argument(
        "--console_log",
        action="store_true",
        help="Set to True to display log info in console",
    )
    parser.add_argument(
        "--output_path", type=str, default="outputs", help="Path to save outputs"
    )
    parser.add_argument(
        "--num_exp", type=int, default=1, help="Repeat how many experiments"
    )
    parser.add_argument(
        "--exp_setting",
        type=str,
        default="tran",
        help="Experiment setting, one of [tran, ind]",
    )
    parser.add_argument(
        "--eval_interval", type=int, default=1, help="Evaluate once per how many epochs"
    )
    parser.add_argument(
        "--save_results",
        action="store_true",
        help="Set to True to save the loss curves, trained model, and min-cut loss for the transductive setting",
    )

    # ---- Dataset ----
    parser.add_argument("--dataset", type=str, default="cora", help="Dataset")
    parser.add_argument("--data_path", type=str, default="./data", help="Path to data")
    parser.add_argument(
        "--labelrate_train",
        type=int,
        default=20,
        help="How many labeled data per class as train set",
    )
    parser.add_argument(
        "--labelrate_val",
        type=int,
        default=30,
        help="How many labeled data per class in valid set",
    )
    parser.add_argument(
        "--split_idx",
        type=int,
        default=0,
        help="For Non-Homo datasets only, one of [0,1,2,3,4]",
    )

    # ---- Model ----
    parser.add_argument(
        "--model_config_path",
        type=str,
        default="./train.conf.yaml",
        help="Path to model configuration",
    )
    parser.add_argument("--teacher", type=str, default="SAGE", help="Teacher model")
    parser.add_argument("--student", type=str, default="MLP", help="Student model")
    parser.add_argument(
        "--num_layers", type=int, default=2, help="Student model number of layers"
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        default=64,
        help="Student model hidden layer dimensions",
    )
    parser.add_argument("--dropout_ratio", type=float, default=0)
    parser.add_argument(
        "--norm_type", type=str, default="none", help="One of [none, batch, layer]"
    )

    # ---- SAGE specific ----
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument(
        "--fan_out",
        type=str,
        default="5,5",
        help="Number of samples for each layer in SAGE. Length = num_layers",
    )
    parser.add_argument(
        "--num_workers", type=int, default=0, help="Number of workers for sampler"
    )

    # ---- Optimization ----
    parser.add_argument("--learning_rate", type=float, default=0.01)
    parser.add_argument("--weight_decay", type=float, default=0.0005)
    parser.add_argument(
        "--max_epoch", type=int, default=500, help="Maximum number of training epochs"
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=50,
        help="Early stop if the score on validation set does not improve for how many epochs",
    )

    # ---- Ablation ----
    parser.add_argument(
        "--feature_noise",
        type=float,
        default=0,
        help="add white noise to features for analysis, value in [0, 1] for noise level",
    )
    parser.add_argument(
        "--split_rate",
        type=float,
        default=0.2,
        help="Rate for graph split, see comment of graph_split for more details",
    )
    parser.add_argument(
        "--compute_min_cut",
        action="store_true",
        help="Set to True to compute and store the min-cut loss",
    )
    parser.add_argument(
        "--feature_aug_k",
        type=int,
        default=0,
        help="Augment node features by aggregating feature_aug_k-hop neighbor features",
    )

    # ---- Distillation ----
    # NOTE (original author): lamb=1 means pure MLP training.
    parser.add_argument(
        "--lamb",
        type=float,
        default=0,
        help="Parameter balances loss from hard labels and teacher outputs, take values in [0, 1]",
    )
    parser.add_argument(
        "--out_t_path", type=str, default="outputs", help="Path to load teacher outputs"
    )
    parser.add_argument(
        "--kl_loss_type",
        type=str,
        default="CRD",
        help="loss type for distill",
    )
    parser.add_argument(
        "--get_center",
        type=str,
        default="mean",
        choices=["mean", "s", "t"],
        help="way to get prototype center",
    )
    # Numeric defaults are written as numbers rather than strings; the parsed
    # values are the same floats as before.
    parser.add_argument(
        "--T1",
        type=float,
        default=2.0,
        help="temperature for prototype contrastive loss",
    )
    parser.add_argument(
        "--T2",
        type=float,
        default=10.0,
        help="temperature for prototype similarity loss",
    )
    parser.add_argument(
        "--lambda1",
        type=float,
        default=0.1,
        help="weight for prototype contrastive loss",
    )
    parser.add_argument(
        "--lambda2",
        type=float,
        default=0.1,
        help="weight for prototype similarity loss",
    )

    return parser.parse_args(argv)


def run(args):
    """Run one distillation experiment and return its test scores.

    ``args`` is modified in place: ``output_path``, ``out_t_path`` and
    ``student`` may be rewritten by the ablation flags, and ``output_dir``,
    ``feat_dim`` and ``label_dim`` are added.

    Returns:
    score_lst: a list of evaluation results on test set.
    len(score_lst) = 1 for the transductive setting.
    len(score_lst) = 2 for the inductive/production setting.

    Raises:
    ValueError: if ``args.exp_setting`` is neither "tran" nor "ind".
    """

    # ---- Seed and device ----
    set_seed(args.seed)
    use_cuda = torch.cuda.is_available() and args.device >= 0
    device = torch.device("cuda:" + str(args.device)) if use_cuda else "cpu"

    # ---- Output locations ----
    if args.feature_noise != 0:
        noisy_root = Path.cwd().joinpath(
            args.output_path, "noisy_features", f"noise_{args.feature_noise}"
        )
        args.output_path = noisy_root
        # Teacher is assumed to be trained on the same noisy features as well.
        args.out_t_path = noisy_root

    if args.feature_aug_k > 0:
        args.output_path = Path.cwd().joinpath(
            args.output_path, "aug_features", f"aug_hop_{args.feature_aug_k}"
        )
        # NOTE: Teacher may or may not have augmented features, so
        # args.out_t_path is left for the caller to specify explicitly.
        args.student = f"GA{args.feature_aug_k}{args.student}"

    # Path components that depend on the experiment setting.
    if args.exp_setting == "tran":
        setting_parts = ("transductive",)
    elif args.exp_setting == "ind":
        setting_parts = ("inductive", f"split_rate_{args.split_rate}")
    else:
        raise ValueError(f"Unknown experiment setting! {args.exp_setting}")

    seed_tag = f"seed_{args.seed}"
    output_dir = Path.cwd().joinpath(
        args.output_path,
        *setting_parts,
        args.dataset,
        f"{args.teacher}_{args.student}_{args.kl_loss_type}",
        seed_tag,
    )
    out_t_dir = Path.cwd().joinpath(
        args.out_t_path,
        *setting_parts,
        args.dataset,
        args.teacher,
        seed_tag,
    )
    args.output_dir = output_dir

    check_writable(output_dir, overwrite=False)
    check_readable(out_t_dir)

    # ---- Logger ----
    logger = get_logger(output_dir.joinpath("log.txt"), args.console_log, args.log_level)
    logger.info(f"output_dir: {output_dir}")
    logger.info(f"out_t_dir: {out_t_dir}")

    # ---- Data ----
    g, labels, idx_train, idx_val, idx_test = load_data(
        args.dataset,
        args.data_path,
        split_idx=args.split_idx,
        seed=args.seed,
        labelrate_train=args.labelrate_train,
        labelrate_val=args.labelrate_val,
    )

    logger.info(f"Total {g.number_of_nodes()} nodes.")
    logger.info(f"Total {g.number_of_edges()} edges.")

    feats = g.ndata["feat"]
    args.feat_dim = feats.shape[1]
    args.label_dim = labels.int().max().item() + 1

    if 0 < args.feature_noise <= 1:
        # Blend the features with Gaussian noise at the requested level.
        noise_level = args.feature_noise
        feats = (1 - noise_level) * feats + noise_level * torch.randn_like(feats)

    # ---- Model config: student config entries override the CLI values ----
    if args.model_config_path is not None:
        student_conf = get_training_config(
            args.model_config_path, args.student, args.dataset
        )
    else:
        student_conf = {}
    conf = dict(vars(args), **student_conf)
    conf["device"] = device
    logger.info(f"conf: {conf}")

    # ---- Model, optimizer, losses ----
    model = Model(conf)
    optimizer = optim.Adam(
        model.parameters(),
        lr=conf["learning_rate"],
        weight_decay=conf["weight_decay"],
    )
    criterion_l = train_student  # hard-label loss
    criterion_t = Distill_loss(  # distillation loss
        kd_type=args.kl_loss_type,
        get_center=args.get_center,
        T1=args.T1,
        T2=args.T2,
        lambda1=args.lambda1,
        lambda2=args.lambda2,
    )

    evaluator = get_evaluator(conf["dataset"])

    # ---- Teacher outputs ----
    out_t = load_out_t(out_t_dir)
    out_t_log = F.log_softmax(out_t, dim=-1)
    logger.debug(
        f"teacher score on train data: {evaluator(out_t_log[idx_train], labels[idx_train])}"
    )
    logger.debug(
        f"teacher score on val data: {evaluator(out_t_log[idx_val], labels[idx_val])}"
    )
    logger.debug(
        f"teacher score on test data: {evaluator(out_t_log[idx_test], labels[idx_test])}"
    )

    # ---- Data split and training ----
    loss_and_score = []
    if args.exp_setting == "tran":
        distill_indices = (
            idx_train,
            torch.cat([idx_train, idx_val, idx_test]),
            idx_val,
            idx_test,
        )

        # Propagate node features.
        if args.feature_aug_k > 0:
            feats = feature_prop(feats, g, args.feature_aug_k)

        out, score_val, score_test = distill_run_transductive(
            conf, model, feats, labels, out_t, distill_indices,
            criterion_l, criterion_t, evaluator, optimizer, logger,
            loss_and_score, g,
        )
        score_lst = [score_test]

    elif args.exp_setting == "ind":
        # Create the inductive split.
        obs_idx_train, obs_idx_val, obs_idx_test, idx_obs, idx_test_ind = graph_split(
            idx_train, idx_val, idx_test, args.split_rate, args.seed
        )
        distill_indices = (
            obs_idx_train,
            torch.cat([obs_idx_train, obs_idx_val, obs_idx_test]),
            obs_idx_val,
            obs_idx_test,
            idx_obs,
            idx_test_ind,
        )

        # Propagate node features. Propagation for the observed nodes only
        # happens within the subgraph obs_g.
        if args.feature_aug_k > 0:
            obs_g = g.subgraph(idx_obs)
            obs_feats = feature_prop(feats[idx_obs], obs_g, args.feature_aug_k)
            feats = feature_prop(feats, g, args.feature_aug_k)
            feats[idx_obs] = obs_feats

        out, score_val, score_test_tran, score_test_ind = distill_run_inductive(
            conf, model, feats, labels, out_t, distill_indices,
            criterion_l, criterion_t, evaluator, optimizer, logger,
            loss_and_score, g,
        )
        score_lst = [score_test_tran, score_test_ind]

    logger.info(
        f"num_layers: {conf['num_layers']}. hidden_dim: {conf['hidden_dim']}. dropout_ratio: {conf['dropout_ratio']}"
    )
    logger.info(f"# params {sum(p.numel() for p in model.parameters())}")

    # ---- Save student outputs ----
    np.savez(output_dir.joinpath("out"), out.detach().cpu().numpy())

    # ---- Save loss curve and model ----
    if args.save_results:
        np.savez(output_dir.joinpath("loss_and_score"), np.array(loss_and_score))
        torch.save(model.state_dict(), output_dir.joinpath("model.pth"))

    # ---- Save min-cut loss (transductive only) ----
    if args.exp_setting == "tran" and args.compute_min_cut:
        min_cut = compute_min_cut_loss(g, out)
        with open(output_dir.parent.joinpath("min_cut_loss"), "a+") as f:
            f.write(f"{min_cut :.4f}\n")

    return score_lst


def repeat_run(args):
    """Run the experiment ``args.num_exp`` times, using seeds 0..num_exp-1.

    ``run`` rewrites ``output_path``, ``out_t_path`` and ``student`` on the
    namespace it receives when the feature-noise / feature-augmentation flags
    are set. Each repetition therefore gets its own shallow copy; otherwise
    the paths would nest and the student prefix would stack once per run.

    Args:
        args: parsed command-line namespace. ``args.seed`` is overwritten with
            each seed, and ``args.output_dir`` is set to the output directory
            of the most recent run.

    Returns:
        Tuple ``(mean, std)`` of the per-run score lists, taken over the runs.
    """
    import copy

    scores = []
    for seed in range(args.num_exp):
        args.seed = seed
        run_args = copy.copy(args)
        scores.append(run(run_args))
        # Expose the output directory on the caller's namespace, as before.
        args.output_dir = run_args.output_dir
    scores_np = np.array(scores)
    return scores_np.mean(axis=0), scores_np.std(axis=0)


def main():
    """Parse the CLI, run the experiment(s), and record the results.

    A single run reports its test scores. Repeated runs report the mean scores
    followed by their standard deviations. The prototype hyper-parameters and
    the score line are appended to ``exp_results`` in the parent of the run's
    output directory, and the score line is printed.

    Raises:
        ValueError: if ``--num_exp`` is smaller than 1.
    """
    args = get_args()
    if args.num_exp < 1:
        # Without this check no run would happen, and the code below would
        # fail on the missing output directory and score string.
        raise ValueError(f"--num_exp must be >= 1, got {args.num_exp}")

    if args.num_exp == 1:
        score = run(args)
        score_str = "".join([f"{s : .4f}\t" for s in score])
    else:
        score_mean, score_std = repeat_run(args)
        score_str = "".join(
            [f"{s : .4f}\t" for s in score_mean] + [f"{s : .4f}\t" for s in score_std]
        )

    with open(args.output_dir.parent.joinpath("exp_results"), "a+") as f:
        f.write(f"{args.get_center}\n")
        f.write(f"{args.T1}\n")
        f.write(f"{args.T2}\n")
        f.write(f"{args.lambda1}\n")
        f.write(f"{args.lambda2}\n")
        f.write(f"{score_str}\n")
        f.write("-----\n")

    # for collecting aggregated results
    print(score_str)


# Script entry point: run the experiment only when executed directly.
if __name__ == "__main__":
    main()
