# python ../scripts/train_end2end.py  | tee ../log/log_3.txt
import argparse
import random
from itertools import islice

import numpy as np
import pandas as pd
import torch
import os

import yaml
from torch.utils.data import DistributedSampler
from torch_geometric.graphgym import optim
from torch.nn.utils import clip_grad_norm_
from torchdrug.utils import comm
from torch import distributed as dist, nn
from torchmetrics import PearsonCorrCoef
from torchmetrics.functional import pearson_corrcoef

from task import End2EndDocking
from data import PDBbind, TankBindDataSet_new
from model import IterativeRefinement
from utils import output_mol, set_seed, record_data, DebugBatchSampler, get_eval_task, post_optimize_output_mol, \
    tankbind_output_pocket, end2end_evaluate

import logging
from torch_geometric.loader import DataLoader
from tqdm import tqdm

import wandb
from easydict import EasyDict

# Root logger (empty name) is used for every message in this script.
logger = logging.getLogger("")
logger.setLevel(logging.INFO)

# Command-line interface: experiment name, path of the YAML defaults file, and
# the local rank argument (default 0).
parser = argparse.ArgumentParser()
parser.add_argument("--exp_name", default="Debug", type=str)
parser.add_argument("--default_dict", default="../config/end2end_dock_defaults.yaml", type=str)
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()
print(args)

# Read the YAML defaults, then let the parsed CLI values overwrite any
# same-named top-level keys before wrapping everything for attribute access.
with open(args.default_dict, mode="r") as config_file:
    default_dict = yaml.load(config_file, Loader=yaml.FullLoader)
default_dict.update(vars(args))
cfg = EasyDict(default_dict)
# Multi-GPU Support
# Resolve this process's rank, the total number of processes, and the device
# it should compute on.
rank = comm.get_rank()
world_size = comm.get_world_size()
is_master = rank == 0
gpus = cfg.engine.gpus
if gpus is not None:
    # One process per configured GPU is required.
    if world_size != len(gpus):
        launch_hint = ". Did you launch with `python -m torch.distributed.launch`?" if world_size == 1 else ""
        error_msg = "World size is %d but found %d GPUs in the argument" + launch_hint
        raise ValueError(error_msg % (world_size, len(gpus)))
    device = torch.device(gpus[rank % len(gpus)])
else:
    device = torch.device("cpu")

# Create the process group once when running with more than one process:
# "nccl" when GPUs are configured, "gloo" otherwise.
if world_size > 1 and not dist.is_initialized():
    if is_master:
        logger.info("Initializing distributed process group")
    backend = "nccl" if gpus is not None else "gloo"
    comm.init_process_group(backend, init_method="env://")

# Only the master process opens a wandb run.
# (Debug runs previously used project="debug_end2enddock".)
if is_master:
    wandb.init(
        project="reformat code",
        config=dict(cfg),
        name=args.exp_name,
        dir=default_dict['wandb_log_dir'],
    )

# >>>>>>>>>>>>>>>>>>>>>

set_seed(cfg.engine['seed'])

# Per-experiment directories for written outputs and saved checkpoints.
output_folder = os.path.join(cfg.engine['output_root'], args.exp_name)  # TODO
model_folder = os.path.join(cfg.engine['model_root'], args.exp_name)  # TODO
os.makedirs(output_folder, exist_ok=True)
os.makedirs(model_folder, exist_ok=True)


# Build the network, wrap it in the docking task, and create the optimizer
# over the task's parameters.
model = IterativeRefinement(**cfg.model).to(device)
task = End2EndDocking(model=model, criterion=cfg.task['criterion'], max_iter_num=cfg.task['max_iter_num'],
                      inter_distance_threshold=cfg.task['inter_distance_threshold'],
                      intra_distance_threshold=cfg.task['intra_distance_threshold'])
optimizer = optim.Adam(task.parameters(), lr=cfg.optimizer['lr'])
if cfg.engine['load_model'] is not None:
    logger.warning(f"Load Model from {cfg.engine['load_model']}")
    # Deserialize onto the CPU: without map_location the tensors are restored
    # to the device they were saved from, so every rank would allocate on that
    # one GPU and a CPU-only run could not open a GPU checkpoint at all.
    # load_state_dict then copies the values into the task's own parameters.
    checkpoint = torch.load(cfg.engine['load_model'], map_location="cpu")
    # strict=False tolerates missing/unexpected keys between checkpoint and task.
    task.load_state_dict(checkpoint['model_state_dict'], strict=False)
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # TODO: cannot load when model is not matched

