from .args_factory import get_args
from .utils import step_loss, draw_atom, preprocess_data, normalize_ohe_features
from datasets import load_dataset
from torch.utils.data import DataLoader
import sys
sys.path.append('../')
from graph_reconstruction.utils import get_model, get_X, get_A, possible_feature_values
from metrics import evaluate_metrics
import yaml
import neptune.new as neptune
import torch
from rdkit import Chem
import random
import matplotlib.pyplot as plt
import io
import numpy as np
from tqdm import tqdm


def run_attack(args, model, criterion, sample, generator=None):
    """Reconstruct a molecule's graph inputs from the model gradient it produces.

    The true gradient of the model loss on ``sample`` (w.r.t. the model
    parameters) is computed first. Dummy adjacency / feature / label tensors
    (those not fixed by ``args.fix_A`` / ``args.fix_X`` / ``args.fix_y``) are
    then optimized with L-BFGS against ``step_loss``, which receives the true
    gradient (presumably a gradient-matching objective — see ``step_loss``).

    Args:
        args: Run configuration. Reads ``device``, ``do_ohe``, ``fix_A``,
            ``fix_X``, ``fix_y`` and ``max_iter``; the whole object is also
            forwarded to ``step_loss`` and ``preprocess_data``.
        model: Model called as ``model(adjacency, features)``.
        criterion: Loss returning an unreduced result; on the two-element
            label vector it must yield exactly two elements.
        sample: One dataloader batch; ``sample['smiles'][0]`` and
            ``sample['target']`` are read.
        generator: Optional ``torch.Generator`` for the random dummy inits.

    Returns:
        The result of ``preprocess_data(args, dummy_A, dummy_X, dummy_label)``
        (the caller in this file unpacks it as ``(recon_X, recon_A, recon_y)``).
    """
    gt_mol = Chem.MolFromSmiles(sample['smiles'][0])
    gt_fms = torch.tensor(get_X(gt_mol, feature_onehot_encoding=args.do_ohe)).to(args.device)
    gt_fms.requires_grad_(True)
    gt_ams = torch.tensor(get_A(gt_mol)).to(args.device)

    # Read the label from the sample we were given (previously this read the
    # module-level ``batch`` global, which only exists in this script's loop).
    label = sample['target']
    if label == 1:
        gt_ls = torch.tensor([0., 1.]).to(args.device)
    else:
        gt_ls = torch.tensor([1., 0.]).to(args.device)

    # Compute the true gradients of the loss w.r.t. the model parameters.
    logits = model(gt_ams, gt_fms)
    loss = criterion(logits, gt_ls)
    assert loss.numel() == 2, f"expected a 2-element unreduced loss, got {loss.numel()}"
    loss = loss.mean()
    true_gradient = torch.autograd.grad(loss, model.parameters())

    # Collect the inputs to optimize; fixed ones are copies of the ground truth.
    to_optim = []
    if not args.fix_A:
        dummy_A = torch.randn(gt_ams.shape, requires_grad=True, generator=generator, device=args.device)
        to_optim.append(dummy_A)
    else:
        dummy_A = gt_ams.clone()

    if not args.fix_X:
        dummy_X = torch.randn(gt_fms.shape, requires_grad=True, generator=generator, device=args.device)
        to_optim.append(dummy_X)
    else:
        dummy_X = gt_fms.clone()

    if not args.fix_y:
        dummy_label = torch.randn(gt_ls.shape, requires_grad=True, generator=generator, device=args.device)
        to_optim.append(dummy_label)
    else:
        dummy_label = gt_ls.clone()

    optimizer = torch.optim.LBFGS(to_optim, lr=0.01)

    # LBFGS re-evaluates the objective through a closure; it depends only on
    # tensors that keep their identity across steps, so build it once.
    def closure():
        return step_loss(args, model, criterion, dummy_X, dummy_A, dummy_label, true_gradient, optimizer)

    # Run attack
    for i in tqdm(range(args.max_iter)):
        optimizer.step(closure)
        if i % 200 == 0:
            loss = closure()
            print(i, "%.4f" % loss.item())

    return preprocess_data(args, dummy_A, dummy_X, dummy_label)

