import torch
from torch.utils.tensorboard import SummaryWriter
import hydra
from tqdm import tqdm
from omegaconf import  OmegaConf
import logging
from time import time
import os
import json
import datetime
import matplotlib.pyplot as plt
import warnings
import math

from sesamo.models import SymmetryEnforcingFlow
from sesamo.loss import ReverseKL
from sesamo.action import HubbardAction, ScalarPhi4Action, ComplexPhi4Action, GaussianMixtureAction

# Silence the repeated "networkx backend defined more than once" warning.
warnings.filterwarnings(action="ignore", message="networkx backend defined more than once")
# Logger shared by this script; uncomment the line below to mute all logging.
logger = logging.getLogger(name="train")
# logging.disable(logging.CRITICAL)

def flatten_json(data, parent_key="", separator="."):
    """
    Collapse nested dicts/lists into a single-level dict.

    Key components are joined with ``separator``. List elements contribute
    their index as a key component (as a string). Anything that is neither
    a dict nor a list is treated as a leaf value.

    Args:
        data: The object to flatten (dict, list, or a leaf value).
        parent_key: Key prefix accumulated during recursion ("" at top level).
        separator: String placed between key components.

    Returns:
        A dict mapping joined key paths to leaf values. Empty containers
        contribute no entries.
    """
    if isinstance(data, dict):
        children = data.items()
    elif isinstance(data, list):
        children = ((str(idx), elem) for idx, elem in enumerate(data))
    else:
        # leaf: emit it under the key path built so far
        return {parent_key: data}

    flat = {}
    for child_key, child in children:
        path = f"{parent_key}{separator}{child_key}" if parent_key else child_key
        flat.update(flatten_json(child, path, separator))
    return flat


def clear_dir(dir):
    """Recursively delete run artifacts below ``dir``.

    Removes every file ending in ``.csv`` or ``.pth`` and every file whose
    name starts with ``events`` (TensorBoard event files). Other files and
    all directories are left in place.
    """
    for folder, _subfolders, filenames in os.walk(dir):
        stale = (
            name
            for name in filenames
            if name.endswith((".csv", ".pth")) or name.startswith("events")
        )
        for name in stale:
            os.remove(os.path.join(folder, name))



def _build_action(cfg):
    """Instantiate the action selected by ``cfg.action["type"]``.

    Each action type reads its keyword arguments from
    ``cfg.action["<type>_params"]``.

    Raises:
        ValueError: if the action type is not one of the supported names.
    """
    action_type = cfg.action["type"]
    if action_type == "hubbard":
        return HubbardAction(**cfg.action["hubbard_params"])
    if action_type == "scalarphi4":
        return ScalarPhi4Action(**cfg.action["scalarphi4_params"])
    if action_type == "complexphi4":
        return ComplexPhi4Action(**cfg.action["complexphi4_params"])
    if action_type == "gaussianmixture":
        return GaussianMixtureAction(**cfg.action["gaussianmixture_params"])
    raise ValueError(f"Action {cfg.action['type']} not implemented")


def _uses_vmonf(cfg):
    """Return True if the sampler selects the "vmonf" sector-mixture code path.

    True when ``cfg.sampler["flow"]`` equals "vmonf" or its last element does
    (presumably a list of flow names -- confirm against the config schema).
    """
    flow_type = cfg.sampler["flow"]
    return flow_type == "vmonf" or flow_type[-1] == "vmonf"


def _sample_with_actions(flow, action, n_samples, vmonf):
    """Draw ``n_samples`` configurations from ``flow`` and evaluate ``action`` on them.

    Without vmonf, the flow's samples and log-probs are returned directly and
    ``prob_c`` is None. With vmonf, ``flow.sample_with_logprob`` yields one
    batch per sector along the leading dimension. The action is evaluated per
    sector. Samples, actions and log-probs are then combined by a weighted sum
    over sectors, using the sector probabilities ``flow.prob_c`` (commented
    upstream as shape (sectors, batch_size)).

    Returns:
        Tuple ``(samples, actions, log_probs, prob_c)``.
    """
    if not vmonf:
        samples, log_probs = flow.sample_with_logprob(n_samples)
        return samples, action(samples), log_probs, None

    samples_c, log_probs_c = flow.sample_with_logprob(n_samples)
    prob_c = flow.prob_c  # shape (sectors, batch_size)

    # evaluate the action sector by sector
    actions_c = torch.zeros(samples_c.shape[0:2], device=samples_c.device, dtype=samples_c.dtype)
    for sector in range(samples_c.shape[0]):
        actions_c[sector] = action(samples_c[sector])

    # broadcast the sector weights over the per-sample configuration dims
    weight_shape = (-1, n_samples, *[1] * (samples_c.dim() - 2))
    samples = torch.sum(samples_c * prob_c.reshape(weight_shape), dim=0)
    actions = torch.sum(actions_c * prob_c, dim=0)
    log_probs = torch.sum(log_probs_c * prob_c, dim=0)
    return samples, actions, log_probs, prob_c


