import argparse
from email.policy import default
import shutil
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
import torch.cuda.amp as amp
from torch.nn.parallel import DistributedDataParallel
from einops import rearrange
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap as ruamelDict
from dadaptation import DAdaptAdam, DAdaptAdan
from adan_pytorch import Adan
from collections import OrderedDict
import wandb
import pickle as pkl
import gc
from torchinfo import summary
from collections import defaultdict

import sys
import os

current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
sys.path.append(current_dir)
sys.path.append(parent_dir)

from data_utils.datasets import get_data_loader, DSET_NAME_TO_OBJECT, pinf_frame_to_h5, pinf_combine_frames_to_h5, calculate_crop_box
from swin_transformer import build_vmae  # models.
from utils import logging_utils
from utils.YParams import YParams
    # from utils.load_ckpt_utils import load_ckpt
# except:
#     from .data_utils.datasets import get_data_loader, DSET_NAME_TO_OBJECT, pinf_frame_to_h5
#     from .models.transformer import build_vmae
#     from .utils import logging_utils
#     from .utils.YParams import YParams
    # from .utils.load_ckpt_utils import load_ckpt
from PIL import Image
from pdb import set_trace as bp



def grad_norm(parameters):
    """Return the global L2 norm over the gradients of *parameters*.

    Parameters whose ``.grad`` is ``None`` contribute nothing. Returns a Python
    float (``0.0`` when there are no gradients).
    """
    with torch.no_grad():
        squared_total = sum(
            p.grad.data.pow(2).sum().item()
            for p in parameters
            if p.grad is not None
        )
        return squared_total ** .5

def grad_clone(parameters):
    """Return a list with a detached copy of each parameter's gradient.

    A parameter without a gradient is represented by a zero tensor shaped like
    the parameter, so the output always lines up one-to-one with the input.
    """
    with torch.no_grad():
        return [
            torch.zeros_like(p) if p.grad is None else p.grad.clone()
            for p in parameters
        ]

def param_norm(parameters):
    """Return the global L2 norm of all tensors in *parameters* as a float."""
    with torch.no_grad():
        squared_total = sum(p.pow(2).sum().item() for p in parameters)
        return squared_total ** .5

def param_diff(params1, params2):
    """Return the L2 norm of ``params2 - params1`` taken element-wise.

    Tensors are paired with ``zip``, so any surplus entries in the longer
    sequence are ignored.
    """
    with torch.no_grad():
        squared_total = sum(
            (after - before).pow(2).sum().item()
            for before, after in zip(params1, params2)
        )
        return squared_total ** .5

def freeze_layers(model, param):
    """Freeze every parameter of *model* whose name is not listed for optimization.

    Args:
        model: module whose parameters may be frozen in place.
        param: configuration providing ``freeze_former_layers`` (bool switch) and
            ``optimized_layers`` (collection of parameter names to keep trainable).

    Returns:
        The same *model* object. When ``param.freeze_former_layers`` is falsy,
        nothing is changed.
    """
    if not param.freeze_former_layers:
        return model
    optimized_layers = param.optimized_layers
    # Use a distinct loop name so the config argument `param` is not shadowed.
    for name, weight in model.named_parameters():
        if name not in optimized_layers:
            weight.requires_grad = False
    return model

def add_weight_decay(model, weight_decay=1e-5, inner_lr=1e-3, skip_list=()):
    """Split trainable parameters into a no-decay group and a decay group.

    Adapted from Ross Wightman:
        https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994/3

    A parameter goes to the no-decay group when it has at most one dimension after
    ``squeeze()`` (typically a bias or a scale) or when its name is in *skip_list*;
    all other trainable parameters get *weight_decay*. Frozen parameters
    (``requires_grad == False``) are left out entirely, and each included
    parameter's name is printed. *inner_lr* is accepted but unused.

    Returns:
        ``[{'params': no_decay, 'weight_decay': 0.}, {'params': decay, 'weight_decay': weight_decay}]``
    """
    no_decay, decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        print('Optimized layers include:', name)
        vector_like = len(param.squeeze().shape) <= 1
        bucket = no_decay if (vector_like or name in skip_list) else decay
        bucket.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay},
    ]

