
import gc
import io
import os
import time
import numpy as np
import tensorflow as tf
# Keep the import below for registering all model definitions
from models.encoder import Encoder
from models.recovery import Recovery
from models.conditional_ncsnpp import NCSNpp as conditional_Time_Score
import losses
import sampling
from models.ema import ExponentialMovingAverage
import sde_lib
from absl import flags
import torch
from torch.utils import tensorboard
from utils import save_checkpoint, save_ER_checkpoint, restore_checkpoint, restore_ER_checkpoint
import torch.nn as nn
import torch.optim as optim
from metrics.discriminative_metrics import discriminative_score_metrics
from metrics.predictive_metrics import predictive_score_metrics
from metrics.visualization_metrics import visualization

from data import load_data, batch_generator, NormMinMax, scaler, inverse_scaler

# Handle to absl's global flag container. None of the functions visible in
# this part of the file read from it.
FLAGS = flags.FLAGS

def train_ER(arg):
  """Pre-train the encoder and recovery networks and save their final state.

  Loads the data with ``load_data(arg)``, normalises it with ``NormMinMax``
  and runs ``arg.iteration`` optimisation steps using the step function
  returned by ``losses.get_ER_step_fn()``. Every step is fed the complete
  normalised data set. Scalars are logged to ``./ER_trained/ER_tensorboard``
  every 50 iterations, and after the last iteration the state is written to
  ``./ER_trained/checkpoints/checkpoint.pth``.

  Args:
    arg: Run configuration. Reads ``arg.config.device``, ``arg.lr``,
      ``arg.beta1`` and ``arg.iteration``; it is also forwarded unchanged to
      ``load_data``, ``Encoder`` and ``Recovery``.
  """
  # load normalized data
  ori_data = load_data(arg)
  config = arg.config
  nete = Encoder(arg).to(config.device)
  netr = Recovery(arg).to(config.device)
  # Only the normalised array is used here; the returned min/max are dropped.
  input_data, _, _ = NormMinMax(ori_data)
  input_data = torch.tensor(input_data, dtype=torch.float32).to(config.device)
  nete.train()
  netr.train()
  optimizer_e = optim.Adam(nete.parameters(), lr=arg.lr, betas=(arg.beta1, 0.999))
  optimizer_r = optim.Adam(netr.parameters(), lr=arg.lr, betas=(arg.beta1, 0.999))

  tb_dir = "./ER_trained/ER_tensorboard"
  tf.io.gfile.makedirs(tb_dir)

  checkpoint_dir = "./ER_trained/checkpoints"
  tf.io.gfile.makedirs(checkpoint_dir)

  # train encoder and decoder
  ER_step_fn = losses.get_ER_step_fn()
  state = {'encoder':nete, 'decoder':netr, 'opt_e':optimizer_e, 'opt_r':optimizer_r}

  writer = tensorboard.SummaryWriter(tb_dir)
  try:
    for step in range(arg.iteration):
      # Train for one iteration
      ER_loss, max_value, min_value = ER_step_fn(state, input_data)

      if step % 50 == 0:
        writer.add_scalar("training_ER_loss", ER_loss, step)
        writer.add_scalar("training_ER_max", max_value, step)
        writer.add_scalar("training_ER_min", min_value, step)
  finally:
    # Flush pending events and release the event file even if a step raises.
    writer.close()

  # Save once, after the last iteration. Nothing is saved when no iteration
  # ran, which matches the per-iteration "last step" check this replaces.
  if arg.iteration > 0:
    save_ER_checkpoint(os.path.join(checkpoint_dir, 'checkpoint.pth'), state)

  
