import os
import numpy as np
import polars as pl
import csv
import torch
from torch import optim
from time import sleep
import mdtraj as md

from data.dataset import Dataset
from utils.miscs import *
from model.classifier import *


class LVDataset(Dataset):
    """Trajectory dataset paired with per-step ``gr`` log-weights for one parameter index.

    Loads ``gr_total.npy`` and ``trajectory.npz`` from ``data_dir``, keeps the
    first two coordinates of the ``sample`` and ``observation`` arrays, and
    standardizes both with the statistics of ``sample``.

    NOTE(review): ``self.K`` and ``self.lag_num`` are read by the methods below
    but never assigned in this class; they are presumably initialised by the
    ``Dataset`` base class -- confirm against ``data.dataset``.
    """

    def __init__(self, data_dir, mode, lag_num, idx):
        """Load and normalise the data for row ``idx`` of ``gr_total.npy``.

        Args:
            data_dir: Directory containing ``gr_total.npy`` and ``trajectory.npz``.
            mode: Forwarded unchanged to the base ``Dataset`` constructor.
            lag_num: Forwarded unchanged to the base ``Dataset`` constructor.
            idx: Row of ``gr_total.npy`` to use as the log-weight series.
        """
        super().__init__(data_dir, mode, lag_num)
        self.sample, self.observation, self.gr_weights = self._process_data(data_dir, idx)
        # gr_weights is reshaped to (1, -1) in _process_data, so para_num is 1.
        (self.para_num, self.data_len) = self.gr_weights.shape
        self.n_traj = 1
        self.dim = self.sample.shape[-1]

    def _process_data(self, data_dir, idx, drop=0):
        """Read the raw arrays from disk and standardise the trajectories.

        Args:
            data_dir: Directory containing ``gr_total.npy`` and ``trajectory.npz``.
            idx: Row of ``gr_total.npy`` to select.
            drop: Unused; kept only so the signature stays backward-compatible.

        Returns:
            Tuple ``(sample, observation, gr_weights)`` where ``gr_weights`` has
            shape ``(1, N)`` and the two trajectories are standardised along
            axis 1 using the mean/std of ``sample``.
        """
        gr_weights = np.load(f'{data_dir}/gr_total.npy')[idx]
        gr_weights = gr_weights.reshape([1, -1])

        traj = np.load(f'{data_dir}/trajectory.npz')
        # Only the first two coordinates of the last axis are used.
        sample = traj['sample'][:, :, :2]
        observation = traj['observation'][:, :, :2]

        # Statistics are taken along axis 1; with keepdims the result is
        # (sample.shape[0], 1, 2) -- presumably (1, 1, 2) for a single trajectory.
        sample_mean = np.mean(sample, axis=1, keepdims=True)
        sample_std = np.std(sample, axis=1, keepdims=True)
        # Guard against division by zero for constant coordinates.
        sample_std = np.where(sample_std == 0, 1.0, sample_std)
        sample = (sample - sample_mean) / sample_std

        # Observations are deliberately scaled with the *sample* statistics so
        # both arrays live in the same normalised coordinates.
        observation = (observation - sample_mean) / sample_std

        return sample, observation, gr_weights

    def update_weight(self, model_weight=None):
        """Recompute ``self.weight`` for the current ``self.K``.

        The log-weights are skipped up to offset ``(K-1)*lag_num``; for every
        remaining position ``j`` the ``lag_num`` entries strictly after ``j``
        are summed and exponentiated. For ``K > 1`` the result is multiplied
        element-wise by ``model_weight``.

        Args:
            model_weight: Array broadcastable against the windowed weights
                (``(1, data_len - K*lag_num)``). Required when ``K != 1``;
                ignored when ``K == 1``.

        Raises:
            TypeError: If ``K != 1`` and ``model_weight`` is ``None``.
        """
        gr_weight_cum = np.cumsum(self.gr_weights[:, (self.K-1)*self.lag_num:], axis=1)
        # Difference of cumulative sums == windowed sum over lag_num steps.
        gr_weight_cum = gr_weight_cum[:, self.lag_num:] - gr_weight_cum[:, :-self.lag_num]

        if self.K == 1:
            weight = np.exp(gr_weight_cum)
        else:
            if model_weight is None:
                # Same exception type as the implicit ``None * ndarray`` failure,
                # but with a message that points at the actual cause.
                raise TypeError('model_weight is required when K != 1')
            weight = model_weight * np.exp(gr_weight_cum)
        self.weight = weight

    def normalized_gr(self):
        """Return lagged input/target pairs and the full-window weights.

        Returns:
            Tuple ``(x, y, obser_x, obser_y, weight)``: ``x``/``obser_x`` drop
            the last ``K*lag_num`` frames, ``y``/``obser_y`` drop the first
            ``K*lag_num`` frames, and ``weight`` is the exponential of the
            log-weights summed over each ``K*lag_num``-step window.
        """
        total_lag = self.K*self.lag_num
        x = self.sample[:, :-total_lag]
        y = self.sample[:, total_lag:]
        obser_x = self.observation[:, :-total_lag]
        obser_y = self.observation[:, total_lag:]
        gr_weight_cum = np.cumsum(self.gr_weights, axis=1)
        gr_weight_cum = gr_weight_cum[:, total_lag:] - gr_weight_cum[:, :-total_lag]  # (n_traj, len(self))
        weight = np.exp(gr_weight_cum)
        return x, y, obser_x, obser_y, weight
    

