
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Conv1d, ReLU, Linear, ConvTranspose1d
import pathlib
import argparse
from gru_d import GRUD
from cal_score import cal_score

def sample_normal_jit(mu, sigma):
    """Draw a reparameterised sample ``z = mu + sigma * rho`` with ``rho ~ N(0, I)``.

    Args:
        mu: Tensor of means; the noise takes its shape, dtype and device.
        sigma: Standard deviations, broadcastable against ``mu``.

    Returns:
        Tuple ``(z, rho)``: the sample and the standard-normal noise behind it.
    """
    # The noise lives in its own tensor. Scaling it in place (``rho.mul_``)
    # would make the returned ``rho`` alias ``z`` instead of holding the noise.
    rho = torch.randn_like(mu)
    z = rho * sigma + mu
    return z, rho

class Normal:
    """Element-wise Gaussian described by a mean and a log standard deviation."""

    def __init__(self, mu, log_sigma):
        """Store ``mu`` and ``log_sigma`` and cache ``sigma = exp(log_sigma)``."""
        self.mu, self.log_sigma = mu, log_sigma
        self.sigma = torch.exp(log_sigma)

    def sample(self, t=1.):
        """Sample with the standard deviation multiplied by temperature ``t``.

        Returns whatever ``sample_normal_jit`` returns for ``(mu, sigma * t)``.
        """
        tempered_sigma = self.sigma * t
        return sample_normal_jit(self.mu, tempered_sigma)

    def log_p(self, samples):
        """Return the element-wise log-density of ``samples`` under this Gaussian."""
        standardized = (samples - self.mu) / self.sigma
        quadratic = - 0.5 * standardized * standardized
        return quadratic - 0.5 * np.log(2 * np.pi) - self.log_sigma

def log_p_standard_normal(samples):
    """Return the element-wise log-density of ``samples`` under ``N(0, 1)``."""
    half_log_two_pi = 0.9189385332  # 0.5 * np.log(2 * np.pi)
    squared = torch.square(samples)
    return - 0.5 * squared - half_log_two_pi

def vae_terms(log_q_conv, eps):
    """Reduce per-variable log-densities to per-sample VAE terms.

    Args:
        log_q_conv: Element-wise log-density of the latent sample under the
            approximate posterior.
        eps: The latent sample itself; it is scored under a standard normal.

    Returns:
        Tuple ``(log_q, log_p, kl_all)``, each summed over dim 1, where
        ``kl_all`` sums ``log_q_conv - log p(eps)``.
    """
    prior_log_density = log_p_standard_normal(eps)  # log p(z)
    kl_per_variable = log_q_conv - prior_log_density  # log q(z|x) - log p(z)
    log_q, log_p, kl_all = (
        torch.sum(term, dim=1)
        for term in (log_q_conv, prior_log_density, kl_per_variable)
    )
    return log_q, log_p, kl_all

def load_data(dir, device='cuda'):
    """Load every ``*.pt`` file of a directory into a dict.

    Args:
        dir: Directory to scan, given as a ``str`` or a ``pathlib.Path``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        Dict mapping each file name, minus its ``.pt`` suffix, to the loaded
        object.
    """
    suffix = '.pt'
    # Coerce so that plain string paths work with the ``/`` operator too.
    root = pathlib.Path(dir)
    tensors = {}
    # Sorted so the dict order does not depend on the file system's listing order.
    for filename in sorted(os.listdir(root)):
        if not filename.endswith(suffix):
            continue
        # Strip only the trailing suffix: a name such as ``a.b.pt`` must map to
        # ``a.b``, the key ``save_data`` would have written it under.
        tensor_name = filename[:-len(suffix)]
        tensors[tensor_name] = torch.load(str(root / filename), map_location=device)
    return tensors

def save_data(dir, **tensors):
    """Write each keyword argument to ``<dir>/<name>.pt`` with ``torch.save``.

    Args:
        dir: Existing target directory, given as a ``str`` or a ``pathlib.Path``.
        **tensors: Objects to save; the keyword becomes the file stem.
    """
    # Coerce so that plain string paths work with the ``/`` operator too.
    root = pathlib.Path(dir)
    for tensor_name, tensor_value in tensors.items():
        torch.save(tensor_value, str(root / tensor_name) + '.pt')