def train_conditional_score(arg):
  """Train the conditional score model alongside the encoder/recovery nets.

  Restores the encoder/recovery state from
  ``./ER_trained/checkpoints/checkpoint.pth``. Then, for every step from 0 to
  ``config.training.n_iters`` inclusive, it:

    1. encodes the normalised data and rescales it with ``scaler``;
    2. draws a batch with ``batch_generator`` and runs one score-model
       training step;
    3. runs one encoder/recovery step on the full normalised data.

  Scalars are logged to ``./score_trained/tensorboard``. At snapshot steps the
  state is written to ``./score_trained/checkpoints/checkpoint.pth``, but only
  when both the evaluation loss and the training loss are lower than at the
  previous snapshot step.

  Args:
    arg: Run configuration. Reads ``arg.config``, ``arg.lr``, ``arg.beta1``
      and ``arg.batch_size``; it is also forwarded unchanged to ``load_data``
      and to the model constructors.

  Raises:
    NotImplementedError: If ``config.training.sde`` is neither ``vpsde`` nor
      ``subvpsde`` (case-insensitive).
  """
  raw_data = load_data(arg)
  config = arg.config
  device = config.device

  ER_dir = "./ER_trained/checkpoints"
  er_ckpt_path = os.path.join(ER_dir, "checkpoint.pth")

  nete = Encoder(arg).to(device)
  netr = Recovery(arg).to(device)
  optimizer_e = optim.Adam(nete.parameters(), lr=arg.lr, betas=(arg.beta1, 0.999))
  optimizer_r = optim.Adam(netr.parameters(), lr=arg.lr, betas=(arg.beta1, 0.999))

  # restore pre-trained checkpoints
  state_ER = {'encoder': nete, 'decoder': netr, 'opt_e': optimizer_e, 'opt_r': optimizer_r}
  restore_ER_checkpoint(er_ckpt_path, state_ER, device)

  work_dir = "./score_trained"
  tf.io.gfile.makedirs(work_dir)

  tb_dir = os.path.join(work_dir, "tensorboard")
  tf.io.gfile.makedirs(tb_dir)

  checkpoint_dir = os.path.join(work_dir, "checkpoints")
  tf.io.gfile.makedirs(checkpoint_dir)

  # Initialize model.
  conditional_score_model = conditional_Time_Score(arg).to(device)
  # conditional_score_model = torch.nn.DataParallel(conditional_score_model, device_ids=arg.parallel_device)
  ema = ExponentialMovingAverage(conditional_score_model.parameters(), decay=config.model.ema_rate)
  optimizer = losses.get_optimizer(config, conditional_score_model.parameters())
  state = dict(
    optimizer=optimizer,
    conditional_model=conditional_score_model,
    ema=ema,
    encoder=state_ER['encoder'],
    decoder=state_ER['decoder'],
    opt_e=state_ER['opt_e'],
    opt_r=state_ER['opt_r'],
    step=0,
  )

  # Setup SDEs
  sde_name = config.training.sde.lower()
  if sde_name == 'vpsde':
    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
  elif sde_name == 'subvpsde':
    sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
  else:
    # Fail with a clear message instead of a NameError on `sde` further down.
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")

  # Build one-step training and evaluation functions
  optimize_fn = losses.optimization_manager(config)
  train_step_fn = losses.get_conditional_step_fn(sde, train=True, optimize_fn=optimize_fn)
  eval_step_fn = losses.get_conditional_step_fn(sde, train=False, optimize_fn=optimize_fn)
  ER_step_fn = losses.get_ER_step_fn()

  num_train_steps = config.training.n_iters
  initial_step = state["step"]

  norm_data, _, _ = NormMinMax(raw_data)
  # The normalised data is the same on every step, so build the device tensor
  # once instead of twice per iteration.
  # NOTE(review): assumes neither the encoder nor ER_step_fn modifies its
  # input tensor in place -- confirm against models/ and losses.
  input_data = torch.tensor(norm_data, dtype=torch.float32).to(device)

  # Start from +inf so the first snapshot step can be saved. Starting from 0
  # would make the "both losses decreased" test fail whenever losses are >= 0.
  prev_loss = float('inf')
  prev_eval_loss = float('inf')

  writer = tensorboard.SummaryWriter(tb_dir)
  try:
    for step in range(initial_step, num_train_steps + 1):
      # Re-encode every step: the encoder is updated by ER_step_fn below.
      encoded_data = state['encoder'](input_data)
      scaled_data = scaler(encoded_data)
      # batch for score network
      batch = batch_generator(scaled_data, arg.batch_size)

      # Execute one training step alternately
      conditional_loss = train_step_fn(state, batch)
      ER_loss, _, _ = ER_step_fn(state, input_data)

      if step % config.training.log_freq == 0:
        writer.add_scalar("training_conditional_loss", conditional_loss, step)
        writer.add_scalar("training_ER_loss", ER_loss, step)
        writer.add_scalar("training_ER_max", encoded_data.max(), step)
        writer.add_scalar("training_ER_min", encoded_data.min(), step)

      # Step 0 always takes this branch, so `eval_loss` is bound before the
      # first snapshot check below.
      if step % config.training.eval_freq == 0:
        eval_batch = batch_generator(scaled_data, arg.batch_size)

        eval_loss = eval_step_fn(state, eval_batch)
        writer.add_scalar("eval_loss", eval_loss.item(), step)

      # Save a checkpoint periodically, and always consider the final step.
      is_periodic_snapshot = step != 0 and step % config.training.snapshot_freq == 0
      if is_periodic_snapshot or step == num_train_steps:
        # Save only if both losses are lower than at the previous snapshot.
        # `eval_loss` is the most recent evaluation, which may come from an
        # earlier step when snapshot_freq is not a multiple of eval_freq.
        if eval_loss < prev_eval_loss and conditional_loss < prev_loss:
          save_step = step // config.training.snapshot_freq
          print(f'save at {save_step}')
          save_checkpoint(os.path.join(checkpoint_dir, 'checkpoint.pth'), state)
        prev_loss = conditional_loss
        prev_eval_loss = eval_loss
  finally:
    # Flush pending events and release the event file even if a step raises.
    writer.close()

