import time

from dataloaders.data_loader import Moving_MNIST_Loader
from dataloaders.shanghaitech import ShanghaiTechLoader
from dataloaders.ucsd import UCSD_Loader
from dataloaders.yup import YUP_Loader
from model.model_reconstruction import Unfolding_RNN
import model.constraints as constraints
import model.metrics as metrics

import numpy as np
from scipy.io import loadmat
import scipy.io as sio
import yaml
from yaml import SafeLoader
import argparse
from datetime import datetime
import os
from os.path import join
import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import cv2
import gc
from torch.autograd.functional import jacobian

import wandb

from dataloaders.utils import (
    assemble_patches,
    create_patches,
)

from transformers import ViTImageProcessor
from model import vit


# Loss magnitude above which training reports the loss as having exploded.
BLOWUP_THRESHOLD = 1e10


def add_noise(batch_raw, perturb_frac, data_std, growing_perturb_coeff=1.0, normalized=False):
    """
    Perturb a batch with zero-mean Gaussian noise.

    The sampled noise has standard deviation ``data_std * perturb_frac`` for
    raw data, or ``perturb_frac`` alone when ``normalized`` is true. It is
    multiplied by ``growing_perturb_coeff`` before being added, so callers can
    ramp the perturbation up over epochs (a coefficient of 0.0 cancels it).

    Args:
        batch_raw: Tensor to perturb; noise is sampled with its shape, on its device.
        perturb_frac: Noise level. Anything that is not strictly positive
            disables the perturbation.
        data_std: Standard deviation of the data, used to scale the noise for
            non-normalized data.
        growing_perturb_coeff: Multiplier applied to the sampled noise.
        normalized: If true, ``data_std`` is ignored and ``perturb_frac`` is
            used directly as the noise standard deviation.

    Returns:
        ``batch_raw`` itself when no perturbation is requested, otherwise a new
        tensor holding the perturbed batch.
    """
    # Guard clause: nothing to sample when the noise level is not positive.
    if not perturb_frac > 0:
        return batch_raw

    # Normalized data is not on the raw intensity scale, so data_std is skipped.
    std = perturb_frac if normalized else data_std * perturb_frac

    gaussian = torch.normal(0, std, size=batch_raw.shape, device=batch_raw.device)
    return batch_raw + gaussian * growing_perturb_coeff


