import numpy as np
import torch

def get_diffusivity(start_pos, final_pos, rollout_steps):
    """Estimate diffusivity from the displacement between two frames.

    Each frame is first shifted by its own centre of mass (mean over dim 1),
    so a rigid drift of the whole system does not contribute to the mean
    squared displacement (MSD).

    Input: B x N x 3

    Args:
        start_pos: positions at the start of the rollout, B x N x 3.
        final_pos: positions at the end of the rollout, B x N x 3.
        rollout_steps: elapsed time that the MSD is divided by.

    Returns:
        Tensor of shape (B,): per-sample MSD divided by ``rollout_steps``.
        Unlike ``get_diffusivity_traj`` no extra factor of 1/6 is applied.
    """
    def _centered(coords):
        # remove the per-sample centre of mass
        return coords - coords.mean(1, keepdims=True)

    displacement = _centered(final_pos) - _centered(start_pos)
    squared_dist = (displacement ** 2).sum(dim=-1)  # B x N
    return squared_dist.mean(dim=1) / rollout_steps

def get_diffusivity_traj(pos_seq, dilation=1):
    """Per-lag diffusivity along a trajectory (Einstein relation, 3-D).

    The centre of mass (mean over particles, dim 1) is removed from every
    frame. Then, for each frame t >= 1, the mean squared displacement (MSD)
    from frame 0 is divided by ``6 * t * dilation``.

    Input: B x N x T x 3
    Output: B x (T - 1)

    Args:
        pos_seq: positions, B x N x T x 3.
        dilation: time between consecutive stored frames; scales the lag.

    Returns:
        Tensor of shape (B, T - 1) on the same device as ``pos_seq``.
    """
    bsize, time_steps = pos_seq.shape[0], pos_seq.shape[2]
    # subtract the centre of mass of every frame so global drift is ignored
    pos_seq = pos_seq - pos_seq.mean(1, keepdims=True)
    displacement = pos_seq[:, :, 1:] - pos_seq[:, :, :1]  # relative to frame 0
    msd = displacement.pow(2).sum(dim=-1).mean(dim=1)  # B x (T - 1)
    # Build the lags on msd's device/dtype: a default arange is a CPU int64
    # tensor, which makes the division fail when pos_seq lives on a GPU.
    lags = torch.arange(1, time_steps, device=msd.device, dtype=msd.dtype) * dilation
    diff = msd / lags / 6
    return diff.view(bsize, time_steps - 1)

def _rdf_curve(pdist, volume, bins, shell_volumes):
    """Return the normalised histogram of the non-zero entries of ``pdist``.

    The number density is the total number of distance entries (zeros
    included) divided by ``volume``. Zero distances are then dropped before
    histogramming, and each bin count is divided by density * shell volume.
    """
    # assumes pdist is a torch tensor; move it to host memory for numpy
    flat = pdist.flatten().detach().cpu().numpy()
    # NOTE(review): density counts every entry (time included) before zeros are
    # removed, matching the original "adjust according to time" intent — confirm.
    rho = flat.shape[0] / volume
    flat = flat[flat != 0]
    hist, _ = np.histogram(flat, bins)
    return hist / (rho * shell_volumes)


def make_rdf_type(pred_seq, data_seq, ptypes, lattices):
    """Plot per-species-pair radial distribution functions: prediction vs. ground truth.

    Species are selected with ``ptypes == 1`` ('H') and ``ptypes == 8`` ('O').
    For each pair H|H, O|O and H|O, distances are obtained from
    ``distance_pbc`` for both ``pred_seq`` and ``data_seq``. They are
    histogrammed on 300 bins between 0.001 and 6, then normalised by the
    density-weighted spherical shell volume of each bin.

    Args:
        pred_seq: predicted positions, passed unchanged to ``distance_pbc``.
        data_seq: reference positions, passed unchanged to ``distance_pbc``.
        ptypes: per-particle type labels compared against 1 and 8.
        lattices: tensor whose element product is used as the box volume.

    Returns:
        The matplotlib figure holding one axis per species pair.
    """
    # NOTE(review): plt and distance_pbc are not among this file's visible
    # imports — confirm they are provided at module scope.
    type2indices = {
        'H': ptypes == 1,
        'O': ptypes == 8,
    }
    pairs = [('H', 'H'), ('O', 'O'), ('H', 'O')]

    # Bin edges, shell volumes and the box volume are the same for every pair.
    bins = np.linspace(0.001, 6, 301)
    shell_volumes = 4 / 3 * np.pi * (bins[1:] ** 3 - bins[:-1] ** 3)
    xaxis = (bins[1:] + bins[:-1]) / 2
    # plain Python float so the numpy arithmetic never mixes in a (GPU) tensor
    volume = float(torch.prod(lattices))

    fig, axs = plt.subplots(1, len(pairs))
    fig.set_size_inches(15, 3)

    for idx, ((type1, type2), ax) in enumerate(zip(pairs, axs)):
        indices0 = type2indices[type1]
        indices1 = type2indices[type2]
        pred_pdist = distance_pbc(pred_seq, lattices, indices0, indices1)
        data_pdist = distance_pbc(data_seq, lattices, indices0, indices1)

        ax.plot(xaxis, _rdf_curve(pred_pdist, volume, bins, shell_volumes), label='pred')
        ax.plot(xaxis, _rdf_curve(data_pdist, volume, bins, shell_volumes), label='gt')
        if idx == 0:
            ax.legend()
        ax.set_title(f'{type1} | {type2}')
    return fig
        

def get_thermo(filename):
    """Parse a whitespace-separated thermodynamics log into column lists.

    The first line is treated as a header and skipped. Every following line
    with exactly five float fields is read in order as time, Etot, Epot,
    Ekin and Temp. Any other line (blank, wrong field count, non-numeric
    text) is skipped.

    Args:
        filename: path of the log file to read.

    Returns:
        dict with keys 'time', 'Et', 'Ep', 'Ek', 'T', each a list of floats
        in file order.
    """
    with open(filename, 'r') as f:
        lines = f.read().splitlines()

    sim_time, Et, Ep, Ek, T = [], [], [], [], []
    for line in lines[1:]:
        try:
            # ValueError covers both a non-float token and a wrong field count
            t, Etot, Epot, Ekin, Temp = [float(x) for x in line.split()]
        except ValueError:
            continue
        sim_time.append(t)
        Et.append(Etot)
        Ep.append(Epot)
        Ek.append(Ekin)
        T.append(Temp)

    return {
        'time': sim_time,
        'Et': Et,
        'Ep': Ep,
        'Ek': Ek,
        'T': T
    }