def _save_current_figure(args, run, neptune_key, local_path):
    """Save the current matplotlib figure to Neptune or to a local PNG file.

    Args:
        args: Run configuration; only ``args.neptune`` is read.
        run: Run handle returned by ``get_args``; used only when ``args.neptune``.
        neptune_key: Key under which the PNG is appended to the Neptune run.
        local_path: File path written when Neptune logging is disabled.
    """
    if args.neptune:
        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        buf.seek(0)
        run[neptune_key].append(neptune.types.File.from_content(buf.getvalue(), extension="png"))
    else:
        plt.savefig(local_path)


if __name__ == '__main__':
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)

    args, run = get_args()

    # Load dataset
    dataset = load_dataset(args.dataset)
    dataset = dataset.shuffle(seed=args.rng_seed)
    dataloader_clintox = DataLoader(dataset.with_format("torch")['train'], batch_size=1, shuffle=False)

    with open(args.config_path, 'r') as file:
        config = yaml.safe_load(file)

    model_args = config['model_args']

    if args.neptune:
        for arg in model_args:
            run[f'parameters/model_{arg}'] = model_args[arg]

    # Load the attacked model and the separate model passed to evaluate_metrics.
    model = get_model(model_args, feat_dim=140, num_cats=2).to(args.device)

    with open(args.eval_config_path, 'r') as file:
        eval_config = yaml.safe_load(file)

    eval_model_args = eval_config['model_args']
    print(eval_config['model_args'])
    eval_model = get_model(eval_model_args, feat_dim=140, num_cats=2).to(args.device)

    criterion = torch.nn.BCEWithLogitsLoss(reduction="none")
    generator = torch.Generator(device=args.device)
    generator.manual_seed(0)

    for i, batch in enumerate(dataloader_clintox):
        if i == args.n_inputs:
            break
        print(f'Running example {i+1}')
        print('-----------------------')

        recon_X, recon_A, recon_y = run_attack(args, model, criterion, batch, generator=generator)

        recon_A = (recon_A > args.A_thrs).int()

        if args.do_ohe:
            recon_X_id, recon_X = normalize_ohe_features(recon_X, soft=False)

        # Compute metrics against the ground-truth graph
        gt_mol = Chem.MolFromSmiles(batch['smiles'][0])
        gt_fms = torch.tensor(get_X(gt_mol, feature_onehot_encoding=args.do_ohe)).to(args.device)
        gt_fms.requires_grad_(True)
        gt_ams = torch.tensor(get_A(gt_mol)).to(args.device)

        metrics = evaluate_metrics(args, gt_ams, gt_fms, recon_A, recon_X, eval_model)

        for metric in metrics:
            print(f'{metric}: {metrics[metric]:.6f}')
            if args.neptune:
                run[f'logs/{metric}'].log(metrics[metric])

        if args.do_ohe:
            # Replace each per-feature index with its value from possible_feature_values().
            # The row variable is not named `i` so it cannot be confused with the example index.
            feature_vals = possible_feature_values()
            recon_X = torch.tensor([
                [feature_vals[j][idx.item()] for j, idx in enumerate(atom_ids)]
                for atom_ids in recon_X_id
            ]).to(args.device)

        # Drawing is best-effort: a failure is reported and the run continues.
        # The two images are handled independently so that a failed reconstruction
        # drawing can no longer cause a stale figure to be saved as the ground truth.
        try:
            recon = draw_atom(recon_X, recon_A)
            plt.clf()
            plt.imshow(recon)
            _save_current_figure(args, run, f"pred/{i}.png", f"pred_{i}.png")
        except Exception as err:
            print(f'Could not draw reconstruction for example {i}: {err!r}')

        try:
            # Draw is a submodule that `from rdkit import Chem` does not load.
            from rdkit.Chem import Draw
            gt = Draw.MolToImage(gt_mol, size=(1200, 1200))
        except Exception as err:
            print(f'Could not draw ground truth for example {i}: {err!r}')
        else:
            plt.clf()
            plt.imshow(gt)
            _save_current_figure(args, run, f"gt/{i}.png", f"gt_{i}.png")

