# Snippets of code have been taken from PIDM and NVIDIA's EDM

import argparse
import json
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import numpy as np
import os
from pathlib import Path
import seaborn as sns
import time
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import yaml

from src.data_utils import Dataset, cycle
from src.utils import fix_seeds, load_model, save_model, noop
from src.utils import EDM2Loss, EMA, exists, evaluate_log_likelihood , edm_sampler #, estimate_dataset_std_sample
from src.models import SoftlyConstrainedDenoiser
from src.residuals_darcy import ResidualsDarcy

#----------------------------------------------------------------------------
# Learning rate decay schedule used in the paper "Analyzing and Improving
# the Training Dynamics of Diffusion Models".

def learning_rate_schedule(cur_nimg, batch_size, ref_lr=100e-4, ref_batches=70e3, rampup_Mimg=10):
    """Return the learning rate to use at training progress ``cur_nimg``.

    Starting from ``ref_lr``, two optional factors are applied:

    * Inverse square-root decay: the rate is divided by
      ``sqrt(max(cur_nimg / (ref_batches * batch_size), 1))``, so it stays
      flat until ``cur_nimg`` exceeds ``ref_batches * batch_size``.
      Skipped when ``ref_batches <= 0``.
    * Linear ramp-up: the rate is multiplied by
      ``min(cur_nimg / rampup_Mimg, 1)``. Skipped when ``rampup_Mimg <= 0``.

    Note that ``rampup_Mimg`` is used as given (no scaling by 1e6), i.e. it
    is expressed in the same units as ``cur_nimg``.
    """
    lr = ref_lr

    # Decay phase: only kicks in once progress passes the reference horizon.
    if ref_batches > 0:
        decay_horizon = ref_batches * batch_size
        progress = cur_nimg / decay_horizon
        lr /= np.sqrt(max(progress, 1))

    # Warm-up phase: scale linearly from 0 up to the full rate.
    if rampup_Mimg > 0:
        warmup_fraction = min(cur_nimg / (rampup_Mimg), 1)
        lr *= warmup_fraction

    return lr

def create_parser():
    """Build the command-line parser for the training script.

    Returns:
        argparse.ArgumentParser: parser exposing the job/tracking options,
        the training hyperparameters and the optimization settings. All
        options have defaults, so parsing an empty argument list is valid.

    Boolean-like switches (``--use_guidance``, ``--use_bcs``,
    ``--pass_loss_res``) are integers: pass ``1`` to enable, ``0`` to disable.
    """
    parser = argparse.ArgumentParser()

    # Job ID
    parser.add_argument('--job_id', type=int, default=0,
                        help='Identifier for this training job')
    parser.add_argument('--wandb_project', type=str, default='pi_diffusion',
                        help='WandB project name for tracking, if using (default: pi_diffusion)')
    parser.add_argument('--gov_eqs', type=str, default='darcy',
                        help='Set of governing equations (default: darcy)')

    # Training parameters with sensible defaults
    parser.add_argument('--seed', type=int, default=51239,
                        help='Random seed for reproducibility (default: 51239)')
    parser.add_argument('--train_iterations', type=int, default=30000,
                        help='Number of training iterations (default: 30000)')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='Training batch size (default: 64)')
    parser.add_argument('--guidance_scale', type=float, default=1.0,
                        help='Scaling parameter for gradient guidance (default: 1.0)')
    parser.add_argument('--log_freq', type=int, default=20,
                        help='Logging frequency in steps (default: 20)')
    parser.add_argument('--residual_log', type=int, default=1000,
                        help='Number of steps to evaluate residuals (default: 1000)')
    parser.add_argument('--n_mc', type=int, default=16,
                        help='Number of Markov samples to estimate gradient guidance (default: 16)')
    parser.add_argument('--model_channels', type=int, default=20,
                        help='Number hidden channels in the UNet (default: 20)')
    parser.add_argument('--num_blocks', type=int, default=8,
                        help='Number of residual blocks per resolution (default: 8)')
    parser.add_argument('--exp_lam', type=float, default=10,
                        help='Exponential weight decay for the guidance term (default: 10)')
    parser.add_argument('--diff_steps', type=int, default=100,
                        help='Number of diffusion steps (default:100)')
    # The value is forwarded to the residual computation as its
    # finite difference accuracy, not as a number of decimals.
    parser.add_argument('--fd_acc', type=int, default=2,
                        help='Finite difference accuracy used for the residuals (default: 2)')
    parser.add_argument('--use_guidance', type=int, default=1,
                        help='Whether to use gradient guidance, 1 = yes, 0 = no (default: 1)')
    parser.add_argument('--use_bcs', type=int, default=1,
                        help='Whether to use boundary conditions, 1 = yes, 0 = no (default: 1)')
    parser.add_argument('--prepend', type=str, default='',
                        help='String to prepend to wandb run name (default: empty)')
    parser.add_argument('--pass_loss_res', type=int, default=0,
                        help='Whether to pass a residual function to the loss, 1 = yes, 0 = no (default: 0)')
    parser.add_argument('--log_model', type=int, default=10000,
                        help='Number of steps to decide model checkpoints (default: 10000)')

    # Optimization parameters
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='Learning rate (default: 1e-4)')

    parser.add_argument('--rampup', type=int, default=10,
                        help='Length of the linear learning-rate ramp-up, in training iterations (default: 10)')

    return parser

