import os
import numpy as np
import polars as pl
import csv
import torch
from torch import optim
from time import sleep
import mdtraj as md

from data.dataset import Dataset
from utils.internal_coordinates import InternalCoordinates
from utils.miscs import *
from model.classifier import *

# Number of leading entries kept along the last axis of the array returned by
# InternalCoordinates.get_all_angles (see AlanInterAnglesDataset._process_data).
angles_num = 2
# Initial learning rate passed to the Adam optimizer in main().
lr = 1e-4

class AlanInterAnglesDataset(Dataset):
    """Dataset built from internal-coordinate angles and per-frame weights.

    On construction the files under ``data_dir`` are loaded via
    ``_process_data`` and the following attributes are set:

    * ``sample``     -- angle array returned by ``_process_data``
    * ``gr_weights`` -- 2-D weight array of shape ``(n_traj, data_len)``
    * ``n_traj``, ``data_len`` -- the two dimensions of ``gr_weights``
    * ``L``, ``dim`` -- the last two dimensions of ``sample``
    """

    def __init__(self, data_dir, mode, lag_num):
        """Initialise the base ``Dataset`` and load the arrays from ``data_dir``."""
        super().__init__(data_dir, mode, lag_num)
        loaded = self._process_data(data_dir)
        self.sample, self.gr_weights = loaded
        weight_shape = self.gr_weights.shape
        (self.n_traj, self.data_len) = weight_shape
        trailing_dims = self.sample.shape[-2:]
        self.L, self.dim = trailing_dims

    def _process_data(self, data_dir, drop=0):
        """Load weights and angle features from ``data_dir``.

        Files read (all relative to ``data_dir``):

        * ``gr_total.csv``          -- column ``logM`` supplies the weights
        * ``trajectory_total.dcd``  -- trajectory coordinates
        * ``alanine-dipeptide.pdb`` -- topology for both mdtraj and
          ``InternalCoordinates``

        Args:
            data_dir: Directory containing the three files above.
            drop: Number of leading frames discarded from both the weights
                and the coordinates.

        Returns:
            ``(angles, gr_weights)`` where ``gr_weights`` is the negated
            ``logM`` column shaped ``(1, n_frames - drop)`` and ``angles`` is
            the output of ``InternalCoordinates.get_all_angles`` with a new
            leading axis and its second axis truncated to ``angles_num``.
        """
        topology_path = f'{data_dir}/alanine-dipeptide.pdb'

        # Weights: one row (a single trajectory), sign-flipped logM values.
        log_m = pl.read_csv(f'{data_dir}/gr_total.csv')['logM'].to_numpy()
        row = log_m.reshape([1, -1])
        gr_weights = -row[:, drop:]

        # Coordinates: add a leading trajectory axis -> (1, n_frames, n_atoms, 3).
        traj = md.load_dcd(f'{data_dir}/trajectory_total.dcd', top=topology_path)
        coords = traj.xyz[None, ...]
        coords = coords[:, drop:]

        converter = InternalCoordinates(topology_path)
        all_angles = converter.get_all_angles(coords)
        angles = all_angles[None, :, :angles_num]

        return angles, gr_weights
    

def _load_training_tensors(dataset, device):
    """Fetch the dataset's current normalized arrays and move the inputs to ``device``.

    Returns ``(x_data, y_data, gr_weight)``: ``x_data`` and ``y_data`` are float
    tensors on ``device`` built from the first two values returned by
    ``dataset.normalized_gr()``; ``gr_weight`` is ``dataset.weight`` read right
    after that call.
    """
    # The third value returned by normalized_gr() is not used by the training loop.
    x_np, y_np, _ = dataset.normalized_gr()
    gr_weight = dataset.weight
    x_data = torch.from_numpy(x_np).float().to(device)
    y_data = torch.from_numpy(y_np).float().to(device)
    return x_data, y_data, gr_weight