def _effective_sample_size(actions, log_probs):
    """Return the normalized effective sample size (mean w)^2 / mean(w^2).

    The log importance weights are ``-actions - log_probs``. The ratio is
    computed in log space with logsumexp, so large weight variance cannot
    overflow ``exp`` (which would turn the result into inf/inf = NaN). The
    weights are detached, so no autograd graph is built.
    """
    log_w = (-actions - log_probs).detach().flatten()
    log_ess = (
        2 * torch.logsumexp(log_w, 0)
        - torch.logsumexp(2 * log_w, 0)
        - math.log(log_w.numel())
    )
    return log_ess.exp().item()


def _plot_samples(cfg, samples, ess, out_dir):
    """Save a histogram of ``samples``, annotated with ``ess``, to ``out_dir/samples.png``.

    - hubbard / gaussianmixture: sum over dim 1, then 2D histogram of the
      first two components.
    - complexphi4: mean over dims (2, 3), then 2D histogram of the first two
      components.
    - scalarphi4: flatten, keep at most 500_000 values, 1D histogram.
    - any other type: 1D histogram of ``samples`` as given.
    """
    action_type = cfg.action["type"]
    if action_type in ["hubbard", "gaussianmixture"]:
        samples = samples.sum(dim=1)
    elif action_type == "complexphi4":
        samples = samples.mean(dim=(2, 3))
    elif action_type == "scalarphi4":
        samples = samples.flatten()[:500_000]

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    if action_type in ["hubbard", "gaussianmixture", "complexphi4"]:
        ax.hist2d(samples[:, 0], samples[:, 1], bins=150, density=True)
        ax.set_xlabel(r"$x_1$")
        ax.set_ylabel(r"$x_2$")
    else:
        ax.hist(samples, bins=150, density=True)
        ax.set_xlabel(r"$\langle x \rangle$")
        ax.set_ylabel("Density")

    ax.text(0.05, 0.95, f"ESS: {ess:.3f}", transform=ax.transAxes, fontsize=14,
            verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))

    plot_path = os.path.join(out_dir, "samples.png")
    fig.savefig(plot_path, dpi=300)
    # pyplot keeps every open figure alive until it is closed explicitly
    plt.close(fig)
    logger.info(f"Saved samples plot to {plot_path}")


