import os
# Restrict this process to physical GPU 5. It is set before torch is imported
# so CUDA initialisation only ever sees that device.
# NOTE(review): this overrides any CUDA_VISIBLE_DEVICES exported in the shell;
# consider os.environ.setdefault(...) if the caller should be able to choose.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
import time
import json
import warnings
from tqdm import tqdm
from copy import deepcopy
from typing import Set, Callable, Any

import pandas as pd
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from transformers import optimization
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModel

import torch
from torch import Tensor
from torch.nn import Module
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch_geometric.data import DataLoader, Data
import tensorboard_logger as tb_logger

from utils.parsing import parse_option
from utils.evaluate import Evaluator
from utils.load_dataset import PygOurDataset
from utils.util import AverageMeter, set_optimizer, calmean
from loss.loss_scl_reg import SupConLossReg
from models.deepgcn import GraphxLSTM
from models.module import MLPMoE


import pdb
# Silence every Python warning process-wide. This also hides deprecation and
# library warnings that may point at real problems; narrow it if debugging.
warnings.filterwarnings("ignore")

def set_seed(num_seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Python's ``random`` module is seeded too; it was previously left unseeded,
    so any code drawing from it was not reproducible. cuDNN determinism is
    not forced here.

    Args:
        num_seed: Integer seed applied to every generator.
    """
    import random  # local import: only this helper needs the stdlib RNG

    random.seed(num_seed)
    np.random.seed(num_seed)
    torch.manual_seed(num_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(num_seed)
        torch.cuda.manual_seed_all(num_seed)

# Parse command-line arguments once, at import time. main() reads `opt` as a
# module-level global.
opt = parse_option()
# NOTE(review): if parse_option() returns an argparse Namespace, vars() gives
# its live __dict__ (not a copy). Fields main() assigns later (tb_folder,
# save_folder) would then also appear in runtime_params.json.
args_dict = vars(opt)

def set_loader(opt: Any, dataname: str) -> tuple:
    """Load the train/validation/test splits of one dataset from ``opt.data_dir``.

    Args:
        opt (Any): Parsed arguments; only ``opt.data_dir`` is read here.
        dataname (str): The folder name of the dataset.

    Returns:
        tuple: ``(train_dataset, val_dataset, test_dataset)``, the
        ``PygOurDataset`` instances for the ``"train"``, ``"valid"`` and
        ``"test"`` phases respectively.
    """

    def _load(phase: str):
        # One dataset object per phase, all rooted at the same directory.
        return PygOurDataset(root=opt.data_dir, phase=phase, dataname=dataname)

    # Construction order (train, test, valid) is kept in case dataset
    # construction has on-disk side effects.
    train_dataset = _load("train")
    test_dataset = _load("test")
    val_dataset = _load("valid")

    return train_dataset, val_dataset, test_dataset

class ContrastiveEncode(torch.nn.Module):
    """Width-preserving two-layer projection head: Linear -> ReLU -> Linear.

    Args:
        dim_feat: Size of both the input and the output feature vectors.
    """

    def __init__(self, dim_feat: int):
        super().__init__()
        layers = [
            torch.nn.Linear(dim_feat, dim_feat),
            torch.nn.ReLU(),
            torch.nn.Linear(dim_feat, dim_feat),
        ]
        # The attribute must stay named `encode` so state_dict keys
        # (encode.0.*, encode.2.*) remain compatible with existing checkpoints.
        self.encode = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x``; the last dimension keeps size ``dim_feat``."""
        return self.encode(x)

class MLP(torch.nn.Module):
    """Two-layer perceptron (Linear -> ReLU -> Linear) with equal in/out width.

    Structurally identical to ``ContrastiveEncode``; it is used as the
    per-branch refinement MLP rather than as a projection head.

    Args:
        dim_feat: Feature width used for every layer.
    """

    def __init__(self, dim_feat: int):
        super().__init__()
        width = dim_feat
        first = torch.nn.Linear(width, width)
        activation = torch.nn.ReLU()
        second = torch.nn.Linear(width, width)
        # Kept under `encode` so parameter names match saved state_dicts.
        self.encode = torch.nn.Sequential(first, activation, second)

    def forward(self, x):
        """Apply the two-layer MLP to ``x``."""
        out = self.encode(x)
        return out

class MolPropertyPrediction(torch.nn.Module):
    """Predict molecular properties from graph features and RDKit descriptors.

    The graph branch (``molxlstm``) and the descriptor branch (``enc_rdkit``)
    are each refined by an ``MLP``. The refined features are:

    * projected by a ``ContrastiveEncode`` head and L2-normalised (used for
      the contrastive loss), and
    * fused by an ``MLPMoE`` mixture-of-experts predictor.

    NOTE(review): ``chemberta`` is registered as a submodule, so it is part of
    ``state_dict()``, ``.cuda()`` and ``deepcopy``, but ``forward`` never uses it.
    """

    def __init__(self, molxlstm: Module, chemberta: Module, opt: Any):
        super().__init__()
        width = opt.num_dim * opt.power
        n_descriptors = 204  # flattened size of `desp` per molecule

        # Submodule creation order is significant: it fixes the sequence in
        # which parameter initialisation draws from the RNG.
        self.molxlstm = molxlstm
        self.enc_rdkit = torch.nn.Sequential(
            torch.nn.Linear(n_descriptors, width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, width),
        )

        # Contrastive projection heads; their outputs are L2-normalised in forward.
        self.head_gnn = ContrastiveEncode(width)
        self.head_rdkit = ContrastiveEncode(width)

        # Per-branch refinement MLPs applied before the heads and the MoE.
        self.mlp_gnn = MLP(width)
        self.mlp_rdkit = MLP(width)

        self.classifier = MLPMoE(
            input_feat=width,
            dim_feat=width,
            num_tasks=opt.num_tasks,
            num_experts=opt.num_experts,
            num_heads=opt.num_heads,
            output_feat=width,
        )

        self.chemberta = chemberta

    def forward(self, input_molecule: Tensor, phase: str = "train"):
        """Run both branches and the MoE head.

        Args:
            input_molecule: A batched graph object (presumably a torch_geometric
                batch). Besides whatever ``molxlstm`` consumes, its ``desp``
                (reshaped to one row per molecule) and ``y`` (only its first
                dimension, as the batch size) are read here.
            phase: ``"train"`` returns the tuple used by the training losses;
                any other value returns the tuple used for evaluation.

        Returns:
            If ``phase == "train"``: ``(output_final, gate, f_out_norm, f_rdkit_norm)``,
            where the last two are projection-head outputs L2-normalised along dim 1.
            Otherwise: ``(output_final, f_out, f_moe, rdkit_feat, expert_num)``.
        """
        graph_emb = self.molxlstm(input_molecule)

        n_mols = input_molecule.y.shape[0]
        descriptors = input_molecule.desp.view(n_mols, -1).float()
        rdkit_emb = self.enc_rdkit(descriptors)

        graph_emb = self.mlp_gnn(graph_emb)
        rdkit_emb = self.mlp_rdkit(rdkit_emb)

        graph_proj = F.normalize(self.head_gnn(graph_emb), p=2)
        rdkit_proj = F.normalize(self.head_rdkit(rdkit_emb), p=2)
        output_final, gate, f_moe, expert_num = self.classifier(graph_emb, rdkit_emb)

        if phase != "train":
            return output_final, graph_emb, f_moe, rdkit_emb, expert_num
        return output_final, gate, graph_proj, rdkit_proj

def set_model(opt: Any):
    """Build the property-prediction model and the task criterion.

    A pretrained ChemBERTa encoder is loaded via Hugging Face
    ``AutoModel.from_pretrained`` (downloaded on first use) and frozen. A
    ``GraphxLSTM`` graph encoder is built from ``opt``, and both are wrapped
    in ``MolPropertyPrediction``. The model is moved to the GPU when CUDA is
    available.

    Args:
        opt (Any): Parsed arguments.

    Returns:
        ``(model, criterion_task)`` where ``criterion_task`` is
        ``torch.nn.MSELoss()``.
        NOTE(review): ``train()`` computes its task loss with ``CustomMSELoss``
        and never calls ``criterion_task``.
    """
    pretrained_id = "DeepChem/ChemBERTa-77M-MLM"
    chemberta = AutoModel.from_pretrained(pretrained_id)
    # Freeze every ChemBERTa parameter.
    chemberta.requires_grad_(False)

    graph_encoder = GraphxLSTM(opt)
    model = MolPropertyPrediction(graph_encoder, chemberta, opt)
    criterion_task = torch.nn.MSELoss()

    if not torch.cuda.is_available():
        return model, criterion_task

    model = model.cuda()
    cudnn.benchmark = False
    return model, criterion_task

def build_soft_target_distribution(y, intervals, mode="gaussian", sigma=0.2):
    """Turn scalar targets into soft assignment weights over expert intervals.

    Each ``(lower, upper)`` interval stands for one expert. In ``"gaussian"``
    mode, the unnormalised weight of expert ``e`` for target ``y`` is
    ``exp(-((y - center_e) / width_e) ** 2 / (2 * sigma ** 2))``, where
    ``center_e = (lower + upper) / 2`` and ``width_e = upper - lower``. Each
    sample's weights are then normalised to sum to (approximately) one.

    Args:
        y: One target per sample, e.g. shape ``[B]`` or ``[B, 1]``. The input
            is flattened, so a 0-d tensor is treated as a batch of one.
            NOTE(review): ``y`` must be in the same space (raw vs.
            standardised) that ``intervals`` were computed in; confirm with
            the callers.
        intervals: Sequence of ``(lower, upper)`` tuples, one per expert.
            Zero-width intervals lead to a division by zero.
        mode: Only ``"gaussian"`` is supported.
        sigma: Gaussian bandwidth, in units of interval width.

    Returns:
        Tensor of shape ``[B, num_experts]`` on ``y``'s device.

    Raises:
        ValueError: If ``mode`` is not ``"gaussian"`` (previously an
            ``UnboundLocalError``).
    """
    if mode != "gaussian":
        raise ValueError(f"Unsupported mode: {mode!r}; only 'gaussian' is implemented")

    device = y.device
    # Flatten to a column so [B], [B, 1] and 0-d inputs all broadcast to
    # [B, num_experts]. A [B, 1] input used to yield [B, 1, E], and the
    # per-sample normalisation below then ran over that size-1 axis.
    y = y.reshape(-1, 1)

    centers = torch.tensor([(l + u) / 2 for l, u in intervals], device=device)
    widths = torch.tensor([(u - l) for l, u in intervals], device=device)
    normalized_diff = (y - centers) / widths  # [B, E]
    soft_labels = torch.exp(-(normalized_diff ** 2) / (2 * sigma ** 2))
    # The epsilon avoids 0/0 when every expert's weight underflows to zero.
    return soft_labels / (soft_labels.sum(dim=1, keepdim=True) + 1e-8)

def gating_supervision_loss(gating_probs, targets_denorm, intervals, loss_type="kl", sigma=0.2):
    """Penalise the MoE gate for diverging from interval-based soft targets.

    Args:
        gating_probs: ``[B, E]`` gate probabilities (softmax output).
        targets_denorm: ``[B]`` targets handed to ``build_soft_target_distribution``.
        intervals: Expert ``(lower, upper)`` intervals.
        loss_type: ``"kl"`` for ``KL(soft_targets || gating_probs)`` averaged
            over the batch, or ``"mse"`` for element-wise mean squared error.
        sigma: Bandwidth forwarded to ``build_soft_target_distribution``.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: For any other ``loss_type`` (raised after the soft targets
            have been built).
    """
    reference = build_soft_target_distribution(targets_denorm, intervals, sigma=sigma)  # [B, E]

    if loss_type == "mse":
        return F.mse_loss(gating_probs, reference)
    if loss_type != "kl":
        raise ValueError("Unsupported loss_type: choose 'kl' or 'mse'")

    # kl_div expects log-probabilities as input; the epsilon avoids log(0).
    log_gate = torch.log(gating_probs + 1e-8)
    return F.kl_div(log_gate, reference, reduction='batchmean')

class CustomExpertMSELoss(torch.nn.Module):
    """Auxiliary loss that supervises the MoE gate with interval-based soft targets.

    Despite the name, the default is a KL divergence (see
    ``gating_supervision_loss``), not an MSE on predictions.
    """

    def __init__(self, intervals=None, weights=None, sigma: float = 0.7, loss_type: str = "kl"):
        """
        Args:
            intervals: list of (lower, upper) tuples, one per expert, e.g.
                [(0.0, 1.5), (0.5, 2.0), (1.0, 2.5), (1.5, 3.0)].
            weights: optional list of expert weights. It is stored but not
                used by ``forward``.
            sigma: Gaussian bandwidth for the soft targets (previously
                hard-coded to 0.7, which remains the default).
            loss_type: ``"kl"`` (default) or ``"mse"``.
        """
        super().__init__()
        self.intervals = intervals
        self.weights = weights
        self.sigma = sigma
        self.loss_type = loss_type

    def forward(self, targets, gate, predictions=None):
        """Return the gate-supervision loss for one batch.

        Args:
            targets: One target per sample (``[B]``).
            gate: Gate probabilities from the MoE. ``squeeze()`` is applied
                first, so e.g. ``[B, 1, E]`` becomes ``[B, E]``.
                NOTE(review): with B == 1 the batch axis is squeezed away too.
            predictions: Unused; kept for call compatibility.

        Returns:
            Scalar loss tensor.
        """
        return gating_supervision_loss(
            gate.squeeze(), targets, self.intervals, loss_type=self.loss_type, sigma=self.sigma
        )

class CustomMSELoss(torch.nn.Module):
    """Mean squared error averaged over every element of ``predictions - targets``.

    Args:
        threshold: Stored on the module; ``forward`` does not use it.
    """

    def __init__(self, threshold=0):
        super().__init__()
        self.threshold = threshold

    def forward(self, predictions, targets):
        # NOTE(review): shapes broadcast silently (e.g. [B, 1] vs [B] -> [B, B]);
        # callers must pass matching shapes.
        residual = predictions - targets
        return residual.pow(2).mean()

def train(
    train_loader: Any,
    model: torch.nn.Sequential,
    criterion_task: Callable,
    optimizer: Optimizer,
    scheduler: Any,
    opt: Any,
    intervals: Any,
    group_labels: Any,
    mu: int = 0,
    std: int = 0,
    dynamic_t: int = 0,
    max_dist: int = 0,
    epoch: int = 0,
):
    """Run one training epoch and return the average losses.

    For each batch, regression labels are standardised with ``mu``/``std`` and
    the model is run in ``"train"`` phase. The minimised loss is
    ``task MSE + opt.lbd * contrastive + gate supervision`` for regression, or
    ``task MSE + contrastive`` for classification. The optimizer and the
    scheduler are both stepped once per batch.

    Args:
        train_loader (Any): Iterable of batched graphs.
        model: Model returning ``(output, gate, f_gnn, f_rdkit)`` in train phase.
        criterion_task (Callable): Unused; the task loss is ``CustomMSELoss``.
        optimizer: Sequence whose first element is the optimizer to step.
        scheduler: LR scheduler stepped after every batch.
        opt (Any): Parsed arguments (``classification``, ``gamma`` and ``lbd``
            are read).
        intervals: Expert ``(lower, upper)`` intervals used for the soft targets.
        group_labels: Indexed with ``batch.idx`` to get per-sample group labels
            for the contrastive loss.
        mu (int, optional): Train-set mean used to standardise regression labels.
        std (int, optional): Train-set standard deviation used to standardise
            regression labels.
        dynamic_t, max_dist, epoch: Unused; kept for call compatibility.

    Returns:
        ``(task loss, contrastive loss, total loss)`` epoch averages from
        ``AverageMeter`` (updated with the batch size as the count).
    """
    model.train()
    # Follow the model's device instead of hard-coding "cuda" (set_model only
    # moves the model to the GPU when CUDA is available).
    device = next(model.parameters()).device

    batch_time = AverageMeter()
    data_time = AverageMeter()

    losses_task = AverageMeter()
    losses_scl = AverageMeter()
    losses = AverageMeter()

    # The loss objects do not depend on the batch: build them once per epoch.
    criterion_cl = SupConLossReg(gamma1=1, gamma2=0)
    criterion_custom_mse = CustomMSELoss()
    criterion_custom_expert_mse = CustomExpertMSELoss(intervals=intervals)

    end = time.time()
    for batch in tqdm(train_loader, desc="Iteration"):
        batch = batch.to(device)
        data_time.update(time.time() - end)

        bsz = batch.y.shape[0]
        if not opt.classification:
            labels = ((batch.y - mu) / std).unsqueeze(1)
        else:
            labels = batch.y

        output_final, gate, f_out, f_rdkit = model(batch, 'train')

        # Two identical "views" per sample, stacked on dim 1 (presumably the
        # [bsz, n_views, dim] layout SupConLossReg expects; confirm there).
        features_graph_1 = torch.cat([f_out.unsqueeze(1), f_out.unsqueeze(1)], dim=1)
        features_graph_2 = torch.cat([f_rdkit.unsqueeze(1), f_rdkit.unsqueeze(1)], dim=1)

        loss_task = criterion_custom_mse(output_final, labels)
        loss_expert = criterion_custom_expert_mse(labels.squeeze(), gate)

        # Flatten to [B] (as for loss_expert above) so the soft targets are
        # [B, num_experts]. Passing the [B, 1] labels directly produced a
        # [B, 1, num_experts] near-binary mask instead of a distribution.
        soft_targets = build_soft_target_distribution(labels.reshape(-1), intervals)
        groups = group_labels[batch.idx].squeeze()
        loss_cl = (criterion_cl(features_graph_1, opt.gamma, soft_targets, groups)
                   + criterion_cl(features_graph_2, opt.gamma, soft_targets, groups))

        if opt.classification:
            loss = loss_task + loss_cl
        else:
            loss = loss_task + loss_cl * opt.lbd + loss_expert

        losses_task.update(loss_task.item(), bsz)
        losses_scl.update(loss_cl.item(), bsz)
        losses.update(loss.item(), bsz)

        optimizer[0].zero_grad()
        loss.backward()
        optimizer[0].step()
        scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

    return losses_task.avg, losses_scl.avg, losses.avg


def draw_tsne(feature, name, y, title="t-SNE of GNN Feature (Validation Set)", perplexity=30):
    """Embed ``feature`` in 2-D with t-SNE, colour points by ``y`` and save a PNG.

    Args:
        feature: 2-D array-like ``[n_samples, n_features]``, either a CPU torch
            tensor or a NumPy array.
        name: Output file is ``tsne_{name}.png`` in the current directory.
        y: Values used to colour the points; squeezed before plotting.
        title: Plot title. The default is the previous hard-coded text, which
            is inaccurate for non-validation splits; pass a specific title.
        perplexity: t-SNE perplexity. It is clamped below the number of
            samples, because scikit-learn rejects ``perplexity >= n_samples``.
    """
    data = np.asarray(feature)
    effective_perplexity = min(perplexity, max(data.shape[0] - 1, 1))
    tsne = TSNE(n_components=2, perplexity=effective_perplexity, random_state=42)
    f_out_tsne = tsne.fit_transform(data)

    fig = plt.figure(figsize=(10, 8))
    try:
        scatter = plt.scatter(f_out_tsne[:, 0], f_out_tsne[:, 1], c=np.squeeze(y), cmap='viridis', s=20, alpha=0.7)
        plt.colorbar(scatter, label="Y")
        plt.title(title)
        plt.xlabel("TSNE-1")
        plt.ylabel("TSNE-2")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(f"tsne_{name}.png", dpi=300)
    finally:
        # Release the figure; otherwise each call leaves one figure open.
        plt.close(fig)

def _dump_features(name, tagged_chunks, y_true):
    """Concatenate each feature list, save it as ``f_{tag}_{name}.npy`` and t-SNE-plot it."""
    for tag, chunks in tagged_chunks:
        feats = torch.cat(chunks)
        np.save(f"f_{tag}_{name}.npy", feats.numpy())
        draw_tsne(feats, f"{tag}_{name}", y_true)


def validation(
    dataset: Set[Data],
    model: torch.nn.Sequential,
    opt: Any,
    name: str = 'val',
    mu: int = 0,
    std: int = 0,
    save_feature: int = 0,
    epoch: int = 0
):
    """Run ``model`` over ``dataset`` and score the predictions.

    Args:
        dataset (Set[Data]): A dataset.
        model (torch.nn.Sequential): Model; called with phase ``"valid"``.
        opt (Any): Parsed arguments.
        name (str, optional): Tag used in the feature / t-SNE file names.
        mu (int, optional): Train-set mean used to de-standardise regression outputs.
        std (int, optional): Train-set standard deviation used to de-standardise
            regression outputs.
        save_feature (int, optional): If truthy (and ``opt.num_tasks == 1``),
            save the learned features as ``.npy`` files and t-SNE plots.
        epoch (int, optional): Unused; kept for call compatibility.

    Returns:
        ``(y_true, y_pred, eval_result, experts)``:
        NumPy targets, NumPy predictions (sigmoid probabilities for
        classification, de-standardised values for regression), the metric
        (``"rocauc"`` for classification, ``"mae"`` otherwise), and the
        concatenated ``expert_num`` outputs as a NumPy array.
    """
    model.eval()
    metric = "rocauc" if opt.classification else "mae"
    evaluator = Evaluator(name=opt.dataset, num_tasks=opt.num_tasks, eval_metric=metric)
    data_loader = DataLoader(
        dataset, batch_size=opt.batch_size, shuffle=False, follow_batch=['fg_x', 'atom2fg_list']
    )
    # Follow the model's device instead of hard-coding "cuda".
    device = next(model.parameters()).device

    with torch.no_grad():
        y_true, y_pred, experts = [], [], []
        f_out_all, f_moe_all, f_rdkit_all = [], [], []
        for batch in tqdm(data_loader, desc="Iteration"):
            batch = batch.to(device)
            output_final, f_out, f_moe, rdkit_feat, expert_num = model(batch, "valid")

            if opt.classification:
                output = torch.sigmoid(output_final)
            else:
                # Undo the (y - mu) / std label standardisation used in train().
                output = output_final * std + mu

            y_true.append(batch.y.detach().cpu())
            y_pred.append(output.detach().cpu())
            experts.append(expert_num.detach().cpu())

            if save_feature:
                f_out_all.append(f_out.detach().cpu())
                f_moe_all.append(f_moe.detach().cpu())
                f_rdkit_all.append(rdkit_feat.detach().cpu())

        experts = torch.cat(experts, dim=0).numpy()
        y_true = torch.cat(y_true, dim=0).squeeze().unsqueeze(1).numpy()
        if opt.num_tasks > 1:
            y_pred = np.concatenate(y_pred)
            input_dict = {"y_true": y_true.squeeze(), "y_pred": y_pred.squeeze()}
        else:
            y_pred = np.expand_dims(np.concatenate(y_pred), 1)
            input_dict = {
                "y_true": np.expand_dims(y_true.squeeze(), 1),
                "y_pred": np.expand_dims(y_pred.squeeze(), 1),
            }
            # NOTE(review): features are only dumped for single-task runs.
            if save_feature:
                _dump_features(
                    name,
                    [("out", f_out_all), ("moe", f_moe_all), ("rdkit", f_rdkit_all)],
                    y_true,
                )

        eval_result = evaluator.eval(input_dict)[metric]

    return y_true, y_pred, eval_result, experts

def _save_split_outputs(save_folder: str, split: str, y_true, y_pred, experts) -> None:
    """Write ``{split}_true_result.csv``, ``{split}_pred_result.csv`` and ``{split}_experts.npy``."""
    pd.DataFrame(y_true.squeeze()).to_csv(os.path.join(save_folder, f"{split}_true_result.csv"))
    pd.DataFrame(y_pred.squeeze()).to_csv(os.path.join(save_folder, f"{split}_pred_result.csv"))
    np.save(os.path.join(save_folder, f"{split}_experts.npy"), experts)


def main():
    """Train, select and evaluate one model per dataset split.

    For each split folder ``{opt.dataset}{suffix}``, this:

    * builds a fresh model and trains it for ``opt.epochs`` epochs;
    * keeps the epoch with the lowest validation metric;
    * writes predictions, expert assignments and the final scores to
      ``{opt.model_path}/{opt.model_name}_{dataname}``.

    Raises:
        NotImplementedError: If ``opt.classification`` is set. ``train()``
            needs the intervals and group labels that ``calmean`` computes
            for regression only, so this path used to crash with
            ``UnboundLocalError`` at the ``train()`` call.
    """
    if opt.classification:
        raise NotImplementedError(
            "Classification is not supported: train() requires the regression "
            "intervals and group labels produced by calmean()."
        )

    # Splits are processed in this order; suffixes _0 and _6.._9 are disabled.
    for dataname in [opt.dataset + suffix for suffix in ("_3", "_1", "_5", "_2", "_4")]:
        set_seed(10)
        train_dataset, val_dataset, test_dataset = set_loader(opt, dataname)
        mu, std, dynamic_t, max_dist, intervals, intervals_overlap, group_labels, group_nums = calmean(train_dataset)

        # build model and optimizer
        model, criterion_task = set_model(opt)
        optimizer = set_optimizer(opt.learning_rate, opt.weight_decay, model)

        model_name = "{}_{}".format(opt.model_name, dataname)

        # output folders
        opt.tb_folder = os.path.join(opt.tb_path, model_name)
        os.makedirs(opt.tb_folder, exist_ok=True)
        opt.save_folder = os.path.join(opt.model_path, model_name)
        os.makedirs(opt.save_folder, exist_ok=True)

        with open(os.path.join(opt.save_folder, "runtime_params.json"), "w") as f:
            json.dump(args_dict, f, indent=4)

        # tensorboard
        logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

        # Validation metric is MAE for regression: lower is better.
        best_acc = float('inf')
        best_model = model
        best_epoch = 0

        train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, drop_last=True, shuffle=True, follow_batch=['fg_x', 'atom2fg_list'])

        # The scheduler is stepped once per batch, so count steps in batches.
        num_training_steps = len(train_loader) * opt.epochs
        num_warmup_steps = int(num_training_steps * 0.1)
        scheduler = optimization.get_linear_schedule_with_warmup(optimizer[0], num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

        # training routine
        for epoch in range(opt.epochs):
            torch.cuda.empty_cache()
            time1 = time.time()
            loss_task, loss_scl, loss = train(
                train_loader,
                model,
                criterion_task,
                optimizer,
                scheduler,
                opt,
                intervals,
                group_labels,
                mu,
                std,
                dynamic_t,
                max_dist,
                epoch,
            )
            time2 = time.time()
            print("epoch {}, total time {:.2f}".format(epoch, time2 - time1))

            _, _, acc, _ = validation(val_dataset, model, opt, 'valid', mu, std, 0, epoch)

            # tensorboard logger
            logger.log_value("task loss", loss_task, epoch)
            logger.log_value("supervised contrastive loss", loss_scl, epoch)
            logger.log_value("overall loss", loss, epoch)
            logger.log_value("validation auroc/rmse", acc, epoch)
            logger.log_value("learning rate", optimizer[0].state_dict()['param_groups'][0]['lr'], epoch)

            if acc < best_acc:
                best_acc = acc
                best_model = deepcopy(model).cpu()
                best_epoch = epoch
                print("val rmse:{}".format(acc))

        # Evaluate the selected model on every split. validation() ignores its
        # epoch argument, so best_epoch is passed (this also avoids a NameError
        # on `epoch` when opt.epochs == 0).
        best_model = best_model.cuda()
        results = {
            split: validation(dataset, best_model, opt, tag + opt.dataset, mu, std, 1, best_epoch)
            for split, tag, dataset in (
                ("val", "valid", val_dataset),
                ("train", "train", train_dataset),
                ("test", "test", test_dataset),
            )
        }
        for split, (y_true, y_pred, _, experts) in results.items():
            _save_split_outputs(opt.save_folder, split, y_true, y_pred, experts)

        val_acc = results["val"][2]
        test_acc = results["test"][2]

        # NOTE(review): the selected model's weights (and mu/std) are never
        # written to disk; add torch.save({...}, path) here if checkpoints are needed.
        with open(os.path.join(opt.save_folder, "result_pre.txt"), "w") as txt_file:
            txt_file.write("validation:" + str(val_acc) + "\n")
            txt_file.write("test:" + str(test_acc) + "\n")
            txt_file.write("best epoch:" + str(best_epoch) + "\n")
        print("Val Result:{}".format(val_acc))
        print("Test Result:{}".format(test_acc))


# Script entry point: train/evaluate only when run directly. Argument parsing
# (parse_option() at module level) still runs on import.
if __name__ == "__main__":
    main()