class Frame_Reconstruction:

    def __init__(self, configs):
        self.configs = configs
        self.dtype = torch.float32
        self.metric = torch.nn.MSELoss()
        self.constrained = configs["constrained"]
        self.normalize = configs.get("normalize", False)
        
        # Set flag based on model type
        if configs["model"].lower() == "vit":
            self.using_vit_processor = True
        else:
            self.using_vit_processor = False
        
        # Initialize flag for using input energy as baseline
        self.configs["use_input_energy_baseline"] = configs.get("use_input_energy_baseline", False)

        if "device" in configs:
            self.device = torch.device(configs["device"])
        elif torch.cuda.is_available():
            print("Running on CUDA")
            self.device = torch.device("cuda")
        # MPS
        elif torch.backends.mps.is_available():
            print("Running on MPS")
            self.device = torch.device("mps")
        else:
            print("Running on CPU")
            self.device = torch.device("cpu")

        if configs["model"].lower() in [
            "sista",
            "l1l1",
            "reweighted",
            "gru",
            "lstm",
            "rnn",
            "dust",
            "dust_vec",
            "unrolled_transformer",
        ]:
            D_init = loadmat(configs["D_init_file_path"])["dict"].astype(
                "float32"
            )  # (F,H)

            # To support both single D (1,F,H) and varying D per layer (K,F,H)
            D_init = np.expand_dims(D_init, axis=0)
            if configs["diff_D"]:
                D_init = D_init.repeat(configs["num_layers"], axis=0)

            if configs["model"].lower() == "vit" and configs["compression_factor"] > 1:
                print(f"WARNING: Received model ViT and compression factor of {configs['compression_factor']} factor is ignored for ViT models.")
                
            n_input = int(configs["num_features"] / configs["compression_factor"])
            A_init = np.asarray(
                np.random.uniform(
                    low=-np.sqrt(6.0 / (n_input + configs["num_features"])),
                    high=np.sqrt(6.0 / (n_input + configs["num_features"])),
                    size=(n_input, configs["num_features"]),
                )
                / 2.0,
                dtype=np.float32,
            )

            self.model = Unfolding_RNN(A_init, D_init, configs).to(self.device)

            # For debugging, harmless
            self.A_init = A_init
            self.D_init = D_init

        elif configs["model"].lower() == "vit":
            print("Initializing ViT model for reconstruction...")
            self.model = vit.PretrainedViTForReconstructionNoPreprocessing(configs).to(self.device)

        else:
            print("not recognised model, exitting...")
            exit()

        if configs["resume"] is True:
            assert os.path.exists(configs["checkpoint"])
            temp = torch.load(
                configs["checkpoint"] + "/" + configs["run_name"] + "/best_model.pth"
            )
            self.model.load_state_dict(
                torch.load(
                    configs["checkpoint"]
                    + "/"
                    + configs["run_name"]
                    + "/best_model.pth"
                )
            )
            print("Checkpoint at {} successfully loaded.".format(configs["checkpoint"]))

    def train(self, data_loader, include_test=True):
        """
        Run the full training loop: optimisation, validation, checkpointing, logging.

        Each epoch trains via ``train_one_epoch``, validates on clean data, runs
        a quick OOD validation, logs to wandb, saves a per-epoch checkpoint (and
        the best model by validation PSNR), and steps the LR scheduler on the
        validation loss.

        Args:
            data_loader: Dataset wrapper handed to the train/validation/test
                helpers. ``data_loader.vit_processor`` is read when both
                ``self.normalize`` and ``self.using_vit_processor`` are set.
            include_test (bool): If true, evaluate on the clean test split and
                run the OOD analysis once training finishes.

        Returns:
            None. Returns early if the training loss explodes and
            ``configs["break_if_exploding"]`` is set.

        Raises:
            ValueError: If ``configs["resilience_optimizer"]`` is neither
                "adam" nor "sgd" while training constrained.
        """
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.configs["lr"],
            weight_decay=self.configs["weight_decay"],
        )

        self.scheduler = ReduceLROnPlateau(
            self.optimizer, mode="min", factor=0.3, patience=5, verbose=True
        )

        self.metric = torch.nn.MSELoss()

        if self.normalize and self.using_vit_processor:
            self.processor = data_loader.vit_processor

        # Start below every attainable PSNR so the first epoch always produces
        # a best_model.pth, even when PSNR is zero or negative.
        best_val_psnr = float("-inf")

        # One constraint per layer; doubled when redundant constraints are enabled.
        num_constraints = self.model.config["num_layers"] * (
            2 if self.configs.get("redundant_constraints", False) else 1
        )

        if self.constrained:
            nu = nn.Parameter(
                torch.ones(
                    num_constraints,
                    requires_grad=True,
                    device=self.device,
                )
                * self.configs["dual_init"]
            )

            if self.configs["resilience_optimizer"] == "adam":
                self.dual_optimizer = torch.optim.Adam(
                    [nu],
                    lr=self.configs["lr_dual"],
                )
            elif self.configs["resilience_optimizer"] == "sgd":
                self.dual_optimizer = torch.optim.SGD(
                    [nu],
                    lr=self.configs["lr_dual"],
                )
            else:
                raise ValueError(
                    f"Resilience optimizer {self.configs['resilience_optimizer']} not supported"
                )
            slacks = torch.zeros(num_constraints, device=self.device)  # resilient
        else:
            # For unconstrained training, a zero nu makes the constraint term vanish.
            nu = torch.zeros(
                num_constraints, requires_grad=False, device=self.device
            )
            slacks = None

        if self.configs["use_input_energy_baseline"]:
            # If using input energy baseline, we'll compute it per batch in train_one_epoch
            initial_energy = None
            print("Using noisy input energy as baseline for initial constraint")
        elif self.configs["initial_energy"] is None:
            initial_energy = self._infer_initial_energy(data_loader, num_batches=500)
        else:
            initial_energy = self.configs["initial_energy"]

        wandb.watch(self.model, log_freq=2000)

        # Get noise schedule from configs
        noise_schedule = self.configs.get('noise_schedule', 'growing')

        # torch.save does not create missing directories, so make sure the
        # run's checkpoint directory exists before the first save.
        checkpoint_dir = f"{self.configs['checkpoint']}/{self.configs['run_name']}"
        os.makedirs(checkpoint_dir, exist_ok=True)

        for epoch in range(1, self.configs["num_epoch"] + 1):
            print("Training epoch {}".format(epoch))

            # Set growing_perturb_coeff based on noise schedule
            if noise_schedule == 'constant':
                growing_perturb_coeff = 1.0  # No growth in noise
            else:  # 'growing': the first epoch is noise-free, then grows linearly
                growing_perturb_coeff = float(epoch - 1)

            # NOTE(review): the slacks returned here are discarded, so every
            # epoch is handed the same initial `slacks` object - confirm intended.
            train_loss, train_psnr, nu, _ = self.train_one_epoch(
                data_loader,
                update_duals=epoch > self.configs["unconstr_warmup_epochs"],
                current_epoch=epoch,
                nu=nu,
                slacks=slacks,
                initial_energy=initial_energy,
                growing_perturb_coeff=growing_perturb_coeff,
            )

            # Count parameters
            params = sum(
                p.numel() for p in self.model.parameters() if p.requires_grad
            )
            print("Current model has {} trainable parameters".format(params))

            if train_loss > BLOWUP_THRESHOLD or torch.isnan(torch.tensor(train_loss)):
                print(f"epoch {epoch}: LOSS EXPLODED")
                print(f"train_loss: {train_loss}")
                if self.configs.get("break_if_exploding", False):
                    print("Breaking training due to loss explosion")
                    return
                print("Continuing training despite loss explosion")

            # when training, always test on clean data.
            val_loss, val_psnr = self.validation(data_loader, perturb_frac_test=0.0)

            print("Running OOD validation")
            self._evaluate_ood_on_split(
                data_loader, split="val", num_batches=100
            )  # quick OOD Validation

            wandb.log(
                {
                    "train_loss": train_loss,
                    "train_psnr": train_psnr,
                    "val_loss": val_loss,
                    "val_psnr": val_psnr,
                    "epoch": epoch,
                },
                commit=False,
            )

            # Model checkpoint
            if val_psnr > best_val_psnr:
                best_val_psnr = val_psnr
                model_path = f"{checkpoint_dir}/best_model.pth"
                print(f"Saving best model to {model_path}")
                torch.save(self.model.state_dict(), model_path)
                wandb.save(model_path, policy="now")

                print("val_psnr increased -> Model checkpoint saved.")
            model_path_checkpoint = f"{checkpoint_dir}/model_e{epoch}.pth"
            torch.save(self.model.state_dict(), model_path_checkpoint)
            wandb.save(model_path_checkpoint, policy="now")
            print(f"Saved model at epoch {epoch}")

            self.scheduler.step(val_loss)
            if epoch % self.configs["display_each"] == 0:
                to_print = "epoch: {}, train loss: {}, train psnr: {}, val loss: {}, val psnr: {}".format(
                    epoch, train_loss, train_psnr, val_loss, val_psnr
                )
                if self.configs["model"].lower() in [
                    "sista",
                    "l1l1",
                    "reweighted",
                ]:
                    to_print += "\nld0: {}, ld1: {}, ld2: {}, sparsity: {}, compressed input range: {} to {} ".format(
                        self.model.lambda0,
                        self.model.lambda1,
                        self.model.lambda2,
                        self.model.sparsity,
                        self.model.now_input.min(),
                        self.model.now_input.max(),
                    )
                print(to_print)

        if include_test:
            # when training, always test on clean data.
            self.test(
                data_loader,
                perturb_frac_test=0.0,
                growing_perturb_coeff=0.0,
                log_wandb=True,
            )
            # Run OOD analysis
            self.ood_analysis(data_loader)

    def train_one_epoch(
        self,
        data_loader,
        update_duals,
        current_epoch,
        nu=None,
        slacks=None,
        initial_energy=None,
        growing_perturb_coeff=1.0,
    ):  # train_one_epoch
        """
        Train the model for one epoch and log per-batch metrics to wandb.

        For every batch the clean input is perturbed, passed through the model,
        and the loss ``raw_loss + nu @ energy_diffs`` is minimised, where
        ``energy_diffs`` holds the layer-to-layer constraint slacks
        ``energy[l] - alpha * energy[l-1]`` (the first one being measured
        against the input baseline), optionally followed by the redundant
        constraints ``energy[l] - alpha**l * baseline``.

        Args:
            data_loader: Provides ``train`` (``train.shape[1]`` divided by the
                batch size gives the number of iterations),
                ``load_batch_train`` and ``train_data_std``.
            update_duals (bool): Whether to update the dual variables (allows warmup).
            current_epoch (int): Current epoch index; not used inside this method.
            nu (torch.Tensor): Dual variables, one per constraint. Required in
                practice even though it defaults to None (it is multiplied with
                the constraint vector).
            slacks (torch.Tensor, optional): Slack variables for resilience.
            initial_energy (float, optional): Pre-computed baseline energy, used
                when ``configs["use_input_energy_baseline"]`` is false.
            growing_perturb_coeff (float): Multiplier for the training noise.

        Returns:
            tuple: (mean batch loss, mean batch PSNR, nu, slacks).

        Raises:
            ValueError: If the loss exceeds ``BLOWUP_THRESHOLD`` and the model
                weights contain NaNs, or if ``configs["resilience_mode"]`` is
                not one of "l2_wd", "l2", "none" during a dual update.
        """
        # Loop-invariant settings shared by both energy computations.
        constraint_type = (
            self.configs["constraint_type"]
            if self.configs["constraint_type"] != "none"
            else self.configs["energy_to_report"]
        )
        vit_processor = (
            self.processor if (self.normalize and self.using_vit_processor) else None
        )
        redundant_enabled = self.configs.get("redundant_constraints", False)
        constraint_alpha = self.configs["constraint_alpha"]

        psnr_batch = []
        loss_batch = []
        iterator = tqdm(
            range(int(data_loader.train.shape[1] / self.configs["batch_size"]))
        )
        if not update_duals:
            print("Not updating duals this epoch")
        # NOTE: the batch index has its own name so the per-layer logging loops
        # below cannot overwrite it (it is logged as "batch_idx").
        for batch_idx in iterator:
            # torch.autograd.set_detect_anomaly(True)
            self.optimizer.zero_grad()
            batch = data_loader.load_batch_train(self.configs["batch_size"])
            clean_input = torch.tensor(batch, dtype=self.dtype, device=self.device)

            noisy_input = self.add_noise(
                clean_input,
                self.configs["perturb"],
                data_loader.train_data_std,
                growing_perturb_coeff=growing_perturb_coeff,
            )
            output = self.model.forward(noisy_input)

            if self.configs["l0_loss"] > 0:
                raw_loss = self.compound_loss(
                    clean_input,
                    output,
                    self.model.sparse_code,
                    # Created on the configured device so CPU/MPS runs work too.
                    torch.tensor(self.configs["l0_loss"], device=self.device),
                )
            else:
                raw_loss = self.metric(clean_input, output)

            # Compute energy at input using identity reconstruction
            # This serves as a baseline of what the worst-case denoising should be
            feature_dim = noisy_input.size(-1)
            dummy_sparse_code = noisy_input.reshape(noisy_input.size(0), 1, -1)  # (B,1,F)
            identity_dict = torch.eye(feature_dim, device=noisy_input.device).unsqueeze(0)  # (1,F,F)

            energy_at_input = constraints.compute_layerwise_metrics(
                clean_input,  # Clean signal as reference
                dummy_sparse_code,
                identity_dict,
                self.model.lambda1,
                self.model.lambda2,
                constraint_type=constraint_type,
                reconstructions=noisy_input.unsqueeze(1),  # Add layer dimension
                vit_processor=vit_processor,
            )

            # Baseline for the first-layer constraint: either the energy of the
            # unprocessed noisy input or the pre-computed/specified value.
            if self.configs["use_input_energy_baseline"]:
                input_baseline = energy_at_input
            else:
                input_baseline = initial_energy

            wandb.log({"energy_at_input": energy_at_input, "initial_energy_baseline": input_baseline}, commit=False)

            energy_per_l = constraints.compute_layerwise_metrics(
                clean_input,
                self.model.sparse_code,
                self.model.D,
                self.model.lambda1,
                self.model.lambda2,
                constraint_type=constraint_type,
                reconstructions=self.model.reconstructions,
                vit_processor=vit_processor,
            )

            energy_diffs = energy_per_l[1:] - constraint_alpha * energy_per_l[:-1]

            # Initial constraint: f(H_1) <= alpha * f(X~)
            initial_constraint = energy_per_l[0:1] - constraint_alpha * input_baseline
            wandb.log({"initial_constraint": initial_constraint.item()}, commit=False)
            wandb.log({"constr_violation@layer1": initial_constraint.item()}, commit=False)

            energy_diffs = torch.concatenate([initial_constraint, energy_diffs])

            # Log the remaining layerwise violations (layer 1 was logged above).
            for layer_pos in range(1, len(energy_diffs)):
                wandb.log({f"constr_violation@layer{layer_pos + 1}": energy_diffs[layer_pos].item()}, commit=False)

            # Add redundant constraints if enabled
            if redundant_enabled:
                # Powers of alpha for each layer: [alpha^0, alpha^1, alpha^2, ...]
                layer_indices = torch.arange(energy_per_l.shape[0], device=energy_per_l.device)
                alpha_powers = constraint_alpha ** layer_indices

                redundant_constraints = energy_per_l - alpha_powers * input_baseline

                for layer_pos in range(energy_per_l.shape[0]):
                    wandb.log({"redundant_constr_violation@layer{}".format(layer_pos + 1): redundant_constraints[layer_pos].item()}, commit=False)

                # Concatenate with the original constraints
                energy_diffs = torch.concatenate([energy_diffs, redundant_constraints])

            # The dual step must not back-propagate into the model.
            detached_energy_diffs = energy_diffs.detach()

            loss = raw_loss + nu @ energy_diffs
            loss.backward()

            if self.configs["gradient_clip"] is True:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.25)
            self.optimizer.step()

            if self.constrained and update_duals:
                # Dual ascent: maximise dual_loss by descending on its negation.
                self.dual_optimizer.zero_grad()
                resilience_alpha = self.configs["resilience_coefficient"]
                coeff = 1 / (2 * resilience_alpha)

                resilience_mode = self.configs["resilience_mode"]
                if resilience_mode == "l2_wd":
                    dual_loss = nu @ detached_energy_diffs - coeff * nu.norm(p=2) ** 2
                elif resilience_mode == "l2":
                    # L2 resilience
                    dual_loss = nu @ detached_energy_diffs - nu @ slacks
                    with torch.no_grad():  # slacks are updated outside autograd
                        slacks = slacks - self.configs["lr_resilience"] * (slacks - nu)
                elif resilience_mode == "none":
                    # dual_loss has no resilience term
                    dual_loss = nu @ detached_energy_diffs
                else:
                    # Fail with a clear message instead of an unbound-variable error.
                    raise ValueError(
                        f"Resilience mode {resilience_mode} not supported"
                    )
                (-dual_loss).backward()

                # Apply norm clipping to dual gradients
                max_norm = 100
                torch.nn.utils.clip_grad_norm_(nu, max_norm)

                self.dual_optimizer.step()
                # Dual variables are kept non-negative.
                nu.data = torch.clamp(nu.data, min=0)

                # Logs
                num_layer_constraints = self.model.config["num_layers"]
                num_constraints_violated = (detached_energy_diffs > 0).sum()

                # Log regular constraints
                for layer_pos in range(min(num_layer_constraints, nu.shape[0])):
                    wandb.log({f"nu@layer{layer_pos + 1}": nu[layer_pos]}, commit=False)

                # Log redundant constraints if enabled
                if redundant_enabled:
                    for layer_pos in range(num_layer_constraints):
                        wandb.log({f"redundant_nu@layer{layer_pos + 1}": nu[layer_pos + num_layer_constraints]}, commit=False)

                if slacks is not None:
                    for slack_pos in range(slacks.shape[0]):
                        wandb.log({f"slack@layer{slack_pos + 1}": slacks[slack_pos]}, commit=False)

                wandb.log({"num_constraints_violated": num_constraints_violated}, commit=False)

            # Keep the regularisation weights strictly positive.
            self.model.lambda2.data = torch.clamp(self.model.lambda2.data, min=1e-3)
            self.model.lambda1.data = torch.clamp(self.model.lambda1.data, min=1e-3)

            if loss > BLOWUP_THRESHOLD:
                print("Loss exploded")
                # If weights blow up, stop training
                if any(torch.isnan(p).any() for p in self.model.parameters()):
                    print("Model weights exploded")
                    raise ValueError("Model weights exploded")

            if loss.isnan().any():
                print("FOUND NANS IN LOSS")

            # Logging
            # TODO: per-layer energy-gradient logging is disabled; the previous
            # attempt was hardcoded to the DUST energy function gradient.
            energy_metrics = {
                f"energy@layer{k + 1}": energy_per_l[k]
                for k in range(energy_per_l.shape[0])
            }

            D_norms = self.model.D.norm(dim=(1, 2))
            D_norm_metrics = {
                f"D_norm@layer{k + 1}": D_norms[k] for k in range(D_norms.shape[0])
            }
            sparsity_metrics = {
                f"sparsity@layer{k + 1}": self.model.sparsity[k]
                for k in range(self.model.sparse_code.shape[1])
            }
            with torch.no_grad():
                if self.normalize and self.using_vit_processor:
                    # PSNR is measured after undoing the processor's transform.
                    inverted_input = self.processor.invert_processor(clean_input)
                    inverted_output = self.processor.invert_processor(output)
                    psnr = metrics.psnr(inverted_input, inverted_output).item()
                else:
                    psnr = metrics.psnr(clean_input, output, normalized=self.normalize).item()
            # This is the only log call that commits, i.e. one wandb step per batch.
            wandb.log(
                {
                    "train_loss": loss.item(),
                    "train_psnr": psnr,
                    "batch_idx": batch_idx,
                    "lambda1": self.model.lambda1.item(),
                    "lambda2": self.model.lambda2.item(),
                    "A_norm": self.model.A.norm().item(),
                    "sparse_code_norm": self.model.sparse_code.norm().item(),
                    "output_norm": output.norm().item(),
                    "input_norm": clean_input.norm().item(),
                    **energy_metrics,
                    **D_norm_metrics,
                    **sparsity_metrics,
                }
            )

            iterator.set_description("batch loss: {:.6f}".format(loss))
            loss_batch.append(loss.item())
            psnr_batch.append(psnr)

        loss = sum(loss_batch) / len(loss_batch)
        psnr = sum(psnr_batch) / len(psnr_batch)

        return loss, psnr, nu, slacks

    def compound_loss(self, ref, pre, sparse_codes, lb):
        """
        """
        loss = self.metric(ref, pre)
        l1loss = torch.mean(torch.norm(sparse_codes[:, -1, ...], p=1, dim=-1).flatten())
        return lb * l1loss + loss

    def save_network_statistics(self, data_loader):
        """
        Dump network weights and sparse codes to a MATLAB ``.mat`` file.

        Runs the model (in eval mode, without gradients) over training batches
        while ``current_idx_train + batch_size`` stays within the first
        ``min(train.shape[1], 1024)`` samples, collecting hidden states (for
        the "rnn" model) or sparse codes (otherwise). The collected codes are
        saved together with the model's D, A, h_0, alpha, the three lambdas and
        model-specific matrices (G/Z/g for "reweighted", G for "l1l1", F
        otherwise) to ``"statistics" + configs["model"] + ".mat"``, a relative
        path.

        Args:
            data_loader: Provides ``train``, ``current_idx_train`` (reset to 0
                here) and ``load_batch_train``.

        NOTE(review): the model is switched to eval mode and not switched back.
        NOTE(review): the non-"rnn" branch indexes the 3-D ``sparse_code``
            array with four indices - verify this path before relying on it.
        """
        self.model.eval()
        # Buffer laid out as (time_steps, num_train_samples, num_hidden).
        sparse_code = np.empty(
            [
                self.configs["time_steps"],
                data_loader.train.shape[1],
                self.configs["num_hidden"],
            ],
            dtype=np.float32,
        )
        with torch.no_grad():
            data_loader.current_idx_train = 0
            # presumably load_batch_train advances current_idx_train, which is
            # what terminates this loop - TODO confirm against the loader.
            while data_loader.current_idx_train + self.configs["batch_size"] <= min(
                data_loader.train.shape[1], 1024
            ):
                data_torch = torch.Tensor(
                    data_loader.load_batch_train(self.configs["batch_size"])
                ).to(self.device)
                # The forward pass is run for its side effect of populating
                # model.hidden / model.sparse_code; `output` itself is unused.
                output = self.model.forward(data_torch)
                if self.configs["model"].lower() == "rnn":
                    # NOTE(review): the slice start is read after the batch was
                    # loaded - confirm it is the intended offset.
                    sparse_code[
                        :,
                        data_loader.current_idx_train : self.configs["batch_size"]
                        + data_loader.current_idx_train,
                        :,
                    ] = self.model.hidden.data.cpu().numpy()
                else:
                    # NOTE(review): four indices on a 3-D buffer (see docstring).
                    sparse_code[
                        :,
                        :,
                        data_loader.current_idx_train : data_torch.shape[1]
                        + data_loader.current_idx_train,
                        :,
                    ] = self.model.sparse_code.data.cpu().numpy()[
                        :,
                        :,
                        : min(
                            data_torch.shape[1],
                            data_loader.train.shape[1] - data_loader.current_idx_train,
                            1024 - data_loader.current_idx_train,
                        ),
                        ...,
                    ]

        # Parameters common to every model variant saved here.
        mat = {
            "sparse_codes": sparse_code,
            "D": self.model.D.data.cpu().numpy(),
            "A": self.model.A.data.cpu().numpy(),
            "h0": self.model.h_0.data.cpu().numpy(),
            "c": self.model.alpha.data.cpu().numpy(),
            "lambdas": np.array(
                [
                    self.model.lambda0.data.cpu().numpy(),
                    self.model.lambda1.data.cpu().numpy(),
                    self.model.lambda2.data.cpu().numpy(),
                ]
            ),
        }

        # Model-specific extra matrices.
        if self.configs["model"].lower() == "reweighted":
            mat["G"] = self.model.G.data.cpu().numpy()
            mat["Z"] = self.model.Z.data.cpu().numpy()
            mat["g"] = self.model.g.data.cpu().numpy()
        elif self.configs["model"].lower() == "l1l1":
            mat["G"] = self.model.G.data.cpu().numpy()
        else:
            mat["F"] = self.model.F.data.cpu().numpy()

        sio.savemat("statistics" + self.configs["model"] + ".mat", mat)

    def _evaluate(
        self,
        data_loader,
        split="test",
        perturb_frac=0.0,
        growing_perturb_coeff=0.0,
        log_wandb=True,
    ): #eval
        """
        Common evaluation routine for the validation and test splits.

        Runs the model without gradients over the chosen split, optionally
        perturbing the input, and logs per-batch loss, PSNR and layerwise
        metrics to wandb (uncommitted). The model's train/eval mode is not
        changed here.

        Args:
            data_loader: The data loader to use. Reads ``eval`` /
                ``eval_data_std`` / ``load_batch_validation`` for "val" and
                ``test`` / ``test_data_std`` / ``load_batch_test`` otherwise.
            split: "val" for validation; any other value selects the test split.
            perturb_frac: Fraction of perturbation to add.
            growing_perturb_coeff: Coefficient for growing perturbation.
            log_wandb: Whether to log the split-level averages to wandb.
                Per-batch metrics are logged regardless of this flag.

        Returns:
            tuple: (loss, psnr) averaged over the evaluated batches.
        """
        # Set up loaders based on split
        if split == "val":
            batch_loader = data_loader.load_batch_validation
            data_shape = data_loader.eval.shape
            data_std = data_loader.eval_data_std
            metric_prefix = "val"
        else:  # "test"
            batch_loader = data_loader.load_batch_test
            data_shape = data_loader.test.shape
            data_std = data_loader.test_data_std
            metric_prefix = "test"

        if self.normalize and self.using_vit_processor:
            self.processor = data_loader.vit_processor
        vit_processor = (
            self.processor if (self.normalize and self.using_vit_processor) else None
        )

        # Deduplicated so constraint_type is not computed twice when it
        # coincides with one of the fixed metrics.
        metric_names = {"psnr", "mse", "loss", "sparsity", self.configs["constraint_type"]}

        with torch.no_grad():
            psnr_batch = []
            loss_batch = []

            for batch_idx in tqdm(range(int(data_shape[1] / self.configs["batch_size"]))):
                batch = batch_loader(self.configs["batch_size"])
                clean_input = torch.tensor(batch, dtype=self.dtype, device=self.device)
                noisy_input = self.add_noise(
                    clean_input,
                    perturb_frac,
                    data_std,
                    growing_perturb_coeff=growing_perturb_coeff,
                )

                output = self.model.forward(noisy_input)

                if self.configs["l0_loss"] > 0:
                    loss = self.compound_loss(
                        clean_input,
                        output,
                        self.model.sparse_code,
                        # Created on the configured device so CPU/MPS runs work too.
                        torch.tensor(self.configs["l0_loss"], device=self.device),
                    )
                else:
                    loss = self.metric(clean_input, output)

                # Compute layerwise metrics: PSNR, MSE, selected energy function.
                metrics_to_log = {}
                for metric in metric_names:
                    layerwise_metrics = constraints.compute_layerwise_metrics(
                        clean_input,
                        self.model.sparse_code,
                        self.model.D,
                        self.model.lambda1,
                        self.model.lambda2,
                        constraint_type=metric,
                        reconstructions=self.model.reconstructions,
                        vit_processor=vit_processor,
                    )
                    for k in range(layerwise_metrics.shape[0]):
                        metrics_to_log[f"{metric_prefix}_{metric}@layer{k+1}"] = (
                            layerwise_metrics[k]
                        )  # e.g. val_psnr@layer1

                if self.normalize and self.using_vit_processor:
                    # PSNR is measured after undoing the processor's transform.
                    inverted_input = self.processor.invert_processor(clean_input)
                    inverted_output = self.processor.invert_processor(output)
                    psnr_value = metrics.psnr(inverted_input, inverted_output).item()
                else:
                    psnr_value = metrics.psnr(clean_input, output, normalized=self.normalize).item()

                loss_batch.append(loss.item())
                psnr_batch.append(psnr_value)

                # NOTE(review): the "epoch" key carries the batch index here,
                # not an epoch number - confirm whether that is intended.
                wandb.log(
                    {
                        f"{metric_prefix}_loss": loss.item(),
                        f"{metric_prefix}_psnr": psnr_value,
                        **metrics_to_log,
                        "epoch": batch_idx,
                    },
                    commit=False,
                )

                # MNIST-specific visualization for test set only
                if self.configs["dataset"] == "moving_mnist" and split == "test":
                    self.mnist_specific_logging(batch, output, batch_idx)

            # Calculate averages
            loss = sum(loss_batch) / len(loss_batch)
            psnr = sum(psnr_batch) / len(psnr_batch)

            # Log overall metrics if requested
            if log_wandb:
                metrics_dict = {
                    f"{metric_prefix}_loss": loss,
                    f"{metric_prefix}_psnr": psnr,
                }
                if split == "test":
                    metrics_dict["test_psnr_noise_0_00"] = psnr
                wandb.log(metrics_dict, commit=False)

            to_print = f"{metric_prefix}_loss: {loss}, {metric_prefix}_psnr: {psnr}"
            print(to_print)

            return loss, psnr

    def validation(self, data_loader, perturb_frac_test=0.0):
        """
        Evaluate the model on the validation split.

        Thin wrapper around ``_evaluate`` with ``split="val"`` and wandb
        logging switched on. ``growing_perturb_coeff`` is pinned to 0.0 here.

        Args:
            data_loader: Loader object, forwarded unchanged to ``_evaluate``.
            perturb_frac_test: Forwarded to ``_evaluate`` as ``perturb_frac``.

        Returns:
            Whatever ``_evaluate`` returns.
        """
        eval_kwargs = {
            "split": "val",
            "perturb_frac": perturb_frac_test,
            "growing_perturb_coeff": 0.0,
            "log_wandb": True,
        }
        return self._evaluate(data_loader, **eval_kwargs)

    def test(
        self,
        data_loader,
        perturb_frac_test=0.0,
        growing_perturb_coeff=0.0,
        log_wandb=True,
    ):
        """
        Evaluate the model on the test split.

        Delegates to ``_evaluate`` with ``split="test"``; every other argument
        is passed straight through.

        Args:
            data_loader: Loader object, forwarded unchanged to ``_evaluate``.
            perturb_frac_test: Forwarded to ``_evaluate`` as ``perturb_frac``.
            growing_perturb_coeff: Forwarded under the same name.
            log_wandb: Forwarded under the same name.

        Returns:
            Whatever ``_evaluate`` returns.
        """
        forwarded = {
            "split": "test",
            "perturb_frac": perturb_frac_test,
            "growing_perturb_coeff": growing_perturb_coeff,
            "log_wandb": log_wandb,
        }
        return self._evaluate(data_loader, **forwarded)

    def mnist_specific_logging(self, test_batch, output, batch_index):
        """
        Save visualization samples for MNIST dataset.

        Takes entry ``[10, 0]`` of both the ground truth and the model output
        and writes three PNG files into ``self.configs["log_folder"]``:

        * ``<batch_index>_10_gt.png``  - ground truth,
        * ``<batch_index>_10_out.png`` - reconstruction clamped to [0, 255],
        * ``<batch_index>_10_ae.png``  - absolute error multiplied by 4,
          capped at 255 and rendered with the JET color map.

        The clamping is applied to a copy of the visualized slice only, so the
        caller's ``output`` tensor is not modified.

        Args:
            test_batch: The original test batch
            output: The model's output tensor
            batch_index: The current batch index (for naming the output files)
        """
        log_folder = self.configs["log_folder"]

        # Clamp out-of-place: logging must not alter the tensor owned by the
        # caller, and only the single visualized slice needs to be clamped.
        sample_out = np.squeeze(
            output[10, 0, ...].detach().clamp(0.0, 255.0).cpu().numpy()
        )
        sample_gt = np.squeeze(test_batch[10, 0, ...])
        abs_error = np.abs(sample_out - sample_gt)

        # NOTE(review): the (16, 16) reshape assumes each sample holds exactly
        # 256 values - confirm against the loader configuration.
        sample_gt = cv2.cvtColor(np.reshape(sample_gt, (16, 16)), cv2.COLOR_GRAY2BGR)
        sample_out = cv2.cvtColor(np.reshape(sample_out, (16, 16)), cv2.COLOR_GRAY2BGR)

        cv2.imwrite(
            os.path.join(log_folder, "{}_10_gt.png".format(batch_index)),
            sample_gt,
        )
        cv2.imwrite(
            os.path.join(log_folder, "{}_10_out.png".format(batch_index)),
            sample_out,
        )

        # Amplify the error so small deviations stay visible, then cap it at
        # 255 before the conversion to uint8.
        error_scale_factor = 4
        abs_error_scaled = abs_error * error_scale_factor
        abs_error_scaled[abs_error_scaled > 255.0] = 255.0
        img_error = cv2.applyColorMap(
            np.reshape(abs_error_scaled.astype(np.uint8), (16, 16)),
            cv2.COLORMAP_JET,
        )
        cv2.imwrite(
            os.path.join(log_folder, "{}_10_ae.png".format(batch_index)),
            img_error,
        )

    def ood_analysis(self, data_loader):
        """
        Run OOD analysis on the test set and compute an aggregated robustness score.

        For every perturbation level the loader indices are reset and
        ``self.test`` is run with ``growing_perturb_coeff=1.0`` and wandb logging
        disabled; the resulting PSNR and loss are then logged here under
        ``test_psnr_noise_<level>`` / ``test_loss_noise_<level>``.

        The score is computed as the area under the PSNR vs noise curve
        (trapezoidal rule), normalized by the clean PSNR and the maximum
        perturbation level. A second score restricts the curve to the levels
        strictly above the training noise (``self.configs["perturb"]``) and is
        normalized by the clean PSNR times the width of that restricted range,
        so that a model keeping the clean PSNR everywhere scores 1 on both.

        Args:
            data_loader: Loader exposing ``reset_indices()``; passed on to
                ``self.test``.

        Returns:
            The robustness score over the full perturbation range.
        """
        perturb_fracs = [0.0, 0.01, 0.05, 0.1, 0.2, 0.25, 0.5, 0.75, 1.0, 1.5]
        psnrs = []

        # np.trapz was renamed to np.trapezoid in NumPy 2.0; prefer the new
        # name and fall back to the old one on earlier releases.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz

        for perturb_frac in perturb_fracs:
            print(f"Running OOD analysis for perturb_frac={perturb_frac}")
            data_loader.reset_indices()
            loss, psnr = self.test(
                data_loader,
                perturb_frac_test=perturb_frac,
                growing_perturb_coeff=1.0,
                log_wandb=False,
            )
            psnrs.append(psnr)
            noise_tag = str(perturb_frac).replace(".", "_")
            wandb.log(
                {
                    f"test_psnr_noise_{noise_tag}": psnr,
                    f"test_loss_noise_{noise_tag}": loss,
                },
                commit=False,
            )

        # Compute area under PSNR vs noise curve (trapezoidal integration)
        auc = trapezoid(y=psnrs, x=perturb_fracs)

        # Normalize by clean PSNR and max perturbation.
        # A perfect model would maintain clean PSNR across all noise levels
        perfect_auc = psnrs[0] * perturb_fracs[-1]  # clean_psnr * max_perturb
        robustness_score = auc / perfect_auc

        # Restrict the curve to noise levels beyond the training noise.
        training_noise = self.configs["perturb"]
        higher_noise_points = [
            (frac, psnr)
            for frac, psnr in zip(perturb_fracs, psnrs)
            if frac > training_noise
        ]

        higher_noise_auc = 0.0
        higher_noise_perfect_auc = 0.0
        higher_noise_robustness_score = 0.0
        # At least two points are needed for a non-degenerate integration range.
        if len(higher_noise_points) >= 2:
            higher_fracs, higher_psnrs = zip(*higher_noise_points)
            higher_noise_auc = trapezoid(y=higher_psnrs, x=higher_fracs)

            # The curve is integrated from the first level above the training
            # noise up to the last level, so the reference area must span that
            # same interval (not max_perturb - training_noise); otherwise even
            # a perfectly robust model could not reach a score of 1.
            higher_noise_perfect_auc = psnrs[0] * (higher_fracs[-1] - higher_fracs[0])
            higher_noise_robustness_score = higher_noise_auc / higher_noise_perfect_auc

        wandb.log(
            {
                "test_ood_robustness_score": robustness_score,
                "test_ood_auc": auc,
                "test_ood_perfect_auc": perfect_auc,
                "test_ood_auc_higher_noise": higher_noise_auc,
                "test_ood_perfect_auc_higher_noise": higher_noise_perfect_auc,
                "test_ood_robustness_score_higher_noise": higher_noise_robustness_score,
            }
        )

        return robustness_score

    def add_noise(self, batch_raw, perturb_frac, data_std, growing_perturb_coeff=1.0, normalized=False):
        """
        Perturb ``batch_raw`` through the module-level ``add_noise`` helper.

        The ``normalized`` argument is accepted for signature parity with the
        module-level function but is not used: the helper is always called
        with ``normalized=self.normalize``.

        Args:
            batch_raw: Batch to perturb.
            perturb_frac: Forwarded unchanged.
            data_std: Forwarded unchanged.
            growing_perturb_coeff: Forwarded unchanged.
            normalized: Ignored (see above).

        Returns:
            The result of the module-level ``add_noise``.
        """
        return add_noise(
            batch_raw,
            perturb_frac,
            data_std,
            growing_perturb_coeff=growing_perturb_coeff,
            normalized=self.normalize,
        )

    def _evaluate_ood_on_split(self, data_loader, split="test", num_batches=256):
        """
        Evaluate OOD performance on a few batches from a given split ('train', 'val', or 'test'),
        without permanently modifying the dataloader state.

        For every perturbation level the split's batch index is rewound to 0 and
        ``num_batches`` batches are loaded, perturbed with ``self.add_noise``
        (``growing_perturb_coeff=1.0``) and reconstructed; PSNR is measured
        against the unperturbed input. Rewinding per level means every level is
        scored on the same batches. The index the loader had on entry is
        restored on exit, including when an exception is raised.

        Logs ``<split>_psnr_noise_<level>`` per level and, at the end,
        ``ood_robustness_score_<split>``, ``ood_auc_<split>`` and
        ``ood_perfect_auc_<split>``.

        Args:
            data_loader: Loader exposing the per-split index, batch-loading
                function and data-std attributes listed in ``split_table``.
            split: One of ``"train"``, ``"val"`` or ``"test"``.
            num_batches: Number of batches evaluated per perturbation level.

        Returns:
            Tuple ``(robustness_score, psnr_results)`` where ``psnr_results``
            maps each perturbation level to its average PSNR.

        Raises:
            ValueError: If ``split`` is not one of the supported names.
        """
        # split -> (index attribute, batch-loading method, data-std attribute)
        split_table = {
            "train": ("current_idx_train", "load_batch_train", "train_data_std"),
            "val": ("current_idx_eval", "load_batch_validation", "eval_data_std"),
            "test": ("current_idx_test", "load_batch_test", "test_data_std"),
        }
        if split not in split_table:
            raise ValueError(f"Unsupported split: {split}")
        reset_attr, load_fn_name, std_attr = split_table[split]

        saved_idx = getattr(data_loader, reset_attr)
        load_fn = getattr(data_loader, load_fn_name)
        data_std = getattr(data_loader, std_attr)

        # Define a set of perturbation levels for the OOD curve.
        perturb_fracs = [0.0, 0.01, 0.05, 0.1, 0.2, 0.25, 0.5, 0.75, 1.0, 1.5]
        psnr_results = {}

        try:
            with torch.no_grad():
                for perturb_frac in perturb_fracs:
                    # Rewind for every level so that all levels are evaluated
                    # on the same batches and the curve is comparable.
                    setattr(data_loader, reset_attr, 0)

                    batch_psnrs = []
                    for _ in range(num_batches):
                        # Load a batch using the appropriate loader function.
                        batch = load_fn(self.configs["batch_size"])
                        input_tensor = torch.tensor(
                            batch, dtype=self.dtype, device=self.device
                        )
                        noisy_input = self.add_noise(
                            input_tensor, perturb_frac, data_std, growing_perturb_coeff=1.0
                        )
                        output = self.model.forward(noisy_input)
                        if self.normalize and self.using_vit_processor:
                            inverted_input = self.processor.invert_processor(input_tensor)
                            inverted_output = self.processor.invert_processor(output)
                            psnr = metrics.psnr(inverted_input, inverted_output).item()
                        else:
                            psnr = metrics.psnr(
                                input_tensor, output, normalized=self.normalize
                            ).item()
                        batch_psnrs.append(psnr)

                    # Average PSNR for this perturbation level.
                    avg_psnr = np.mean(batch_psnrs)
                    psnr_results[perturb_frac] = avg_psnr
                    # Log the average PSNR for this noise level.
                    wandb.log(
                        {
                            f"{split}_psnr_noise_{str(perturb_frac).replace('.', '_')}": avg_psnr
                        },
                        commit=False,
                    )
        finally:
            # Restore the original index so that subsequent evaluations remain
            # unaffected, even if the evaluation above failed part-way.
            setattr(data_loader, reset_attr, saved_idx)

        # Compute aggregated robustness score (area under the PSNR vs. noise curve).
        # np.trapz was renamed to np.trapezoid in NumPy 2.0; support both.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        perturb_fracs_arr = np.array(perturb_fracs)
        psnrs_arr = np.array([psnr_results[frac] for frac in perturb_fracs])
        auc = trapezoid(psnrs_arr, perturb_fracs_arr)
        perfect_auc = (
            psnrs_arr[0] * perturb_fracs_arr[-1]
        )  # baseline using the clean PSNR.
        robustness_score = auc / perfect_auc

        # Log aggregated metrics with a split-specific tag.
        wandb.log(
            {
                f"ood_robustness_score_{split}": robustness_score,
                f"ood_auc_{split}": auc,
                f"ood_perfect_auc_{split}": perfect_auc,
            }
        )

        return robustness_score, psnr_results

    def test_video(self, data_loader, split, sample_output=(10,), error_map=False):
        """
        Frame reconstruction on the whole video using non-overlapping patches and sequences of time_steps frames.

        Every clip in ``data_loader.test_clips`` is cut into patches with
        ``create_patches``, pushed through the model in chunks of
        ``self.configs["batch_size"]``, reassembled with ``assemble_patches``,
        clamped to [0, 255] and scored with ``metrics.psnr_numpy``.

        If sample_output is not None, the frames listed in it are saved to
        ``self.configs["log_folder"]`` as ``<clip>_<frame>_gt.png`` and
        ``<clip>_<frame>_rec.png`` (plus ``<clip>_<frame>_ae.png`` when
        ``error_map`` is set).

        Args:
            data_loader: Loader exposing ``test_clips``.
            split: Currently unused.
                NOTE(review): clips always come from ``test_clips`` regardless
                of this value, although the script passes "eval" for the
                ``video_eval`` configuration - confirm whether a separate
                evaluation clip set should be used.
            sample_output: Iterable of frame indices to save per clip, or None
                to save nothing.
            error_map: When truthy, also save a JET-colored absolute error map
                (error multiplied by 4, capped at 255).

        Returns:
            The mean PSNR over all clips, or None when there are no clips.
        """
        tile_size = self.configs["tile_size"]
        batch_size = self.configs["batch_size"]
        log_folder = self.configs["log_folder"]

        with torch.no_grad():
            clips = data_loader.test_clips
            psnrs = []

            for c_i, clip in enumerate(clips):
                sx, sy = clip.shape[1], clip.shape[2]

                # Create patches
                patches = create_patches(clip, tile_size)

                # Flatten and permute patches/frames indices
                patches = np.swapaxes(
                    np.reshape(patches, (patches.shape[0], patches.shape[1], -1)), 0, 1
                )

                # Inference (with batch division to avoid memory OF)
                torch.cuda.empty_cache()
                # Allocate the CPU result buffer directly; every slice along
                # axis 1 is overwritten by the loop below.
                output = torch.empty(patches.shape, dtype=self.dtype, device="cpu")

                num_columns = patches.shape[1]
                for start in range(0, num_columns, batch_size):
                    stop = min(start + batch_size, num_columns)
                    chunk = torch.tensor(
                        patches[:, start:stop, ...],
                        dtype=self.dtype,
                        device=self.device,
                    )
                    output[:, start:stop, ...] = self.model.forward(chunk).cpu()

                # Reconstruct output clip
                out = np.reshape(
                    output.detach().numpy(),
                    (patches.shape[0], patches.shape[1], tile_size, tile_size),
                )
                out = assemble_patches(out, tile_size, sx, sy)

                out[out > 255.0] = 255.0
                out[out < 0.0] = 0.0
                psnrs.append(metrics.psnr_numpy(clip, out))

                if sample_output is not None:
                    for so in sample_output:
                        img_gt = np.squeeze(clip[so, ...])
                        img_pred = np.squeeze(out[so, ...])
                        abs_error = np.abs(img_pred - img_gt)
                        img_gt = cv2.cvtColor(img_gt, cv2.COLOR_GRAY2BGR)
                        img_pred = cv2.cvtColor(img_pred, cv2.COLOR_GRAY2BGR)
                        cv2.imwrite(
                            os.path.join(log_folder, "{}_{}_gt.png".format(c_i, so)),
                            img_gt,
                        )
                        cv2.imwrite(
                            os.path.join(log_folder, "{}_{}_rec.png".format(c_i, so)),
                            img_pred,
                        )

                        # Output error map
                        if error_map:
                            error_scale_factor = 4
                            abs_error_scaled = abs_error * error_scale_factor
                            abs_error_scaled[abs_error_scaled > 255.0] = 255.0
                            img_error = cv2.applyColorMap(
                                abs_error_scaled.astype(np.uint8), cv2.COLORMAP_JET
                            )
                            cv2.imwrite(
                                os.path.join(
                                    log_folder, "{}_{}_ae.png".format(c_i, so)
                                ),
                                img_error,
                            )

            # Without clips there is nothing to average.
            if not psnrs:
                print("No clips to evaluate")
                return None

            # Mean PSNR
            avg_psnr = sum(psnrs) / len(psnrs)
            print("Avg_psnr {}".format(avg_psnr))
            return avg_psnr

    def _infer_initial_energy(self, data_loader, num_batches=None):
        """
        Estimate the initial energy by computing the energy for a specified number of batches
        from the training data.

        This is used as a fixed baseline for the initial constraint when use_input_energy_baseline=False.
        When use_input_energy_baseline=True, this method isn't used, and instead the energy of each
        noisy input batch is computed dynamically during training.

        For each batch the model is run once and the layerwise metric selected
        by ``self.configs["constraint_type"]`` is computed; the value of the
        last layer is collected. The loader indices are reset before returning.

        Parameters:
        -----------
        data_loader : DataLoader
            The data loader containing the training data.
        num_batches : int, optional
            The number of batches to use for estimating the initial energy. If not provided,
            it is derived from the size of the training data and the batch size.

        Returns:
        --------
        float
            The mean of the computed initial energies multiplied by 10.
        """
        batch_size = self.configs["batch_size"]
        if not num_batches:
            num_batches = int(data_loader.train.shape[1] / batch_size)

        initial_energies = []
        with torch.no_grad():
            for _ in range(num_batches):
                batch = data_loader.load_batch_train(batch_size)
                clean_input = torch.tensor(batch, dtype=self.dtype, device=self.device)
                # add_noise needs a real data std. The training-split std is
                # used here; previously an unassigned placeholder was passed,
                # which failed on the first iteration. The growing coefficient
                # stays at 0.0, so the sampled noise is scaled to zero.
                noisy_input = self.add_noise(
                    clean_input,
                    self.configs["perturb"],
                    data_loader.train_data_std,
                    0.0,
                )

                # The forward pass populates the model attributes read below;
                # its return value is not needed.
                self.model(noisy_input)

                energy_per_l = constraints.compute_layerwise_metrics(
                    clean_input,
                    self.model.sparse_code,
                    self.model.D,
                    self.model.lambda1,
                    self.model.lambda2,
                    constraint_type=self.configs["constraint_type"],
                    reconstructions=self.model.reconstructions,
                )

                initial_energies.append(energy_per_l[-1].item())
        # Reset the data_loader to the start.
        data_loader.reset_indices()
        return np.mean(initial_energies) * 10


