import torch
torch.set_default_dtype(torch.float64)

import numpy as np
import pandas as pd
from torch_geometric.data import Data

import ase
from ase import units
from ase.io import read
from ase.md.verlet import VelocityVerlet
from ase.md.langevin import Langevin
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution, Stationary, ZeroRotation
from ase.calculators.calculator import Calculator, all_changes

from trainer import set_seed, load_config
from models.visnet import ViSNet


# Multiplicative factor converting an energy in kcal/mol to eV
# (1 eV corresponds to 23.060549 kcal/mol).
EV_PER_KCAL = 1.0 / 23.060549


def kabsch_rmsd_numpy(P, Q):
    """Return the RMSD between two point sets after optimal superposition.

    Both sets are centred on their centroids, then ``P`` is rotated onto
    ``Q`` with the proper rotation (determinant +1) given by the Kabsch
    algorithm. Points are paired by row index.

    Args:
        P: Array-like of shape ``(N, D)``; the coordinates that get rotated.
        Q: Array-like of shape ``(N, D)``; the reference coordinates.

    Returns:
        Root-mean-square deviation over the ``N`` paired points.

    Raises:
        ValueError: If the inputs are not 2-D arrays of identical shape.
    """
    P = np.asarray(P, dtype=float)
    Q = np.asarray(Q, dtype=float)
    if P.ndim != 2 or P.shape != Q.shape:
        raise ValueError(
            f"P and Q must be 2-D arrays of the same shape, got {P.shape} and {Q.shape}"
        )

    # Remove the translational component.
    Pc = P - P.mean(axis=0)
    Qc = Q - Q.mean(axis=0)

    # Cross-covariance matrix and its SVD.
    H = Pc.T @ Qc
    U, _, Vt = np.linalg.svd(H)

    # Coordinates are stored as ROW vectors, so the orthogonal matrix that
    # minimises ||Pc @ R - Qc|| is U @ Vt (its transpose, V @ U.T, is the
    # column-vector form and would rotate P the wrong way here).
    if np.linalg.det(U @ Vt) < 0:
        # Flip the axis of the smallest singular value to turn the
        # reflection into a proper rotation.
        U[:, -1] *= -1
    R = U @ Vt

    diff = Pc @ R - Qc
    return np.sqrt((diff ** 2).sum() / P.shape[0])


class NNIPCalculator(Calculator):
    """ASE calculator that evaluates a neural-network interatomic potential.

    The wrapped model is called with a ``torch_geometric`` ``Data`` object
    carrying ``z`` (atomic numbers), ``pos`` (positions) and ``batch`` (all
    zeros, i.e. one structure). Its return value must expose ``out`` (a
    single-element energy tensor) and ``forces``. Both are multiplied by the
    factor that converts ``energy_unit`` to eV before being stored in
    ``self.results``.

    Args:
        model: Torch module to evaluate; moved to ``device`` and put in eval mode.
        energy_unit: Energy unit of the model output (case-insensitive).
            Supported: ``'kcal/mol'`` and ``'eV'``.
        device: Torch device on which the model and input tensors live.
        **kwargs: Forwarded to ``ase.calculators.calculator.Calculator``.
    """

    implemented_properties = ["energy", "forces"]

    # Factor that converts one unit of model energy to eV, keyed by the
    # lower-cased unit name. Forces are scaled by the same factor.
    _UNIT_TO_EV = {'kcal/mol': EV_PER_KCAL, 'ev': 1.0}

    def __init__(self, model, energy_unit='kcal/mol', device='cpu', **kwargs):
        super().__init__(**kwargs)

        self.model = model.to(device).eval()
        self.device = device
        self.energy_unit = energy_unit.lower()

    def calculate(self, atoms=None, properties=('energy', 'forces'), system_changes=all_changes):
        """Evaluate the model and fill ``self.results``.

        ``results['energy']`` is always set; ``results['forces']`` only when
        ``'forces'`` is in ``properties``.

        Raises:
            ValueError: If ``self.energy_unit`` is not a supported unit.
        """
        super().calculate(atoms, properties, system_changes)
        if atoms is None:
            # Honour the documented default: fall back to the structure the
            # base class already holds.
            atoms = self.atoms

        # Reject an unknown unit before paying for a forward pass.
        try:
            to_ev = self._UNIT_TO_EV[self.energy_unit]
        except KeyError:
            raise ValueError(
                f"Unsupported energy unit {self.energy_unit!r}; "
                f"expected one of {sorted(self._UNIT_TO_EV)}"
            ) from None

        z = torch.tensor(atoms.get_atomic_numbers(), dtype=torch.long, device=self.device)
        pos = torch.tensor(
            atoms.get_positions(),
            dtype=torch.get_default_dtype(),
            device=self.device,
            requires_grad='forces' in properties,
        )

        # Single structure: every atom belongs to batch element 0.
        batch = torch.zeros_like(z, dtype=torch.long, device=self.device)
        sample = Data(z=z, pos=pos, batch=batch)

        # Gradients stay enabled even if the caller is inside torch.no_grad().
        # NOTE(review): presumably the model derives forces via autograd - confirm.
        with torch.set_grad_enabled(True):
            outputs = self.model(sample)
        E, F = outputs.out, outputs.forces

        self.results['energy'] = (E * to_ev).detach().cpu().item()
        if 'forces' in properties:
            self.results['forces'] = (F * to_ev).detach().cpu().numpy()