def normalize(data):
    """Min-max scale ``data`` along axis 0.

    Each entry becomes ``(value - min) / (max - min + 1e-7)``, with the
    minimum and maximum taken over axis 0; the epsilon keeps constant columns
    from dividing by zero.
    """
    lowest = np.min(data, 0)
    span = np.max(data, 0) - lowest
    return (data - lowest) / (span + 1e-7)

class TimeDataset(torch.utils.data.Dataset):
    """Sliding-window dataset over a CSV time series with simulated missing rows.

    The CSV (one header row, comma separated) is loaded, its row order is
    reversed, and every column is min-max scaled. Stride-1 windows of
    ``seq_len`` rows are cut twice: once from the clean data and once after a
    fixed-seed random subset of rows has been overwritten with NaN and a
    row-index column has been appended. Both sets are shuffled with the same
    permutation, so item ``i`` of each refers to the same window.

    Args:
        data_path: Path of the CSV file.
        seq_len: Number of rows per window.
        missing_rate: Fraction of rows to overwrite with NaN.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1 or larger than the number
            of rows.
    """

    def __init__(self, data_path, seq_len, missing_rate=0.0):
        data = np.loadtxt(data_path, delimiter=",", skiprows=1)
        data = data[::-1]
        num_rows = len(data)
        if not 1 <= seq_len <= num_rows:
            raise ValueError(
                f"seq_len must be between 1 and the number of rows ({num_rows}), got {seq_len}"
            )

        # Per-column minimum and range (``max_val`` holds max - min, despite
        # its name). The same statistics drive the scaling below, so they can
        # be used to undo it.
        self.min_val = np.min(data, 0)
        self.max_val = np.max(data, 0) - self.min_val
        norm_data = (data - self.min_val) / (self.max_val + 1e-7)

        clean_windows = torch.from_numpy(self._windows(norm_data, seq_len))
        # One permutation shared by the clean and the corrupted windows.
        idx = torch.randperm(len(clean_windows))
        self.original_sample = clean_windows[idx].to(torch.float32)

        # Mean over windows, kept as a (1, seq_len, features) ndarray.
        self.X_mean = self.original_sample.numpy().mean(axis=0, keepdims=True)

        # Fixed seed: the same rows are removed on every run.
        generator = torch.Generator().manual_seed(56789)
        num_removed = int(num_rows * missing_rate)
        removed_points = torch.randperm(num_rows, generator=generator)[:num_removed].sort().values
        # Safe to overwrite in place: the clean windows above are copies.
        norm_data[removed_points.numpy()] = float('nan')

        # The row index is appended as a last column and serves as time stamp.
        index = np.arange(num_rows).reshape(-1, 1)
        indexed_data = np.concatenate((norm_data, index), axis=1)
        corrupted_windows = torch.from_numpy(self._windows(indexed_data, seq_len))
        self.samples = corrupted_windows[idx].to(torch.float32)
        self.size = len(self.samples)

    @staticmethod
    def _windows(series, seq_len):
        """Copy every stride-1 window of ``seq_len`` rows into one stacked array."""
        count = len(series) - seq_len + 1
        return np.stack([series[start:start + seq_len] for start in range(count)])

    def __getitem__(self, index):
        """Return ``(window with NaNs and index column, clean window)``."""
        return self.samples[index], self.original_sample[index]

    def __len__(self):
        """Return the number of windows."""
        return len(self.samples)

def to_tensor(data):
    """Convert a NumPy array into a float32 torch tensor."""
    as_tensor = torch.from_numpy(data)
    return as_tensor.to(torch.float32)