def evaluate(arg):
  """Sample from the trained score model and report quality metrics.

  Restores the state stored in ``./score_trained/checkpoints/checkpoint.pth``,
  copies the EMA parameters into the score model, and then for each of the
  ``generate_iter`` rounds:

    1. draws as many latent samples as there are rows in the loaded data;
    2. maps them through ``inverse_scaler`` and the decoder;
    3. saves the decoded samples to ``./save_<round>.npy``;
    4. computes the discriminative and predictive scores, each averaged over
       ``arg.metric_iteration`` runs, and calls ``visualization``.

  The per-round metrics are printed after every round and the collected
  metrics of all rounds once at the end.

  Args:
    arg: Run configuration. Reads ``arg.config``, ``arg.seq_len``,
      ``arg.hidden_dim`` and ``arg.metric_iteration``; it is also forwarded
      unchanged to ``load_data`` and to the model constructors.

  Raises:
    NotImplementedError: If ``config.training.sde`` is neither ``vpsde`` nor
      ``subvpsde`` (case-insensitive).
  """
  #########################edit config#################################
  ode_sampling = False # For ode sampling, but we do not use it.
  generate_iter = 1 # We calculated results on the paper by setting 10
  #########################edit config#################################
  # Convert once up front so `.shape` is available whatever container
  # `load_data` returns, and so the conversion is not repeated per round.
  prev_ori_data = np.array(load_data(arg))
  samp_num = prev_ori_data.shape[0]
  config = arg.config
  device = config.device
  work_dir = "./score_trained"
  tf.io.gfile.makedirs(work_dir)

  sde_name = config.training.sde.lower()
  if sde_name == 'vpsde':
    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
    sampling_eps = 1e-3
  elif sde_name == 'subvpsde':
    sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
    sampling_eps = 1e-3
  else:
    # Fail with a clear message instead of a NameError on `sde` further down.
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")

  metric_results = dict()
  total_metric_result = dict()

  checkpoint_dir = os.path.join(work_dir, "checkpoints")
  ckpt_path = os.path.join(checkpoint_dir, 'checkpoint.pth')

  # The models, the restored checkpoint and the sampler are identical for
  # every round, so they are built once rather than once per round.
  nete = Encoder(arg).to(device)
  netr = Recovery(arg).to(device)
  conditional_score_model = conditional_Time_Score(arg).to(device)
  # conditional_score_model = torch.nn.DataParallel(conditional_score_model, device_ids=arg.parallel_device)
  ema = ExponentialMovingAverage(conditional_score_model.parameters(), decay=config.model.ema_rate)
  optimizer = losses.get_optimizer(config, conditional_score_model.parameters())
  state = dict(optimizer=optimizer, conditional_model=conditional_score_model, ema=ema, encoder=nete, decoder=netr, step=0)
  state = restore_checkpoint(ckpt_path, state, device=device)

  # Sample with the EMA parameters rather than the raw trained ones.
  ema.copy_to(conditional_score_model.parameters())

  sampling_shape = (samp_num, arg.seq_len, arg.hidden_dim)
  sampling_fn = sampling.get_sampling_fn(config, sde, sampling_eps, ode_sampling)

  for samp_iter in range(generate_iter):
    result, _ = sampling_fn(conditional_score_model, sampling_shape)

    # The decoded output is detached straight away, so no autograd graph is
    # needed for this forward pass.
    with torch.no_grad():
      latent_samples = inverse_scaler(result)
      generated_data_curr = state['decoder'](latent_samples).detach().cpu().numpy()

    np.save(f'./save_{samp_iter}', generated_data_curr) # save sample
    # The metric functions are given a list with one array per sample.
    generated_data = [generated_data_curr[i, :arg.seq_len, :] for i in range(samp_num)]

    ## Performance metrics

    # 1. Discriminative Score
    discriminative_score = [
      discriminative_score_metrics(prev_ori_data, generated_data)
      for _ in range(arg.metric_iteration)
    ]
    metric_results['discriminative'] = np.mean(discriminative_score)
    total_metric_result[f'discriminative_{samp_iter}'] = np.mean(discriminative_score)

    # 2. Predictive score
    predictive_score = [
      predictive_score_metrics(prev_ori_data, generated_data)
      for _ in range(arg.metric_iteration)
    ]
    metric_results['predictive'] = np.mean(predictive_score)
    total_metric_result[f'predictive_{samp_iter}'] = np.mean(predictive_score)

    # 3. Visualization (tSNE)
    visualization(prev_ori_data, generated_data, samp_iter=samp_iter)

    # Print discriminative and predictive scores
    print(metric_results)
  print(total_metric_result)