class MDLogger:
    """Accumulate per-step MD diagnostics in a DataFrame and write them to CSV.

    Each :meth:`log_step` call appends one row with the columns
    ``step, time_fs, E_pot_eV, E_kin_eV, E_total_eV, F_max, RMSD, H_ext``.

    Args:
        path_to_logfile: Destination CSV path used by :meth:`save_logs`.
        initial_pos: Reference coordinates against which the RMSD is computed.
        verbose: If true, every logged row is also printed.
    """

    def __init__(self, path_to_logfile, initial_pos, verbose=False):
        self.path_to_logfile = path_to_logfile
        self.log_df = pd.DataFrame(columns=['step', 'time_fs', 'E_pot_eV', 'E_kin_eV', 'E_total_eV', 'F_max', 'RMSD', 'H_ext'])

        self.verbose = verbose
        # Keep a private copy so in-place changes to the caller's array
        # cannot move the RMSD reference.
        self.initial_pos = np.array(initial_pos, dtype=float)

        # Running sum of the thermostat work term (Langevin runs only).
        self.Q_cum = 0.0

    def log_step(self, atoms, dynamic):
        """Append the diagnostics of the current state as a new row.

        Args:
            atoms: The ASE ``Atoms`` object being propagated.
            dynamic: The ASE dynamics object; ``nsteps`` and ``dt`` are read.
        """
        step = dynamic.nsteps
        t_fs = step * dynamic.dt / units.fs
        e_pot = atoms.get_potential_energy()
        e_kin = atoms.get_kinetic_energy()
        e_total = e_pot + e_kin
        # Largest per-atom force magnitude.
        f_max = np.linalg.norm(atoms.get_forces(), axis=1).max()
        rmsd = kabsch_rmsd_numpy(self.initial_pos, atoms.get_positions())

        if isinstance(dynamic, Langevin):
            v = atoms.get_velocities()
            thermostat_force = atoms.arrays['langevin_friction'] + atoms.arrays['langevin_random']
            # dynamic.dt is already in ASE internal time units (it is divided
            # by units.fs above to obtain femtoseconds), so force * velocity * dt
            # needs no additional unit factor.
            # NOTE(review): the 'langevin_*' arrays are not set anywhere in
            # this file; assumed to be per-atom forces - confirm their units
            # and sign convention. The sum only grows on logged steps.
            work = np.sum(thermostat_force * v) * dynamic.dt
            self.Q_cum += work
            H_ext = e_pot + e_kin + self.Q_cum
        else:
            H_ext = 0.0

        self.log_df.loc[len(self.log_df)] = [step, t_fs, e_pot, e_kin, e_total, f_max, rmsd, H_ext]

        if self.verbose:
            print(f"{step:6d}  {t_fs:8.3f}  {e_pot:10.6f}  {e_kin:10.6f}  "
              f"{e_total:10.6f}  {f_max:10.6f} {rmsd:10.6f} {H_ext:10.6f}\n")

    def save_logs(self):
        """Write all rows collected so far to ``path_to_logfile`` (overwrites)."""
        self.log_df.to_csv(self.path_to_logfile, index=False)


# Script entry point: load a ViSNet checkpoint, run NVE molecular dynamics on
# the first frame of an XYZ trajectory and log diagnostics to a CSV file.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    init_config_path = 'configs/visnet/visnet_md22.yml'
    path_to_logfile = 'ase-Ac-Ala3-NHMe-traj-base-nve-100ps-0.5fs.csv'
    checkpoint_path = 'visnet-Ac-Ala3-NHMe/last.ckpt'
    path_traj = 'md22_Ac-Ala3-NHMe.xyz'

    init_config = load_config(init_config_path)
    set_seed(init_config.seed)

    # Define model
    print('Use Phi-Module:', init_config.model.use_phi_module)
    model = ViSNet(config=init_config)
    model.epoch = 'inference'
    if checkpoint_path is not None:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        skipped_keys = {"electrostatic_offset", "electrostatic_bias"}
        wrapper_prefix = 'module.'
        # Strip only a LEADING 'module.' wrapper prefix. A plain
        # str.replace would also delete the substring inside parameter names
        # such as '...phi_module.weight' and break load_state_dict.
        state_dict = {
            (k[len(wrapper_prefix):] if k.startswith(wrapper_prefix) else k): v
            for k, v in checkpoint['model_state_dict'].items()
            if k not in skipped_keys
        }
        print()
        model.load_state_dict(state_dict)

    model = model.double()
    print('Model is loaded')

    calc = NNIPCalculator(model=model, device=device)

    atoms = read(path_traj, index=0)
    atoms.calc = calc
    initial_pos = atoms.get_positions()
    # Draw initial velocities at 300 K, then remove net translation and rotation.
    MaxwellBoltzmannDistribution(atoms, temperature_K=300)
    Stationary(atoms)
    ZeroRotation(atoms)

    timestep = 0.5  # fs
    friction = 0.1 / units.fs  # not used by the VelocityVerlet run below
    dynamic = VelocityVerlet(atoms, timestep=timestep * units.fs)
    logger = MDLogger(path_to_logfile=path_to_logfile, initial_pos=initial_pos, verbose=True)

    ps_duration = 100
    steps, log_step_size = int(ps_duration * 1000 / timestep), 100
    for step in range(steps):
        dynamic.run(1)
        # Logging happens after the step has been taken, so the first logged
        # row is written after one integration step, not for the initial state.
        if step % log_step_size == 0:
            logger.log_step(atoms, dynamic)

            # Rewrite the CSV at every logged step so partial results exist
            # if the run is interrupted.
            logger.save_logs()

    logger.save_logs()