from copy import deepcopy
import os
from types import SimpleNamespace
import numpy as np
import torch
# Global precision/speed trade-off, applied once at import time:
# allow TensorFloat-32 for CUDA matmuls and for cuDNN convolutions ...
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# ... and let float32 matmuls use the faster reduced-precision ("high") mode
# instead of full float32 ("highest") internal precision.
torch.set_float32_matmul_precision('high')

from diffusion import create_diffusion
from mask_generator import VideoMaskGenerator
from torch.utils.data import DataLoader
from time import time
from tqdm import tqdm

from torch.nn.utils import clip_grad_norm_
from data_utils import create_timeseries_dataset
import json

from metrics import evaluate_model, calculate_final_weighted_score


from utils import (
    create_logger,
    load_checkpoint,
    save_checkpoint,
    EMA,
    instantiate_model,
    instantiate_lrsched_optim,
)
from unconditional.utils import sample_unconditional


def load_baseline_scores(score_file) -> dict | None:
    """
    Load baseline scores from a JSON file.

    Args:
        score_file: Path of the JSON file to read.

    Returns:
        The decoded JSON content, or ``None`` when ``score_file`` does not
        exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    # Open directly instead of checking for existence first: this avoids the
    # window in which the file could disappear between the check and the open.
    try:
        with open(score_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None


def save_baseline_scores(score_file, score_dict: dict) -> None:
    """
    Save scores to a JSON file.
    If the file already exists, it will be overwritten.

    The data is first written to a sibling ``<score_file>.tmp`` file and then
    moved over ``score_file``. If serialisation fails (or the process dies
    mid-write) the previously saved scores are therefore left untouched
    instead of being truncated.

    Args:
        score_file: Destination path of the JSON file.
        score_dict: JSON-serialisable mapping of scores.

    Raises:
        TypeError: If ``score_dict`` contains values ``json`` cannot encode.
        OSError: If the file cannot be written or replaced.
    """
    tmp_file = f"{os.fspath(score_file)}.tmp"
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(score_dict, f, indent=4)
        os.replace(tmp_file, score_file)
    finally:
        # After a successful replace the temp file is gone; it only still
        # exists when something above failed, in which case we clean it up.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)


def main(args: SimpleNamespace) -> None:
    """Train the diffusion model configured by ``args`` and evaluate it periodically.

    The function builds the dataset and data loaders, instantiates the model,
    its EMA tracker, the optimizer and the LR scheduler, optionally resumes
    from ``<checkpoint_dir>/<ckpt_num>.pt`` and then runs a mixed-precision
    training loop. Every ``args.save_chkpt_every`` epochs (and on the final
    epoch) it saves a resumable checkpoint, samples with the EMA parameters
    applied, scores the samples, and keeps ``best.pt`` and
    ``baseline_scores.json`` inside ``args.checkpoint_dir`` up to date.

    Args:
        args: Run configuration. Fields read here: ``ckpt_num``,
            ``checkpoint_dir``, ``exp_dir``, ``batch_size``,
            ``shuffle_tloader``, ``gaussian_diffusion``, ``num_frames``,
            ``input_size``, ``epochs``, ``validation_metrics_weight``,
            ``mask_choice``, ``time_covariates``, ``save_chkpt_every`` and
            ``early_stopping_patience`` (``-1`` disables early stopping).
            The namespace is mutated: ``device``, ``resume_from_ckpt`` and
            ``scheduler`` are set, and ``num_frames`` is injected into
            ``gaussian_diffusion``. Helpers called here may read further
            fields.

    Raises:
        AssertionError: If no CUDA device is available.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    args.device = "cuda"
    if args.ckpt_num is not None:
        args.resume_from_ckpt = os.path.join(args.checkpoint_dir, f"{args.ckpt_num}.pt")
    else:
        args.resume_from_ckpt = None

    logger = create_logger(args.exp_dir)
    logger.info(f"Experiment directory: {args.exp_dir}")

    timeseries_data, timefreq_data, tmps_covariate = create_timeseries_dataset(
        args, logger
    )

    freq_centers = timeseries_data.get_frequencies_centers()
    cov_adj = timeseries_data.get_covariate_adj()
    freq_adj = timeseries_data.get_frequency_adj()

    train_loader = DataLoader(
        dataset=torch.utils.data.TensorDataset(timefreq_data, tmps_covariate),
        batch_size=int(args.batch_size),
        shuffle=args.shuffle_tloader,
        num_workers=4,
        persistent_workers=True,
        pin_memory=True,
        drop_last=True,
    )
    total_steps_per_epoch = len(train_loader)

    # The evaluation set is a random quarter of the training samples, drawn
    # without replacement. It is a subset of the training data, not a held-out
    # split.
    # NOTE(review): with drop_last=True the loader yields nothing when this
    # subset is smaller than one batch — confirm dataset sizes make that
    # impossible.
    eval_idxs = np.random.choice(
        np.arange(len(timefreq_data)),
        size=len(timefreq_data) // 4,
        replace=False,
    )
    val_loader = DataLoader(
        dataset=torch.utils.data.TensorDataset(
            timefreq_data[eval_idxs], tmps_covariate[eval_idxs]
        ),
        batch_size=int(args.batch_size),
        shuffle=False,
        num_workers=4,
        persistent_workers=True,
        pin_memory=True,
        drop_last=True,
    )

    model, args = instantiate_model(
        args,
        timefreq_data.shape,
        tmps_covariate.shape,
        freq_centers=freq_centers,
        cov_adj=cov_adj,
        freq_adj=freq_adj,
    )
    ema = EMA(model=model, decay=0.999)
    scheduler, optimizer = instantiate_lrsched_optim(model, args)
    # Record the scheduler class name on args so it shows up in the logged config.
    args.scheduler = scheduler.__class__.__name__

    if args.resume_from_ckpt is not None:
        # Checkpoints are named "<epoch>.pt", so the file name gives the last
        # completed epoch.
        start_epoch = int(os.path.basename(args.resume_from_ckpt).split(".")[0])
        start_epoch += 1  # we will resume from the next epoch

        model, ema, optimizer, scheduler = load_checkpoint(
            checkpoint_path=args.resume_from_ckpt,
            model=model,
            ema=ema,
            optimizer=optimizer,
            scheduler=scheduler,
        )

        logger.info(f"Loaded checkpoint from {args.resume_from_ckpt}")
        logger.info(f"Resuming training from epoch {start_epoch}...")
    else:
        start_epoch = 0

    # vars() returns the namespace's own dict, so adding "num_frames" here also
    # adds it to args.gaussian_diffusion (and hence to the logged parameters).
    diff_vars = vars(args.gaussian_diffusion)
    diff_vars["num_frames"] = args.num_frames
    diffusion = create_diffusion(
        **diff_vars,
        training=True,
    )

    # Evaluation uses a separate sampler with only 20 respaced timesteps:
    # speed is preferred over sample quality during validation.
    eval_diff_vars = deepcopy(diff_vars)
    eval_diff_vars["timestep_respacing"] = "20"
    evaluation_diffusion = create_diffusion(
        **eval_diff_vars,
        training=False,
    )

    # input_size may be an (H, W) tuple or a single int for square inputs.
    generator = VideoMaskGenerator(
        (
            (args.num_frames, args.input_size[0], args.input_size[1])
            if isinstance(args.input_size, tuple)
            else (args.num_frames, args.input_size, args.input_size)
        ),
    )
    generator.predict_given_frame_length = args.num_frames // 2
    generator.backward_given_frame_length = args.num_frames // 2

    logger.info(
        f"Starting training for {args.epochs - start_epoch} epochs with {total_steps_per_epoch} batches per epoch.\nTraining parameters:\n{args}\n"
    )

    baseline_score_path = os.path.join(args.checkpoint_dir, "baseline_scores.json")
    best_checkpoint_path = os.path.join(args.checkpoint_dir, "best.pt")

    # None when no baseline file exists yet; the first evaluation then becomes
    # the baseline unconditionally.
    baseline_score_dict = load_baseline_scores(baseline_score_path)
    logger.info(f"Baseline scores loaded: {baseline_score_dict}")

    iterations_no_improvement = 0
    validation_metrics_weight = vars(args.validation_metrics_weight)

    # Check if your GPU supports bfloat16
    bf16_supported = torch.cuda.is_bf16_supported()
    logger.info(f"BF16 Supported: {bf16_supported}")
    # Choose the best available dtype
    dtype = torch.bfloat16 if bf16_supported else torch.float16
    # Loss scaling is only needed for float16; with bfloat16 the scaler is
    # disabled and its methods become pass-throughs.
    scaler = torch.amp.GradScaler(args.device, enabled=not bf16_supported)
    for epoch in range(start_epoch, args.epochs):
        model.train()

        logger.info(f"Beginning epoch {epoch}...")

        train_tqdm_bar = tqdm(
            enumerate(train_loader),
            total=total_steps_per_epoch,
            desc=f"Epoch {epoch}/{args.epochs}",
        )

        # Per-epoch mean of every loss term, accumulated batch by batch.
        running_loss_terms = {}
        start_time = time()
        for i, model_input in train_tqdm_bar:

            optimizer.zero_grad()

            x, tc = model_input
            x = x.to(device=args.device, dtype=torch.float)  # (B, T, C, H, W)

            mask = generator(
                batch_size=x.shape[0],
                device=args.device,
                idx=args.mask_choice,
            )
            t = torch.randint(
                0, diffusion.num_timesteps, (x.shape[0],), device=args.device
            )
            # NOTE(review): the noise is created with x's shape and then has
            # dims 1 and 2 swapped; presumably training_losses works in that
            # layout — confirm.
            noise = torch.randn_like(x).permute(0, 2, 1, 3, 4)
            if args.time_covariates:
                model_kwargs = {"tc": tc.to(args.device, dtype=dtype)}
            else:
                model_kwargs = None

            with torch.amp.autocast(device_type=args.device, dtype=dtype):
                loss_dict = diffusion.training_losses(
                    model, x, t, model_kwargs=model_kwargs, noise=noise, mask=mask
                )
                loss = loss_dict["loss"].mean()

            scaler.scale(loss).backward()
            # Unscale before clipping so max_norm applies to the true gradients.
            scaler.unscale_(optimizer)
            clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            ema.update()

            if scheduler.is_batch_based():
                # Batch-based schedulers receive a fractional epoch.
                scheduler.step(epoch + i / total_steps_per_epoch)

            postfix_dict = {k: v.mean().item() for k, v in loss_dict.items()}

            train_tqdm_bar.set_postfix(
                **postfix_dict,
                lr=optimizer.param_groups[0]["lr"],
            )
            if not running_loss_terms:
                running_loss_terms = {
                    k: v / total_steps_per_epoch for k, v in postfix_dict.items()
                }
            else:
                running_loss_terms = {
                    k: v + postfix_dict[k] / total_steps_per_epoch
                    for k, v in running_loss_terms.items()
                }

        if not scheduler.is_batch_based():
            scheduler.step()

        cost_time = time() - start_time
        logger.info(
            f"Epoch {epoch} | Cost time: {cost_time} seconds | Steps/Sec {total_steps_per_epoch / cost_time:.2f}"
        )
        logger.info(
            f"Epoch {epoch} | {' | '.join([f'{k}: {v:.8f}' for k, v in running_loss_terms.items()])} | Epoch end LR {optimizer.param_groups[0]['lr']:.8f}"
        )

        if (epoch + 1) % args.save_chkpt_every == 0 or epoch == args.epochs - 1:
            # Save model checkpoint for resuming training
            checkpoint_path = os.path.join(args.checkpoint_dir, f"{epoch}.pt")
            save_checkpoint(
                model=model,
                ema=ema,
                optimizer=optimizer,
                scheduler=scheduler,
                file_path=checkpoint_path,
            )
            logger.info(
                f"Epoch {epoch} | Saved training checkpoint to {checkpoint_path}"
            )

            # Evaluate the model:
            logger.info(f"Epoch {epoch} | Evaluating model")
            eval_start_time = time()

            # Validation or testing using EMA parameters. The try/finally
            # guarantees the raw training weights are put back even when
            # sampling fails, so the model is never left holding EMA weights.
            ema.apply_shadow()
            try:
                generated_timefreq_data, original_timefreq_data = sample_unconditional(
                    model=model,
                    data_loader=val_loader,
                    args=args,
                    mask_generator=generator,
                    diffusion_sampler=evaluation_diffusion,
                )
            finally:
                ema.restore()  # Restore original parameters after validation

            # Metrics are computed on time series recovered from the
            # time-frequency data; a metric is only computed when its weight
            # is positive.
            metrics_dict = evaluate_model(
                ori_data=(
                    timeseries_data.get_timeseries_from_timefreq(original_timefreq_data)
                    .numpy()
                    .astype(np.float64)
                ),
                fake_data=(
                    timeseries_data.get_timeseries_from_timefreq(
                        generated_timefreq_data
                    )
                    .numpy()
                    .astype(np.float64)
                ),
                metrics_iterations=3,
                score_file=None,
                compute_fid=validation_metrics_weight["context_fid"] > 0,
                compute_cross_corr=validation_metrics_weight["cross_correlation"] > 0,
                compute_discriminative=validation_metrics_weight["discriminative_score"] > 0,
                compute_predictive=validation_metrics_weight["predictive_score"] > 0,
            )

            logger.info(
                f"Epoch {epoch} | Evaluation Cost Time: {time() - eval_start_time:.2f} seconds | Evaluation metrics: {metrics_dict}"
            )

            # Short-circuit evaluation: the weighted score is only computed
            # when a baseline exists. A score <= 1.0 is treated as "at least
            # as good as the baseline".
            if baseline_score_dict is None or (
                calculate_final_weighted_score(
                    current_scores=metrics_dict,
                    baseline_scores=baseline_score_dict,
                    weights=validation_metrics_weight,
                )
                <= 1.0
            ):
                logger.info(
                    f"Epoch {epoch} | New best model found. Saving scores and best checkpoint."
                )
                iterations_no_improvement = 0
                baseline_score_dict = metrics_dict
                save_checkpoint(
                    model=model,
                    ema=ema,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    file_path=best_checkpoint_path,
                )
            else:
                iterations_no_improvement += 1
                logger.info(
                    f"Epoch {epoch} | Current model is worse than baseline. Iterations without improvement: {iterations_no_improvement}"
                )
                # A patience of -1 disables early stopping.
                if (
                    args.early_stopping_patience != -1
                    and iterations_no_improvement >= args.early_stopping_patience
                ):
                    logger.info(f"Epoch {epoch} | Early stopping triggered.")
                    break

            # Persist the current best scores after every evaluation so that a
            # later (resumed) run compares against them.
            save_baseline_scores(
                score_file=baseline_score_path,
                score_dict=baseline_score_dict,
            )

    model.eval()
    logger.info("Done!")