# ---------------------------------------------------------------------------
# Command-line arguments and run configuration
# ---------------------------------------------------------------------------
parser = create_parser()
args = parser.parse_args()

# Hyperparameters forwarded to the model, the loss and the sampler.
lam, n_mc = args.exp_lam, args.n_mc
model_channels, num_blocks = args.model_channels, args.num_blocks
diff_steps = args.diff_steps

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

guidance_scale = torch.tensor(args.guidance_scale, device=device)

# The job id offsets the base seed, so jobs with different ids use different seeds.
fix_seeds(args.seed + args.job_id)

name = 'YOUR_MODEL_NAME'
wandb_track = True  # whether to track training with wandb

# Set load_model_flag to True to start from a stored checkpoint; the model
# configuration is then read from that run's folder instead of ./model.yaml.
load_model_flag = False
if load_model_flag:
    name = 'your_pretrained_model'
    load_path = './trained_models/' + name
    load_model_step = 0
    config_file = Path(load_path, 'model', 'model.yaml')
else:
    config_file = Path('model.yaml')
config = yaml.safe_load(config_file.read_text())

gov_eqs = args.gov_eqs
fd_acc = args.fd_acc  # finite difference accuracy

# Evaluation settings.
test_eval_freq = 1000              # validation loss is computed every this many iterations
sample_freq = args.residual_log    # samples are drawn every this many iterations
ema_start = 1000                   # EMA updates begin after this iteration
ema = EMA(0.99)
use_double = False
no_samples = 8                     # number of samples drawn per sampling round
eval_residuals = True
# Both flags are driven by the same command-line switch.
log_grads = use_guidance = args.use_guidance
residual_heatmap = True
logging_img = ["p", "K"]           # per-channel labels used when logging sample images

# ---------------------------------------------------------------------------
# Problem definition, datasets and data loaders
# ---------------------------------------------------------------------------
data_paths = None
if gov_eqs != 'darcy':
    raise ValueError('Unknown governing equations.')

# Darcy flow: [xi_1, xi_2] -> [p, K]
input_dim = output_dim = 2
pixels_at_boundary = True
domain_length = 1.
reverse_d1 = False  # this is to be consistent with ascending coordinates in the figures
bcs = 'none'  # 'none', 'periodic'
pixels_per_dim = 64

# Training / validation files: one CSV per field.
data_paths = (
    './data/darcy/train/p_standardized.csv',
    './data/darcy/train/logK_standardized.csv',
)
data_paths_valid = (
    './data/darcy/valid/p_standardized.csv',
    './data/darcy/valid/logK_standardized.csv',
)
ds = Dataset(data_paths, use_double=use_double)
ds_valid = Dataset(data_paths_valid, use_double=use_double)

