import os
from collections import defaultdict
import numpy as np
import copy
from scipy.stats import spearmanr
from copy import deepcopy
import safetensors
from tqdm import tqdm
import torch
from flowdock.utils.rotation import expm_SO3
from flowdock.utils.transforms import (get_batch_pred_torsion_updates,
                                        apply_tr_rot_changes_to_batch_inplace, apply_tor_changes_to_batch_inplace,
                                        get_torsion_angles, find_rigid_alignment, compute_angle_MAE)
from flowdock.utils.spyrmsd import get_symmetry_rmsd
from flowdock.utils.scoring import compute_ndcg
# import torchdiffeq


def load_from_checkpoint(model, checkpoint_path, strict=True):
    """Load weights from ``<checkpoint_path>/model.safetensors`` into ``model``.

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=...)``.
        checkpoint_path: Directory that contains the ``model.safetensors`` file.
        strict: Forwarded unchanged to ``model.load_state_dict``.

    Returns:
        The same ``model`` object after its state dict has been loaded.
    """
    # ``import safetensors`` alone does not load the ``safetensors.torch``
    # submodule, so ``safetensors.torch.load_file`` only works if some other
    # module happened to import it first. Import the loader explicitly.
    from safetensors.torch import load_file

    weights_file = os.path.join(checkpoint_path, 'model.safetensors')
    # Tensors are always materialised on the CPU; moving them is left to the caller.
    state_dict = load_file(weights_file, device="cpu")
    model.load_state_dict(state_dict, strict=strict)
    return model


def euler(model, batch, device, num_steps=20):
    """Integrate the model's predicted updates with a fixed-step Euler scheme.

    A deep copy of ``batch`` is moved to ``device``, its ``ligand.t`` is reset to
    zeros, and ``model.forward_step`` is evaluated ``num_steps`` times under
    ``torch.no_grad()`` using a step size of ``1 / num_steps``.

    Args:
        model: Object whose ``forward_step(batch)`` returns a 5-tuple; the first
            three entries are used as translation, rotation and torsion updates
            (the rotation and torsion entries may be ``None``).
        batch: Input batch. It is deep-copied before any update is applied.
        device: Device the working copy is moved to.
        num_steps: Number of integration steps.

    Returns:
        Tuple ``(cur_batch, tr_agg, R_agg, tor_agg, dtr_hist, drot_hist, tr_hist,
        rot_hist)``. The ``*_hist`` entries are the per-step values stacked along
        dim 1; ``drot_hist`` is ``None`` when the rotation update of the last
        step was ``None``.
    """
    work = deepcopy(batch).to(device)
    step_size = 1. / num_steps
    n_items = len(work)
    identity_rot = torch.eye(3, device=device).repeat(n_items, 1, 1)

    # Starting state is read from the batch before the time is reset.
    R_agg = work.ligand.init_rot
    tr = work.ligand.init_tr
    tor = work.ligand.init_tor
    work.ligand.t = torch.zeros_like(work.ligand.t)

    tor_mask, tor_updates = get_batch_pred_torsion_updates(work)

    dtr_hist, drot_hist = [], []
    tr_hist, rot_hist = [], []
    tor_agg = torch.zeros_like(tor)

    with torch.no_grad():
        for _step in range(num_steps):
            dtr, drot, dtor, _, _ = model.forward_step(work)

            tr = tr + step_size * dtr

            if dtor is not None:
                tor = step_size * dtor
                # Entries selected by the mask take the precomputed torsion updates.
                tor[tor_mask] = step_size * tor_updates[tor_mask]
                tor_agg += tor
            # NOTE: when ``dtor`` is None, the previously bound ``tor`` (initially
            # ``ligand.init_tor``) is passed to the torsion update below again.

            if drot is None:
                step_rot = identity_rot
            else:
                step_rot = expm_SO3(drot, step_size)
                drot_hist.append(drot)

            apply_tor_changes_to_batch_inplace(work, tor, is_reverse_order=False)
            apply_tr_rot_changes_to_batch_inplace(work, tr, step_rot)

            work.ligand.t += step_size
            R_agg = torch.bmm(step_rot, R_agg)

            tr_hist.append(tr)
            rot_hist.append(R_agg)
            dtr_hist.append(dtr)

    dtr_hist = torch.stack(dtr_hist, dim=1)
    tr_hist = torch.stack(tr_hist, dim=1)
    # Only the last step's ``drot`` decides whether a rotation history is returned.
    drot_hist = torch.stack(drot_hist, dim=1) if drot is not None else None
    rot_hist = torch.stack(rot_hist, dim=1)

    return work, tr, R_agg, tor_agg, dtr_hist, drot_hist, tr_hist, rot_hist


