from abc import ABC, abstractmethod
import numpy as np
import pandas as pd
import polars as pl
import random
import os
import torch
from typing import Optional, Literal, Dict, Any
import mdtraj as md

from utils.miscs import compute_dihedral

class Dataset(ABC, torch.utils.data.Dataset):
    """Base class for time-lagged trajectory datasets with per-frame log-weights.

    The base ``__init__`` only stores the lag settings and initialises the data
    attributes to ``None``. Subclasses are expected to fill in ``sample``,
    ``gr_weights``, ``n_traj``, ``data_len`` and ``dim`` and to override
    ``__getitem__`` and ``collate``.

    ``gr_weights`` is treated as a log-weight per frame: windows of it are
    summed along axis 1 and exponentiated (see ``update_weight`` and
    ``normalized_gr``).
    """

    def __init__(
        self,
        data_dir: str,
        mode: Literal['train', 'val'],
        lag_num: int,
        **kwargs,
    ):
        """Store the lag configuration; no data is loaded here.

        Args:
            data_dir: Location of the data. Unused by the base class.
            mode: ``'train'`` or ``'val'``; stored as ``self.mode``.
            lag_num: Number of frames in one lag step. Must be positive: the
                ``[:-lag_num]`` slices used below are empty for ``lag_num == 0``.
            **kwargs: Accepted and ignored.
        """
        self.lag_num = lag_num
        self.mode = mode
        self.K = 1  # number of lag steps between the x and y frame of a pair
        self.weight = None  # set by update_weight()
        self.sample, self.gr_weights = None, None
        (self.n_traj, self.data_len, self.dim) = (None, None, None)

    def __len__(self):
        """Return the number of (x, y) pairs over all trajectories."""
        return (self.data_len - self.K * self.lag_num) * self.n_traj

    def update_weight(self, model_weight=None):
        """Recompute ``self.weight`` from ``gr_weights`` for the current ``K``.

        The log-weights are summed over a window of ``lag_num`` frames, starting
        ``(K - 1) * lag_num`` frames into each trajectory, exponentiated, and
        normalised so that every trajectory (row) has mean 1. The result has
        ``data_len - K * lag_num`` columns, matching ``__len__``.

        Args:
            model_weight: Multiplicative factor applied before normalisation.
                Only read when ``K != 1``, in which case it must be given and
                broadcast against an ``(n_traj, data_len - K * lag_num)`` array.
        """
        gr_weight_cum = np.cumsum(self.gr_weights[:, (self.K-1)*self.lag_num:], axis=1)
        # Difference of cumulative sums = sum over a sliding window of lag_num frames.
        gr_weight_cum = gr_weight_cum[:, self.lag_num:] - gr_weight_cum[:, :-self.lag_num]

        if self.K == 1:
            ## normalizing to 1
            weight = np.exp(gr_weight_cum) / np.mean(np.exp(gr_weight_cum), axis=1, keepdims=True)
        else:
            weight = model_weight * np.exp(gr_weight_cum)
            weight /= np.mean(weight, axis=1, keepdims=True)
        self.weight = weight

    def normalized_gr(self):
        """Return all lagged pairs together with their normalised weights.

        Returns:
            Tuple ``(x, y, weight)``: ``x`` is ``sample`` without its last
            ``K * lag_num`` frames, ``y`` is ``sample`` without its first
            ``K * lag_num`` frames, and ``weight`` is the exponentiated sum of
            ``gr_weights`` over each ``K * lag_num``-frame window, normalised to
            mean 1 per trajectory.
        """
        x = self.sample[:, :-self.K*self.lag_num]
        y = self.sample[:, self.K*self.lag_num:]
        gr_weight_cum = np.cumsum(self.gr_weights, axis=1)
        gr_weight_cum = gr_weight_cum[:, self.K*self.lag_num:] - gr_weight_cum[:, :-self.K*self.lag_num] # (n_traj, len(self))
        ## normalizing to 1
        weight = np.exp(gr_weight_cum) / np.mean(np.exp(gr_weight_cum), axis=1, keepdims=True)
        return x, y, weight

    # NOTE: __getitem__ must be an instance method and collate a static method.
    # As classmethods, the first argument was bound to the class, so instance
    # data was unreachable and ``dataset.collate(batch)`` raised TypeError.
    # They are deliberately not marked @abstractmethod so that subclasses which
    # do not implement item access can still be instantiated.
    def __getitem__(self, idx):
        """Return the sample at flat index ``idx``; subclasses must override."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement __getitem__"
        )

    @staticmethod
    def collate(batch_list):
        """Merge a list of ``__getitem__`` results into a batch; subclasses must override."""
        raise NotImplementedError("subclasses must implement collate")

class DWDataset(Dataset):
    """Dataset of time-lagged pairs read from a single ``.npz`` archive.

    The archive must contain the arrays ``sample`` (unpacked here as
    ``(n_traj, data_len, dim)``) and ``gr_weights`` (indexed as
    ``(n_traj, data_len)``). ``update_weight()`` has to be called before items
    are requested, because ``__getitem__`` reads ``self.weight``.
    """

    def __init__(
        self,
        data_dir: str,
        mode: Literal['train', 'val'],
        lag_num: int,
        **kwargs,
    ):
        """Load the archive at ``data_dir`` and record its dimensions.

        Args:
            data_dir: Path of the ``.npz`` file (passed straight to ``np.load``).
            mode: ``'train'`` or ``'val'``; forwarded to the base class.
            lag_num: Number of frames in one lag step.
            **kwargs: Accepted and ignored.
        """
        super().__init__(data_dir, mode, lag_num)
        self.sample, self.gr_weights = self._process_npz(data_dir, drop=0)
        (self.n_traj, self.data_len, self.dim) = self.sample.shape

    def _process_npz(self, npz_path, drop=0):
        """Read ``sample`` and ``gr_weights``, dropping the first ``drop`` frames.

        The archive is opened in a ``with`` block so its file handle is released
        as soon as both arrays have been read into memory.
        """
        with np.load(npz_path) as data:
            sample = data['sample'][:, drop:]
            gr_weights = data['gr_weights'][:, drop:]
        return sample, gr_weights

    @property
    def name(self):
        return "Four well"

    # Must be an instance method: as a classmethod ``self`` was the class and
    # none of the per-instance arrays below were reachable.
    def __getitem__(self, idx):
        """Return the ``idx``-th lagged pair and its weight as float tensors.

        ``idx`` is a flat index over all trajectories; it is split into a
        trajectory index and a start frame within that trajectory.

        Returns:
            Dict with ``'x'`` (frame at the start index), ``'y'`` (frame
            ``K * lag_num`` steps later) and ``'normalized_weight'`` (shape
            ``(1,)``).

        Raises:
            RuntimeError: If ``update_weight()`` has not been called yet.
        """
        if self.weight is None:
            raise RuntimeError(
                "weights are not initialised; call update_weight() before indexing"
            )
        pairs_per_traj = self.data_len - self.K * self.lag_num
        traj_idx, num_idx = divmod(idx, pairs_per_traj)
        x = self.sample[traj_idx, num_idx]
        y = self.sample[traj_idx, num_idx + self.K * self.lag_num]
        weight = self.weight[traj_idx, num_idx]
        return {
            'x': torch.tensor(x).float(),  # (dim)
            'y': torch.tensor(y).float(),  # (dim)
            'normalized_weight': torch.tensor(weight).reshape(-1).float(),
        }

    # Must be a static method: as a classmethod ``batch_list`` received the
    # class itself and ``dataset.collate(batch)`` raised TypeError.
    @staticmethod
    def collate(batch_list):
        """Stack a list of ``__getitem__`` dicts key by key along a new batch axis."""
        return {
            key: torch.stack([feat_dict[key] for feat_dict in batch_list])
            for key in batch_list[0]
        }


class AlanDataset(Dataset):
    """Trajectory dataset built from a DCD trajectory, a PDB topology and a CSV of log-weights.

    Exposes the atom coordinates as ``sample``, the negated ``logM`` column as
    ``gr_weights`` and the two dihedral angles (phi, psi) as ``reaction_feats``.
    """

    def __init__(
        self,
        data_dir: str,
        mode: Literal['train', 'val'],
        lag_num: int,
        **kwargs,
    ):
        """Load everything under ``data_dir`` and record the array dimensions."""
        super().__init__(data_dir, mode, lag_num)
        loaded = self._process_data(data_dir)
        self.sample, self.gr_weights, self.reaction_feats = loaded
        self.n_traj, self.data_len = self.gr_weights.shape
        # Last two axes of ``sample``: atoms per frame and coordinates per atom.
        self.L, self.dim = self.sample.shape[-2:]

    def _process_data(self, data_dir, drop=0):
        """Read the raw files and return ``(sample, gr_weights, reaction_feats)``.

        Expects ``trajectory_total.dcd``, ``alanine-dipeptide.pdb`` and
        ``gr_total.csv`` (with a ``logM`` column) inside ``data_dir``. The first
        ``drop`` frames are discarded from the coordinates and the weights.
        """
        traj = md.load_dcd(
            f'{data_dir}/trajectory_total.dcd',
            top=f'{data_dir}/alanine-dipeptide.pdb',
        )
        # Add a leading trajectory axis: (1, n_frames, n_atoms, 3).
        coords = traj.xyz[np.newaxis][:, drop:]

        log_m = pl.read_csv(f'{data_dir}/gr_total.csv')['logM'].to_numpy()
        # The stored column is negated to obtain the log-weights.
        gr_weights = -(log_m.reshape([1, -1])[:, drop:])

        def _torsion(atom_ids):
            # Gather the four atoms of one dihedral for every frame.
            quad = coords[:, :, atom_ids, :].reshape(-1, 4, 3)
            return compute_dihedral(quad[:, 0], quad[:, 1], quad[:, 2], quad[:, 3])

        phi = _torsion([4, 6, 8, 14])
        psi = _torsion([6, 8, 14, 16])
        reaction_feats = np.concatenate(
            [phi[..., np.newaxis], psi[..., np.newaxis]], axis=-1
        ).reshape([1, -1, 2])

        # NOTE(review): the mean is taken over atoms AND coordinates jointly, so
        # one scalar per frame is subtracted rather than a per-axis centroid
        # (which would be axis=-2) - confirm this is intended.
        coords -= np.nanmean(coords, axis=(-2, -1), keepdims=True)
        return coords, gr_weights, reaction_feats

    @property
    def name(self):
        return "Alanine Dipeptide"
    

def train_loader(batch_size, data_dir, lag_num, num_workers=4):
    """Build the training ``DWDataset`` and a shuffling ``DataLoader`` over it.

    The returned dataset has ``weight`` set to ``None``; call
    ``dataset.update_weight()`` before iterating the loader, because
    ``DWDataset.__getitem__`` reads ``dataset.weight``.

    Args:
        batch_size: Number of samples per batch.
        data_dir: Path handed to ``DWDataset`` as ``data_dir``.
        lag_num: Number of base time steps (delta t = 0.001) that make up one
            lag time tau.
        num_workers: Number of loader worker processes. Defaults to 4, the
            value that used to be hard-coded; pass 0 to load in the main process.

    Returns:
        Tuple ``(dataset, dataloader)``.
    """
    dataset = DWDataset(data_dir=data_dir, mode='train', lag_num=lag_num)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=dataset.collate,
    )
    return dataset, dataloader

   
    
if __name__ == '__main__':
    # Smoke test only: build a dataset and loader, then print the shape of
    # every tensor in each batch.
    dataset = DWDataset(data_dir='', mode='train', lag_num=50)
    dataset.update_weight()
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=128,
        shuffle=True,
        num_workers=4,
        collate_fn=dataset.collate,
    )

    # (label, batch key) pairs in print order; expected shapes are
    # (batch_size, dim), (batch_size, dim) and (batch_size, 1).
    shape_report = (
        ('x_t shape:', 'x'),
        ('x_{t+Ktau} shape:', 'y'),
        ('normalized_weight shape:', 'normalized_weight'),
    )
    for batch in dataloader:
        for label, key in shape_report:
            print(label, batch[key].shape)

