import os
import torch
import numpy as np
from tqdm import tqdm
from config import get_config
from models import load_model
from functools import partialmethod
from core.metric import class_eval
from torch.utils.data import DataLoader
from core.meters import AverageMeter
from core.loss_func import DockingLoss
from core.dataset import batch_to_device
from core.dataset import PPDockDataset, collate_fnc

def predict(config, out_dir):
    """Restore a trained model and evaluate it on the ``rcsb`` test split.

    Builds the test ``DataLoader`` from ``config``, instantiates the model class
    returned by ``load_model(config.model)``, loads the ``'state_dict'`` entry of
    the checkpoint at ``config.restore`` and delegates to ``eval``, which writes
    per-sample predictions into ``out_dir`` and prints averaged metrics.

    Args:
        config: run configuration; reads ``batch_size``, ``num_data_workers``,
            ``model`` and ``restore`` here (and is passed on to the dataset,
            model and loss constructors).
        out_dir: directory that receives the prediction files.
    """
    loader = DataLoader(
        PPDockDataset(config, split='test', dataset_name='rcsb'),
        batch_size=config.batch_size,
        collate_fn=collate_fnc,
        num_workers=config.num_data_workers,
        pin_memory=True,
    )

    # resolve the model class by name, then build it (on GPU when one is present)
    model_cls = load_model(f'{config.model}')
    net = model_cls(config)
    if torch.cuda.is_available():
        net.cuda()
    loss_fn = DockingLoss(config)

    # Deserialize onto the CPU; load_state_dict then copies the weights into the
    # model's parameters wherever they live.
    # NOTE(review): torch.load unpickles the file - only restore trusted checkpoints.
    with open(config.restore, 'rb') as ckpt_file:
        checkpoint = torch.load(ckpt_file, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])

    eval(net, loader, loss_fn, out_dir)


# Names of the running-average meters. Each tuple is ordered to match the values it is
# paired with: ``criterion(output)`` returns the five losses in ``_LOSS_NAMES`` order and
# ``class_eval`` returns its (precision, recall, F-score, AUC, AP) lists in metric order.
_LOSS_NAMES = ('LigBSPLoss', 'RecBSPLoss', 'LigAttnLoss', 'RecAttnLoss', 'NCELoss')
_LIG_METRIC_NAMES = ('LigPrec', 'LigRec', 'LigFsc', 'LigAUC', 'LigAP')
_REC_METRIC_NAMES = ('RecPrec', 'RecRec', 'RecFsc', 'RecAUC', 'RecAP')
# Arrays copied verbatim from each original .npz file into the prediction file.
_RAW_KEYS = ('lig_verts', 'lig_faces', 'lig_atom_info',
             'rec_verts', 'rec_faces', 'rec_atom_info')


def _as_fixed_tuple(values, names):
    """Return ``values`` as a tuple, raising ValueError unless it pairs one-to-one with ``names``."""
    values = tuple(values)
    if len(values) != len(names):
        raise ValueError(f'expected {len(names)} values for {names}, got {len(values)}')
    return values


def _load_raw_sample(raw_fpath):
    """Read the ``_RAW_KEYS`` arrays from the .npz at ``raw_fpath``, closing it before returning."""
    # np.load on an .npz returns an NpzFile that holds an open file handle; the ``with``
    # block releases it immediately instead of leaving it to garbage collection.
    with np.load(raw_fpath) as raw_data:
        return {key: raw_data[key] for key in _RAW_KEYS}


def _iface_label(num_verts, iface_p2p, device):
    """Return a float32 vector of ``num_verts`` zeros with ones at rows ``iface_p2p[:, 0]``."""
    # Allocated on the model output's device so index tensors that travelled to the GPU
    # with the batch can be used directly (a CPU tensor cannot be indexed by CUDA indices).
    label = torch.zeros(num_verts, dtype=torch.float32, device=device)
    label[iface_p2p[:, 0]] = 1
    return label


def _check_vertex_count(raw_verts, label, side, fname):
    """Raise ValueError if the raw mesh and the per-vertex label disagree in length."""
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if len(raw_verts) != label.size(0):
        raise ValueError(f'{fname}: {side} has {len(raw_verts)} raw vertices but '
                         f'{label.size(0)} predicted vertices')


def _format_meters(meters, names):
    """Render ``'Name: avg'`` pairs (3 decimals) for ``names``, joined by ``', '``."""
    return ', '.join(f'{name}: {meters[name].avg:.3f}' for name in names)