if world_size > 1:
    # DistributedDataParallel rejects device_ids for CPU modules, so only pass
    # the device when this process actually runs on an accelerator.
    ddp_device_ids = None if device.type == "cpu" else [device]
    task = nn.parallel.DistributedDataParallel(task, device_ids=ddp_device_ids,
                                               find_unused_parameters=True)


def _make_loader(dataset, **loader_kwargs):
    """Build a DataLoader for ``dataset`` using the shared settings in ``cfg.engine``.

    ``loader_kwargs`` carries the per-loader options (``sampler=...`` for the
    training loaders, ``shuffle=False`` for evaluation loaders).
    """
    return DataLoader(dataset, batch_size=cfg.engine['batch_size'],
                      follow_batch=['x', 'y', 'compound_pair'],
                      num_workers=cfg.engine['num_workers'], pin_memory=True,
                      **loader_kwargs)


def _collect_eval_results(eval_task, dataloader):
    """Evaluate every batch of ``dataloader`` and concatenate the results per metric.

    Returns a dict keyed by the names in ``cfg.task['eval_metrics']``; each value
    is the concatenation of what ``eval_task.evaluate_metric`` returned for that
    key over all batches.
    """
    metric_names = cfg.task['eval_metrics'].keys()
    eval_result_list = {key: torch.empty(0, device=device) for key in metric_names}
    eval_task.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = batch.to(device)
            result = eval_task.evaluate_metric(batch, eval_metrics=metric_names)
            for key in metric_names:
                eval_result_list[key] = torch.cat([eval_result_list[key], result[key].detach()])
    return eval_result_list


def _confidence_pearson(eval_result_list):
    """Pearson correlation between "rmsd" and "confidence" over entries with rmsd < 10."""
    pos_mask = eval_result_list["rmsd"] < 10
    return pearson_corrcoef(eval_result_list["rmsd"][pos_mask], eval_result_list["confidence"][pos_mask])