class VariationalAutoencoderConv(nn.Module):
    """Convolutional VAE on top of a ``GRUD`` embedder.

    Two parallel encoders (one for the mean, one for the log-scale) turn the
    embedder output into the parameters of the latent Gaussian. Two parallel
    decoders turn a latent sample into the mean and log-scale of a window of
    shape ``(seq_len, feat_dim)``. The encoder convolutions take ``seq_len``
    input channels and the flattened result feeds a linear layer of
    ``hidden_layer_sizes[-1] * feat_dim`` inputs, so the embedder must return
    ``(batch, seq_len, feat_dim)``.

    Args:
        args: Namespace providing ``input_size``, ``seq_len``, ``latent_dim``,
            ``reconstruction_wt`` and ``hidden_size``, and normally ``device``.
        x_mean: Empirical mean handed to ``GRUD``. When omitted, the
            module-level ``dataset`` created by the script entry point is
            used, which keeps ``VariationalAutoencoderConv(args)`` working.

    Raises:
        ValueError: If ``x_mean`` is omitted and no module-level ``dataset``
            exists, or if no device can be determined.
    """

    def __init__(self, args, x_mean=None):
        super().__init__()
        self.hidden_layer_sizes = [100, 200]
        self.feat_dim = feat_dim = args.input_size
        self.seq_len = args.seq_len
        self.latent_dim = args.latent_dim
        self.encoder_last_dense_dim = feat_dim * self.hidden_layer_sizes[-1]
        self.reconstruction_wt = args.reconstruction_wt

        # Construction order (mu encoder, sigma encoder, mu decoder, sigma
        # decoder, embedder) is deliberate: it keeps seeded initialisation and
        # the state-dict layout stable.
        self.mu_encoder = self._build_encoder()
        self.sig_encoder = self._build_encoder()
        self.mu_decoder = self._build_decoder()
        self.sig_decoder = self._build_decoder()

        if x_mean is None:
            # Legacy path: running the file as a script defines ``dataset``.
            legacy_dataset = globals().get('dataset')
            if legacy_dataset is None:
                raise ValueError(
                    "x_mean must be given when no module-level `dataset` is defined"
                )
            x_mean = legacy_dataset.X_mean
        target_device = getattr(args, 'device', None)
        if target_device is None:
            # Legacy path: running the file as a script defines ``device``.
            target_device = globals().get('device')
        if target_device is None:
            raise ValueError("args.device must be set")
        self.embedder = GRUD(feat_dim, args.hidden_size, x_mean, device=target_device)

    def _build_encoder(self):
        """Return one encoder: a Conv1d+ReLU pair per hidden size, then a Linear head."""
        layers = []
        in_channels = self.seq_len
        for out_channels in self.hidden_layer_sizes:
            layers.append(Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding='same'))
            layers.append(ReLU())
            in_channels = out_channels
        layers.append(Linear(in_channels * self.feat_dim, self.latent_dim))
        return nn.ModuleList(layers)

    def _build_decoder(self):
        """Return one decoder: Linear, ReLU, ConvTranspose1d, ReLU, Linear."""
        widest = self.hidden_layer_sizes[-1]
        narrowest = self.hidden_layer_sizes[0]
        return nn.ModuleList([
            Linear(self.latent_dim, self.encoder_last_dense_dim),
            ReLU(),
            ConvTranspose1d(in_channels=widest, out_channels=narrowest, kernel_size=3, padding=1),
            ReLU(),
            Linear(self.feat_dim * narrowest, self.seq_len * self.feat_dim),
        ])

    @staticmethod
    def _encode(layers, h, batch_size):
        """Run ``h`` through the conv stack of ``layers``, flatten, apply the head."""
        out = h
        for layer in layers[:-1]:
            out = layer(out)
        return layers[-1](out.reshape(batch_size, -1))

    def _decode_branch(self, layers, z):
        """Map latents ``z`` to a ``(batch, seq_len, feat_dim)`` tensor with one decoder."""
        batch_size = z.shape[0]
        out = layers[0](z).reshape(batch_size, self.hidden_layer_sizes[-1], self.feat_dim)
        for layer in layers[1:-1]:
            out = layer(out)
        out = layers[-1](out.reshape(batch_size, -1))
        return out.reshape(batch_size, self.seq_len, self.feat_dim)

    def _decode(self, z):
        """Return the ``(mean, log-scale)`` decoder outputs for latents ``z``."""
        return (
            self._decode_branch(self.mu_decoder, z),
            self._decode_branch(self.sig_decoder, z),
        )

    def forward(self, x):
        """Encode ``x``, sample a latent and decode it.

        Args:
            x: Input accepted by the embedder; dim 0 is the batch.

        Returns:
            Tuple ``(logits, log_q_conv, z, dist)``: decoder mean and log-scale
            concatenated along dim 1, the element-wise posterior log-density of
            ``z``, the latent sample, and the posterior ``Normal``.
        """
        h = self.embedder(x)
        batch_size = x.shape[0]
        mu_z = self._encode(self.mu_encoder, h, batch_size)
        logvar_z = self._encode(self.sig_encoder, h, batch_size)

        # NOTE: despite the "logvar" name, ``Normal`` treats its second
        # argument as the log standard deviation.
        dist = Normal(mu_z, logvar_z)
        z, _ = dist.sample()
        log_q_conv = dist.log_p(z)

        mu_x, logvar_x = self._decode(z)
        logits = torch.cat([mu_x, logvar_x], dim=1)

        return logits, log_q_conv, z, dist

    def decoder_output(self, logits):
        """Turn the ``logits`` returned by ``forward`` into a ``Normal``.

        The logits are squashed into (-5, 5) with ``5 * tanh(x / 5)`` and split
        in two along dim 1 into mean and log-scale.
        """
        logits = logits.div(5.).tanh_().mul(5.)
        mu, log_sigma = torch.chunk(logits, 2, dim=1)
        return Normal(mu, log_sigma)

    def sampling(self, z):
        """Decode latents ``z`` and return one sample of shape ``(batch, seq_len, feat_dim)``.

        NOTE(review): unlike ``decoder_output``, no tanh squashing is applied
        to the decoder outputs here - confirm this difference is intended.
        """
        mu_x, logvar_x = self._decode(z)
        output, _ = Normal(mu_x, logvar_x).sample()
        return output

    def initialize(self, modules):
        """Re-initialise the ``weight`` of every layer that has one from ``N(0, 1)``."""
        for layer in modules:
            weight = getattr(layer, 'weight', None)
            # Weight-less layers (e.g. ReLU) are skipped explicitly instead of
            # swallowing every exception.
            if isinstance(weight, torch.Tensor):
                nn.init.normal_(weight, 0., 1.)

