import os
import numpy as np
import polars as pl
import csv
import torch
from torch import optim
from time import sleep
import mdtraj as md

from data.dataset import Dataset
from utils.miscs import *
from model.classifier import *

class OUDataset(Dataset):
    """Trajectory dataset paired with one row of per-step ``gr`` weights.

    Reads ``gr_total.npy`` and ``trajectory.npz`` from ``data_dir``. Only row
    ``idx`` of the weight array is kept, reshaped to ``(1, -1)``, so
    ``n_traj`` is always 1 for this class.

    ``self.K`` and ``self.lag_num`` are read here but never assigned in this
    class; they are presumably set by ``Dataset.__init__`` -- confirm against
    the base class.
    """

    def __init__(self, data_dir, mode, lag_num, idx):
        """Load the arrays and cache their basic shape information.

        Args:
            data_dir: Directory containing ``gr_total.npy`` and ``trajectory.npz``.
            mode: Forwarded unchanged to ``Dataset.__init__``.
            lag_num: Forwarded unchanged to ``Dataset.__init__``.
            idx: Row of ``gr_total.npy`` to use as the weight sequence.
        """
        super().__init__(data_dir, mode, lag_num)
        self.sample, self.observation, self.gr_weights = self._process_data(data_dir, idx)
        (self.n_traj, self.data_len) = self.gr_weights.shape
        self.dim = self.sample.shape[-1]

    def _process_data(self, data_dir, idx, drop=0):
        """Read the weight row and the trajectory arrays from disk.

        Args:
            data_dir: Directory containing the two input files.
            idx: Row of ``gr_total.npy`` to keep.
            drop: Unused; kept so existing callers keep working.

        Returns:
            Tuple ``(sample, observation, gr_weights)``. ``sample`` and
            ``observation`` keep only the first two components of the last
            axis; ``gr_weights`` has shape ``(1, L)``.
        """
        gr_weights = np.load(f'{data_dir}/gr_total.npy')[idx].reshape([1, -1])

        # np.load on an .npz returns an NpzFile that holds the archive open;
        # the context manager closes it once both arrays are in memory.
        with np.load(f'{data_dir}/trajectory.npz') as traj:
            sample = traj['sample'][:, :, :2]
            observation = traj['observation'][:, :, :2]
        return sample, observation, gr_weights

    def update_weight(self, model_weight=None):
        """Recompute ``self.weight`` for the current ``self.K``.

        The weights are first cut to start at column ``(K-1)*lag_num``, then
        summed over sliding windows of ``lag_num`` steps and exponentiated.
        For ``K != 1`` the result is multiplied element-wise by
        ``model_weight``.

        Args:
            model_weight: Multiplier applied when ``K != 1``; ignored when
                ``K == 1``.

        Raises:
            TypeError: If ``K != 1`` and ``model_weight`` is None.
        """
        gr_weight_cum = np.cumsum(self.gr_weights[:, (self.K-1)*self.lag_num:], axis=1)
        # Difference of cumulative sums lag_num apart = sum over each window.
        gr_weight_cum = gr_weight_cum[:, self.lag_num:] - gr_weight_cum[:, :-self.lag_num]

        if self.K == 1:
            weight = np.exp(gr_weight_cum)
        else:
            if model_weight is None:
                # Same exception type the bare `None * ndarray` would raise,
                # but with a message that names the actual problem.
                raise TypeError('model_weight must be provided when K != 1')
            weight = model_weight * np.exp(gr_weight_cum)
        self.weight = weight

    def normalized_gr(self):
        """Return lagged pairs and their exponentiated window sums.

        Returns:
            Tuple ``(x, y, obser_x, obser_y, weight)``. ``x``/``obser_x`` drop
            the last ``K*lag_num`` steps and ``y``/``obser_y`` drop the first
            ``K*lag_num`` steps of ``sample``/``observation``. ``weight`` is
            ``exp`` of the ``gr_weights`` sums over windows of ``K*lag_num``
            steps.
        """
        shift = self.K * self.lag_num
        x = self.sample[:, :-shift]
        y = self.sample[:, shift:]
        obser_x = self.observation[:, :-shift]
        obser_y = self.observation[:, shift:]
        gr_weight_cum = np.cumsum(self.gr_weights, axis=1)
        gr_weight_cum = gr_weight_cum[:, shift:] - gr_weight_cum[:, :-shift] # (n_traj, len(self))
        weight = np.exp(gr_weight_cum)
        return x, y, obser_x, obser_y, weight
    

