import sys
sys.path.append('..')

from src.utils import set_random_seed
import argparse
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import MSELoss, BCEWithLogitsLoss
import numpy as np
import random
from src.data.featurizer import Vocab, N_ATOM_TYPES, N_BOND_TYPES
from fragment_mol.utils.chem_utils import DATASET_TASKS
from fragment_mol.utils.utils import WarmUpLR
from fragment_mol.models.model_utils import ModelWithEMA 

from src.trainer.scheduler import PolynomialDecayLR
from src.trainer.finetune_trainer import Trainer
from fragment_mol.evaluator import Evaluator
import time 
import json 
from pathlib import Path 
import wandb 

from fragment_mol.register import MODEL_DICT, DATASET_DICT, COLLATOR_DICT, MODEL_ARG_FUNC_DICT
from fragment_mol.utils.chem_utils import DATASET_TASKS, get_task_metrics, get_task_type, METRIC_BEST_TYPE 
from fragment_mol.utils.fingerprint import FP_FUNC_DICT
from fragment_mol.ps_lg.mol_bpe_new import TokenizerNew
from fragment_mol.models.model_utils import model_n_params 

from fragment_mol.register import EXPLAIN_DICT
from fragment_mol.datasets import DatasetAttribution
from fragment_mol.explain import GradInput

from tqdm import tqdm 

import warnings
warnings.filterwarnings("ignore")