def loss(model, x, result):
    """Compute the training loss for one batch.

    Args:
        model: Model providing ``decoder_output`` and ``reconstruction_wt``.
        x: Batch as packed by ``prepare_irregular``; channel 0 (dim 1) is the
            reconstruction target and channel 2 the mask.
        result: Tuple returned by the model's forward pass.

    Returns:
        Tuple ``(loss, recon_err2)``: the mean KL term plus
        ``reconstruction_wt`` times the masked MSE, and that MSE alone.
    """
    logits, all_log_q, all_eps, _ = result
    _, _, kl_all = vae_terms(all_log_q, all_eps)
    reconstruction, _ = model.decoder_output(logits).sample()

    observed_mask = x[:, 2, :, :]
    target = x[:, 0, :, :]
    # Multiplying by the mask zeroes the reconstruction wherever the mask is 0.
    masked_reconstruction = reconstruction * observed_mask
    recon_err2 = torch.nn.MSELoss()(masked_reconstruction, target)
    total = torch.mean(kl_all) + model.reconstruction_wt * recon_err2

    return total, recon_err2

def train(args, dataset):
    """Train a ``VariationalAutoencoderConv`` and save its weights.

    Each step draws a random batch, packs it with ``prepare_irregular`` and
    takes one Adam step on ``loss``. Progress is printed every 100 steps;
    every 5000 steps as many windows as the dataset holds are sampled from
    the prior and scored with ``cal_score``. After the last step the state
    dict is written to ``<args.save_dir>/model.pt``, the file ``test`` loads.

    Args:
        args: Namespace with ``device``, ``max_steps``, ``batch_size``,
            ``save_dir`` and the fields the model and ``cal_score`` read.
        dataset: Dataset yielding ``(corrupted, original)`` pairs and exposing
            ``size``.

    Returns:
        The trained model.
    """
    device = args.device
    model = VariationalAutoencoderConv(args).to(device)
    optimizer = torch.optim.Adam(model.parameters())
    max_steps = args.max_steps
    batch_size = args.batch_size

    for step in range(1, max_steps + 1):
        batch, original_batch = batch_generator(dataset, batch_size)
        batch = batch.to(device)
        original_batch = original_batch.to(device)
        x = prepare_irregular(batch, original_batch, device=device)

        result = model(x)

        loss_vae, recon = loss(model, x, result)
        optimizer.zero_grad()
        loss_vae.backward()
        optimizer.step()

        if step % 100 == 0:
            print(
                f"step: {step}/{max_steps}"
                f", loss_e: {np.round(loss_vae.item(), 4)}"
                f", recon: {np.round(recon.item(), 4)}"
            )
        if step % 5000 == 0:
            dataset_size = dataset.size
            with torch.no_grad():
                _, original_batch = batch_generator(dataset, dataset_size)
                z = torch.randn(size=(original_batch.shape[0], model.latent_dim), device=device)
                result = model.sampling(z).detach().cpu()
            metric_results = cal_score(args, original_batch, result, dataset_size)
            print(f'{step}/{max_steps} : {metric_results}')

    # Persist the weights: ``test`` expects ``<save_dir>/model.pt`` to exist.
    checkpoint_dir = pathlib.Path(args.save_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), str(checkpoint_dir / 'model.pt'))
    return model