def euler_raw(model, batch, device, num_steps=20):
    """Euler integration that writes the integrated tensor straight into ``ligand.pos``.

    The working copy of ``batch`` starts at ``ligand.t == 0``; the integrated
    tensor starts from ``batch.ligand.random_pos``. At every step the first
    output of ``model.forward_step`` is scaled by ``1 / num_steps`` and added.

    Args:
        model: Object whose ``forward_step(batch)`` result yields exactly two
            items when sliced with ``[:2]``; only the first one is used.
        batch: Input batch. It is deep-copied before being modified.
        device: Device the working copy and the start tensor are moved to.
        num_steps: Number of integration steps.

    Returns:
        The updated working copy of the batch.
    """
    work = deepcopy(batch).to(device)
    step_size = 1. / num_steps
    work.ligand.t = torch.zeros_like(work.ligand.t)

    positions = batch.ligand.random_pos.to(device)
    with torch.no_grad():
        for _step in range(num_steps):
            velocity, _ = model.forward_step(work)[:2]
            positions = positions + step_size * velocity
            work.ligand.pos = positions
            work.ligand.t += step_size
    return work


def print_metrics_stats(metrics_dict):
    """Print summary statistics over all per-sample metric dicts.

    Args:
        metrics_dict: Mapping whose values are lists of dicts. Every dict must
            provide the keys ``'rmsd'``, ``'symm_rmsd'``, ``'tr_err'``,
            ``'rot_similarity'`` and ``'torsion_angles_err'``.

    All values are rounded to three decimals before printing. Nothing is returned.
    """
    def column(key):
        # Gather one metric across every sample of every entry.
        return np.array([sample[key] for samples in metrics_dict.values() for sample in samples])

    def mean_and_median(values):
        return np.round(np.mean(values), 3), np.round(np.median(values), 3)

    def share_below(values, threshold):
        return np.round((values < threshold).mean(), 3)

    rmsd_reports = (
        ('rmsd', 'Avg rmsd:', 'Rmsd < 2A:', 'Rmsd < 5A:'),
        ('symm_rmsd', 'Avg symm rmsd:', 'Symm rmsd < 2A:', 'Symm rmsd < 5A:'),
    )
    for key, avg_label, below2_label, below5_label in rmsd_reports:
        values = column(key)
        print(avg_label, *mean_and_median(values))
        print(below2_label, share_below(values, 2))
        print(below5_label, share_below(values, 5))

    # All three arrays are gathered before anything is printed for them.
    tr_err = column('tr_err')
    rot_sim = column('rot_similarity')
    tor_err = column('torsion_angles_err')

    print('Translation err:', *mean_and_median(tr_err), share_below(tr_err, 1), share_below(tr_err, 2))
    print('Rotation similarity:', *mean_and_median(rot_sim))
    print('Torsion angles err:', *mean_and_median(tor_err))

# from flowdock.utils.posebusters_utils import calc_posebusters