@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg):
    """Train a SymmetryEnforcingFlow on the configured action and plot its samples.

    Reads from ``cfg``:
        out_dir: output directory. A trailing "/None" is stripped, and existing
            .csv/.pth/TensorBoard event files inside it are deleted.
        action.type / action.<type>_params: which action to build and its kwargs.
        sampler: passed to ``SymmetryEnforcingFlow``; ``sampler["flow"]``
            selects the "vmonf" sector-mixture code path.
        train: seed, lr, lr_scheduler_params, n_steps, batch_size, clip_grad
            (0 disables clipping), ess_interval, save_interval and the optional
            reinforce flag.

    Writes to ``out_dir``: config.json; TensorBoard text/scalars; stats.csv
    (one row per ESS evaluation); checkpoint_<step>.pth every save_interval
    steps plus a final one; and samples.png.

    Ctrl-C stops training early. The final checkpoint and the plot are still
    produced.
    """
    # get output directory, strip a trailing "/None" and remove old artifacts
    out_dir = cfg.out_dir
    if out_dir.endswith('/None'):
        out_dir = out_dir[:-5]
    os.makedirs(out_dir, exist_ok=True)
    clear_dir(out_dir)

    # init tensorboard
    writer = SummaryWriter(out_dir)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # log flattened hyperparameters to tensorboard and the full config to config.json
    cfg_dict = OmegaConf.to_container(cfg, resolve=True)
    for key, value in flatten_json(cfg_dict).items():
        writer.add_text(key, str(value))
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump(cfg_dict, f, indent=4)

    action = _build_action(cfg)
    loss = ReverseKL()

    torch.manual_seed(cfg.train["seed"])
    flow = SymmetryEnforcingFlow(cfg.sampler).to(device)
    optimizer = torch.optim.Adam(flow.parameters(), lr=cfg.train["lr"], amsgrad=True)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **cfg.train["lr_scheduler_params"], threshold=1e-4)

    n_steps = cfg.train["n_steps"]
    batch_size = cfg.train["batch_size"]
    vmonf = _uses_vmonf(cfg)
    csv_filename = os.path.join(out_dir, 'stats.csv')

    t0 = time()
    logger.info(f"Start training in output directory: {out_dir}")
    pbar = tqdm(range(1, n_steps + 1))

    # Current step number. It stays 0 if the loop never runs (n_steps == 0) or
    # is interrupted before its first iteration, so the final checkpoint name
    # and the interrupt message are always defined.
    step = 0
    try:
        for step in pbar:
            t0_per_step = time()
            optimizer.zero_grad()

            _, actions, log_probs, prob_c = _sample_with_actions(flow, action, batch_size, vmonf)
            losses = loss(actions, log_probs)
            if prob_c is not None:
                # vmonf: add the entropy -sum(p log p) of the sector probabilities
                losses = losses - torch.sum(prob_c * torch.log(prob_c + 1e-10), dim=0)

            # REINFORCE term for sub-flows that expose a log_modprob,
            # using the mean-centred (detached) reward as the advantage
            if cfg.train.get("reinforce", False):
                reward = (-actions - log_probs).detach()
                advantage = reward - reward.mean()
                for single_flow in flow.flow:
                    if hasattr(single_flow, "log_modprob"):
                        losses += -single_flow.log_modprob * advantage

            # add regularization to loss
            losses += flow.regularization.mean().unsqueeze(0)

            if losses.isnan().any():
                logger.info("Loss is NaN, skipping step")
                continue

            loss_mean = losses.mean()
            loss_std = losses.std()
            loss_mean.backward()

            # gradient clipping can stabilize training considerably
            if cfg.train["clip_grad"] != 0:
                torch.nn.utils.clip_grad_norm_(flow.parameters(), cfg.train["clip_grad"])

            optimizer.step()
            # the plateau scheduler monitors the spread of the per-sample loss
            scheduler.step(loss_std)

            pbar.set_postfix({
                'loss': loss_mean.item(),
                'loss_std': loss_std.item(),
                'action': actions.mean().item(),
            })

            # values logged to tensorboard (and to stats.csv on ESS steps)
            time_left_min = (time() - t0) / step * (n_steps - step) / 60
            stats = {
                'loss': loss_mean.item(),
                'loss_std': loss_std.item(),
                'loss_nan': losses.isnan().any().item(),
                'log_prob': log_probs.mean().item(),
                'action': actions.mean().item(),
                'action_std': actions.std().item(),
                'step': step,
                'time_min': (time() - t0) / 60,
                'time_left_min': time_left_min,
                'steps_per_second': 1/(time() - t0_per_step),
                'loss_penalty': flow.regularization.mean().item(),
                'lr': optimizer.param_groups[0]['lr'],
            }
            if device == 'cuda':
                stats['gpu_memory_GB'] = torch.cuda.memory_allocated() / 1024**3
            if hasattr(action, 'logZ'):
                log_z = action.logZ()
                stats['KL'] = (log_probs + actions).mean().item() + log_z
                stats['logZ'] = log_z
                stats['logZ_estimator'] = torch.logsumexp(-actions - log_probs, 0).item() - math.log(actions.shape[0])
            for single_flow in flow.flow:
                if hasattr(single_flow, "breaking") and single_flow.breaking.numel() == 1:
                    stats["breaking"] = single_flow.breaking.item()
                    breaking_grad = single_flow.breaking.grad
                    # grad is None when no gradient was accumulated into `breaking`;
                    # log NaN instead of crashing so the CSV columns stay stable
                    stats["breaking_grad"] = breaking_grad.item() if breaking_grad is not None else float("nan")

            # ESS (and the CSV row) only every <ess_interval> steps, plus first and last step
            if step % cfg.train["ess_interval"] == 0 or step == n_steps or step == 1:
                stats['ess'] = _effective_sample_size(actions, log_probs)

                write_header = not os.path.exists(csv_filename)
                with open(csv_filename, 'a') as fd:
                    if write_header:
                        fd.write(','.join(stats.keys()) + '\n')
                    fd.write(','.join(str(value) for value in stats.values()) + '\n')

            for key, value in stats.items():
                writer.add_scalar(key, value, step)

            if step % cfg.train["save_interval"] == 0:
                filename = os.path.join(out_dir, f'checkpoint_{step}.pth')
                torch.save({'optim': optimizer.state_dict(), 'net': flow.state_dict()}, filename)

    except KeyboardInterrupt:
        logger.info(f"Training interrupted by user at step {step}")
    finally:
        pbar.close()
        # flush buffered TensorBoard events; nothing below writes to the writer
        writer.close()

    # save final model
    filename = os.path.join(out_dir, f'checkpoint_{step}.pth')
    torch.save({'optim': optimizer.state_dict(), 'net': flow.state_dict()}, filename)

    train_time = datetime.timedelta(seconds=int(time() - t0))
    logger.info(f"Training finished after {train_time}")

    # ===========================================
    # PLOTTING
    # ===========================================
    logger.info("Plotting results...")

    # NOTE(review): this 100k-sample draw builds an autograd graph. Wrapping it
    # in torch.no_grad() would save memory if the flow's sampling does not rely
    # on autograd -- confirm before changing.
    samples, actions, log_probs, _ = _sample_with_actions(flow, action, 100_000, vmonf)
    samples = samples.detach().cpu()
    actions = actions.detach().cpu()
    log_probs = log_probs.detach().cpu()

    ess = _effective_sample_size(actions, log_probs)
    _plot_samples(cfg, samples, ess, out_dir)


if __name__ == "__main__":
    # main() is called without arguments; the @hydra.main decorator supplies cfg
    main()