def test(args, dataset):
    """Score samples drawn from a saved model against the original windows.

    Loads ``<args.save_dir>/model.pt``, samples as many windows as the dataset
    holds from the prior and hands them to ``cal_score``.

    Args:
        args: Namespace with ``device``, ``save_dir`` and the fields the model
            and ``cal_score`` read.
        dataset: Dataset yielding ``(corrupted, original)`` pairs and exposing
            ``size``.

    Returns:
        Whatever ``cal_score`` returns.
    """
    model = VariationalAutoencoderConv(args).to(args.device)
    # map_location lets a checkpoint saved on one device load on another.
    state = torch.load(f"{args.save_dir}/model.pt", map_location=args.device)
    model.load_state_dict(state)
    model.eval()
    dataset_size = dataset.size
    with torch.no_grad():
        _, original_batch = batch_generator(dataset, dataset_size)
        z = torch.randn(size=(original_batch.shape[0], model.latent_dim), device=args.device)
        # Moved to CPU, as in the periodic evaluation inside ``train``.
        result = model.sampling(z).detach().cpu()
    metric_results = cal_score(args, original_batch, result, dataset_size)
    print(f'test : {metric_results}')
    return metric_results

def batch_generator(dataset, batch_size):
    """Draw a random batch without replacement.

    Args:
        dataset: Object supporting ``len`` and integer indexing, where each
            item is a ``(sample, original)`` pair of tensors.
        batch_size: Number of items to draw; if larger than the dataset, the
            whole dataset is returned in random order.

    Returns:
        Tuple ``(batch, ori_batch)`` of tensors stacked along a new dim 0.
    """
    idx = torch.randperm(len(dataset))
    # Fetch each selected item once and split the pairs afterwards.
    pairs = [dataset[i] for i in idx[:batch_size]]
    batch = torch.stack([pair[0] for pair in pairs])
    ori_batch = torch.stack([pair[1] for pair in pairs])
    return batch, ori_batch