def run_evaluation(dataloader, num_steps, solver, model, compute_metrics=True, do_print=True):
    """Run docking inference over ``dataloader`` and collect per-complex outputs.

    For every batch the ``solver`` produces an optimized batch. The input pose
    (with the aggregated torsion updates applied) is then rigidly aligned onto the
    optimized pose to obtain a translation and rotation per sample, and one
    result dict per sample is appended under its complex name.

    Args:
        dataloader: Iterable yielding dicts with a ``'batch'`` entry.
        num_steps: Number of solver steps, forwarded to ``solver``.
        solver: Callable ``solver(model, batch, device=..., num_steps=...)`` that
            returns an 8-tuple ``(optimized, tr_agg, R_agg, tor_agg, dtr_hist,
            drot_hist, tr_hist, rot_hist)``.
        model: Model handed to the solver. If ``model.use_scoring_rollout`` is
            truthy, one extra ``forward_step`` call is made on the optimized
            batch and its scoring output is stored as ``'model_error_estimate'``.
        compute_metrics: When true, additionally compute the reference-based
            entries (``'rmsd'``, ``'tr_err'``, ``'rot_similarity'``, torsion
            errors, ``'symm_rmsd'``, ...). When false these entries are omitted.
        do_print: When true and ``compute_metrics`` is true, print summary
            statistics with ``print_metrics_stats``.

    Returns:
        Dict mapping each name in ``batch.names`` to a list of result dicts, one
        per occurrence of that name across the dataloader.
    """
    def revert_augm(batch):
        # Multiply positions by ``original_augm_rot`` in place (presumably undoing
        # the augmentation rotation), then re-zero the padded atom slots.
        batch.ligand.pos[:] = torch.einsum('bij,bjk->bik', batch.ligand.pos, batch.original_augm_rot)
        for batch_idx, num_atoms in enumerate(batch.ligand.num_atoms):
            batch.ligand.pos[batch_idx, num_atoms:] = 0.
        return batch.ligand.pos

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    metrics_dict = {}
    for batch in tqdm(dataloader, desc="Docking inference"):
        batch = batch['batch']
        # Only the optimized batch, the translation tensor (for its shape) and the
        # aggregated torsions are used; the remaining solver outputs are ignored.
        optimized, tr_agg, _, tor_agg, _, _, _, _ = solver(model, batch,
                                                           device=device,
                                                           num_steps=num_steps)

        if model.use_scoring_rollout:
            with torch.no_grad():
                _, _, _, _, scoring_pred = model.forward_step(optimized)
                _, complex_scoring_pred = scoring_pred
                complex_scoring_pred = complex_scoring_pred[:, 0].cpu().numpy()

        # compute RMSD in case of zero-padded batches
        num_batch_atoms = (~batch.ligand.is_padded_mask).sum(dim=1).to(device)

        for batch_idx, num_atoms in enumerate(optimized.ligand.num_atoms):
            optimized.ligand.pos[batch_idx, num_atoms:] = 0.

        # Normal alignment process: apply the aggregated torsions to a fresh copy
        # of the input and rigidly align it onto the optimized pose per sample.
        aligned_batch = copy.deepcopy(batch).to(device)
        tr_aligned = torch.zeros_like(tr_agg, device=device)
        rot_aligned = torch.eye(3, device=device).repeat(tr_agg.shape[0], 1, 1)
        apply_tor_changes_to_batch_inplace(aligned_batch, tor_agg, is_reverse_order=False)
        for i in range(len(optimized.ligand.pos)):
            n_atoms = optimized.ligand.num_atoms[i]
            pos_pred = aligned_batch.ligand.pos[i, :n_atoms]
            pos_true = optimized.ligand.pos[i, :n_atoms]

            rot, tr = find_rigid_alignment(pos_pred, pos_true)
            tr_aligned[i] = tr
            rot_aligned[i] = rot

        apply_tr_rot_changes_to_batch_inplace(aligned_batch, tr_aligned, rot_aligned)
        # From here on the alignment result replaces the solver's aggregates.
        tr_agg = tr_aligned
        R_agg = rot_aligned

        # compute minimum distance between protein and ligand atoms
        distances = torch.linalg.norm(optimized.ligand.pos[:, None, :] - optimized.protein.pos[:, :, None], dim=-1)
        min_distances = []
        for i in range(len(distances)):
            cur_dist = distances[i, ~optimized.protein.is_padded_mask[i]][:, ~optimized.ligand.is_padded_mask[i]]
            # ``numel`` (not ``len``) so that an empty ligand axis is caught too;
            # ``.min()`` raises on a tensor without elements.
            if cur_dist.numel() > 0:
                min_distances.append(cur_dist.min())
            else:
                min_distances.append(torch.tensor(1000., device=device))
        min_distances = torch.stack(min_distances)

        tr_agg_init_coord = torch.bmm((optimized.original_pocket_center + tr_agg)[:, None, :],
                                      optimized.original_augm_rot)

        init_batch = copy.deepcopy(batch).to(device)
        init_batch.ligand.pos = optimized.ligand.pos.clone().to(device)
        transformed_orig = revert_augm(init_batch)

        tor_pred = tor_agg.cpu().numpy()

        if compute_metrics:
            tr_true_init = torch.bmm((optimized.original_pocket_center + optimized.ligand.final_tr)[:, None, :],
                                     optimized.original_augm_rot)

            # Per-sample RMSD, normalised by the number of non-padded atoms.
            rmsds = torch.sqrt(((optimized.ligand.true_pos - transformed_orig) ** 2)
                               .sum(axis=2).sum(axis=1) / num_batch_atoms).cpu().numpy()

            # compute translation error (l2 norm)
            tr_errors = torch.linalg.norm(tr_agg - optimized.ligand.final_tr, dim=1).cpu().numpy()

            # Flattened dot product of the two 3x3 matrices, divided by 3.
            rot_similarity = torch.matmul(R_agg.view(-1, 1, 9),
                                          optimized.ligand.final_rot.view(-1, 9, 1)).view(-1) / 3

            tor_true = optimized.ligand.final_tor.cpu().numpy()

        for i, name in enumerate(batch.names):
            n_atoms = optimized.ligand.num_atoms[i]
            n_bonds = optimized.ligand.num_rotatable_bonds[i]

            complex_metrics = {
                'min_distance': float(min_distances[i]),
                'orig_pos_before_augm': optimized.ligand.orig_pos_before_augm[i, :n_atoms].cpu().numpy(),
                'transformed_orig': transformed_orig[i, :n_atoms].cpu().numpy(),
                'tr_pred_init': tr_agg_init_coord[i].cpu().numpy(),
                'rot_pred': R_agg[i].cpu().numpy(),
                'rot_augm': optimized.original_augm_rot[i].cpu().numpy(),
                'full_protein_center': optimized.protein.full_protein_center[i].cpu().numpy(),
            }

            # compute torsion angles
            bonds_ext = optimized.ligand.rotatable_bonds_ext
            bond_properties_for_angles = {
                key: getattr(bonds_ext, key)[i, :n_bonds]
                for key in ('start', 'end', 'neighbor_of_start', 'neighbor_of_end', 'bond_periods')
            }

            torsion_angles_pred = get_torsion_angles(torch.from_numpy(np.copy(complex_metrics['transformed_orig'])).to(device),
                                                     bond_atoms_for_angles=bond_properties_for_angles)
            complex_metrics['torsion_angles_pred'] = torsion_angles_pred.cpu().numpy()

            if model.use_scoring_rollout:
                complex_metrics['model_error_estimate'] = complex_scoring_pred[i]

            if compute_metrics:
                complex_metrics['tr_true_init'] = tr_true_init[i].cpu().numpy()
                complex_metrics['rmsd'] = float(rmsds[i])
                complex_metrics['tr_err'] = float(tr_errors[i])
                complex_metrics['rot_similarity'] = float(rot_similarity[i])
                complex_metrics['rot_true'] = optimized.ligand.final_rot[i].cpu().numpy()

                torsion_angles_true = get_torsion_angles(torch.from_numpy(np.copy(complex_metrics['orig_pos_before_augm'])).to(device),
                                                         bond_atoms_for_angles=bond_properties_for_angles)
                complex_metrics['torsion_angles_true'] = torsion_angles_true.cpu().numpy()
                complex_metrics['torsion_angles_err'] = compute_angle_MAE(torsion_angles_pred, torsion_angles_true,
                                                                          bond_properties_for_angles['bond_periods'])
                complex_metrics['bond_properties_for_angles'] = {key: value.cpu().numpy() for key, value in bond_properties_for_angles.items()}

                # save ligand torsions: ``tor_ptr`` holds the slice boundaries of
                # sample ``i`` inside the flat torsion arrays.
                start, end = optimized.ligand.tor_ptr[:-1][i], optimized.ligand.tor_ptr[1:][i]
                complex_metrics['tor_pred'] = tor_pred[start: end]
                complex_metrics['tor_true'] = tor_true[start: end]

                complex_metrics['orig_mol'] = optimized.ligand.orig_mols[i]

                # Best effort: if the symmetry-aware RMSD cannot be computed, fall
                # back to the plain RMSD instead of aborting the evaluation.
                try:
                    symm_rmsd = get_symmetry_rmsd(optimized.ligand.orig_mols[i],
                                                  complex_metrics['orig_pos_before_augm'],
                                                  complex_metrics['transformed_orig'])
                except Exception:
                    symm_rmsd = complex_metrics['rmsd']
                complex_metrics['symm_rmsd'] = symm_rmsd

            metrics_dict.setdefault(name, []).append(complex_metrics)

    if compute_metrics and do_print:
        print_metrics_stats(metrics_dict)
    return metrics_dict