if __name__ == "__main__":
    # Script entry point: load the YAML defaults, overlay the CLI arguments,
    # build the data loader for the chosen dataset, then train or test.

    # defaults
    CONFIG_PATH = "frame_reconstruction_configs.yaml"
    with open(CONFIG_PATH, "r") as stream:
        # SafeLoader already yields native ints/floats/bools, so the loaded
        # values need no further casting.
        config = yaml.load(stream, SafeLoader)

    # Parse CLI arguments with proper type conversion
    parser = argparse.ArgumentParser(description="Multiple L1 for frame reconstruction")
    parser.add_argument(
        "-m",
        "--model",
        choices=[
            "sista",
            "l1l1",
            "reweighted",
            "gru",
            "lstm",
            "rnn",
            "dust",
            "dust_vec",
            "unrolled_transformer",
            "vit",
        ],
        default="dust",
    )
    parser.add_argument("-ld0", "--lambda0", help="Lambda0 value", type=float)
    parser.add_argument("-ld1", "--lambda1", help="Lambda1 value", type=float)
    parser.add_argument("-ld2", "--lambda2", help="Lambda2 value", type=float)
    parser.add_argument(
        "-learn_ld0", "--learn_lambda0", help="Learn Lambda0 flag", type=int
    )
    parser.add_argument(
        "-learn_ld1", "--learn_lambda1", help="Learn Lambda1 flag", type=int
    )
    parser.add_argument(
        "-learn_ld2", "--learn_lambda2", help="Learn Lambda2 flag", type=int
    )
    parser.add_argument(
        "-h0",
        "--h0_init",
        choices=["zeros", "ones", "normal_0", "normal_d"],
        default="zeros",
        help="Initialization of the hidden state.",
    )
    parser.add_argument(
        "-e0",
        "--initial_energy",
        type=float,
        required=False,
        help="Energy value for initial layer constraint.",
    )
    parser.add_argument(
        "--use_input_energy_baseline",
        action="store_true",
        help="Use energy of noisy input as baseline for initial constraint, ignoring initial_energy parameter.",
    )
    parser.add_argument("-k", "--num_layers", help="Number of layers", type=int)
    parser.add_argument("-wd", "--weight_decay", help="Weight decay", type=float)
    parser.add_argument("-lr", "--lr", help="Learning rate", type=float)
    parser.add_argument("-dlr", "--lr_dual", help="Dual learning rate", type=float)
    parser.add_argument(
        "-d_init",
        "--dual_init",
        default=0.0,
        help="Initial value of the dual variables",
        type=float,
    )
    # Resilience parameters
    parser.add_argument(
        "-rlr",
        "--lr_resilience",
        default=0.0,
        help="Resilience learning rate",
        type=float,
    )
    parser.add_argument(
        "-rc",
        "--resilience_coefficient",
        type=float,
        default=1.0,
        help="Coefficient for resilience term",
    )
    parser.add_argument(
        "--unconstr_warmup_epochs",
        type=int,
        default=1,
        help="Number of epochs to warm up the model before applying constraints",
    )

    parser.add_argument(
        "-ro",
        "--resilience_optimizer",
        choices=["adam", "sgd"],
        help="Resilience optimizer",
        default="sgd",
    )
    parser.add_argument(
        "-rm",
        "--resilience_mode",
        choices=["l2", "l2_wd", "none"],
        help="Resilience mode: L2 (regular) or L2 with weight decay or none. If L2, will need to provide lr_resilience",
        default="none",
    )

    parser.add_argument(
        "-cf", "--compression_factor", help="Compression factor", type=float
    )
    parser.add_argument("-epoch", "--num_epoch", help="Number of epochs", type=int)
    parser.add_argument(
        "-nh", "--num_hidden", help="Number of units per hidden layer", type=int
    )
    parser.add_argument("-ts", "--time_steps", help="Time steps per sequence", type=int)
    parser.add_argument(
        "-d_path",
        "--D_init_file_path",
        help="Path for dictionary initialization",
        type=str,
    )
    parser.add_argument(
        "-d",
        "--dataset",
        choices=["moving_mnist", "yup", "ucsd", "shanghaitech", "avenue"],
        default="moving_mnist",
    )
    parser.add_argument("-r", "--resume", action="store_true", default=False)
    parser.add_argument("-cpt", "--checkpoint", default="./checkpoints", type=str)
    parser.add_argument("-t", "--test", action="store_true", default=False)
    parser.add_argument("-tc", "--test_config", type=str)
    # NOTE(review): store_true combined with default=True means this flag is
    # always True and cannot be switched off from the command line.
    parser.add_argument("-gc", "--gradient_clip", action="store_true", default=True)
    parser.add_argument("-l0", "--l0_loss", default=0.0, help="L0 loss", type=float)
    parser.add_argument("-em", "--error_map", default=False, action="store_true")
    parser.add_argument("--constrained", action="store_true")
    parser.add_argument(
        "--constraint_type",
        choices=["energy", "energy_jacobian", "loss", "none", "mse"],
        default="energy",
        help="Type of constraint to apply.",
    )
    parser.add_argument(
        "--energy_to_report",
        choices=["energy", "energy_jacobian", "loss", "mse", "psnr", "sparsity"],
        default="mse",
        help="Energy function to report when constraint_type is 'none'.",
    )
    parser.add_argument(
        "-ca",
        "--constraint_alpha",
        type=float,
        default=1.0,
        help="Constraint alpha parameter",
    )

    parser.add_argument("-diff", "--diff_D", action="store_true")
    parser.add_argument(
        "-pert", "--perturb", default=0.0, help="Perturbation rate", type=float
    )
    parser.add_argument(
        "-p", "--wandb_project", default="project", type=str
    )
    parser.add_argument("--experiment_tag", default="e0_untagged", type=str)
    # NOTE(review): the value of --run_name is overwritten further down by the
    # automatically generated run name, so passing it has no effect.
    parser.add_argument("--run_name", default="", type=str)
    parser.add_argument("--seed", type=int, default=2018, help="Random seed")
    parser.add_argument(
        "--noise-schedule",
        type=str,
        choices=["growing", "constant"],
        default="growing",
        help="Noise schedule during training: growing (increases over epochs) or constant",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda", "mps"],
        required=False,
        help="Device to run the model on",
    )
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument(
        "--redundant_constraints",
        action="store_true",
        help="Enable redundant constraints",
    )
    parser.add_argument("--wandb_note", type=str, default="", help="Additional notes for wandb run")
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="Use normalized images with proper inversion before PSNR calculation",
    )
    parser.add_argument(
        "--break_if_exploding",
        action="store_true",
        help="Break training if loss exceeds BLOWUP_THRESHOLD or becomes NaN",
    )
    args = vars(parser.parse_args())

    # Transfer the CLI values into the config dictionary. Options that were not
    # given and have no argparse default arrive as None and keep their YAML
    # value; options with an argparse default always overwrite the YAML value.
    for key, value in args.items():
        if value is not None:
            config[key] = value

    # Update boolean flags separately
    for flag in ["resume", "test", "gradient_clip", "error_map", "diff_D", "use_input_energy_baseline", "normalize"]:
        config[flag] = bool(args.get(flag))

    # Handle the constrained flag and its relation to constraint_type
    if args.get("constrained") and config.get("constraint_type") == "none":
        raise ValueError(
            f'Contradicting "constrained" and "constraint_type" parameters: {args["constrained"]} and {config["constraint_type"]}'
        )
    config["constrained"] = config.get("constraint_type", "none") != "none"
    config["is_constrained"] = config["constrained"]

    # Ensure redundant constraints are only used when constraints are enabled
    if config.get("redundant_constraints", False) and not config["constrained"]:
        raise ValueError("Redundant constraints can only be used when constraints are enabled (constraint_type != 'none')")

    # Construct log folder and run name
    now = datetime.now()
    config["log_folder"] = join(
        config["log_path"],
        f"{now.year}_{now.month}_{now.day} {now.hour}h{now.minute}_{now.second}",
    )
    config[
        "log_folder"
    ] += f" {config['model']}_{config['dataset']}_k{config['num_layers']}_C{config['compression_factor']}_n{config['num_hidden']}"
    os.makedirs(config["log_folder"])

    run_name = f"{config['model']}_{config['dataset']}_k{config['num_layers']}_C{config['compression_factor']}_n{config['num_hidden']}_diffD{config['diff_D']}_constrType_{config['constraint_type']}_ca{config['constraint_alpha']}"

    # Add redundant constraints to run name if enabled
    if config.get("redundant_constraints", False):
        run_name += "_redundant"

    config["run_name"] = run_name
    config["seed"] = args["seed"]

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(config["checkpoint"] + "/" + config["run_name"], exist_ok=True)

    # save configs to file
    with open(join(config["log_folder"], "configs.yml"), "w") as outfile:
        yaml.dump(config, outfile, default_flow_style=False)

    # TODO: codesmell bunch of repeating parameters.
    if config["model"] == "vit":
        processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
        vit_clip_processor = vit.ViTClipProcessor(processor=processor, device=args["device"])
        flatten = False  # ViT expects H and W dimensions.
    else:
        vit_clip_processor = None
        flatten = True

    if config["dataset"] == "moving_mnist":
        data_loader = Moving_MNIST_Loader(
            config["moving_mnist_path"],
            time_steps=config["time_steps"],
            load_only=-1,
            flatten=flatten,
            scale=False,
        )
    elif config["dataset"] == "yup":
        data_loader = YUP_Loader(
            config["yup_path"],
            patch_size=config["tile_size"],
            time_steps=20,
            flatten=flatten,
            scale=False,
            normalize=False,
            category=["Street"],
            camera=["static"],
        )
    elif config["dataset"] == "ucsd":
        data_loader = UCSD_Loader(
            config["ucsd_path"],
            patch_size=config["tile_size"],
            time_steps=20,
            flatten=flatten,
            scale=False,
            normalize=config["normalize"],
            vit_clip_processor=vit_clip_processor,
            normalized_images=config["normalize"],
        )
    elif config["dataset"] == "shanghaitech":
        data_loader = ShanghaiTechLoader(
            config["shanghaitech_loader_base_path"],
            patch_size=config["tile_size"],
            time_steps=20,
            flatten=flatten,
            scale=False,
            normalize=config["normalize"],
            vit_clip_processor=vit_clip_processor,
            normalized_images=config["normalize"],
        )
    elif config["dataset"] == "avenue":
        # Avenue reuses the ShanghaiTech loader and base path with its own split.
        data_loader = ShanghaiTechLoader(
            config["shanghaitech_loader_base_path"],
            patch_size=config["tile_size"],
            time_steps=20,
            flatten=flatten,
            scale=False,
            normalize=config["normalize"],
            split=["avenue"],
            vit_clip_processor=vit_clip_processor,
            normalized_images=config["normalize"],
        )
    else:
        # Abort with a non-zero exit status instead of a silent exit().
        raise SystemExit(config["dataset"] + ": Invalid dataset")

    # NOTE(review): seeding happens after the data loader is constructed, so
    # any randomness inside the loader constructors is not covered by the seed.
    np.random.seed(config["seed"])
    torch.manual_seed(config["seed"])
    wandb.init(
        project=args["wandb_project"],
        name=run_name,
        notes=args["wandb_note"],
        tags=[args["experiment_tag"]],
        config=config,
    )
    fr = Frame_Reconstruction(config)

    if not config["test"]:
        fr.train(data_loader)
    else:
        # test_config may be absent from both the YAML file and the CLI.
        test_config = config.get("test_config")
        if test_config == "video_test":
            fr.test_video(data_loader, "test", error_map=args["error_map"])
        elif test_config == "video_eval":
            fr.test_video(data_loader, "eval", error_map=args["error_map"])
        elif test_config in ("patch_test", "mnist_test"):
            fr.metric = torch.nn.MSELoss()
            fr.test(data_loader, perturb_frac_test=0.0, growing_perturb_coeff=0.0)
        elif test_config in ("patch_eval", "mnist_eval"):
            fr.metric = torch.nn.MSELoss()
            fr.validation(data_loader)
        else:
            print("Unrecognized test configuration")

    wandb.finish()
    assert wandb.run is None

print("Ok!")