def prepare_irregular(batch, original_batch, device='cuda'):
    """Pack a batch with missing values into the 4-channel model input.

    Args:
        batch: Tensor ``(B, T, F + 1)``. The last feature is the time stamp;
            missing values in the first ``F`` features are NaN. It is not
            modified.
        original_batch: Fully observed counterpart of ``batch``. It is not
            needed for the computation and is kept for interface
            compatibility.
        device: Device of the returned tensor.

    Returns:
        Tensor ``(B, 4, T, F)`` stacking along dim 1:
        0. the values with NaNs replaced by 0,
        1. the values with each missing entry replaced by the last observed
           one (0 if nothing was observed yet),
        2. the observation mask (1 = observed, 0 = missing),
        3. the time since the last observation, divided by its maximum over
           the batch.
    """
    time = batch[:, :, -1:]
    # Clone so that zero-filling never writes into the caller's tensor.
    miss_x = batch[:, :, :-1].clone()
    missing = torch.isnan(miss_x)
    miss_x[missing] = 0.0

    # Step-to-step time differences, taken per sample and broadcast over the
    # features; the first step has no predecessor and keeps a gap of 0.
    delta = torch.zeros_like(miss_x)
    delta[:, 1:, :] = time[:, 1:, :] - time[:, :-1, :]

    last_obs = miss_x.clone()
    for t in range(1, miss_x.shape[1]):
        # If the previous step was missing, its (already accumulated) gap is
        # added, so the gap keeps growing until a value is observed. This
        # holds for every step up to the last one.
        delta[:, t, :] += delta[:, t - 1, :] * missing[:, t - 1, :].to(delta.dtype)
        # Carry the last observed value forward across missing entries.
        last_obs[:, t, :] = torch.where(missing[:, t, :], last_obs[:, t - 1, :], last_obs[:, t, :])

    # Scale by the largest gap; skipped when all gaps are 0 (e.g. T == 1) to
    # avoid a 0/0 division.
    if delta.numel() > 0:
        max_gap = delta.max()
        if max_gap > 0:
            delta = delta / max_gap

    observed = (~missing).to(miss_x.dtype)
    return torch.stack([miss_x, last_obs, observed, delta], dim=1).to(device)

#####################################################################################################
#####################################################################################################


if __name__ == '__main__':
    # Per data set: CSV path relative to this file, number of feature columns
    # (``input_size``) and the ``hidden_size`` forwarded to the GRUD embedder.
    dataset_configs = {
        'air': ('data/airquality_data.csv', 13, 39),
        'stock': ('data/stock_data.csv', 6, 24),
        'energy': ('data/energy_data.csv', 28, 56),
        'ai4i': ('data/ai4i_data.csv', 5, 20),
    }

    parser = argparse.ArgumentParser("TimeVAE")
    parser.add_argument("--model1", type=str, default="timevae")
    # ``choices`` rejects an unknown data set up front; without it the script
    # would only fail later with an undefined ``data_path``.
    parser.add_argument("--data", type=str, default="energy", choices=list(dataset_configs))
    parser.add_argument("--seq_len", type=int, default=24)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--max_steps", type=int, default=40000)
    parser.add_argument("--max_steps_metric", type=int, default=10)
    parser.add_argument("--train", default=True)
    parser.add_argument("--save_dir", type=str, default='test')
    parser.add_argument("--missing_value",type=float,default=0.7)
    parser.add_argument("--device_idx", type=int, default=3)
    parser.add_argument("--latent_dim", type=int, default=2)
    parser.add_argument("--reconstruction_wt", type=int, default=3)

    here = pathlib.Path(__file__).resolve().parent
    args = parser.parse_args()
    # args.aug_mapping = True
    device_idx = args.device_idx
    # ``device`` and ``dataset`` (below) stay module-level names on purpose:
    # VariationalAutoencoderConv reads them when it is constructed.
    args.device = device = torch.device(f'cuda:{device_idx}' if torch.cuda.is_available() else "cpu")

    print(device)
    relative_csv, args.input_size, args.hidden_size = dataset_configs[args.data]
    data_path = here / relative_csv

    # --missing_value is the fraction of rows the dataset replaces with NaN.
    dataset = TimeDataset(data_path, args.seq_len, args.missing_value)
    args.num_layers = 24
    args.save_dir = here / args.save_dir
    os.makedirs(args.save_dir, exist_ok=True)

    train(args, dataset)
    # test(args, dataset)  # loads <save_dir>/model.pt