# NOTE: Values estimated from the data. Hardcoded for now, but can gather them in a dictionary and then load them in
K_max, K_min = 45.69186865278786, 0.0133886577644885
K_mean = 1.3857543618069972
p_max, p_min = 1.385958867811089, -1.687631424742386
logK_std = 0.8067736271701023
p_std = 0.0721202269541207
# Normalizing constants are twice the respective standard deviation.
K_normalizing_c = logK_std * 2
p_normalizing_c = p_std * 2
sigma_data = 0.5

train_batch_size = args.batch_size
train_iterations = args.train_iterations

channels = np.arange(output_dim)

if use_double:
    torch.set_default_dtype(torch.float64)

# Both loaders iterate in dataset order (no shuffling) and are wrapped in
# `cycle` so that `next()` can be called on them throughout training.
loader_kwargs = dict(batch_size=train_batch_size, shuffle=False)
dl = cycle(DataLoader(ds, **loader_kwargs))
dl_valid = cycle(DataLoader(ds_valid, **loader_kwargs))

# ---------------------------------------------------------------------------
# Residual operators (governing equations) and model
# ---------------------------------------------------------------------------
if gov_eqs != 'darcy':
    raise ValueError('Unknown residuals mode.')

# The training and evaluation residual operators are separate instances
# built from the same settings.
residual_kwargs = dict(
    fd_acc=fd_acc,
    pixels_per_dim=pixels_per_dim,
    pixels_at_boundary=pixels_at_boundary,
    reverse_d1=reverse_d1,
    device=device,
    bcs=bcs,
    domain_length=domain_length,
    r=10.0,
)
residuals_train = ResidualsDarcy(**residual_kwargs)
residuals_eval = ResidualsDarcy(**residual_kwargs)


def _abs_error(in_data):
    """Element-wise absolute value, used as the residual error function."""
    return torch.abs(in_data)


def _darcy_constraint(x):
    """Constraint handed to the denoiser: Darcy residual of `x` from the training operator."""
    return residuals_train.compute_residual_direct_logK(
        x,
        error_fn=_abs_error,
        logK_norm_c=K_normalizing_c,
        p_norm_c=p_normalizing_c,
        use_bcs=args.use_bcs,
        rms=False,
    )


# model
if gov_eqs != 'darcy':
    raise ValueError('Unknown governing equations, cannot create model.')
model = SoftlyConstrainedDenoiser(
    constraint_f=_darcy_constraint,
    sigma_data=sigma_data,
    num_blocks=num_blocks,
    model_channels=model_channels,
).to(device)

# Optionally restore weights from a stored checkpoint.
if load_model_flag:
    checkpoint_file = Path(load_path, 'model', f'checkpoint_{load_model_step}.pt')
    load_model(checkpoint_file, model)

# ---------------------------------------------------------------------------
# EMA, optimizer, loss, logging and bookkeeping used by the training loop
# ---------------------------------------------------------------------------
ema.register(model)

trainable_sizes = (param.numel() for param in model.parameters() if param.requires_grad)
num_params = sum(trainable_sizes)
print(f'Number of trainable parameters: {num_params}')

optimizer = optim.Adam(model.parameters(), lr=args.lr)
loss_fn = EDM2Loss(P_mean=-0.6, P_std=1.2, sigma_data=sigma_data)

# Metrics go to wandb when tracking is enabled, otherwise to a no-op.
log_fn = noop
if wandb_track:
    import wandb
    wandb.init(project=args.wandb_project, name=name)
    log_fn = wandb.log
log_freq = args.log_freq

# Checkpoints and per-step outputs are stored below this folder.
output_save_dir = './trained_models/' + name
os.makedirs(output_save_dir, exist_ok=True)
smallest_loss = smallest_res = np.inf

# Accumulator for the guidance-term bins returned by the loss
# (only allocated when gradient logging is enabled).
n_bins = 64 if log_grads else 0
if log_grads:
    avg_bins = torch.zeros(n_bins)

# One extra iteration so that the loop also runs at `train_iterations` itself.
pbar = tqdm(range(train_iterations + 1))