def main(lag_num, data_dir, OUTPUT_DIR, idx):
    """Train the classifier on one weight row, growing the lag every 25 epochs.

    Runs 1000 epochs on CUDA device 0 with fixed seeds. Every 25 epochs the
    model's weights are evaluated against ``dataset.weight``, a row is appended
    to ``mse.csv``, the model weights are saved as ``.npy``, ``dataset.K`` is
    incremented and the dataset weights are rebuilt from the model output.
    Every 100 epochs the model state dict is saved.

    Args:
        lag_num: Base lag, forwarded to ``OUDataset``.
        data_dir: Directory with the input arrays, forwarded to ``OUDataset``.
        OUTPUT_DIR: Existing directory that receives ``mse.csv``, the ``.npy``
            weight dumps and the ``.pt`` checkpoints.
        idx: Weight-row index, forwarded to ``OUDataset``.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    cuda_idx = 0
    torch.cuda.set_device(cuda_idx)
    device = 'cuda:' + str(cuda_idx)

    batchsize = 2000
    dataset = OUDataset(
        data_dir=data_dir,
        mode='train',
        lag_num=lag_num,
        idx=idx
    )
    dim = dataset.dim
    model = ClassifierMSDense2(dim=dim).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    def _load_pairs():
        """Fetch the current lagged pairs as float tensors on ``device``."""
        # The fifth value returned by normalized_gr() is not used in training.
        x_np, y_np, x_obs_np, y_obs_np, _ = dataset.normalized_gr()
        return tuple(
            torch.from_numpy(arr).float().to(device)
            for arr in (x_np, y_np, x_obs_np, y_obs_np)
        )

    # Start a fresh log; rows are appended at every evaluation below.
    with open(f'{OUTPUT_DIR}/mse.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['epoch', 'lagtime', 'mse', 'model_std', 'gr_std'])

    epochs = 1000

    if dataset.K == 1:
        dataset.update_weight()

    x_data, y_data, x_obs, y_obs = _load_pairs()
    gr_weight = dataset.weight

    for epoch in range(epochs):
        model.train()
        loss_total_cls = []
        # Divisor used to split a flat sample index into (trajectory, position).
        pairs_per_traj = dataset.data_len - dataset.K * dataset.lag_num
        iter_num = len(dataset) // batchsize
        idx_list = np.random.permutation(len(dataset))
        for batch_iter in range(iter_num):
            idx_batch = idx_list[batch_iter*batchsize:(batch_iter+1)*batchsize]
            traj_idx = idx_batch // pairs_per_traj
            num_idx = idx_batch % pairs_per_traj
            x = x_data[traj_idx, num_idx].reshape([batchsize, dim])
            y = y_data[traj_idx, num_idx].reshape([batchsize, dim]) # (B, dim)
            weight = dataset.weight[traj_idx, num_idx].reshape(-1, 1) # (B, 1)
            weight = torch.from_numpy(weight).float().to(device)
            loss_cls = model(x, y, weight)

            optimizer.zero_grad()
            loss_cls.backward()
            # Clip between backward() and step(): clipping after the step
            # would leave the applied update unclipped.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            loss_total_cls.append(loss_cls.item())

        # Decay the learning rate by 5% after each of the first 51 epochs.
        if epoch <= 50:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.95

        print('Train Epoch: {}/{} ({:.0f}%)\tLogistic Regression Loss: {:.6f}'.format(
                epoch + 1, epochs, 100. * (epoch + 1) / epochs,
                np.array(loss_total_cls).mean()))

        if (epoch+1) % 25 == 0:
            model.eval()
            with torch.no_grad():
                # NOTE(review): model.update_weight presumably returns a numpy
                # array (it is mixed with numpy below) -- confirm in the model.
                model_weight = model.update_weight(
                    x=x_data.reshape([-1, dim]),
                    y=y_data.reshape([-1, dim]),
                ).reshape([dataset.n_traj, -1])
                mse = np.mean((gr_weight - model_weight)**2)

                obs_weight = model.update_weight(
                    x=x_obs.reshape([-1, dim]),
                    y=y_obs.reshape([-1, dim]),
                ).reshape([dataset.n_traj, -1])
            with open(f'{OUTPUT_DIR}/mse.csv', 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([epoch, dataset.K*dataset.lag_num, mse.item(), np.std(model_weight).item(), np.std(gr_weight).item()])

            np.save(f'{OUTPUT_DIR}/model_weight_obs_lag{(dataset.K)*dataset.lag_num}.npy', obs_weight)
            np.save(f'{OUTPUT_DIR}/model_weight_lag{(dataset.K)*dataset.lag_num}.npy', model_weight)

            # Move to the next lag multiple; trim the model weights by lag_num
            # columns so they match the shorter window array built inside
            # dataset.update_weight for the new K.
            dataset.K += 1
            model_weight = model_weight[:, :-dataset.lag_num]

            dataset.update_weight(model_weight)
            x_data, y_data, x_obs, y_obs = _load_pairs()
            gr_weight = dataset.weight

        if (epoch+1) % 100 == 0:
            # K was incremented just above (100 is a multiple of 25), so K-1
            # is the lag multiple the model was actually trained on.
            torch.save(model.state_dict(), f'{OUTPUT_DIR}/lagtime{(dataset.K-1)*dataset.lag_num}.pt')

if __name__ == '__main__':
    # Run one independent training per weight-row index, each writing into
    # its own `idx<N>` sub-directory of output_dir.
    lag_num = 100
    data_dir = '' # TODO: add data dir
    output_dir = '' # TODO: add output dir
    idx_total = np.arange(0, 225)

    for idx in idx_total:
        OUTPUT_DIR = f'{output_dir}/idx{idx}'
        # exist_ok avoids the check-then-create race and is a no-op on reruns.
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        main(lag_num, data_dir, OUTPUT_DIR, idx)