def run_multiple_inferences(loader, num_steps, n_reps, split, checkpoint_path, solver, model, seed, fname_prefix):
    """Repeat ``run_evaluation`` ``n_reps`` times and save all results to one file.

    Args:
        loader: Dataloader forwarded to ``run_evaluation``.
        num_steps: Number of solver steps forwarded to ``run_evaluation``.
        n_reps: Number of evaluation passes.
        split: Label embedded in the output file name.
        checkpoint_path: Directory the output ``.npy`` file is written to.
        solver: Solver forwarded to ``run_evaluation``.
        model: Model forwarded to ``run_evaluation``.
        seed: Value embedded in the output file name.
        fname_prefix: Prefix of the output file name.

    Returns:
        List with one ``run_evaluation`` result per repetition.
    """
    out_name = f'{fname_prefix}_{split}_all_rmsds_{num_steps}steps_{n_reps}runs_{seed}seed.npy'
    fname = os.path.join(checkpoint_path, out_name)
    print('Saving results to ', fname)

    print(f'Start inference (n_runs = {n_reps}, n_steps = {num_steps}) for', checkpoint_path)
    all_metrics = [
        run_evaluation(loader, num_steps=num_steps, solver=solver, model=model)
        for _rep in range(n_reps)
    ]

    # The file is written once, after every repetition has finished.
    np.save(fname, all_metrics)
    print()
    return all_metrics


