
import gc
import io
import os
import time

import numpy as np
import tensorflow as tf
import logging
# Keep the import below for registering all model definitions
from models.encoder import Encoder
from models.recovery import Recovery
from models.conditional_ncsnpp import NCSNpp as conditional_Time_Score
import losses
import sampling
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import likelihood
import sde_lib
from absl import flags
import torch
from torch.utils import tensorboard
from torchvision.utils import make_grid, save_image
from utils import save_checkpoint, save_ER_checkpoint, restore_checkpoint, extract_time, NormMinMax, restore_ER_checkpoint, scaler, inverse_scaler, restore_data
import torch.nn as nn
import torch.optim as optim
from metrics.discriminative_metrics import discriminative_score_metrics
from metrics.predictive_metrics import predictive_score_metrics
from metrics.visualization_metrics import visualization

from options import Options
from lib.data import load_data, batch_generator

# Module-level alias for absl's global flag registry (`absl.flags.FLAGS`).
FLAGS = flags.FLAGS


def train_ER(opt, ori_data, work_dir, device):
  """Train, or resume training of, the encoder/recovery networks.

  If ``./ER_trained/checkpoints`` contains any ``checkpoint_<n>.pth``, the one
  with the largest ``n`` is restored. On a fresh run, training starts from the
  newly constructed networks instead of failing.

  Every 10000 iterations a new checkpoint is written. Its index continues the
  numbering of the restored one (or starts at 1 on a fresh run).

  Every 50 iterations, the loss and the max/min values returned by the ER step
  function are logged to TensorBoard under ``./ER_trained/ER_tensorboard``.

  Args:
    opt: Options object. ``lr``, ``beta1`` and ``iteration`` are read here,
      and the whole object is passed to ``Encoder`` / ``Recovery``.
    ori_data: Original training sequences. They are converted as a whole into
      a single float32 tensor on ``device``.
    work_dir: Unused; kept for interface compatibility.
    device: Torch device that the networks and the data are moved to.
  """
  # NOTE(review): device_ids=[0, 1] assumes two visible GPUs.
  nete = torch.nn.DataParallel(Encoder(opt).to(device), device_ids=[0, 1])
  netr = torch.nn.DataParallel(Recovery(opt).to(device), device_ids=[0, 1])
  input_data = torch.tensor(ori_data, dtype=torch.float32).to(device)
  nete.train()
  netr.train()
  optimizer_e = optim.Adam(nete.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
  optimizer_r = optim.Adam(netr.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

  tb_dir = "./ER_trained/ER_tensorboard"
  tf.io.gfile.makedirs(tb_dir)
  writer = tensorboard.SummaryWriter(tb_dir)

  checkpoint_dir = "./ER_trained/checkpoints"
  tf.io.gfile.makedirs(checkpoint_dir)

  ER_step_fn = losses.get_ER_step_fn()
  state = {'encoder': nete, 'decoder': netr, 'opt_e': optimizer_e, 'opt_r': optimizer_r}

  # Only files named checkpoint_<digits>.pth count. Any other file in the
  # directory used to make int() raise.
  prefix, suffix = "checkpoint_", ".pth"
  ckpt_ids = [
      int(name[len(prefix):-len(suffix)])
      for name in os.listdir(checkpoint_dir)
      if name.startswith(prefix) and name.endswith(suffix)
      and name[len(prefix):-len(suffix)].isdigit()
  ]
  # On a fresh run there is nothing to resume, so numbering starts from 0.
  # Indexing [-1] on the empty list used to raise IndexError here.
  ckpt = max(ckpt_ids, default=0)
  if ckpt_ids:
    state = restore_ER_checkpoint(
        os.path.join(checkpoint_dir, f"checkpoint_{ckpt}.pth"), state, device=device)

  try:
    for step in range(opt.iteration):
      # Train for one iteration on the full dataset tensor.
      ER_loss, max_value, min_value = ER_step_fn(state, input_data)

      if step % 50 == 0:
        writer.add_scalar("training_ER_loss", ER_loss, step)
        writer.add_scalar("training_ER_max", max_value, step)
        writer.add_scalar("training_ER_min", min_value, step)

      if (step + 1) % 10000 == 0:
        # Continue numbering after the checkpoint we resumed from.
        save_step = ckpt + (step + 1) // 10000
        save_ER_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{save_step}.pth'), state)
  finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer.close()

  
def train_conditional_score(opt, ori_data, config, workdir):
  """Train the conditional score model on encoder-space data.

  Steps:
    1. Restore the newest ``checkpoint_<n>.pth`` from ``./ER_trained/checkpoints``
       into freshly built encoder/recovery networks.
    2. Encode ``ori_data`` once with the restored encoder, min-max normalise it
       with ``NormMinMax`` and apply ``scaler``.
    3. Run ``config.training.n_iters`` training steps on random batches of the
       encoded data. The training loss is logged every ``log_freq`` steps and
       an evaluation loss every ``eval_freq`` steps, both to TensorBoard under
       ``./score_trained/tensorboard``.
    4. At every ``snapshot_freq`` step, and at the final step, overwrite
       ``./score_trained/checkpoints/checkpoint.pth``. This happens only when
       both the training and the evaluation loss are lower than at the
       previous snapshot.

  Args:
    opt: Options object. ``lr``, ``beta1`` and ``batch_size`` are read here,
      and the whole object is passed to ``Encoder`` / ``Recovery``.
    ori_data: Original (unencoded) training sequences.
    config: Experiment config. ``device``, ``model.*`` and ``training.*`` are
      read.
    workdir: Unused; kept for interface compatibility.

  Raises:
    FileNotFoundError: If no encoder/recovery checkpoint exists.
    NotImplementedError: If ``config.training.sde`` is not vpsde, subvpsde or
      vesde.
  """
  # Locate the newest encoder/recovery checkpoint (checkpoint_<digits>.pth).
  ER_dir = "./ER_trained/checkpoints"
  prefix, suffix = "checkpoint_", ".pth"
  er_ids = [
      int(name[len(prefix):-len(suffix)])
      for name in os.listdir(ER_dir)
      if name.startswith(prefix) and name.endswith(suffix)
      and name[len(prefix):-len(suffix)].isdigit()
  ]
  if not er_ids:
    raise FileNotFoundError(
        f"No checkpoint_<n>.pth found in {ER_dir}; train the encoder/recovery first.")
  ckpt_dir = os.path.join(ER_dir, f"checkpoint_{max(er_ids)}.pth")

  device = config.device
  # NOTE(review): device_ids=[0, 1] assumes two visible GPUs.
  nete = torch.nn.DataParallel(Encoder(opt).to(device), device_ids=[0, 1])
  netr = torch.nn.DataParallel(Recovery(opt).to(device), device_ids=[0, 1])
  # The optimizers are only built so the checkpoint state dict can be restored.
  optimizer_e = optim.Adam(nete.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
  optimizer_r = optim.Adam(netr.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
  state_ER = {'encoder': nete, 'decoder': netr, 'opt_e': optimizer_e, 'opt_r': optimizer_r}
  state_ER = restore_ER_checkpoint(ckpt_dir, state_ER, device)

  work_dir = "./score_trained"
  tf.io.gfile.makedirs(work_dir)

  tb_dir = os.path.join(work_dir, "tensorboard")
  tf.io.gfile.makedirs(tb_dir)
  writer = tensorboard.SummaryWriter(tb_dir)

  checkpoint_dir = os.path.join(work_dir, "checkpoints")
  tf.io.gfile.makedirs(checkpoint_dir)

  # Initialize the model, its EMA and its optimizer.
  conditional_score_model = conditional_Time_Score(config).to(device)
  conditional_score_model = torch.nn.DataParallel(conditional_score_model, device_ids=[0, 1])
  ema = ExponentialMovingAverage(conditional_score_model.parameters(), decay=config.model.ema_rate)
  optimizer = losses.get_optimizer(config, conditional_score_model.parameters())
  state = dict(optimizer=optimizer, conditional_model=conditional_score_model, ema=ema, step=0)

  # Set up the SDE.
  sde_name = config.training.sde.lower()
  if sde_name == 'vpsde':
    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
  elif sde_name == 'subvpsde':
    sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
  elif sde_name == 'vesde':
    sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
  else:
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")

  # Build the one-step training and evaluation functions.
  optimize_fn = losses.optimization_manager(config)
  train_step_fn = losses.get_conditional_step_fn(sde, train=True, optimize_fn=optimize_fn)
  eval_step_fn = losses.get_conditional_step_fn(sde, train=False, optimize_fn=optimize_fn)

  num_train_steps = config.training.n_iters
  initial_step = state["step"]

  # Encode the whole dataset once. The score model is trained on this scaled
  # latent representation. The raw data is not needed inside the loop (it used
  # to be copied to the device on every step without being used).
  ori_time, _ = extract_time(ori_data)
  with torch.no_grad():
    encoded = state_ER['encoder'](torch.tensor(ori_data, dtype=torch.float32).to(device))
  encoded, _, _ = NormMinMax(encoded.detach().cpu().numpy())
  latent_data = scaler(encoded)

  # Start from +inf so the first snapshot can be saved. With the previous
  # initial value of 0, the first snapshot only saved if both losses were
  # negative, and a checkpoint might never be written at all.
  prev_loss = float('inf')
  prev_eval_loss = float('inf')
  eval_loss = float('inf')

  logging.info("Starting training loop at step %d.", initial_step)

  try:
    for step in range(initial_step, num_train_steps + 1):
      X0, _ = batch_generator(latent_data, ori_time, opt.batch_size)
      batch = torch.tensor(X0, dtype=torch.float32).to(device)

      # Execute one training step.
      conditional_loss = train_step_fn(state, batch)

      if step % config.training.log_freq == 0:
        writer.add_scalar("training_conditional_loss", conditional_loss, step)

      if step % config.training.eval_freq == 0:
        X0, _ = batch_generator(latent_data, ori_time, opt.batch_size)
        eval_batch = torch.tensor(X0, dtype=torch.float32).to(device)
        eval_loss = eval_step_fn(state, eval_batch)
        writer.add_scalar("eval_loss", eval_loss.item(), step)

      is_snapshot = (step != 0 and step % config.training.snapshot_freq == 0) or step == num_train_steps
      if is_snapshot:
        # Overwrite the single checkpoint only when both losses are LOWER than
        # at the previous snapshot.
        if eval_loss < prev_eval_loss and conditional_loss < prev_loss:
          save_step = step // config.training.snapshot_freq
          print(f'save at {save_step}')
          save_checkpoint(os.path.join(checkpoint_dir, 'checkpoint.pth'), state)
        prev_loss = conditional_loss
        prev_eval_loss = eval_loss
  finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer.close()

def cal_score(opt, prev_ori_data, config, workdir, *, samp_num=19711,
              generate_iter=1, ode_sampling=False):
  """Generate synthetic sequences with the trained score model and evaluate them.

  Steps:
    1. Restore the newest encoder/recovery checkpoint from
       ``./ER_trained/checkpoints``.
    2. Restore the score checkpoint ``./score_trained/checkpoints/checkpoint.pth``
       and load its EMA weights into the model.
    3. Encode ``prev_ori_data`` to recover the ``NormMinMax`` statistics used
       during training.
    4. For each of ``generate_iter`` rounds:
       - Draw ``samp_num`` latent samples and undo the scaling.
       - Decode the samples and truncate each one to the length of the
         matching original sequence.
       - Compute the discriminative and predictive metrics
         ``opt.metric_iteration`` times each and store their means.
       - Draw a t-SNE visualization.

  Args:
    opt: Options object. ``lr``, ``beta1`` and ``metric_iteration`` are read
      here, and the whole object is passed to ``Encoder`` / ``Recovery``.
    prev_ori_data: Original (unencoded) sequences.
    config: Experiment config. ``device``, ``model.*`` and ``training.sde``
      are read.
    workdir: Unused; kept for interface compatibility.
    samp_num: Number of sequences to generate per round. It must not exceed
      the number of original sequences, because each generated sequence is
      cut to the length of the original sequence with the same index.
    generate_iter: Number of independent sampling/evaluation rounds.
    ode_sampling: Forwarded to ``sampling.get_sampling_fn``.

  Raises:
    NotImplementedError: If ``config.training.sde`` is unknown.
    FileNotFoundError: If the encoder/recovery or score checkpoint is missing.
    ValueError: If ``samp_num`` exceeds the number of original sequences.
  """
  sde_name = config.training.sde.lower()
  if sde_name == 'vpsde':
    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
    sampling_eps = 1e-3
  elif sde_name == 'subvpsde':
    sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
    sampling_eps = 1e-3
  elif sde_name == 'vesde':
    sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
    sampling_eps = 1e-5
  else:
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")

  device = config.device

  # Restore the newest encoder/recovery checkpoint (checkpoint_<digits>.pth).
  ER_dir = "./ER_trained/checkpoints"
  prefix, suffix = "checkpoint_", ".pth"
  er_ids = [
      int(name[len(prefix):-len(suffix)])
      for name in os.listdir(ER_dir)
      if name.startswith(prefix) and name.endswith(suffix)
      and name[len(prefix):-len(suffix)].isdigit()
  ]
  if not er_ids:
    raise FileNotFoundError(
        f"No checkpoint_<n>.pth found in {ER_dir}; train the encoder/recovery first.")
  # NOTE(review): device_ids=[0, 1] assumes two visible GPUs.
  nete = torch.nn.DataParallel(Encoder(opt).to(device), device_ids=[0, 1])
  netr = torch.nn.DataParallel(Recovery(opt).to(device), device_ids=[0, 1])
  optimizer_e = optim.Adam(nete.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
  optimizer_r = optim.Adam(netr.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
  state_ER = {'encoder': nete, 'decoder': netr, 'opt_e': optimizer_e, 'opt_r': optimizer_r}
  state_ER = restore_ER_checkpoint(
      os.path.join(ER_dir, f"checkpoint_{max(er_ids)}.pth"), state_ER, device)

  # The score checkpoint path does not depend on the sampling loop, so the
  # model is restored once. Fail early rather than evaluate an untrained model.
  work_dir = "./score_trained"
  tf.io.gfile.makedirs(work_dir)
  ckpt_path = os.path.join(os.path.join(work_dir, "checkpoints"), 'checkpoint.pth')
  if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(
        f"Score checkpoint {ckpt_path} not found; train the conditional score model first.")
  conditional_score_model = conditional_Time_Score(config).to(device)
  conditional_score_model = torch.nn.DataParallel(conditional_score_model, device_ids=[0, 1])
  ema = ExponentialMovingAverage(conditional_score_model.parameters(), decay=config.model.ema_rate)
  optimizer = losses.get_optimizer(config, conditional_score_model.parameters())
  state = dict(optimizer=optimizer, conditional_model=conditional_score_model, ema=ema, step=0)
  state = restore_checkpoint(ckpt_path, state, device=device)
  # Sample with the EMA weights.
  ema.copy_to(conditional_score_model.parameters())

  # Lengths of the original sequences, used to truncate the generated ones.
  # Check them before the expensive sampling step, not after.
  ori_time, _ = extract_time(prev_ori_data)
  if samp_num > len(ori_time):
    raise ValueError(
        f"samp_num={samp_num} exceeds the {len(ori_time)} original sequences "
        "whose lengths are used to truncate the samples.")

  # Encode once. Only the min/max statistics are needed to undo the
  # normalisation that was applied to the training data.
  with torch.no_grad():
    latent = state_ER['encoder'](torch.tensor(prev_ori_data, dtype=torch.float32).to(device))
  _, min_val, max_val = NormMinMax(latent.detach().cpu().numpy())

  # Sample in the same per-sequence latent shape the score model was trained
  # on. Previously this was hard-coded as (24, 28).
  sampling_shape = (samp_num,) + tuple(latent.shape[1:])
  sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, sampling_eps, ode_sampling)

  prev_ori_data = np.array(prev_ori_data)
  ckpt = 5  # label only: used in the metric keys and the plot name
  metric_results = {}
  total_metric_result = {}

  for samp_iter in range(generate_iter):
    sample, _ = sampling_fn(conditional_score_model)
    # Undo `scaler`, then the NormMinMax normalisation (x * max + min).
    generated = inverse_scaler(sample).detach().cpu().numpy() * max_val + min_val
    # Cast explicitly to float32: the numpy arithmetic above may promote the
    # dtype, and the decoder is fed float32 everywhere else.
    generated = torch.tensor(generated, dtype=torch.float32).to(device)
    with torch.no_grad():
      generated_data_curr = state_ER['decoder'](generated).detach().cpu().numpy()

    np.save('./save', generated_data_curr)
    generated_data = [generated_data_curr[i, :ori_time[i], :] for i in range(samp_num)]

    # 1. Discriminative score
    disc_mean = np.mean([
        discriminative_score_metrics(prev_ori_data, generated_data)
        for _ in range(opt.metric_iteration)
    ])
    metric_results[f'discriminative_{ckpt}'] = disc_mean
    total_metric_result[f'discriminative_{samp_iter}_{ckpt}'] = disc_mean

    # 2. Predictive score
    pred_mean = np.mean([
        predictive_score_metrics(prev_ori_data, generated_data)
        for _ in range(opt.metric_iteration)
    ])
    metric_results[f'predictive_{ckpt}'] = pred_mean
    total_metric_result[f'predictive_{samp_iter}_{ckpt}'] = pred_mean

    # 3. Visualization (t-SNE)
    visualization(prev_ori_data, generated_data, 'tsne', samp_iter=samp_iter, number=ckpt)

    print(metric_results)
  print(total_metric_result)