def main(lag_num, data_dir, OUTPUT_DIR, device):
    """Train ``ClassifierMSDense2`` on the angle dataset and log evaluation stats.

    Workflow:
      1. Build ``AlanInterAnglesDataset`` from ``data_dir`` and create the model
         and an Adam optimizer (initial learning rate ``lr``).
      2. Train for 1000 epochs with batches of 1024 randomly permuted indices;
         a trailing partial batch is skipped.
      3. After each of the first 51 epochs (``epoch <= 50``) multiply the
         learning rate by 0.95.
      4. Every 20 epochs: compute model weights, normalise them per trajectory,
         append a row to ``mse.csv``, save the weights to
         ``model_weight_lag{K*lag_num}.npy``, increment ``dataset.K``, feed the
         truncated weights back through ``dataset.update_weight`` and reload
         the training tensors.
      5. Save the final model state dict to ``last.pt``.

    Args:
        lag_num: Lag passed to the dataset constructor.
        data_dir: Directory holding the input files read by the dataset.
        OUTPUT_DIR: Existing directory that receives ``mse.csv``, the ``.npy``
            weight dumps and ``last.pt``.
        device: Torch device string the model and tensors are moved to.
    """
    dataset = AlanInterAnglesDataset(
        data_dir=data_dir,
        mode='train',
        lag_num=lag_num,
    )
    dim = dataset.dim

    batchsize = 1024
    epochs = 1000
    mse_path = f'{OUTPUT_DIR}/mse.csv'

    model = ClassifierMSDense2(dim=dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Start a fresh log file containing only the header row.
    with open(mse_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['epoch', 'lagtime', 'mse', 'model_std', 'gr_std'])

    if dataset.K == 1:
        dataset.update_weight()

    x_data, y_data, gr_weight = _load_training_tensors(dataset, device)

    for epoch in range(epochs):
        model.train()
        loss_total_cls = []
        n_samples = len(dataset)
        iter_num = n_samples // batchsize
        idx_list = np.random.permutation(n_samples)
        # Length used to split a flat sample index into (trajectory, position).
        # K and lag_num only change in the evaluation step below, so this is
        # constant for the whole epoch.
        window = dataset.data_len - dataset.K * dataset.lag_num
        for batch_iter in tqdm(range(iter_num)):
            idx = idx_list[batch_iter * batchsize:(batch_iter + 1) * batchsize]
            traj_idx = idx // window
            num_idx = idx % window
            x = x_data[traj_idx, num_idx].reshape([batchsize, dim])
            y = y_data[traj_idx, num_idx].reshape([batchsize, dim])  # (B, dim)
            weight = dataset.weight[traj_idx, num_idx].reshape(-1, 1)  # (B, 1)
            weight = torch.from_numpy(weight).float().to(device)
            loss_cls = model(x, y, weight)

            optimizer.zero_grad()
            loss_cls.backward()
            optimizer.step()
            loss_total_cls.append(loss_cls.item())

        # Exponential learning-rate decay during the first 51 epochs only.
        if epoch <= 50:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.95

        print('Train Epoch: {}/{} ({:.0f}%)\tLogistic Regression Loss: {:.6f}'.format(
                epoch + 1, epochs, 100. * (epoch + 1) / epochs,
                np.array(loss_total_cls).mean()))

        if (epoch + 1) % 20 == 0:
            model.eval()
            with torch.no_grad():
                model_weight = model.update_weight(
                    x=x_data.reshape([-1, dim]),
                    y=y_data.reshape([-1, dim]),
                ).reshape([dataset.n_traj, -1])
                # Compare weights after scaling each trajectory to unit mean.
                model_weight /= np.mean(model_weight, axis=1, keepdims=True)
                gr_weight_eval = gr_weight / np.mean(gr_weight, axis=1, keepdims=True)
                mse = np.mean((gr_weight_eval - model_weight)**2)
            with open(mse_path, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([epoch, dataset.K*dataset.lag_num, mse.item(), np.std(model_weight).item(), np.std(gr_weight).item()])

            np.save(f'{OUTPUT_DIR}/model_weight_lag{(dataset.K)*dataset.lag_num}.npy', model_weight)

            # Advance to the next lag multiple: drop the last lag_num entries of
            # each trajectory and hand the result to the dataset as new weights.
            dataset.K += 1
            model_weight = model_weight[:, :-dataset.lag_num]
            dataset.update_weight(model_weight)
            x_data, y_data, gr_weight = _load_training_tensors(dataset, device)

    torch.save(model.state_dict(), f'{OUTPUT_DIR}/last.pt')

if __name__ == '__main__':
    # Script entry point: select the CUDA device, set the run configuration,
    # make sure the output directory exists, then start training.
    cuda_idx = 0
    torch.cuda.set_device(cuda_idx)
    device = 'cuda:' + str(cuda_idx)

    lag_num = 1
    data_dir = '' # TODO: add data dir

    OUTPUT_DIR = '' # TODO: add output dir
    # exist_ok=True replaces the exists()-then-makedirs() pair, which could
    # fail if the directory was created between the check and the call.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    main(lag_num, data_dir, OUTPUT_DIR, device)