def scoring_inference(loader, model):
    """Score every sample produced by ``loader`` and group the scores by name.

    Args:
        loader: Iterable yielding dicts with a ``'batch'`` entry; each batch must
            support ``.to(device)`` and expose ``names``.
        model: Object with an ``objective`` attribute and a ``forward_step(batch)``
            method whose first output holds the raw predictions.

    Returns:
        Tuple ``(metrics_dict, tr_scores_dict)``. ``metrics_dict`` maps each name to
        the list of its scores; ``tr_scores_dict`` is always an empty dict here.

    Raises:
        ValueError: If ``model.objective`` is not one of the supported objectives.
    """
    # Objectives whose raw prediction is used directly as the score.
    passthrough_objectives = (
        'regression', 'regression_thresholded', 'binary', 'ranking',
        'cross_entropy', 'cross_entropy_with_gains',
    )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    metrics_dict = {}
    tr_scores_dict = {}
    for batch in tqdm(loader, desc="Scoring"):
        batch = batch['batch'].to(device)
        with torch.no_grad():
            rmsd_pred = model.forward_step(batch)[0]

        if model.objective == 'multiclass':
            # Columns: max logit, max softmax probability, index of the winning class.
            max_prob, class_scores = torch.softmax(rmsd_pred, dim=1).max(dim=1)
            max_logit, _ = rmsd_pred.max(dim=1)
            scores = torch.cat([max_logit.unsqueeze(1), max_prob.unsqueeze(1), class_scores.unsqueeze(1)], dim=1)
        elif model.objective in passthrough_objectives:
            scores = rmsd_pred
        else:
            # Previously this only printed a warning and then failed with a
            # NameError on ``scores``; fail with an explicit message instead.
            raise ValueError(f'Unsupported objective: {model.objective!r}')

        for name, score in zip(batch.names, scores.cpu().numpy()):
            metrics_dict.setdefault(name, []).append(score)
    return metrics_dict, tr_scores_dict