def main():
    """Train the module-level ``task`` and periodically validate, test and checkpoint it.

    Reads all settings from the module-level ``cfg``. Each epoch trains on the
    warm-up subset (rows with ``use_compound_com``) while
    ``epoch < cfg.dataset["warm_up_epoch"]`` and on the full training split
    afterwards. On the master process it then runs validation / test at their
    configured intervals, logs to wandb and saves ``Epoch_{epoch}.pt`` into
    ``model_folder``.

    The "confidence" metric is reported as a correlation against "rmsd", so
    "rmsd" must also be listed in ``cfg.task['eval_metrics']`` when
    "confidence" is.
    """
    base_dataset = TankBindDataSet_new(root=cfg.dataset['root'],
                                       pocket_info_path=cfg.dataset['pocket_info_path'],
                                       protein_embed_folder=cfg.dataset['protein_embed_folder'],
                                       compound_embed_path=cfg.dataset['compound_embed_path'],
                                       compound_folder=cfg.dataset['compound_folder'],
                                       setting=cfg.dataset["setting"],
                                       )
    # Drop rows outside the size / contact limits and renumber, so the index
    # values selected below are positions in the filtered table.
    base_dataset.info = base_dataset.info.query("c_length < 100 and native_num_contact > 5").reset_index(
        drop=True)  # TODO
    train_warm_index = base_dataset.info.query("use_compound_com and group =='train'").index.values
    train_index = base_dataset.info.query("group =='train'").index.values
    valid_index = base_dataset.info.query("use_compound_com and group =='valid'").index.values
    test_index = base_dataset.info.query(
        "(num_contact/native_num_contact > 0.9) and group =='test' and (not use_compound_com)").index.values  # TODO miss one example: 347 -> 346
    # test_index = base_dataset.info.query("group =='test' and use_compound_com").index.values #TODO miss one example: 347 -> 346
    train_warm_dataset = base_dataset[train_warm_index]
    train_dataset = base_dataset[train_index]
    valid_dataset = base_dataset[valid_index]
    test_dataset = base_dataset[test_index]
    # The noise option is set on the two training subsets only.
    train_dataset.add_noise_to_com = cfg.dataset['add_noise_to_com']
    train_warm_dataset.add_noise_to_com = cfg.dataset['add_noise_to_com']
    # train_dataset.native_pocket_threshold = 0.7 # TODO

    # Training loaders are sharded across processes; evaluation loaders are not,
    # because evaluation only runs on the master process.
    train_warm_sampler = DistributedSampler(train_warm_dataset, num_replicas=world_size, rank=rank)
    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
    train_warm_dataloader = _make_loader(train_warm_dataset, sampler=train_warm_sampler)
    train_dataloader = _make_loader(train_dataset, sampler=train_sampler)
    valid_dataloader = _make_loader(valid_dataset, shuffle=False)
    test_dataloader = _make_loader(test_dataset, shuffle=False)

    log_interval = cfg.engine['log_interval']
    for epoch in tqdm(range(cfg.engine['num_epochs'])):
        if epoch < cfg.dataset["warm_up_epoch"]:
            _train_dataloader = train_warm_dataloader
            train_warm_sampler.set_epoch(epoch)
        else:
            _train_dataloader = train_dataloader
            train_sampler.set_epoch(epoch)

        wandb_log_dict = {}
        # Train
        loss_list = torch.empty(0, device=device)
        metric_list = {key: torch.empty(0, device=device) for key in cfg.task['criterion'].keys()}
        task.train()
        # First batch of the current gradient-accumulation window.
        # NOTE(review): this resets every epoch while gradients are only zeroed
        # after an optimizer step, so gradients from an incomplete final window
        # carry over into the next epoch's first step - confirm this is intended.
        start_id = 0
        for batch_id, batch in enumerate(tqdm(islice(_train_dataloader, cfg.engine['max_batch_per_epoch']))):
            batch = batch.to(device)
            loss, metric = task(batch)
            loss.backward()
            loss_list = torch.cat([loss_list, loss.detach().unsqueeze(0)])
            for key in metric_list:
                metric_list[key] = torch.cat([metric_list[key], metric[key].detach().unsqueeze(0)])
            # Step once every `gradient_interval` batches.
            if batch_id - start_id + 1 == cfg.engine['gradient_interval']:
                clip_grad_norm_(task.parameters(), max_norm=10, error_if_nonfinite=True)  # TODO
                optimizer.step()
                optimizer.zero_grad()
                start_id = batch_id + 1
            if batch_id % log_interval == 0:
                # Report the mean over (at most) the last `log_interval` batches.
                logger.warning(f"Batch id: {batch_id}")
                logger.warning(f"Loss: {loss_list[-log_interval:].mean()}")
                for key in metric_list:
                    logger.warning(f"{key}: {metric_list[key][-log_interval:].mean()}")
        wandb_log_dict["Train/Loss"] = loss_list.mean()
        for key in metric_list:
            wandb_log_dict[f"Train/{key}"] = metric_list[key].mean()
        if world_size > 1:
            # Average the per-process training statistics across processes.
            wandb_log_dict = comm.reduce(wandb_log_dict, op="mean")

        if is_master:
            # Evaluate
            eval_task = get_eval_task(task)
            if epoch % cfg.engine['evaluate_interval'] == 0:
                logger.warning("Begin Evaluation>>>>>")
                eval_result_list = _collect_eval_results(eval_task, valid_dataloader)
                for key in cfg.task['eval_metrics'].keys():
                    if key == "confidence":
                        wandb_log_dict["Valid/confidence_pearson"] = _confidence_pearson(eval_result_list)
                    else:
                        wandb_log_dict[f"Valid/{key}"] = eval_result_list[key].mean()

            # Test
            if epoch % cfg.engine['test_interval'] == 0:
                logger.warning("Begin Test>>>>>")
                eval_result_list = _collect_eval_results(eval_task, test_dataloader)
                for key in cfg.task['eval_metrics'].keys():
                    if key == "rmsd":
                        # Mean plus quartiles for RMSD.
                        wandb_log_dict["Test/rmsd"] = eval_result_list["rmsd"].mean()
                        wandb_log_dict["Test/rmsd_25"] = torch.quantile(eval_result_list["rmsd"], 0.25)
                        wandb_log_dict["Test/rmsd_50"] = torch.quantile(eval_result_list["rmsd"], 0.50)
                        wandb_log_dict["Test/rmsd_75"] = torch.quantile(eval_result_list["rmsd"], 0.75)
                    elif key == "confidence":
                        wandb_log_dict["Test/confidence_pearson"] = _confidence_pearson(eval_result_list)
                    else:
                        wandb_log_dict[f"Test/{key}"] = eval_result_list[key].mean()

                # Write Mol
                logger.warning("Begin Writing>>>>>")
                output_mol(test_dataset, eval_task, output_folder=output_folder, wandb_log_dict=wandb_log_dict)

            wandb.log(wandb_log_dict)
            logger.warning(f"Epoch {epoch} Finished >>>>>>>>>>>>>")
            # One checkpoint per epoch: task weights and optimizer state.
            torch.save({
                'model_state_dict': eval_task.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, os.path.join(model_folder, f"Epoch_{epoch}.pt"))


def test(eval_unseen=False):
    """Evaluate the module-level ``task`` on the test rows selected via ``info_score.pkl``.

    Reads ``info_score.pkl`` from ``output_folder``, keeps for each compound
    name (first four characters of the ``index`` column) the row(s) with the
    highest ``score``, and evaluates the test-group entries whose ``pdb`` is
    among the kept ``index`` values. Per-metric means (plus quartiles for
    "rmsd" and "centroid dis") are logged to wandb.

    Args:
        eval_unseen: when True, additionally exclude test entries whose ``uid``
            occurs in the training group.
    """
    base_dataset = TankBindDataSet_new(root=cfg.dataset['root'],
                                       pocket_info_path=cfg.dataset['pocket_info_path'],
                                       protein_embed_folder=cfg.dataset['protein_embed_folder'],
                                       compound_embed_path=cfg.dataset['compound_embed_path'],
                                       compound_folder=cfg.dataset['compound_folder'],
                                       setting=cfg.dataset["setting"],
                                       )
    score_info = pd.read_pickle(os.path.join(output_folder, "info_score.pkl"))
    score_info['compound_name'] = [index[:4] for index in score_info['index'].to_list()]
    # Use the string alias rather than the builtin `max`: passing builtin
    # callables to groupby transform is deprecated in pandas.
    max_indices = score_info.groupby(['compound_name'])['score'].transform("max") == score_info['score']
    # Ties on the maximum keep every tied row. The names `score_info_max` and
    # `train_uid` are referenced as @-variables inside the query strings below,
    # so they must not be renamed.
    score_info_max = score_info[max_indices]  # Duplicate: 6jan_0

    train_index = base_dataset.info.query("group =='train'").index.values
    train_uid = base_dataset.info.uid[train_index].drop_duplicates().to_list()

    if eval_unseen:
        test_index = base_dataset.info.query(
            "group =='test' and (pdb in @score_info_max['index']) and (uid not in @train_uid)").index.values
    else:
        test_index = base_dataset.info.query(
            "group =='test' and (pdb in @score_info_max['index'])").index.values
    test_dataset = base_dataset[test_index]
    test_dataloader = DataLoader(test_dataset, batch_size=cfg.engine['batch_size'],
                                 follow_batch=['x', 'y', 'compound_pair'],
                                 shuffle=False, num_workers=cfg.engine['num_workers'], pin_memory=True)

    wandb_log_dict = {}
    eval_task = get_eval_task(task)
    logger.warning("Begin Test>>>>>")
    metric_names = cfg.task['eval_metrics'].keys()
    eval_result_list = {key: torch.empty(0, device=device) for key in metric_names}
    eval_task.eval()
    with torch.no_grad():
        # Concatenate the per-batch results of every metric.
        for batch in tqdm(test_dataloader):
            batch = batch.to(device)
            result = eval_task.evaluate_metric(batch, eval_metrics=metric_names)
            for key in metric_names:
                eval_result_list[key] = torch.cat([eval_result_list[key], result[key].detach()])
        for key in metric_names:
            wandb_log_dict[f"Test/{key}"] = eval_result_list[key].mean()
            if key == "rmsd":
                wandb_log_dict["Test/rmsd_25"] = torch.quantile(eval_result_list["rmsd"], 0.25)
                wandb_log_dict["Test/rmsd_50"] = torch.quantile(eval_result_list["rmsd"], 0.50)
                wandb_log_dict["Test/rmsd_75"] = torch.quantile(eval_result_list["rmsd"], 0.75)
            if key == "centroid dis":
                wandb_log_dict["Test/centroid rmsd_25"] = torch.quantile(eval_result_list["centroid dis"], 0.25)
                wandb_log_dict["Test/centroid rmsd_50"] = torch.quantile(eval_result_list["centroid dis"], 0.50)
                wandb_log_dict["Test/centroid rmsd_75"] = torch.quantile(eval_result_list["centroid dis"], 0.75)

    # Write Mol
    # logger.warning(f"Begin Writing>>>>>")
    # output_mol(test_dataset, eval_task, output_folder=output_folder, wandb_log_dict=wandb_log_dict)

    # wandb.init is only called on the master process, so only log from there.
    if is_master:
        wandb.log(wandb_log_dict)
    logger.warning("Test Finished >>>>>>>>>>>>>")

# Script entry point: run the training loop. test() is defined above but is
# not invoked here.
if __name__ == "__main__":
    main()