class Trainer:
    def __init__(self, params, global_rank, local_rank, device, sweep_id=None):
        """Assemble the data, model, optimizer and LR schedule for this process.

        Args:
            params: run configuration, read through attribute access.
            global_rank: global rank of this process (rank 0 does the logging).
            local_rank: process index on this node; used as the CUDA device index.
            device: device the model is moved to.
            sweep_id: optional sweep identifier, stored on ``self.sweep_id``.

        Setup order: data, then model, then optimizer. Any checkpoint (resume or
        pretrained) is restored before the scheduler is built, so a resumed run does
        not warm up again. The autocast dtype is bfloat16 when the GPU supports it,
        else float16 (``initialize_model`` may switch it to float16).
        """
        self.device = device
        self.params = params
        self.global_rank, self.local_rank = global_rank, local_rank
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.sweep_id = sweep_id
        self.log_to_screen = params.log_to_screen

        # Loss, counters and debugging switch.
        self.train_loss = nn.MSELoss()
        self.startEpoch, self.epoch, self.iters = 0, 0, 0
        self.debug_grad = params.debug_grad
        bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        self.mp_type = torch.bfloat16 if bf16_ok else torch.half

        self.initialize_data(self.params)
        logger.info(f"Initializing model on rank {self.global_rank}")
        self.initialize_model(self.params)
        self.initialize_optimizer(self.params)

        if params.resuming:
            self.single_print("Loading checkpoint %s" % params.checkpoint_path)
            self.restore_checkpoint(params.checkpoint_path)
        # `== False` (rather than `not`) is deliberate: a None `resuming` does not qualify.
        elif params.resuming == False and params.pretrained:  # noqa: E712
            self.single_print("Starting from pretrained model at %s" % params.pretrained_ckpt_path)
            self.restore_checkpoint(params.pretrained_ckpt_path)
            # Fine-tuning restarts the step / epoch counters.
            self.iters = 0
            self.startEpoch = 0
        # Built last so it can start from the (possibly restored) step.
        self.initialize_scheduler(self.params)

    def single_print(self, *text):
        if self.global_rank == 0 and self.log_to_screen:
            # print(' '.join([str(t) for t in text]))
            logger.info(' '.join([str(t) for t in text]))

    def initialize_data(self, params, inference=False):
        """Create the training / OOD loaders and the validation dataset list.

        Args:
            params: run configuration (attribute access plus ``params['ssl']``).
            inference: when True, skip the training and OOD loaders and only build
                ``self.valid_datasets``.

        Sets:
            * ``self.train_data_loader``, ``self.train_dataset``, ``self.train_sampler``
              when ``params.train_data_paths`` is non-empty (and not inference).
            * ``self.ood_data_loader``, ``self.ood_dataset``, ``self.ood_sampler`` when
              ``params['ssl'] != 'none'`` (and not inference).
            * ``self.valid_datasets``: one dataset per ``params.valid_data_paths`` entry,
              plus the OOD validation dataset when SSL is enabled.
        """
        # With tie_batches, rank 0 is passed to get_data_loader on every process
        # (presumably so all ranks see the same batches).
        in_rank = 0 if params.tie_batches else self.global_rank
        if self.log_to_screen:
            self.single_print(f"Initializing data on rank {self.global_rank}")

        if not inference:
            if len(params.train_data_paths) > 0:
                loaders = get_data_loader(
                    params, params.train_data_paths, dist.is_initialized(),
                    split='train', train_val_test=params.train_val_test, rank=in_rank,
                    train_offset=self.params.embedding_offset,
                    rollout=getattr(params, "rollout_train", 1))
                self.train_data_loader, self.train_dataset, self.train_sampler = loaders
                if dist.is_initialized():
                    self.train_sampler.set_epoch(0)
            if params['ssl'] != 'none':
                loaders = get_data_loader(
                    params, params.ood_train_data_paths, dist.is_initialized(),
                    split='val', train_val_test=None, rank=in_rank,
                    train_offset=self.params.embedding_offset,
                    rollout=getattr(params, "rollout_train", 1))
                self.ood_data_loader, self.ood_dataset, self.ood_sampler = loaders
                if dist.is_initialized():
                    self.ood_sampler.set_epoch(0)

        # Validation keeps only the dataset object from each (loader, dataset, sampler).
        self.valid_datasets = []
        for valid_path in params.valid_data_paths:
            _, valid_dataset, _ = get_data_loader(
                params, [valid_path], dist.is_initialized(), split='val',
                train_val_test=params.train_val_test, rank=in_rank,
                rollout=getattr(params, "rollout_test", 1))
            self.valid_datasets.append(valid_dataset)
        if params['ssl'] != 'none':
            ood_valid = get_data_loader(
                params, params.ood_valid_data_paths, dist.is_initialized(), split='val',
                train_val_test=None, rank=in_rank,
                rollout=getattr(params, "rollout_test", 1))
            self.valid_datasets.append(ood_valid[1])


    def initialize_model(self, params):
        """Build the model on ``self.device``, optionally compile it, and wrap it in DDP.

        When ``self.params.compile`` is set, the autocast dtype is forced to float16
        (per the logged warning, some compiled ops do not support bfloat16) and the
        model is passed through ``torch.compile``. When a process group is
        initialized, the model is wrapped in ``DistributedDataParallel`` on
        ``self.local_rank``. The total parameter count is logged on rank 0.
        """
        self.model = build_vmae(params).to(self.device)
        if self.params.compile:
            logger.warning('WARNING: BFLOAT NOT SUPPORTED IN SOME COMPILE OPS SO SWITCHING TO FLOAT16')
            self.mp_type = torch.half
            self.model = torch.compile(self.model)

        if dist.is_initialized():
            # DDP documents output_device as an int or torch.device (not a list).
            self.model = DistributedDataParallel(self.model, device_ids=[self.local_rank],
                                                 output_device=self.local_rank, find_unused_parameters=True)

        self.single_print(f'Model parameter count: {sum(p.numel() for p in self.model.parameters())}')

    def initialize_optimizer(self, params):
        """Create ``self.optimizer`` and the AMP ``self.gscaler``.

        * If ``params.freeze_former_layers`` is set, ``freeze_layers`` runs first.
        * ``'adam'`` / ``'adan'`` use the two param groups from ``add_weight_decay``
          (no decay on bias/scale-like tensors). A negative ``learning_rate`` selects
          the D-Adaptation variant, built with ``lr=1.``.
        * ``'sgd'``: momentum SGD over ``self.model.parameters()`` (no decay groups).
        * The grad scaler is enabled only for float16 autocast with ``enable_amp``.

        Raises:
            ValueError: for any other ``params.optimizer``.
        """
        if params.freeze_former_layers:
            self.model = freeze_layers(self.model, params)
        # Dont use weight decay on bias/scaling terms
        param_groups = add_weight_decay(self.model, self.params.weight_decay)
        choice = params.optimizer
        if choice in ('adam', 'adan'):
            if self.params.learning_rate < 0:
                if choice == 'adam':
                    self.optimizer = DAdaptAdam(param_groups, lr=1., growth_rate=1.05, log_every=100, decouple=True)
                else:
                    self.optimizer = DAdaptAdan(param_groups, lr=1., growth_rate=1.05, log_every=100)
            elif choice == 'adam':
                self.optimizer = optim.AdamW(param_groups, lr=params.learning_rate)
            else:
                self.optimizer = Adan(param_groups, lr=params.learning_rate)
        elif choice == 'sgd':
            self.optimizer = optim.SGD(self.model.parameters(), lr=params.learning_rate, momentum=0.9)
        else:
            raise ValueError(f"Optimizer {params.optimizer} not supported")
        fp16_amp = self.mp_type == torch.half and params.enable_amp
        self.gscaler = amp.GradScaler(enabled=fp16_amp)

    def initialize_scheduler(self, params):
        if params.scheduler_epochs > 0:
            sched_epochs = params.scheduler_epochs
        else:
            sched_epochs = args.max_epochs
        if params.scheduler == 'cosine':
            if self.params.learning_rate < 0:
                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, 
                                                                            last_epoch = (self.startEpoch*params.epoch_size) - 1,
                                                                            T_max=sched_epochs*params.epoch_size, 
                                                                            eta_min=params.learning_rate / 100)
            else:
                k = params.warmup_steps
                if (self.startEpoch*params.epoch_size) < k:
                    warmup = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=.01, end_factor=1.0, total_iters=k)
                    decay = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, eta_min=params.learning_rate / 100, T_max=sched_epochs)
                    self.scheduler = torch.optim.lr_scheduler.SequentialLR(self.optimizer, [warmup, decay], [k], last_epoch=(params.epoch_size*self.startEpoch)-1)
        else:
            self.scheduler = None


    def save_checkpoint(self, checkpoint_path, model=None):
        """ Save model and optimizer to checkpoint """
        if not model:
            model = self.model

        torch.save({'iters': self.epoch*self.params.epoch_size, 'epoch': self.epoch, 'model_state': model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict()}, checkpoint_path)

    def restore_checkpoint(self, checkpoint_path):
        """ Load model/opt from path """
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.local_rank))
        if 'model_state' in checkpoint:
            model_state = checkpoint['model_state']
        else:
            model_state = checkpoint
        try: # Try to load with DDP Wrapper
            self.model.load_state_dict(model_state)
        except: # If that fails, either try to load into module or strip DDP prefix
            if hasattr(self.model, 'module'):
                self.model.module.load_state_dict(model_state)
            else:
                new_state_dict = OrderedDict()
                for key, val in model_state.items():
                    # Failing means this came from DDP - strip the DDP prefix
                    name = key[7:]
                    new_state_dict[name] = val
                self.model.load_state_dict(new_state_dict)
        
        if self.params.resuming:  #restore checkpoint is used for finetuning as well as resuming. If finetuning (i.e., not resuming), restore checkpoint does not load optimizer state, instead uses config specified lr.
            self.iters = checkpoint['iters']
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # self.startEpoch = 0 # checkpoint['epoch']
            self.epoch = self.startEpoch
        else:
            self.iters = 0
        if self.params.pretrained:
            if self.params.freeze_middle:
                self.model.module.freeze_middle()
            elif self.params.freeze_processor:
                self.model.module.freeze_processor()
            else:
                self.model.module.unfreeze()
            # See how much we need to expand the projections
            exp_proj = 0
            # Iterate through the appended datasets and add on enough embeddings for all of them. 
            for add_on in self.params.append_datasets:
                exp_proj += len(DSET_NAME_TO_OBJECT[add_on]._specifics()[2])
            self.model.module.expand_projections(exp_proj)
        checkpoint = None
        self.model = self.model.to(self.device)

    def train_one_epoch(self):
        """Train for one pass over the epoch's loader.

        Data source:
            * ``params['ssl'] != 'none'``: iterate the OOD loader. When
              ``params.train_data_paths`` is non-empty, also draw one supervised batch
              per OOD batch from the train loader, restarting it when exhausted.
            * otherwise: iterate the train loader.

        Losses:
            * supervised: autoregressive rollout of ``params.rollout_train`` steps from
              the first ``params.n_steps`` frames. Each step's squared error is
              normalised by the target's mean energy, then averaged over rollout steps.
            * ``'gt'``: the same rollout on OOD data against its density-channel target.
            * ``'interp'``: the OOD input is trilinearly interpolated in time to
              ``rollout_train`` extra frames, which serve as rollout targets.
            Losses are divided by ``params.accum_grad``. Gradients are clipped to norm
            1, and the optimizer and scheduler step once every ``accum_grad`` batches.

        Returns:
            ``(tr_time, data_time, logs)``:
              * seconds spent in forward/backward/step;
              * seconds spent fetching and preparing batches;
              * a dict of train nRMSE and grad-norm entries, plus ``'iters'`` and
                ``'parameter norm'`` on global rank 0.
            The accumulators behind the logs are summed across ranks when distributed.
        """
        # DistributedDataParallel does not expose the wrapped model's custom methods,
        # so call set_dropout_p on the underlying module when wrapped.
        getattr(self.model, 'module', self.model).set_dropout_p(0)
        self.model.train()
        self.epoch += 1
        tr_time = 0
        data_time = 0
        data_start = time.time()
        logs = {}
        steps = 0
        use_ssl = self.params['ssl'] != 'none'
        has_train_data = len(self.params.train_data_paths) > 0
        # Only used when debug_grad is on: gradients seen after the previous micro-batch.
        last_grads = [torch.zeros_like(p) for p in self.model.parameters()]
        # Per-dataset accumulators, kept as 1-element device tensors so they can be all-reduced.
        grad_logs = defaultdict(lambda: torch.zeros(1, device=self.device))
        grad_counts = defaultdict(lambda: torch.zeros(1, device=self.device))
        loss_logs = defaultdict(lambda: torch.zeros(1, device=self.device))
        loss_counts = defaultdict(lambda: torch.zeros(1, device=self.device))

        if use_ssl:
            self.single_print('train_loader_size', len(self.ood_data_loader), len(self.ood_dataset))
            loader = self.ood_data_loader
            if has_train_data:
                train_loader = self.train_data_loader
                train_iterator = iter(train_loader)
        else:
            self.single_print('train_loader_size', len(self.train_data_loader), len(self.train_dataset))
            loader = self.train_data_loader

        for batch_idx, data in enumerate(loader):
            steps += 1
            if use_ssl:
                # OOD data
                inp_ood, file_index_ood, field_labels_ood, bcs_ood, tar_ood = map(lambda x: x.to(self.device), data)
                tar_ood = tar_ood[:, :1]  # only pick the density channel
                tar_ood_all = torch.cat([inp_ood[:, :, :1], tar_ood.unsqueeze(1)], dim=1)
                inp_ood = rearrange(inp_ood, 'b t c h w -> t b c h w')
                # original training set data
                if has_train_data:
                    try:
                        data_train = next(train_iterator)
                    except StopIteration:  # train loader exhausted: start it over
                        train_iterator = iter(train_loader)
                        data_train = next(train_iterator)
                    inp, file_index, field_labels, bcs, tar = map(lambda x: x.to(self.device), data_train)
                    if inp.shape[1] + 1 < self.params.n_steps + self.params.rollout_train: continue  # input temporal length is too short for rollout
            else:
                inp, file_index, field_labels, bcs, tar = map(lambda x: x.to(self.device), data)
                if inp.shape[1] + 1 < self.params.n_steps + self.params.rollout_train: continue  # input temporal length is too short for rollout

            if has_train_data:
                tar_all = torch.cat([inp, tar.unsqueeze(1)], dim=1)
                dset_type = self.train_dataset.sub_dsets[file_index[0]].type
                pde_param = self.train_dataset.sub_dsets[file_index[0]].pde_param
                inp = rearrange(inp, 'b t c h w -> t b c h w')
                loss_counts[dset_type] += 1
            else:
                dset_type = 'ssl'  # SSL-only run: grad/loss logs go under the 'ssl' key
            data_time += time.time() - data_start

            # Gradients are only synced (and the optimizer stepped) on the last
            # micro-batch of each accumulation cycle.
            self.model.require_backward_grad_sync = ((1 + batch_idx) % self.params.accum_grad == 0)
            with amp.autocast(self.params.enable_amp, dtype=self.mp_type):
                model_start = time.time()
                loss = 0
                raw_loss = 0
                if has_train_data:
                    xx = inp[:self.params.n_steps]
                    for t in range(self.params.rollout_train):
                        tar = tar_all[:, self.params.n_steps + t]
                        output = self.model(xx)
                        # output is (B, C, *spatial): it is stacked onto xx (T, B, C, ...) below
                        spatial_dims = tuple(range(output.ndim))[2:]
                        # Autoregressive: the prediction becomes the newest input frame.
                        xx = torch.cat((xx[1:], output.unsqueeze(0)), dim=0)
                        residuals = output - tar
                        # Relative error: normalise by the target's mean energy
                        tar_norm = (1e-7 + tar.pow(2).mean(spatial_dims, keepdim=True))
                        raw_loss += ((residuals).pow(2).mean(spatial_dims, keepdim=True) / tar_norm)
                    raw_loss /= self.params.rollout_train
                    # Scale loss for accum
                    loss = raw_loss.mean() / self.params.accum_grad

                if self.params['ssl'] == 'gt':
                    raw_loss_ood = 0
                    _inp_ood = inp_ood[:self.params.n_steps]
                    for t in range(self.params.rollout_train):
                        _tar_ood = tar_ood_all[:, self.params.n_steps + t]
                        output = self.model(_inp_ood)
                        spatial_dims = tuple(range(output.ndim))[2:]
                        _inp_ood = torch.cat((_inp_ood[1:], output.unsqueeze(0)), dim=0)
                        residuals = output - _tar_ood
                        _tar_norm = (1e-7 + _tar_ood.pow(2).mean(spatial_dims, keepdim=True))
                        raw_loss_ood += ((residuals).pow(2).mean(spatial_dims, keepdim=True) / _tar_norm)
                    # NOTE(review): unlike the 'interp' branch and the supervised loss, this
                    # sum is not divided by rollout_train - confirm that is intended.
                elif self.params['ssl'] == 'interp':
                    inp_ood_interp = rearrange(inp_ood, 't b c h w -> b c t h w')
                    inp_ood_interp = F.interpolate(inp_ood_interp, size=(inp_ood_interp.shape[2] + self.params.rollout_train, *inp_ood_interp.shape[3:]), mode='trilinear', align_corners=True)
                    inp_ood_interp = rearrange(inp_ood_interp, 'b c t h w -> t b c h w')
                    raw_loss_ood = 0
                    _inp_ood = inp_ood_interp[:self.params.n_steps]
                    for t in range(self.params.rollout_train):
                        _tar_ood = inp_ood_interp[self.params.n_steps + t]
                        output_ssl = self.model(_inp_ood)
                        spatial_dims = tuple(range(output_ssl.ndim))[2:]
                        # Feed the prediction back as the newest frame, as in the other
                        # branches; otherwise every rollout step would predict from the
                        # same input against a different target.
                        _inp_ood = torch.cat((_inp_ood[1:], output_ssl.unsqueeze(0)), dim=0)
                        residuals_ssl = output_ssl - _tar_ood
                        tar_norm = (1e-7 + _tar_ood.pow(2).mean(spatial_dims, keepdim=True))
                        raw_loss_ood += ((residuals_ssl).pow(2).mean(spatial_dims, keepdim=True) / tar_norm)  # relative
                    raw_loss_ood /= self.params.rollout_train
                if use_ssl:
                    # Scale loss for accum (raw_loss_ood only exists when SSL is enabled)
                    loss_ood = raw_loss_ood.mean() / self.params.accum_grad
                    loss += loss_ood
                    loss_counts['ssl'] += 1

                # Logging
                with torch.no_grad():
                    if has_train_data:
                        loss_logs[dset_type] += raw_loss.sqrt().mean().item()
                    if use_ssl:
                        loss_logs['ssl'] += raw_loss_ood.sqrt().mean().item()
                # Scaler is no op when not using AMP
                self.gscaler.scale(loss).backward()

                # Gradient debugging: the first micro-batch of a cycle sees only its own
                # gradient; later micro-batches subtract the previously accumulated one.
                if self.debug_grad and batch_idx % self.params.accum_grad == 0:
                    with torch.no_grad():
                        gnorm = self.params.accum_grad * grad_norm(self.model.parameters())
                        grad_logs[dset_type] += gnorm
                        grad_counts[dset_type] += 1
                        last_grads = grad_clone(self.model.parameters())
                elif self.debug_grad:
                    with torch.no_grad():
                        new_last_grads = grad_clone(self.model.parameters())
                        new_grad = [p - g for p, g in zip(new_last_grads, last_grads)]
                        gnorm = self.params.accum_grad * param_norm(new_grad)
                        grad_logs[dset_type] += gnorm
                        grad_counts[dset_type] += 1
                        last_grads = new_last_grads
                grads_unscaled = False
                if self.debug_grad and self.model.require_backward_grad_sync:
                    with torch.no_grad():
                        self.gscaler.unscale_(self.optimizer)
                        grads_unscaled = True
                        grad_diff = grad_norm(self.model.parameters())
                        porig = [p.clone() for p in self.model.parameters()]
                # Only take step once per accumulation cycle
                if self.model.require_backward_grad_sync:
                    if not grads_unscaled:
                        # GradScaler.unscale_ may only be called once per optimizer step.
                        self.gscaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                    self.gscaler.step(self.optimizer)
                    self.gscaler.update()
                    if self.debug_grad:
                        if self.global_rank == 0:
                            pdiff = param_diff(self.model.parameters(), porig)
                            if self.params['ssl'] == 'none':
                                self.single_print('grad_norm', grad_diff, 'last_step_size', pdiff, 'loss', loss.item(), 'data_shape', inp.shape)
                            else:
                                self.single_print('grad_norm', grad_diff, 'last_step_size', pdiff, 'loss', (loss-loss_ood).item(), 'loss_ood', loss_ood.item(), 'data_shape', inp_ood.shape)
                    self.optimizer.zero_grad(set_to_none=True)
                    if self.scheduler is not None:
                        self.scheduler.step()
                tr_time += time.time() - model_start
                if self.log_to_screen and batch_idx % self.params.log_interval == 0 and self.global_rank == 0:
                    if self.params['ssl'] == 'none':
                        self.single_print(f"Epoch {self.epoch} Batch {batch_idx} Train Loss {loss_logs[dset_type] / loss_counts[dset_type]}")
                    else:
                        self.single_print(f"Epoch {self.epoch} Batch {batch_idx} Train Loss {loss_logs[dset_type] / loss_counts[dset_type]} OOD Loss {loss_logs['ssl'] / loss_counts['ssl']}")
                data_start = time.time()

        # If distributed, sum the per-rank accumulators so the ratios below are global.
        # NOTE(review): all_reduce needs the same key set on every rank (same dataset
        # types seen this epoch); otherwise the collectives will not match up.
        if dist.is_initialized():
            for accumulator in (loss_logs, grad_logs, loss_counts, grad_counts):
                for key in sorted(accumulator.keys()):
                    dist.all_reduce(accumulator[key].detach())

        # NOTE(review): non-SSL keys are tagged with the pde_param of the last batch seen.
        for key in loss_logs.keys():
            if key != 'ssl':
                logs[f'{key}_{pde_param}/train_nrmse'] = loss_logs[key] / loss_counts[key]
            else:
                logs[f'{key}/train_nrmse'] = loss_logs[key] / loss_counts[key]
        for key in grad_logs.keys():
            if key != 'ssl':
                logs[f'{key}_{pde_param}/train_grad_norm'] = grad_logs[key] / grad_counts[key]
            else:
                logs[f'{key}/train_grad_norm'] = grad_logs[key] / grad_counts[key]

        self.iters += steps
        if self.global_rank == 0:
            logs['iters'] = self.iters
            logs['parameter norm'] = param_norm(self.model.parameters())
        self.single_print('all reduces executed!')

        return tr_time, data_time, logs


    def validate(self, epoch, train_logs, cutoff=999999999999):
        for _dataset in self.valid_datasets:
            if epoch==args.max_epochs-1 or epoch==-1:
                valid_logs = self.validate_one_epoch(_dataset, cutoff=cutoff, save_img=True)
            else:
                # Only do full validation set on last epoch - don't waste time
                valid_logs = self.validate_one_epoch(_dataset, cutoff=999999999999 if self.params['ssl'] != 'none' else 40, save_img=epoch==0 or (epoch+1)%10==0) # TODO scalarflow needs to save all outputs)
            
            train_logs.update(valid_logs)


    def validate_one_epoch(self, dataset, cutoff=-1, save_img=False):
        """
        Validates - for each batch just use a small subset to make it easier.

        Note: need to split datasets for meaningful metrics, but TBD. 
        """
        # Don't bother with full validation set between epochs
        self.model.eval()
        if cutoff:
            cutoff = cutoff
        else:
            cutoff = 999999999999
        self.single_print(f'STARTING VALIDATION!!! {dataset.type_list} - {dataset.pde_param_list}')
        self.single_print('val dataset size', len(dataset))
        with torch.inference_mode():
            # There's something weird going on when i turn this off.
            with amp.autocast(False, dtype=self.mp_type):
                field_labels = dataset.get_state_names()
                counts = {}
                logs = {} # 
                # Iterate through all folder specific datasets
                for subset_group in dataset.sub_dsets:
                    for subset in subset_group.get_per_file_dsets():
                        dset_type = subset.title
                        self.single_print('VALIDATING ON', dset_type)
                        # Create data loader for each
                        if self.params.use_ddp:
                            temp_loader = torch.utils.data.DataLoader(subset, batch_size=self.params.batch_size if self.params.ssl == 'none' else 1,
                                                                    num_workers=self.params.num_data_workers,
                                                                    sampler=torch.utils.data.distributed.DistributedSampler(subset,
                                                                                                                            drop_last=True)
                                    )
                        else:
                            # Seed isn't important, just trying to mix up samples from different trajectories
                            temp_loader = torch.utils.data.DataLoader(subset, batch_size=self.params.batch_size if self.params.ssl == 'none' else 1,
                                                                    num_workers=self.params.num_data_workers, 
                                                                    shuffle=False,
                                                                    generator=torch.Generator().manual_seed(0), # TODO need randomness
                                                                    drop_last=True)
                        count = 0
                        for batch_idx, data in enumerate(temp_loader):
                            
                            # Only do a few batches of each dataset if not doing full validation
                            if count > cutoff:
                                del(temp_loader)
                                break
                            count += 1
                            counts[f'{dset_type}_{subset.pde_param}'] = counts.get(f'{dset_type}_{subset.pde_param}', 0) + 1
                            inp, bcs, tar, files_idx = map(lambda x: x.to(self.device), data) 
                            # import pdb; pdb.set_trace()
                            if inp.shape[1] + 1 < self.params.n_steps + self.params.rollout_test: continue # input temporal length is too short for rollout
                            # Labels come from the trainset - useful to configure an extra field for validation sets not included
                            tar_all = torch.cat([inp, tar.unsqueeze(1)], dim=1)
                            inp = rearrange(inp, 'b t c h w -> t b c h w')
                            nmse = 0

                            xx = inp[:self.params.n_steps]
                            outputs = []


                            for t in range(self.params.rollout_test):
                                ##################################
                                tar = tar_all[:, self.params.n_steps+t]
                                output = self.model(xx)
                                xx = torch.cat((xx[1:], output.unsqueeze(0)), dim=0)
                                spatial_dims = tuple(range(output.ndim))[2:] # Assume 0, 1, 2 are T, B, C
                                if subset.type == 'scalarflow':
                                    tar = tar[:, :1]
                                    output = output[:, :1]
                                outputs.append(output.detach()) # collect output at every step
                                residuals = output - tar
                                nmse += (residuals.pow(2).mean(spatial_dims, keepdim=True) 
                                        / (1e-7 + tar.pow(2).mean(spatial_dims, keepdim=True))).sqrt()
                            
                            outputs = torch.clip(torch.cat(outputs, dim=0), 0, 1).detach().cpu() * 255 # T, C, H, W
                            outputs = F.interpolate(outputs, size=(self.flow_h, self.flow_w), mode='bilinear', align_corners=True)
                            outputs = F.pad(outputs, (self.flow_pad_w_left, self.flow_pad_w_right, self.flow_pad_h_top, self.flow_pad_h_bottom))
                            outputs = outputs.numpy()
                            
                            if subset.type == 'scalarflow' and save_img:
                                sample_file_name = subset.files_paths[files_idx[0].item()].split('/')[-1].split('.')[0]
                                _file_path = os.path.join(self.params.experiment_dir, "output_frames", sample_file_name)
                                os.makedirs(_file_path, exist_ok=True)
                                for _i in range(len(outputs)):
                                    _file_name = os.path.join(_file_path, f"{sample_file_name}_{(_i + batch_idx*temp_loader.batch_size+self.params.n_steps)}")
                                    np.save(_file_name+".npy", outputs[_i, 0]) # only pick density channel
                                    im = Image.fromarray(np.repeat(outputs[_i, 0, :, :, np.newaxis], 3, axis=2).astype(np.uint8))
                                    im.save(_file_name+".png")

                            nmse /= self.params.rollout_test
                            logs[f'{dset_type}_{subset.pde_param}/valid_nrmse'] = logs.get(f'{dset_type}_{subset.pde_param}/valid_nrmse',0) + nmse.mean()
                            # logs[f'{dset_type}/valid_rmse'] = (logs.get(f'{dset_type}/valid_mse',0) + residuals.pow(2).mean(spatial_dims).sqrt().mean())
                            logs[f'{dset_type}_{subset.pde_param}/valid_l1'] = (logs.get(f'{dset_type}_{subset.pde_param}/valid_l1', 0) 
                                                                + residuals.abs().mean())

                            for i, field in enumerate(dataset.subset_dict[subset.type]):
                                field_name = field_labels[field]
                                logs[f'{dset_type}_{subset.pde_param}/{field_name}_valid_nrmse'] = (logs.get(f'{dset_type}_{subset.pde_param}/{field_name}_valid_nrmse', 0) 
                                                                                + nmse[:, i].mean())
                                logs[f'{dset_type}_{subset.pde_param}/{field_name}_valid_rmse'] = (logs.get(f'{dset_type}_{subset.pde_param}/{field_name}_valid_rmse', 0) 
                                                                                    + residuals[:, i:i+1].pow(2).mean(spatial_dims).sqrt().mean())
                                logs[f'{dset_type}_{subset.pde_param}/{field_name}_valid_l1'] = (logs.get(f'{dset_type}_{subset.pde_param}/{field_name}_valid_l1', 0) 
                                                                            +  residuals[:, i].abs().mean())
                        else:
                            del(temp_loader)

            self.single_print('DONE VALIDATING - NOW SYNCING')
            for k, v in logs.items():
                dset_type = k.split('/')[0]
                logs[k] = v/counts[dset_type]

            # logs['valid_nrmse'] = 0
            # for dset_type in distinct_dsets:
            #     logs['valid_nrmse'] += logs[f'{dset_type}/valid_nrmse']/len(distinct_dsets)
            
            if dist.is_initialized():
                for key in sorted(logs.keys()):
                    dist.all_reduce(logs[key].detach()) # There was a bug with means when I implemented this - dont know if fixed
                    logs[key] = float(logs[key].item()/dist.get_world_size())
                    if 'rmse' in key:
                        logs[key] = logs[key]
            self.single_print('DONE SYNCING - NOW LOGGING')
        return logs               


    def train(self):
        """Run the training loop, optionally preceded by a stand-alone validation pass.

        Steps, as implemented below:

        1. When ``params.log_to_wandb`` is set, initialise wandb. In sweep mode the
           sweep's hyper-parameters are read from ``wandb.config`` and merged into
           ``self.params``.
        2. In sweep mode, share the sweep config with the other ranks when
           ``torch.distributed`` is initialised (rank 0 writes a temporary pickle
           file, the others read it). Then rebuild data, model and, unless running
           inference, optimizer and scheduler so the new hyper-parameters take effect.
        3. If ``--inference`` was given or ``params['ssl'] != 'none'``, run one
           validation pass as epoch ``-1`` and print the per-dataset valid nrmse.
           With ``--inference`` the process then exits with status 0.
        4. Train for epochs ``self.startEpoch .. args.max_epochs - 1``.
           - Validation runs every epoch, or only on epoch 0 and every 10th epoch
             when SSL is enabled.
           - Logs go to wandb when enabled.
           - Rank 0 writes checkpoints: the latest one (``params.save_checkpoint``),
             a per-epoch snapshot every ``checkpoint_save_interval`` epochs, and
             the best one by mean valid nrmse.

        Relies on the module-level ``args`` namespace created by the ``__main__``
        block (``args.inference``, ``args.max_epochs``).
        """

        def _metric_summary(logs, suffix):
            """Return ``(keys, values)`` for the entries of ``logs`` whose key ends with ``suffix``.

            ``float()`` accepts both Python floats (what ``validate`` stores after a
            distributed all-reduce) and single-element tensors (what it stores
            otherwise). ``.item()`` would fail on the former.
            """
            keys = [key for key in logs if key.endswith(suffix)]
            return keys, [float(logs[key]) for key in keys]

        # This is set up this way based on old code to allow wandb sweeps
        if self.params.log_to_wandb:
            if self.sweep_id:
                wandb.init(dir=self.params.experiment_dir)
                hpo_config = wandb.config.as_dict()
                self.params.update_params(hpo_config)
            else:
                wandb.init(dir=self.params.experiment_dir, config=self.params, name=self.params.name,
                           group=self.params.group, project=self.params.project,
                           entity=self.params.entity, resume=not args.inference)

        if self.sweep_id:
            if dist.is_initialized():
                # Hand rank 0's sweep config to the other ranks through a temp pickle
                # file keyed by the SLURM job id (the file is written by this job itself).
                param_file = f"temp_hpo_config_{os.environ['SLURM_JOBID']}.pkl"
                if self.global_rank == 0:
                    with open(param_file, 'wb') as f:
                        pkl.dump(hpo_config, f)
                dist.barrier()  # wait until rank 0 has written the config
                if self.global_rank != 0:
                    with open(param_file, 'rb') as f:
                        hpo_config = pkl.load(f)
                dist.barrier()  # wait until every rank has read it before deleting the file
                if self.global_rank == 0:
                    os.remove(param_file)
                # If tuning batch size, need to go from global to local batch size
                if 'batch_size' in hpo_config:
                    hpo_config['batch_size'] = int(hpo_config['batch_size'] // self.world_size)
                self.params.update_params(hpo_config)

            # Redundant with __init__, but the HPs only arrive from wandb now. This must run for
            # single-process sweeps too, otherwise the sweep values never reach data/model/optimizer.
            self.initialize_data(self.params, inference=args.inference)
            self.initialize_model(self.params)
            if not args.inference:
                self.initialize_optimizer(self.params)
                self.initialize_scheduler(self.params)

        if self.global_rank == 0:
            summary(self.model)
        if self.params.log_to_wandb:
            wandb.watch(self.model)

        ssl_enabled = self.params['ssl'] != 'none'
        if args.inference or ssl_enabled:
            train_logs = {}
            valid_start = time.time()
            # TODO scalarflow needs to save all outputs
            self.validate(-1, train_logs, cutoff=999999999999 if ssl_enabled else 40)
            if self.params.log_to_wandb:
                wandb.log(train_logs)
            if self.global_rank == 0:
                cur_time = time.time()
                valid_keys, valid_nrmses = _metric_summary(train_logs, '/valid_nrmse')
                self.single_print(f'Time for valid: {cur_time-valid_start}')
                self.single_print('Valid Loss: {}'.format(list(zip(valid_keys, valid_nrmses))))
            if args.inference:
                sys.exit(0)

        self.single_print("Starting Training Loop...")
        # Actually train now, saving checkpoints, logging time, and logging to wandb
        best_valid_loss = 1.e6
        for epoch in range(self.startEpoch, args.max_epochs):
            start = time.time()

            # Rollout curriculum: from epoch 40 on, raise rollout_train to epoch // 20, capped at 8.
            # This rebuilds the data loaders, so it runs before the sampler's epoch is set below.
            target_rollout = min(epoch // 20, 8)
            if epoch >= 40 and target_rollout > self.params['rollout_train']:
                self.params['rollout_train'] = target_rollout
                self.initialize_data(self.params)
                print('Update rollout_train to ', self.params['rollout_train'])

            if dist.is_initialized():
                self.train_sampler.set_epoch(epoch)

            tr_time, data_time, train_logs = self.train_one_epoch()

            valid_start = time.time()
            # With SSL enabled, validate only on the first epoch and on every 10th epoch.
            if not ssl_enabled or epoch == 0 or (epoch + 1) % 10 == 0:
                self.validate(epoch, train_logs)

            post_start = time.time()
            train_logs['time/train_time'] = valid_start - start
            train_logs['time/valid_time'] = post_start - valid_start
            train_logs['time/train_data_time'] = data_time
            train_logs['time/train_compute_time'] = tr_time
            if self.params.log_to_wandb:
                wandb.log(train_logs)
            gc.collect()
            torch.cuda.empty_cache()

            if self.global_rank == 0:
                if self.params.save_checkpoint:
                    self.save_checkpoint(self.params.checkpoint_path)
                if epoch % self.params.checkpoint_save_interval == 0:
                    self.save_checkpoint(self.params.checkpoint_path + f'_epoch{epoch}')

                train_keys, train_nrmses = _metric_summary(train_logs, '/train_nrmse')
                valid_keys, valid_nrmses = _metric_summary(train_logs, '/valid_nrmse')
                # Epochs that skipped validation have no valid_nrmse entries. Averaging an
                # empty list would only warn and give nan, so skip best-checkpoint logic then.
                if valid_nrmses:
                    mean_valid_nrmse = np.mean(valid_nrmses)
                    if mean_valid_nrmse <= best_valid_loss:
                        self.save_checkpoint(self.params.best_checkpoint_path)
                        best_valid_loss = mean_valid_nrmse

                cur_time = time.time()
                self.single_print(f'Time for train {valid_start-start}. For valid: {post_start-valid_start}. For postprocessing:{cur_time-post_start}')
                self.single_print('Time taken for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
                self.single_print('Train Loss: {}'.format(list(zip(train_keys, train_nrmses))))
                self.single_print('Valid Loss: {}'.format(list(zip(valid_keys, valid_nrmses))))

if __name__ == '__main__':
    # Script entry point. Everything stays at module level on purpose: Trainer.train() reads
    # the global `args` (and rank 0 uses the global `logger`).
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument("--run_name", default='00', type=str)
    parser.add_argument("--use_ddp", action='store_true', help='Use distributed data parallel')
    parser.add_argument("--yaml_config", default='./config/base_config.yaml', type=str)
    parser.add_argument("--config", default=None, type=str)
    parser.add_argument("--sweep_id", default=None, type=str, help='sweep config from ./configs/sweeps.yaml')
    parser.add_argument("--inference", action='store_true', help='Inference only')
    parser.add_argument("--ckpt", default="", help='Path to checkpoint to load. For inference only.')
    # training options
    parser.add_argument("--frame_num", type=int, default=20, help='number of frames per video to use')
    parser.add_argument("--rollout_train", type=int, default=5, help='number of rollout steps used during training')
    parser.add_argument("--n_steps", type=int, default=7, help='number of input frames fed to the model')

    # for finetune (if params.freeze_former_layers is True)
    parser.add_argument("--finetune", action='store_true', help='finetune')
    parser.add_argument("--hy_dir", type=str, default='', help='hy frames dir')
    parser.add_argument("--num_hy_frames", type=int, default=0, help='number of frames from HY')

    parser.add_argument("--max_epochs", type=int, default=101, help='maximum number of training epochs')

    args = parser.parse_args()
    params = YParams(os.path.abspath(args.yaml_config), config_name=args.config)

    params.use_ddp = args.use_ddp
    # Set up distributed training (rank/world size come from the torchrun-style environment).
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if args.use_ddp:
        dist.init_process_group("nccl")
        torch.cuda.set_device(local_rank)  # Torch docs recommend just using device, but I had weird memory issues without setting this.
    device = torch.device(local_rank) if torch.cuda.is_available() else torch.device("cpu")

    # Modify params: the configured batch size is global, so each rank gets an equal share.
    params['batch_size'] = int(params.batch_size // world_size)
    params['startEpoch'] = 0
    if args.sweep_id:
        jid = os.environ['SLURM_JOBID']  # so different sweeps dont resume
        expDir = os.path.join(params.exp_dir, args.sweep_id, args.config, str(args.run_name), jid)
    else:
        expDir = os.path.join(params.exp_dir, args.config, str(args.run_name))

    params['scalarflow_frame_num'] = args.frame_num
    params['rollout_train'] = args.rollout_train
    params['n_steps'] = args.n_steps

    params['old_exp_dir'] = expDir  # I dont remember what this was for but not removing it yet
    params['experiment_dir'] = os.path.abspath(expDir)
    params['checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/ckpt.tar')
    params['best_checkpoint_path'] = os.path.join(expDir, 'training_checkpoints/best_ckpt.tar')
    params['old_checkpoint_path'] = os.path.join(params.old_exp_dir, 'training_checkpoints/best_ckpt.tar')

    num_f = params['scalarflow_frame_num']  # frame_num
    # Test clips run 30 frames (plus any HY frames) past the training cutoff. If that would reach
    # 120 frames or more, -1 is used instead (frame_num_cutoff=-1 selects the test-size crop below).
    # NOTE(review): with num_f_test == -1, rollout_test becomes -1 - num_f (negative), so the
    # validation rollout loop would not run. Confirm this branch is never hit in practice.
    num_f_test = num_f + args.num_hy_frames + 30 if num_f + args.num_hy_frames + 30 < 120 else -1
    params['rollout_test'] = num_f_test - num_f

    params.freeze_former_layers = args.finetune  # store_true flag, already a bool
    print("freeze_former_layers:", params.freeze_former_layers)

    params['vmae_pretrained'] = args.ckpt if args.ckpt else False

    if not args.inference:
        print('rollout_train:', params['rollout_train'], 'n_steps:', params['n_steps'])
    else:
        print('rollout_test:', params['rollout_test'], 'n_steps:', params['n_steps'])

    if "scalarflow_frame_num" in params:
        print("Converting Scalarflow H5 files...")

        # test videos load all frames
        # frame_num_cutoff=params['scalarflow_frame_num'] ~ size_train
        # frame_num_cutoff=-1 ~ size_test
        bbox_train, size_train = calculate_crop_box(split='all', frame_num_cutoff=num_f, half_res=True)
        bbox_test, size_test = calculate_crop_box(split='all', frame_num_cutoff=num_f_test, half_res=True)

        save_path = 'swin'
        if args.num_hy_frames:
            # With extra HY frames, both splits use the test-size crop box.
            scalarflow_path_trains = pinf_combine_frames_to_h5(split='train', frame_num_cutoff=num_f, target_shape=params['input_size'], bbox=bbox_test, save_path=save_path, hy_dir=args.hy_dir, hy_frames=args.num_hy_frames, half_res=True)
            scalarflow_path_tests = pinf_combine_frames_to_h5(split='test', frame_num_cutoff=num_f_test, target_shape=params['input_size'], bbox=bbox_test, save_path=save_path, hy_dir=args.hy_dir, hy_frames=args.num_hy_frames, half_res=True)
            if args.inference:
                # Inference evaluates on every video (split='all').
                scalarflow_path_tests = pinf_combine_frames_to_h5(split='all', frame_num_cutoff=num_f_test, target_shape=params['input_size'], bbox=bbox_test, save_path=save_path, hy_dir=args.hy_dir, hy_frames=args.num_hy_frames, half_res=True)
        else:
            scalarflow_path_trains = pinf_frame_to_h5(split='train', frame_num_cutoff=num_f, target_shape=params['input_size'], save_path=save_path, bbox=bbox_train, half_res=True)
            scalarflow_path_tests = pinf_frame_to_h5(split='test', frame_num_cutoff=num_f_test, target_shape=params['input_size'], bbox=bbox_train, save_path=save_path, half_res=True)
            if args.inference:
                # Inference evaluates on every video (split='all').
                scalarflow_path_tests = pinf_frame_to_h5(split='all', frame_num_cutoff=num_f_test, target_shape=params['input_size'], bbox=bbox_test, save_path=save_path, half_res=True)
        # Validation during training uses the train crop size; finetuning (HY frames) and
        # inference use the larger test crop.
        (flow_h, flow_w), (flow_pad_h_top, flow_pad_h_bottom, flow_pad_w_left, flow_pad_w_right) = size_test if args.num_hy_frames or args.inference else size_train

        for _path in scalarflow_path_trains:
            params['ood_train_data_paths'].append([_path, 'scalarflow', '', 'train_' + _path.split('/')[-1]])

        for _path in scalarflow_path_tests:
            params['ood_valid_data_paths'].append([_path, 'scalarflow', '', 'test'])

    # Have rank 0 make the experiment/checkpoint directories. exist_ok also repairs an existing
    # experiment dir that is missing its training_checkpoints/ folder.
    if global_rank == 0:
        os.makedirs(os.path.join(expDir, 'training_checkpoints'), exist_ok=True)
    params['resuming'] = False  # True if os.path.isfile(params.checkpoint_path)  and args.inference  else False

    # WANDB things
    params['name'] = str(args.run_name + ("_inference" if args.inference else ""))
    if global_rank == 0:
        # Configure the rank-0 file logger and dump the params exactly once.
        logger = logging_utils.log_to_file(logger_name=None, log_filename=os.path.join(expDir, 'out.log'))
        params.log()

    params['log_to_wandb'] = (global_rank == 0) and params['log_to_wandb']
    params['log_to_screen'] = (global_rank == 0) and params['log_to_screen']
    torch.backends.cudnn.benchmark = False

    if global_rank == 0:
        # NOTE(review): hparams/yaml are built here but never written out by this script.
        hparams = ruamelDict()
        yaml = YAML()
        for key, value in params.params.items():
            hparams[str(key)] = str(value)
        if not args.inference:
            shutil.copyfile(args.yaml_config, os.path.join(expDir, 'hyperparams.yaml'))
    trainer = Trainer(params, global_rank, local_rank, device, sweep_id=args.sweep_id)

    # Crop geometry that validate() uses to resize and pad rollout outputs.
    trainer.flow_h = flow_h
    trainer.flow_w = flow_w
    trainer.flow_pad_h_top = flow_pad_h_top
    trainer.flow_pad_h_bottom = flow_pad_h_bottom
    trainer.flow_pad_w_left = flow_pad_w_left
    trainer.flow_pad_w_right = flow_pad_w_right

    if args.sweep_id and trainer.global_rank == 0:
        logger.info('sweep_id=%s entity=%s project=%s', args.sweep_id, trainer.params.entity, trainer.params.project)
        wandb.agent(args.sweep_id, function=trainer.train, count=1, entity=trainer.params.entity, project=trainer.params.project)
    else:
        try:
            trainer.train()
        except Exception:
            # Keep the full traceback. Only open a post-mortem debugger in an interactive
            # session (batch/distributed jobs have no terminal). Then re-raise so the process
            # exits with a failure status.
            traceback.print_exc()
            if sys.stdin is not None and sys.stdin.isatty():
                try:
                    import ipdb as debugger
                except ImportError:
                    import pdb as debugger
                debugger.post_mortem()
            raise
    if params.log_to_screen:
        logger.info('DONE ---- rank %d', global_rank)