def scoring_inference_with_metrics(loader, data_exp, checkpoint_path, model):
    """Score all samples, print ranking diagnostics against RMSD, and save the scores.

    Global statistics (Spearman correlation and NDCG between the last score
    column and ``batch.ligand.rmsd``) are printed first, followed by the same
    statistics aggregated per complex. A complex id is the sample name with its
    last ``'_'``-separated token removed.

    Args:
        loader: Iterable yielding dicts with a ``'batch'`` entry; each batch must
            support ``.to(device)`` and expose ``names`` and ``ligand.rmsd``.
        data_exp: Path-like string used to derive the output file names: the
            part before ``'.npy'`` with ``'/'`` replaced by ``'__'``.
        checkpoint_path: Directory the ``*_scoring.npy`` file is written to.
        model: Object with an ``objective`` attribute and a ``forward_step(batch)``
            method whose first output holds the raw predictions.

    Returns:
        Dict mapping each sample name to the list of its scores.

    Raises:
        ValueError: If ``model.objective`` is not one of the supported objectives.
    """
    def fill_nans(scores_all):
        # Overwrite rows that contain any NaN with an objective-specific fallback.
        # NOTE(review): the cross-entropy objectives have no fallback here, so
        # their NaN rows are returned unchanged.
        nan_mask = np.isnan(scores_all).sum(axis=1).astype(bool)
        if nan_mask.sum() > 0:
            if model.objective == 'multiclass':
                scores_all[nan_mask, 2] = 6.
                scores_all[nan_mask, 0] = 0.
                scores_all[nan_mask, 1] = 0.
            elif model.objective == 'binary':
                scores_all[nan_mask] = 0.
            elif model.objective == 'regression':
                scores_all[nan_mask] = 50.
            elif model.objective == 'regression_thresholded':
                scores_all[nan_mask] = 5.
            elif model.objective == 'ranking':
                print('Not implemented for ranking')
                scores_all[nan_mask] = 0.
        return scores_all

    # Objectives whose raw prediction is used directly as the score.
    passthrough_objectives = (
        'regression', 'regression_thresholded', 'binary', 'ranking',
        'cross_entropy', 'cross_entropy_with_gains',
    )

    print('Start scoring for', checkpoint_path)
    save_path = os.path.join(checkpoint_path, f'{"__".join(data_exp.split(".npy")[0].split("/"))}_scoring.npy')
    print('Saving to', save_path)
    save_path_tr = os.path.join(checkpoint_path, f'{"__".join(data_exp.split(".npy")[0].split("/"))}_tr_scores.npy')
    print('Saving to', save_path_tr)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    metrics_dict = {}
    # The translation-score containers below are never filled in this function;
    # the code paths guarded by their length are therefore inactive.
    tr_scores_dict = {}
    metrics_dict_by_uid = defaultdict(dict)
    rmsds_dict_by_uid = defaultdict(dict)
    rmsds_all = []
    scores_all = []
    tr_errs_all_pred = []
    tr_errs_all_true = []
    for batch in tqdm(loader):
        batch = batch['batch'].to(device)
        with torch.no_grad():
            rmsd_pred = model.forward_step(batch)[0]

        if model.objective == 'multiclass':
            # Columns: max logit, max softmax probability, index of the winning class.
            max_prob, class_scores = torch.softmax(rmsd_pred, dim=1).max(dim=1)
            max_logit, _ = rmsd_pred.max(dim=1)
            scores = torch.cat([max_logit.unsqueeze(1), max_prob.unsqueeze(1), class_scores.unsqueeze(1)], dim=1)
        elif model.objective in passthrough_objectives:
            scores = rmsd_pred
        else:
            # Previously this only printed a warning and then failed with a
            # NameError on ``scores``; fail with an explicit message instead.
            raise ValueError(f'Unsupported objective: {model.objective!r}')

        batch_rmsds = batch.ligand.rmsd.cpu().numpy()
        batch_scores = scores.cpu().numpy()
        rmsds_all.append(batch_rmsds)
        scores_all.append(batch_scores)

        for name, rmsd, score in zip(batch.names, batch_rmsds, batch_scores):
            metrics_dict.setdefault(name, []).append(score)
            # Complex id: the sample name without its last '_'-separated token.
            complex_name = '_'.join(name.split('_')[:-1])
            metrics_dict_by_uid[complex_name].setdefault(name, []).append(score)
            rmsds_dict_by_uid[complex_name].setdefault(name, []).append(rmsd)

    rmsds_all = np.concatenate(rmsds_all)
    scores_all = np.concatenate(scores_all, axis=0)
    scores_all = fill_nans(scores_all)
    # Only the last score column is used for ranking statistics.
    scores_all = scores_all[:, -1]
    if len(tr_errs_all_pred) > 0:
        tr_errs_all_pred = np.concatenate(tr_errs_all_pred)[:, -1]
        tr_errs_all_true = np.concatenate(tr_errs_all_true)

    ndcg, ndcg_top5, ndcg_top1 = compute_ndcg(scores_all, rmsds_all, k=5)

    print('Corr with rmsd:', spearmanr(rmsds_all, scores_all).correlation)
    print('NDCG:', ndcg)
    print('NDCG top5:', ndcg_top5)
    print('NDCG top1:', ndcg_top1)

    if len(tr_errs_all_pred) > 0:
        print('Corr with tr_err:', spearmanr(tr_errs_all_pred, tr_errs_all_true).correlation)

    complex_corrs = []
    avg_complex_corrs = []
    ndcg_full_by_complex = []
    ndcg_top5_by_complex = []
    ndcg_top1_by_complex = []
    top_scored_rmsds = []
    for uid in metrics_dict_by_uid:
        # One row per sample name, one column per repetition of that name.
        sample_rmsds = np.array(list(rmsds_dict_by_uid[uid].values()))
        sample_scores = np.array(list(metrics_dict_by_uid[uid].values()))
        n_samples = sample_rmsds.shape[0]
        n_reps = sample_rmsds.shape[1]
        sample_scores = sample_scores.reshape(-1, sample_scores.shape[-1])
        sample_scores = fill_nans(sample_scores)
        sample_scores = sample_scores[:, -1]
        # The RMSD of the first repetition is paired with the repetition-averaged score.
        avg_rmsds = sample_rmsds[:, 0]
        avg_sample_scores = sample_scores.reshape(n_samples, n_reps).mean(axis=1)
        sample_rmsds = sample_rmsds.flatten()
        if len(sample_rmsds) > 1:
            ndcg, ndcg_top5, ndcg_top1 = compute_ndcg(sample_scores, sample_rmsds, k=5)
            top_scored_rmsds.append(sample_rmsds[np.argmax(sample_scores)])
            # Correlation is only computed when the scores are not all equal;
            # otherwise 0 is recorded.
            if len(set(sample_scores)) > 1:
                complex_corr = spearmanr(sample_rmsds, sample_scores).correlation
            else:
                complex_corr = 0.
            if len(set(avg_sample_scores)) > 1:
                avg_complex_corr = spearmanr(avg_rmsds, avg_sample_scores).correlation
            else:
                avg_complex_corr = 0.
        else:
            complex_corr = 0.
            avg_complex_corr = 0.
            ndcg = 0.
            ndcg_top5 = 0.
            ndcg_top1 = 0.
        complex_corrs.append(complex_corr)
        avg_complex_corrs.append(avg_complex_corr)
        ndcg_full_by_complex.append(ndcg)
        ndcg_top5_by_complex.append(ndcg_top5)
        ndcg_top1_by_complex.append(ndcg_top1)

    top_scored_rmsds = np.array(top_scored_rmsds)

    print('Mean and median corr by complex: ', np.mean(complex_corrs), np.median(complex_corrs))
    # ``n_reps`` here is the value from the last complex processed above.
    print(f'Mean and median avg corr (n_reps={n_reps}) by complex: ', np.mean(avg_complex_corrs), np.median(avg_complex_corrs))
    print('Mean and median ndcg by complex: ', np.mean(ndcg_full_by_complex), np.median(ndcg_full_by_complex))
    print('Mean and median ndcg top5 by complex: ', np.mean(ndcg_top5_by_complex), np.median(ndcg_top5_by_complex))
    print('Mean and median ndcg top1 by complex: ', np.mean(ndcg_top1_by_complex), np.median(ndcg_top1_by_complex))
    print('Mean and median top scored rmsds: ', np.mean(top_scored_rmsds), np.median(top_scored_rmsds), np.mean(top_scored_rmsds <= 2), np.mean(top_scored_rmsds <= 5))
    print()
    np.save(save_path, metrics_dict)
    if len(tr_scores_dict) > 0:
        np.save(save_path_tr, tr_scores_dict)
    return metrics_dict