def _lag_tensors(dataset, device):
    """Return ``(x, y, x_obs, y_obs)`` for the dataset's current K as float tensors on ``device``."""
    x_np, y_np, x_obs_np, y_obs_np, _ = dataset.normalized_gr()
    return tuple(
        torch.from_numpy(arr).float().to(device)
        for arr in (x_np, y_np, x_obs_np, y_obs_np)
    )


def main(lag_num, data_dir, OUTPUT_DIR, idx):
    """Train the classifier on one parameter index, growing the lag every 25 epochs.

    Every 25 epochs the model's weights are evaluated on the whole dataset,
    compared against the current reference weights, logged to
    ``{OUTPUT_DIR}/mse.csv``, and then ``dataset.K`` is incremented so the next
    stage trains on a lag that is ``lag_num`` steps longer.

    Args:
        lag_num: Base lag (in frames) passed to ``LVDataset``.
        data_dir: Directory containing the input arrays read by ``LVDataset``.
        OUTPUT_DIR: Existing directory that receives ``mse.csv`` and the
            ``model_weight*_lag*.npy`` snapshots.
        idx: Row of ``gr_total.npy`` to train on.
    """
    torch.manual_seed(0)
    np.random.seed(0)

    cuda_idx = 0
    torch.cuda.set_device(cuda_idx)
    device = 'cuda:' + str(cuda_idx)

    batchsize = 2000
    dataset = LVDataset(
        data_dir=data_dir,
        mode='train',
        lag_num=lag_num,
        idx=idx
    )
    dim = dataset.dim
    # ClassifierMSDense2 is not defined in this file; presumably it comes from
    # the star-import of model.classifier.
    model = ClassifierMSDense2(dim=dim).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    csv_path = f'{OUTPUT_DIR}/mse.csv'
    # Truncate any previous log and write the header once.
    with open(csv_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['epoch', 'lagtime', 'mse', 'model_std', 'gr_std'])

    epochs = 2000

    if dataset.K == 1:
        dataset.update_weight()

    x_data, y_data, x_obs, y_obs = _lag_tensors(dataset, device)
    gr_weight = dataset.weight

    for epoch in range(epochs):
        model.train()
        loss_total_cls = []
        iter_num = len(dataset) // batchsize
        idx_list = np.random.permutation(len(dataset))
        # Number of (x, y) pairs per trajectory at the current lag; constant
        # within an epoch because K only changes after the batch loop.
        pairs_per_traj = dataset.data_len - dataset.K * dataset.lag_num
        for batch_iter in range(iter_num):
            idx_batch = idx_list[batch_iter*batchsize:(batch_iter+1)*batchsize]
            # Split the flat index into (trajectory, position-in-trajectory).
            traj_idx = idx_batch // pairs_per_traj
            num_idx = idx_batch % pairs_per_traj
            x = x_data[traj_idx, num_idx].reshape([batchsize, dim])
            y = y_data[traj_idx, num_idx].reshape([batchsize, dim])  # (B, dim)
            weight = dataset.weight[traj_idx, num_idx].reshape(-1, 1)  # (B, 1)
            weight = torch.from_numpy(weight).float().to(device)
            loss_cls = model(x, y, weight)

            optimizer.zero_grad()
            loss_cls.backward()
            # Clip BEFORE stepping: clipping after optimizer.step() cannot
            # influence the update that was just applied.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            loss_total_cls.append(loss_cls.item())

        # Exponential learning-rate decay during the first 50 epochs only.
        if epoch < 50:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.95

        print('Train Epoch: {}/{} ({:.0f}%)\tLogistic Regression Loss: {:.6f}'.format(
                epoch + 1, epochs, 100. * (epoch + 1) / epochs,
                np.array(loss_total_cls).mean()))

        if (epoch+1) % 25 == 0:
            model.eval()
            with torch.no_grad():
                # NOTE(review): model.update_weight is used below with np.mean,
                # np.std and np.save, so it presumably returns a NumPy array.
                model_weight = model.update_weight(
                    x=x_data.reshape([-1, dim]),
                    y=y_data.reshape([-1, dim]),
                ).reshape([dataset.n_traj, -1])
                mse = np.mean((gr_weight - model_weight)**2)

                obs_weight = model.update_weight(
                    x=x_obs.reshape([-1, dim]),
                    y=y_obs.reshape([-1, dim]),
                ).reshape([dataset.n_traj, -1])

            lagtime = dataset.K * dataset.lag_num
            with open(csv_path, 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([epoch, lagtime, mse.item(), np.std(model_weight).item(), np.std(gr_weight).item()])
            # Snapshot the weights only at lag times that are multiples of 500.
            if lagtime % 500 == 0:
                np.save(f'{OUTPUT_DIR}/model_weight_obs_lag{lagtime}.npy', obs_weight)
                np.save(f'{OUTPUT_DIR}/model_weight_lag{lagtime}.npy', model_weight)

            # Advance to the next lag stage: the learned weights, trimmed by
            # lag_num so they align with the shorter window, seed the new
            # reference weights.
            dataset.K += 1
            model_weight = model_weight[:, :-dataset.lag_num]

            dataset.update_weight(model_weight)
            x_data, y_data, x_obs, y_obs = _lag_tensors(dataset, device)
            gr_weight = dataset.weight
            

if __name__ == '__main__':
    # Script entry point: run one independent training job per parameter index,
    # each writing into its own ``idx{N}`` sub-directory of ``output_dir``.
    lag_num = 50
    data_dir = '' # TODO: add data dir
    output_dir = '' # TODO: add output dir
    idx_total = np.arange(0, 225)

    print(len(idx_total))
    for idx in idx_total:
        OUTPUT_DIR = f'{output_dir}/idx{idx}'
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        main(lag_num, data_dir, OUTPUT_DIR, idx)