# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""Training and evaluation for score-based generative models. """
import time
from tqdm import tqdm
import numpy as np
import pandas as pd
import logging

from models import ncsnpp_tabular
import losses
import likelihood
import sampling
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import datasets
from torch.utils.data import DataLoader
import evaluation
import sde_lib
from absl import flags
import torch
from torch.utils import tensorboard
from utils import save_checkpoint, restore_checkpoint, apply_activate
import collections
import os
from sklearn.preprocessing import MinMaxScaler
from matplotlib import pyplot as plt
FLAGS = flags.FLAGS


  
def train(config, workdir):
  """Train one score model per entry of `transformer.final_n_clusters`, then sample and evaluate.

  The function proceeds in three phases:
    1. Load the dataset via `datasets.get_dataset` and build the SDE selected
       by `config.training.sde`.
    2. For every model index, build the model/optimizer/EMA state, resume from
       `<workdir>/checkpoints-meta/checkpoint_<idx>.pth`, and train for
       `config.training.epoch` epochs. Every 1000 epochs, samples are drawn
       with the EMA weights and scored; the best-scoring state is written to
       `<workdir>/checkpoints/checkpoint_<idx>_max.pth`.
    3. Restore each best checkpoint, draw 5 rounds of samples per model, write
       them as CSV under `<workdir>/sample_split`, and log/print the scores
       from `evaluation.compute_scores` and `evaluation.compute_diversity`.

  Args:
    config: Experiment configuration. This function reads fields under
      `config.data`, `config.training`, `config.model`, `config.optim` and
      `config.device`, and may lower `config.training.batch_size` in place.
    workdir: Directory under which tensorboard logs, checkpoints and samples
      are written (sub-directories are created if missing).

  Raises:
    NotImplementedError: If `config.training.sde` is not one of
      'vpsde', 'subvpsde' or 'vesde' (case-insensitive).
  """
  # Fix every RNG we can reach so that runs are repeatable.
  random_seed = 2021
  torch.manual_seed(random_seed)
  torch.cuda.manual_seed(random_seed)
  torch.cuda.manual_seed_all(random_seed)
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False
  np.random.seed(random_seed)

  tb_dir = os.path.join(workdir, "tensorboard")
  os.makedirs(tb_dir, exist_ok=True)
  writer = tensorboard.SummaryWriter(tb_dir)

  checkpoint_dir = os.path.join(workdir, "checkpoints")
  samples_dir = os.path.join(workdir, "samples")
  sample_split_dir = os.path.join(workdir, "sample_split")
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
  os.makedirs(checkpoint_dir, exist_ok=True)
  os.makedirs(checkpoint_meta_dir, exist_ok=True)
  os.makedirs(samples_dir, exist_ok=True)
  os.makedirs(sample_split_dir, exist_ok=True)

  logging.info(config)
  print(config)

  # Build data iterators
  train_ds, eval_ds, transformer, log_probs, meta = datasets.get_dataset(
      config, uniform_dequantization=config.data.uniform_dequantization)

  # Pick the headline metric reported for this problem type.
  if meta['problem_type'] == 'binary_classification':
    metric = 'binary_f1'
  elif meta['problem_type'] == 'regression':
    metric = "r2"
  else:
    metric = 'macro_f1'

  num_models = len(transformer.final_n_clusters)

  logging.info(f"train shape : {[train_ds[id_].shape for id_ in range(num_models)]}")

  # Never ask for a batch larger than the dataset itself.
  if config.training.batch_size > len(train_ds[0]):
    config.training.batch_size = len(train_ds[0])

  logging.info(f"batch size: {config.training.batch_size}")
  real_train_ds = [transformer.inverse_transform(data, id_)
                   for id_, data in zip(transformer.final_n_clusters, train_ds)]

  if metric != "r2":
    logging.info('raw data : {}'.format([collections.Counter(i[:, -1]) for i in real_train_ds]))

  # Boolean mask per model: which rows have this cluster setting selected in `opt_sel`.
  if_belongs = [transformer.opt_sel == id_ for id_ in transformer.final_n_clusters]
  # Each row fed to the loader is [features..., log_prob, belongs]; the last
  # two columns are split off again inside the training loop.
  train_iters = [
      list(DataLoader(
          np.concatenate([train_ds[id_],
                          log_probs[id_].reshape(-1, 1),
                          if_belongs[id_].reshape(-1, 1)], axis=1),
          batch_size=config.training.batch_size))
      for id_ in range(num_models)
  ]

  scaler = datasets.get_data_scaler(config)  # not used below; call kept as in the original flow
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Setup SDEs
  sde_name = config.training.sde.lower()
  if sde_name == 'vpsde':
    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
    sampling_eps = 1e-3
  elif sde_name == 'subvpsde':
    sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
    sampling_eps = 1e-3
  elif sde_name == 'vesde':
    sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
    sampling_eps = 1e-5
  else:
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")

  continuous = config.training.continuous
  reduce_mean = config.training.reduce_mean
  likelihood_weighting = config.training.likelihood_weighting

  # One (num_rows, num_features) sampling shape per model; num_rows is the
  # number of rows assigned to that cluster setting.
  sampling_shape = [(np.sum(transformer.opt_sel == id_), config.data.image_size)
                    for id_ in transformer.final_n_clusters]

  total_num_params = []

  def make_fn(id_, count_params=True):
    """Build sampling/likelihood/training functions and the model state for model `id_`.

    Args:
      id_: Positional index of the model (index into `train_ds`, `log_probs`, ...).
      count_params: When True, record the model's parameter count in
        `total_num_params`. Pass False when re-building a model that has
        already been counted, so that the reported total is not inflated.

    Returns:
      Dict with keys 'sampling_fn', 'likelihood_fn', 'optimize_fn',
      'train_step_fn' and 'state'.
    """
    # (smallest finite log-prob, largest log-prob) among rows belonging to this model.
    member_log_probs = log_probs[id_][if_belongs[id_]]
    log_prob_range = (member_log_probs[np.isfinite(member_log_probs)].min(),
                      np.max(member_log_probs))

    sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape[id_], inverse_scaler, sampling_eps)
    optimize_fn = losses.optimization_manager(config)
    likelihood_fn = likelihood.get_likelihood_fn(sde, inverse_scaler)

    train_step_fn = losses.get_step_fn(sde, range=log_prob_range, train=True, spcl=config.training.spcl,
                                       optimize_fn=optimize_fn,
                                       reduce_mean=reduce_mean, continuous=continuous,
                                       likelihood_weighting=likelihood_weighting, writer=writer)

    score_model = mutils.create_model(config)
    logging.info(score_model)

    num_params = sum(p.numel() for p in score_model.parameters())
    print("the number of parameters", num_params)
    logging.info("the number of parameters %d" % (num_params))
    if count_params:
      total_num_params.append(num_params)

    ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
    optimizer = losses.get_optimizer(config, score_model.parameters(), lr=config.optim.lr)
    state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0, epoch=0)

    # Resume from the latest resumable checkpoint for this model.
    state = restore_checkpoint(os.path.join(checkpoint_meta_dir, f"checkpoint_{id_}.pth"), state, config.device)

    return {
      'sampling_fn': sampling_fn,
      "likelihood_fn": likelihood_fn,
      "optimize_fn": optimize_fn,
      'train_step_fn': train_step_fn,
      'state': state,
    }

  def running_training(fn, id_):
    """Run the training loop for model `id_` using the bundle returned by `make_fn`."""
    state = fn['state']
    initial_step = int(state['epoch'])
    logging.info(f"Starting training loop at epoch {initial_step}")
    # -inf so that the first validation score always produces a "max" checkpoint.
    best_score = float("-inf")

    # Accumulates the per-sample weights returned by the train step over all epochs.
    # NOTE(review): sized from real_train_ds[0] regardless of id_ — assumes every
    # entry of real_train_ds has the same number of rows; confirm.
    running_weight_iter = torch.zeros((real_train_ds[0].shape[0]))
    batch_size = config.training.batch_size

    for epoch in range(initial_step, config.training.epoch + 1):
      state['epoch'] += 1

      for iteration, batch in enumerate(train_iters[id_]):
        batch = batch.to(config.device).float()
        # Split the packed row back into features, log-prob and membership flag.
        batch, log_prob, belong = batch[:, :config.data.image_size], batch[:, -2], batch[:, -1]

        loss, weight = fn['train_step_fn'](state, batch, log_prob, belong, id_)
        running_weight_iter[batch_size * iteration:batch_size * (iteration + 1)] += weight.detach().cpu()
        writer.add_scalar(f"training_loss/{id_}", loss.item(), state['step'])

      # Reports the loss of the last batch of the epoch.
      logging.info("id_: %d, epoch: %d, iter: %d, training_loss: %.5e" % (id_, state['epoch'], iteration, loss.item()))

      if epoch != 0 and epoch % 1000 == 0:
        logging.info(f"Start validation on epoch {epoch}")
        # Sample with the EMA weights, then put the raw weights back.
        state['ema'].store(state['model'].parameters())
        state['ema'].copy_to(state['model'].parameters())
        sample, ode_nfe, dopri_step = fn['sampling_fn'](state['model'], sampling_shape=sampling_shape[id_])
        sample = apply_activate(sample, transformer.output_info)
        state['ema'].restore(state['model'].parameters())
        sample = transformer.inverse_transform(sample.cpu().numpy(), transformer.final_n_clusters[id_])
        score, _ = evaluation.compute_scores(train=real_train_ds[id_], test=None, synthesized_data=[sample],
                                             metadata=meta, eval=False)

        if metric != "r2":
          logging.info('sampled data: {}, real data: {}'.format(
              collections.Counter(sample[:, -1]), collections.Counter(real_train_ds[id_][:, -1])))

        logging.info(f'fake {id_}: {score[metric]}')
        logging.info(f"epoch: {epoch}, {metric}: {score[metric]}, ode_nfe: {ode_nfe}, dopri_step: {dopri_step}")
        metric_value = torch.tensor(score[metric])
        writer.add_scalar(f"{metric}/{id_}", metric_value, epoch)
        writer.add_scalar(f"ode_nfe/{id_}", ode_nfe, epoch)
        writer.add_scalar(f"dopri_step/{id_}", dopri_step, epoch)

        if best_score < metric_value:
          best_score = metric_value
          save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{id_}_max.pth'), state)

        # Save the resumable checkpoint after every validation round.
        save_checkpoint(os.path.join(checkpoint_meta_dir, f"checkpoint_{id_}.pth"), state)

    # Average weight per sample over the epochs counted in the state.
    running_weight_iter = running_weight_iter / state['epoch']
    fig = plt.figure()
    plt.scatter(running_weight_iter, log_probs[id_].reshape(-1, 1))
    writer.add_figure(f"avgweight_density_scatter/{id_}", fig, 0)

  for model_idx in range(num_models):
    print(f"Start training on {model_idx}")
    logging.info(f"Start training on {model_idx}")
    model_fn = make_fn(model_idx)
    running_training(fn=model_fn, id_=model_idx)

  logging.info(f"total num of params : {total_num_params}")
  logging.info(f"total num of params : {sum(total_num_params)}")

  sample_list = []
  num_sampling_rounds = 5
  nfes = {}
  dopri_steps = {}

  for id_ in range(num_models):
    # Rebuild the model bundle without re-counting its parameters.
    test_fn = make_fn(id_, count_params=False)

    ckpt_filename = os.path.join(checkpoint_dir, f'checkpoint_{id_}_max.pth')
    test_fn['state'] = restore_checkpoint(ckpt_filename, test_fn['state'], device=config.device)
    logging.info(f"checkpoint : {test_fn['state']['step']}")
    # NOTE(review): validation above samples with the EMA weights copied into the
    # model, but here the restored model is used as-is. Confirm whether
    # `ema.copy_to(...)` is intended before final sampling.
    score_model = test_fn['state']['model']

    nfes[id_] = []
    dopri_steps[id_] = []

    entire_data = []
    for _ in range(num_sampling_rounds):
      samples, ode_nfe, dopri_step = test_fn['sampling_fn'](score_model, sampling_shape=sampling_shape[id_])
      nfes[id_].append(ode_nfe)
      dopri_steps[id_].append(dopri_step)

      samples = apply_activate(samples, transformer.output_info)
      entire_data.append(transformer.inverse_transform(samples.cpu().numpy(), transformer.final_n_clusters[id_]))

    sample_list.append(entire_data)

  # For each sampling round, dump every model's samples and stack them into one table.
  final_sample = []
  for i in range(num_sampling_rounds):
    temp = []
    for j in range(num_models):
      pd.DataFrame(sample_list[j][i]).to_csv(f"{sample_split_dir}/id_{j}_num_{i}.csv", header=False, index=False)
      temp.append(sample_list[j][i])
    final_sample.append(np.concatenate(temp))

  # NOTE(review): real_train_ds[0] is used as the reference training set — presumably
  # all entries describe the same rows; confirm.
  scores, stds = evaluation.compute_scores(train=real_train_ds[0], test=eval_ds, synthesized_data=final_sample,
                                           metadata=meta, eval=True)
  diversity_mean, diversity_std = evaluation.compute_diversity(train=real_train_ds[0], fake=final_sample)

  scores['coverage'] = diversity_mean['coverage']
  scores['density'] = diversity_mean['density']
  stds['coverage'] = diversity_std['coverage']
  stds['density'] = diversity_std['density']

  for id_ in range(num_models):
    scores[f"nfe_{id_}"] = np.mean(nfes[id_])
    scores[f"dopri_step_{id_}"] = np.mean(dopri_steps[id_])

    stds[f"nfe_{id_}"] = np.std(nfes[id_])
    stds[f"dopri_step_{id_}"] = np.std(dopri_steps[id_])

  logging.info(f"{scores}")
  logging.info(f"{stds}")
  print(f"{scores}")
  print(f"{stds}")

  logging.info(f"total num of params : {total_num_params}")
  logging.info(f"total num of params : {sum(total_num_params)}")
  print(f"total num of params : {total_num_params}")
  print(f"total num of params : {sum(total_num_params)}")

  # Flush pending tensorboard events to disk before returning.
  writer.close()