# Noise-level range and schedule exponent passed to the sampler / likelihood evaluation.
t_max, t_min, rho = 80, 5e-2, 10
train_gs = guidance_scale

# Wall-clock durations of the training and sampling parts of each iteration.
train_loop_times, sample_loop_times = [], []

# ---------------------------------------------------------------------------
# Main training loop.
#
# Every iteration performs one optimizer step. Afterwards `ema.ema(...)` is
# applied to the model (presumably swapping in the EMA weights; they are
# undone again by `ema.restore(...)` at the end of the iteration) and, at
# their respective frequencies, the validation loss is computed, samples are
# drawn and evaluated, and a checkpoint is written.
# ---------------------------------------------------------------------------
for iteration in pbar:
    t0_train = time.time()
    model.train()
    cur_batch = next(dl).to(device)
    optimizer.zero_grad()
    loss, bins = loss_fn(model, cur_batch, n_bins=n_bins, log_bins=log_grads and use_guidance,
                         use_guidance=use_guidance, lam=lam, guidance_scale=train_gs, n_mc=n_mc)
    if torch.isnan(loss) or torch.isinf(loss):
        # Training diverged: store what is needed to inspect the failure, then stop.
        print(f"Loss is NaN or inf: {loss.item()}")
        debug_dir = "YOUR_DEBUG_FOLDER"
        # torch.save does not create missing folders.
        os.makedirs(debug_dir, exist_ok=True)
        torch.save(
            {
                "batch": cur_batch,
                "model": model.state_dict(),
                "loss": loss.item(),
                "step": iteration,
                # Read the rate from the optimizer: the schedule variable is
                # only assigned at the end of an iteration, so it does not
                # exist yet if the very first step diverges.
                "lr": optimizer.param_groups[0]['lr'],
            },
            os.path.join(debug_dir, "debug.pt"),
        )
        break
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
    optimizer.step()
    # logging
    if iteration % log_freq == 0:
        pbar.set_description(f'training loss: {loss.item():.3e}')
        log_fn({'loss': loss.item()}, step=iteration)
    # ema update
    if iteration > ema_start:
        ema.update(model)

    if log_grads:
        # Running average of the bins over one sampling period.
        avg_bins += bins / sample_freq

    # Update the learning rate; it takes effect from the next optimizer step.
    lr = learning_rate_schedule(cur_nimg=iteration, batch_size=5, ref_lr=args.lr, ref_batches=15e3, rampup_Mimg=args.rampup)
    for g in optimizer.param_groups:
        g['lr'] = lr

    train_loop_times.append(time.time()-t0_train)
    # The whole list is rewritten on every dump, so only do it at the logging
    # cadence (and on the last iteration) instead of on every single step.
    if iteration % log_freq == 0 or iteration == train_iterations:
        with open("train_loop_times.json", 'w') as f:
            json.dump(train_loop_times, f)

    # evaluation on validation set
    model.eval()
    ema.ema(model, backup=True)
    with torch.no_grad():
        if iteration % test_eval_freq == 0 and exists(dl_valid):
            cur_test_batch = next(dl_valid).to(device)

            loss_test, bins = loss_fn(model, cur_test_batch, n_bins=n_bins, log_bins=log_grads and use_guidance,
                                      use_guidance=use_guidance, lam=lam, guidance_scale=train_gs, n_mc=n_mc,
                                      evaluation=True)
            log_fn({'test_loss': loss_test.item()}, step=iteration)
            del bins

            print(f'test loss at iteration {iteration}: {loss_test:.3e}')

        t0_sampling = time.time()

        # generate and evaluate samples
        if (iteration % sample_freq == 0) or (iteration == train_iterations):
            if gov_eqs == 'darcy':
                sample_shape = (no_samples, output_dim, pixels_per_dim, pixels_per_dim)

            noise = torch.randn(sample_shape, device=device)
            seqs = edm_sampler(model, noise, num_steps=diff_steps, use_guidance=use_guidance,
                               lam=lam, n_mc=n_mc, guidance_scale=train_gs, evaluation=True, reduce_batch=True,
                               sigma_min=t_min, sigma_max=t_max,
                               # NOTE: Uncomment below for stochastic sampler
                               # S_churn=40, S_min=5e-2, S_max=50, S_noise=1,
                               rho=rho).detach()
            # Undo the normalization: channel 0 (p) is rescaled, channel 1 is
            # rescaled and exponentiated to obtain K.
            seqs[:,0] = seqs[:,0] * p_normalizing_c
            seqs[:,1] = torch.exp(seqs[:,1] * K_normalizing_c)
            if iteration > 20000:
                # NOTE(review): this uses the most recently drawn validation
                # batch, which may stem from an earlier iteration when
                # sample_freq and test_eval_freq differ.
                nll = -evaluate_log_likelihood(model, cur_test_batch, tmax=t_max, tmin=t_min, rho=rho,
                                               num_steps=diff_steps, n_mc=n_mc, use_guidance=use_guidance,
                                               lam=lam, guidance_scale=train_gs, evaluation=True,
                                               reduce_batch=True).mean()
                log_fn({'NLL': nll.detach().item()}, step=iteration)
            del noise

            if eval_residuals:
                sample_residual = residuals_eval.compute_residual_direct(seqs, reduce_batch=False, use_bcs=args.use_bcs).abs()
                residual_mean = sample_residual[:,0].mean()
                # Keep only the first sample's residual for the heatmap below.
                sample_residual = sample_residual[0].detach().cpu()
                log_fn({'residual_loss': residual_mean}, step=iteration)

                # The title does not depend on the sample or channel being plotted.
                title = f'eq: {residual_mean:.2e}'
                for seq in seqs:
                    for sel_channel in channels:
                        last_pred = seq[sel_channel].detach().cpu().numpy()
                        ax = sns.heatmap(last_pred, fmt='.2f')
                        plt.xticks(ticks=[], labels=[])
                        plt.yticks(ticks=[], labels=[])
                        plt.title(title, color='green')
                        fig = ax.get_figure()
                        # wandb is only imported when tracking is enabled.
                        if wandb_track:
                            log_fn({ f'{logging_img[sel_channel]} sample' : wandb.Image(fig) }, step=iteration)
                        plt.close(fig)

            if log_grads:
                if wandb_track:
                    log_fn({"Tracking bins guidance term": wandb.Histogram(avg_bins, num_bins=n_bins)}, step=iteration)
                # Start a fresh average for the next sampling period.
                avg_bins = torch.zeros(n_bins)

            sample_loop_times.append(time.time() - t0_sampling)
            with open("sample_loop_times.json", 'w') as f:
                json.dump(sample_loop_times, f)

            output_save_dir_step = output_save_dir + f'/training/step_{iteration}/'
            os.makedirs(output_save_dir_step, exist_ok=True)

            # The heatmap needs the residual and title computed above, which
            # only exist when residual evaluation is enabled.
            if residual_heatmap and eval_residuals:
                ax = sns.heatmap(sample_residual[0], fmt='.2f', vmax = torch.max(sample_residual[0]), norm=LogNorm())
                plt.xticks(ticks=[], labels=[])
                plt.yticks(ticks=[], labels=[])
                plt.title(title, color='green')
                fig = ax.get_figure()
                if wandb_track:
                    log_fn({ f'sample residual' : wandb.Image(fig) }, step=iteration)
                plt.close(fig)

            del seqs

        # Checkpoint at a fixed cadence. This happens before `ema.restore`,
        # i.e. with the weights installed by `ema.ema` above. The two
        # `smallest_*` values are bookkeeping for an optional
        # "save only on improvement" criterion:
        # if iteration % args.log_model == 0 and (loss < smallest_loss or residual_mean < smallest_res):
        if iteration % args.log_model == 0:
            save_model(model, iteration, output_save_dir)
            smallest_loss = loss
            if eval_residuals:
                smallest_res = residual_mean.mean()

        ema.restore(model)

# Final flush so the timing file also covers the iterations that ran after
# the last periodic dump (e.g. when the loop stopped early).
with open("train_loop_times.json", 'w') as f:
    json.dump(train_loop_times, f)

if wandb_track:
    wandb.finish()