def eval(model, data_loader, criterion, out_dir, raw_dir='datasets/dataset_DB5/'):
    """Evaluate ``model`` on ``data_loader``, save per-sample predictions and print averages.

    For every sample, the original vertices/faces/atom info are read from
    ``raw_dir/<basename of batch['batch_fpath'][i]>`` and written, together with the
    model's per-vertex ``bsp`` and ``h`` outputs, the ground-truth interface labels and
    the per-sample classification scores, to ``out_dir/<same basename>``.

    NOTE: this function shadows the builtin ``eval`` inside this module; the name is
    kept so existing callers keep working.
    NOTE(review): ``predict`` builds the test set with ``dataset_name='rcsb'`` while
    ``raw_dir`` defaults to the DB5 directory - confirm the file names match.

    Args:
        model: network called as ``model(batch)``; its output dict must provide
            ``'lig_dict'`` and ``'rec_dict'`` for ``class_eval``.
        data_loader: iterable of collated batches.
        criterion: callable returning exactly five scalar tensors, in the order of
            ``_LOSS_NAMES``.
        out_dir: output directory (created if missing).
        raw_dir: directory holding the original .npz files.

    Raises:
        ValueError: if ``criterion``/``class_eval`` return the wrong number of values,
            or a raw mesh's vertex count differs from the model's per-vertex output.
    """
    meters = {name: AverageMeter(name, ':5.3f')
              for name in _LOSS_NAMES + _LIG_METRIC_NAMES + _REC_METRIC_NAMES}

    # switch to evaluate mode
    model.eval()
    os.makedirs(out_dir, exist_ok=True)
    with torch.no_grad():
        for batch in tqdm(data_loader):
            # send data to device and compute model output
            if torch.cuda.is_available():
                batch = batch_to_device(batch)
            output = model(batch)

            losses = _as_fixed_tuple(criterion(output), _LOSS_NAMES)
            lig_metrics = _as_fixed_tuple(class_eval(output['lig_dict']), _LIG_METRIC_NAMES)
            rec_metrics = _as_fixed_tuple(class_eval(output['rec_dict']), _REC_METRIC_NAMES)
            bsize = len(batch['lig_dict']['num_verts'])

            for name, loss in zip(_LOSS_NAMES, losses):
                meters[name].update(loss.item(), bsize)
            # class_eval's lists hold per-sample values (indexed per sample below);
            # each meter records the batch mean weighted by the batch size.
            for names, metrics in ((_LIG_METRIC_NAMES, lig_metrics),
                                   (_REC_METRIC_NAMES, rec_metrics)):
                for name, per_sample in zip(names, metrics):
                    meters[name].update(sum(per_sample) / bsize, bsize)

            # split the concatenated per-vertex tensors back into per-sample chunks
            lig_bid = batch['lig_dict']['num_verts']
            lig_bsp = batch['lig_dict']['bsp'].split(lig_bid)
            lig_h = batch['lig_dict']['h'].split(lig_bid)
            rec_bid = batch['rec_dict']['num_verts']
            rec_bsp = batch['rec_dict']['bsp'].split(rec_bid)
            rec_h = batch['rec_dict']['h'].split(rec_bid)
            # rows follow class_eval's metric order; column idx belongs to sample idx
            lig_scores = torch.tensor(list(lig_metrics)).cpu().numpy()
            rec_scores = torch.tensor(list(rec_metrics)).cpu().numpy()

            # save predictions
            for idx in range(bsize):
                fname = os.path.basename(batch['batch_fpath'][idx])
                # load original data for vertices, faces, and atom info
                raw = _load_raw_sample(os.path.join(raw_dir, fname))
                # ground truth labels
                lig_label = _iface_label(lig_bsp[idx].size(0),
                                         batch['lig_dict']['iface_p2p'][idx],
                                         lig_bsp[idx].device)
                _check_vertex_count(raw['lig_verts'], lig_label, 'lig', fname)
                rec_label = _iface_label(rec_bsp[idx].size(0),
                                         batch['rec_dict']['iface_p2p'][idx],
                                         rec_bsp[idx].device)
                _check_vertex_count(raw['rec_verts'], rec_label, 'rec', fname)
                # save output
                np.savez(os.path.join(out_dir, fname),
                         lig_verts=raw['lig_verts'],
                         lig_faces=raw['lig_faces'],
                         lig_atom_info=raw['lig_atom_info'],
                         lig_bsp=lig_bsp[idx].cpu().numpy(),
                         lig_h=lig_h[idx].cpu().numpy(),
                         lig_iface_label=lig_label.cpu().numpy(),
                         lig_scores=lig_scores[:, idx],
                         rec_verts=raw['rec_verts'],
                         rec_faces=raw['rec_faces'],
                         rec_atom_info=raw['rec_atom_info'],
                         rec_bsp=rec_bsp[idx].cpu().numpy(),
                         rec_h=rec_h[idx].cpu().numpy(),
                         rec_iface_label=rec_label.cpu().numpy(),
                         rec_scores=rec_scores[:, idx])

    summary_lines = [
        '***** test',
        _format_meters(meters, ('LigBSPLoss',) + _LIG_METRIC_NAMES),
        _format_meters(meters, ('RecBSPLoss',) + _REC_METRIC_NAMES),
        _format_meters(meters, ('LigAttnLoss',)),
        _format_meters(meters, ('RecAttnLoss',)),
        _format_meters(meters, ('NCELoss',)),
        '*****',
    ]
    print('\n'.join(summary_lines) + '\n')


if __name__ == '__main__':
    config = get_config()
    # force the ``serial`` flag on for this evaluation script
    config.serial = True
    # without an explicit checkpoint, fall back to this run's best pre-training weights
    if config.restore is None:
        config.restore = f'./{config.run_name}_checkpoints/model_pretrain_best.pt'

    # Progress bars are opt-in: unless ``unmute_tqdm`` is set, every tqdm created
    # from now on is constructed with ``disable=True``.
    if not config.unmute_tqdm:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    predict(config, './step2_output/')