def seed_worker(worker_id):
    """Re-seed NumPy's global RNG and Python's ``random`` from torch's seed.

    The derived seed is ``torch.initial_seed()`` reduced modulo 2**32, which
    keeps it within the range NumPy's legacy seeding accepts. The signature
    matches the ``worker_init_fn`` recipe from PyTorch's reproducibility
    notes, but nothing in this file currently passes it to a DataLoader.

    Args:
        worker_id: accepted for the ``worker_init_fn`` signature; ignored.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    # NumPy first, then the stdlib generator (same order as before).
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)


def valid_model(model, valid_dataloader, evaluators, criterion, args):
    """Evaluate ``model`` on ``valid_dataloader``; return metric scores and mean loss.

    Args:
        model: callable taking a batch's ``input_data`` and returning predictions.
            It is switched to eval mode (and left there).
        valid_dataloader: iterable of ``(input_data, labels)`` batches. ``labels``
            may contain NaN for missing targets; those entries are masked out
            of the loss.
        evaluators: mapping ``metric_name -> evaluator``. Each evaluator is
            called once as ``evaluator.eval(all_labels, all_predicts)`` on the
            concatenated (unmasked) labels and predictions.
        criterion: loss whose output is multiplied element-wise by a 0/1 label
            mask, so it is expected to return per-element losses
            (e.g. ``reduction='none'``).
        args: unused; kept for interface compatibility.

    Returns:
        ``(score, avg_loss)`` where ``score`` maps each metric name to its
        evaluator's result and ``avg_loss`` is the mean of per-batch losses.

    Raises:
        ValueError: if ``valid_dataloader`` yields no batches.
    """
    model.eval()
    label_list = []
    predict_list = []
    epoch_loss = 0.0
    n_batches = 0
    # Inference only: without no_grad every stored prediction would keep its
    # autograd graph alive until the whole validation pass finishes.
    with torch.no_grad():
        for input_data, labels in valid_dataloader:
            predict = model(input_data)
            label_list.append(labels)
            predict_list.append(predict)

            is_labeled = (~torch.isnan(labels)).to(torch.float32)
            labels = torch.nan_to_num(labels, nan=0.0)
            # Unlabeled entries contribute 0 to the sum but still count in the
            # mean's denominator (same as the original behaviour).
            loss = (criterion(predict, labels) * is_labeled).mean()

            epoch_loss += loss.item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError("valid_dataloader yielded no batches")

    avg_loss = epoch_loss / n_batches
    labels = torch.cat(label_list, dim=0)
    predicts = torch.cat(predict_list, dim=0)

    score = {metric: evaluator.eval(labels, predicts) for metric, evaluator in evaluators.items()}
    return score, avg_loss

def finetune(args):
    """Explain a fine-tuned FragFormer on the test split and score the attributions.

    Despite its name this function does no training. It:
      1. builds the test split of ``args.dataset`` and loads the checkpoint at
         ``args.model_path``;
      2. drops negative samples and runs ``EXPLAIN_DICT[args.explain_method]``
         on each remaining molecule (batch size 1);
      3. compares each predicted fragment attribution with the ground truth
         from ``DatasetAttribution`` using the dataset's task metrics;
      4. writes the 10 best-ranked molecules to ``explain_mutag.json`` and
         prints the mean of each metric over the evaluated molecules.

    Args:
        args: parsed CLI namespace. Reads ``seed``, ``dataset``, ``scaffold_id``,
            ``device``, ``dropout``, ``model_path``, ``explain_method`` and
            (optionally) ``target_mean`` / ``target_std``. It is also
            forwarded whole to the dataset class and the explainer.
    """
    set_random_seed(args.seed)

    dataset_class = DATASET_DICT['frag_graph']
    test_dataset = dataset_class(args.dataset, split="test", scaffold_id=args.scaffold_id, args=args)

    collator_class = COLLATOR_DICT['frag_graph']
    collator = collator_class(device=args.device)

    n_tasks_ft = DATASET_TASKS[args.dataset]

    # The backbone is built with n_tasks=len(tokenizer) (presumably the
    # pre-training head size) so its parameters match the checkpoint, then
    # the fine-tuning head with n_tasks_ft outputs is attached.
    vocab_path = "fragment_mol/ps_lg/chembl29_vocab_lg.txt"
    tokenizer = TokenizerNew(vocab_path=vocab_path)
    model_class = MODEL_DICT['fragformer']
    model_args = MODEL_ARG_FUNC_DICT['fragformer']()
    n_tasks_pt = len(tokenizer)
    model = model_class(args=model_args, n_tasks=n_tasks_pt)
    model.init_ft_predictor(n_tasks_ft, args.dropout)

    # Load on CPU first so a checkpoint saved on a different device can still
    # be restored; the model is moved to args.device right after. The
    # 'module.' prefix is stripped (presumably added by DataParallel/DDP).
    state_dict = torch.load(args.model_path, map_location="cpu")
    model.load_state_dict({k.replace('module.', ''): v for k, v in state_dict.items()})
    model = model.to(args.device)
    model.eval()

    print("model have {}M paramerters in total".format(sum(x.numel() for x in model.parameters())//int(10**6)))

    metrics = get_task_metrics(args.dataset)
    evaluators = {metric: Evaluator(name=args.dataset, eval_metric=metric, n_tasks=DATASET_TASKS[args.dataset], mean=getattr(args, 'target_mean', None), std=getattr(args, 'target_std', None)) for metric in metrics}

    # After removing negatives, test_dataset.pos_index_list is used to look up
    # the matching ground-truth attributions.
    test_dataset.remove_negative_samples()
    test_explain_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collator)

    dataset_attribution = DatasetAttribution(args.dataset)
    explain_cls = EXPLAIN_DICT[args.explain_method]
    explain_model = explain_cls(args)

    gt_attr_score_list = [dataset_attribution.attribution[i] for i in test_dataset.pos_index_list]
    gt_smiles = [dataset_attribution.smiles[i] for i in test_dataset.pos_index_list]
    print(f"size of gt: {len(gt_attr_score_list)}")
    print(f"size of explain: {len(test_dataset)}")

    # batch_size=1, so [0] unwraps the single molecule of each batch. No
    # torch.no_grad() here: explainers may need gradients.
    pred_attr_score_list = []
    for data, labels in tqdm(test_explain_dataloader):
        group_idx = data['group_idx_list'][0]
        attr_score = explain_model.explain(model, data, labels[0], frag=True, group_idx=group_idx)
        pred_attr_score_list.append(attr_score)

    print("finish explain, start evaluating...")
    if len(gt_attr_score_list) != len(pred_attr_score_list):
        # zip() below would silently drop the surplus entries.
        print(f"warning: {len(gt_attr_score_list)} ground-truth vs {len(pred_attr_score_list)} predicted attributions; extra entries are ignored")
    if gt_attr_score_list and pred_attr_score_list:
        # Quick visual check that ground truth and predictions line up.
        print(gt_attr_score_list[0].shape)
        print(pred_attr_score_list[0].shape)
        print(test_dataset.smiles[0])
        print(gt_smiles[0])

    score_list = []
    n_error = 0
    valid_gt_attr_score_list = []
    valid_pred_attr_score_list = []
    valid_smiles = []
    for i, (gt_score, pred_score) in enumerate(zip(gt_attr_score_list, pred_attr_score_list)):
        gt_score, pred_score = gt_score.reshape([-1, 1]), pred_score.reshape([-1, 1])

        # Assuming 0/1 ground truth: all-zero or all-one molecules have a
        # single class, which ranking metrics cannot score, so skip and count.
        if gt_score.sum() == 0 or gt_score.sum() == len(gt_score):
            n_error += 1
            continue
        score = {metric: evaluator.eval(gt_score, pred_score) for metric, evaluator in evaluators.items()}
        score_list.append(score)
        valid_gt_attr_score_list.append(gt_score.reshape(-1).tolist())
        valid_pred_attr_score_list.append(pred_score.reshape(-1).tolist())
        valid_smiles.append(test_dataset.smiles[i])

    # Every entry of score_list already has a non-trivial ground truth, so rank
    # directly by ROC-AUC (or by the first metric if ROC-AUC is not computed).
    # Positions here index the *filtered* lists built above.
    rank_metric = 'rocauc' if 'rocauc' in evaluators else next(iter(evaluators), None)
    indices = np.argsort([s.get(rank_metric, 0) for s in score_list])[::-1]
    res_list = [
        {
            "score": score_list[index],
            "smiles": valid_smiles[index],
            "gt_attr": valid_gt_attr_score_list[index],
            "pred_attr": valid_pred_attr_score_list[index],
        }
        for index in indices[:10]
    ]
    # NOTE(review): output file name is fixed regardless of args.dataset.
    with open("explain_mutag.json", "w") as f:
        json.dump(res_list, f, indent=2)

    print(f"n_error: {n_error}")
    if not score_list:
        print("no molecule with a non-trivial ground-truth attribution; nothing to average")
        return
    score = {metric: np.mean([s[metric] for s in score_list]) for metric in score_list[0]}
    print(score)
    
  
    
    
    
    
if __name__ == '__main__':
    # CLI entry point: parse arguments, optionally start a wandb run, run the
    # explanation/evaluation, and report wall-clock time.
    start_time = time.time()
    parser = argparse.ArgumentParser(description="Arguments for training LiGhT")
    parser.add_argument("--seed", type=int, default=22)
    parser.add_argument('--device', type=int, default=0, help='device id')
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument('--warmup_epochs', type=int, default=1, help='# of warmup epochs')
    parser.add_argument('--warmup', action='store_true', help='whether to use warmup')
    parser.add_argument("--ema", type=float, default=0.9, help='ema decay')

    parser.add_argument("--model_path", type=str, required=True)
    # Required: every downstream lookup (DATASET_TASKS, DatasetAttribution, ...)
    # is keyed by the dataset name, so a missing value can only fail later.
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--data_path", type=str)
    parser.add_argument('--scaffold_id', type=int, default=0, help='scaffold id')
    parser.add_argument("--batch_size", type=int, default=32, help="batch size")

    parser.add_argument("--weight_decay", type=float, required=True)
    parser.add_argument("--dropout", type=float, required=True)
    parser.add_argument("--lr", type=float, required=True)
    parser.add_argument("--n_threads", type=int, default=4)
    parser.add_argument("--knodes", type=str, default=[], nargs="*", help="knowledge type",
                        choices=list(FP_FUNC_DICT.keys()))
    parser.add_argument('--wandb', action='store_true', help='whether to use wandb')
    parser.add_argument('--debug', action='store_true', help='debug mode')

    parser.add_argument('--explain_method', type=str, default="cam", help='name of explain method')

    # Unknown arguments are tolerated (ignored) rather than rejected.
    args, _ = parser.parse_known_args()

    if args.wandb:
        wandb.init(
                project = "FragFormer",
                name = f"finetune-{args.dataset}-{args.scaffold_id}",
                config = args,
            )
    print(f"finetune on {args.dataset}, {args.scaffold_id}, model={args.model_path}")

    try:
        finetune(args)
        end_time = time.time()
        print(f"Time cost: {end_time-start_time:.2f}s")
    finally:
        # Close the wandb run even if finetune() raises.
        if args.wandb:
            wandb.finish()
    
    


    
    
    


