""" Main training script. It is mostly inspired from the repository
https://github.com/PolymathicAI/multiple_physics_pretraining """
import argparse
import os
import sys
from pathlib import Path
current_path = Path(os.getcwd())
sys.path.append(str(current_path))
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap as ruamelDict
from dadaptation import DAdaptAdam, DAdaptAdan
from adan_pytorch import Adan
from collections import OrderedDict
import wandb
import gc
from collections import defaultdict

from src.global_constants import *
from src.utils import is_debug, YParams, logging_utils, TimeTracker, standardize
from src.data import get_data_objects
from src.models import build_model



def add_weight_decay(params, weight_decay=1e-5, skip_list=()):
    """ From Ross Wightman at:
    https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994/3 
    
    Goes through the parameter list and if the squeeze dim is 1 or 0 (usually means bias or scale) 
    then don't apply weight decay. 
    """
    decay = []
    no_decay = []
    for name, param in params:
        if not param.requires_grad:
            continue
        if (len(param.squeeze().shape) <= 1 or name in skip_list):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.,},
        {'params': decay, 'weight_decay': weight_decay}
    ]


def param_norm(parameters):
    with torch.no_grad():
        total_norm = 0
        for p in parameters:
            total_norm += p.pow(2).sum().item()
        return total_norm**.5


def grad_norm(parameters):
    with torch.no_grad():
        total_norm = 0
        for p in parameters:
            if p.grad is not None:
                total_norm += p.grad.pow(2).sum().item()
        return total_norm**.5


class Trainer:

    def __init__(self, params, global_rank, local_rank, device):
        self.device = device
        self.params = params
        self.base_dtype = torch.float
        self.global_rank = global_rank
        self.local_rank = local_rank
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.log_to_screen = params.log_to_screen

        # Basic setup
        self.start_epoch = 0
        self.epoch = 0
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            self.mp_type = torch.bfloat16
        else:
            self.mp_type = torch.half
        self.true_time = params.true_time

        self.iters = 0
        self.initialize_data(self.params)
        self.initialize_model(self.params)
        self.initialize_optimizer(self.params)
        if params.resuming:
            print("Loading checkpoint %s"%params.checkpoint_path)
            print('LOADING CHECKPOINTTTTTT')
            self.restore_checkpoint(params.checkpoint_path)
        if params.pretrained_MPP:
            assert not params.resuming
            print("Loading MPP checkpoint")
            self.restore_MPP()
            self.iters = 0
            self.start_epoch = 0
        if 'finetune' in params and params.finetune:
            print("Loading checkpoint")
            if self.params.model == "avit":
                self.restore_avit()
            if self.params.model == "icnpde":
                self.restore_oml()
            if self.params.model == "ardiff":
                self.restore_ardiff()
            self.iters = 0
            self.start_epoch = 0
        if not params.resuming and params.pretrained:
            print("Starting from pretrained model at %s"%params.pretrained_ckpt_path)
            self.restore_checkpoint(params.pretrained_ckpt_path)
            self.iters = 0
            self.start_epoch = 0
        self.initialize_scheduler(self.params)

    def initialize_data(self, params):
        
        train_data_paths = params.train_data_paths or []
        valid_data_paths = params.valid_data_paths or []

        if self.log_to_screen:
            print(f"Initializing data on rank {self.global_rank}")
        self.train_dataset, self.train_sampler, self.train_data_loader = get_data_objects(
            train_data_paths, params.batch_size, params.epoch_size, params.train_val_test,
            params.n_past, params.n_future, dist.is_initialized(), params.num_data_workers,
            rank=self.global_rank, world_size=self.world_size, split='train', 
            data_params=params.train_data_params, template_name=params.train_template, 
            mode=params.train_mode
        )
        self.valid_dataset, _, self.valid_data_loader = get_data_objects(
            valid_data_paths, params.batch_size, params.epoch_size, params.train_val_test,
            params.n_past, params.n_future, dist.is_initialized(), params.num_data_workers,
            rank=self.global_rank, world_size=self.world_size, split='val', 
            data_params=params.valid_data_params, template_name=params.valid_template, 
            mode=params.valid_mode
        )
        if dist.is_initialized():
            self.train_sampler.set_epoch(0)

    def initialize_model(self, params):

        print(f"Initializing model on rank {self.global_rank}")
        
        self.model = build_model(params).to(self.device, dtype=self.base_dtype)

        if dist.is_initialized():
            self.model = DistributedDataParallel(
                self.model, device_ids=[self.local_rank], 
                output_device=[self.local_rank], find_unused_parameters=True,
            )

        n_params = sum([p.numel() for p in self.model.parameters()])
        
        self.single_print(f'Model parameter count: {n_params:,}')
        if self.params.model == "icnpde" and ('finetune' not in self.params or not self.params.finetune) and not self.params.in_channels is None:
            if dist.is_initialized():
                self.single_print(f"Operator class params: {self.model.module.out_chans:,}")
            else:
                self.single_print(f"Operator class params: {self.model.out_chans:,}")

    def initialize_optimizer(self, params): 

        # parameters_standard = [(n, p) for n, p in self.model.named_parameters()]
        parameters_standard = self.model.named_parameters()
        parameters = add_weight_decay(parameters_standard, self.params.weight_decay) # Dont use weight decay on bias/scaling terms
        if params.optimizer == 'adam':
            if self.params.learning_rate < 0:
                self.optimizer =  DAdaptAdam(parameters, lr=1., growth_rate=1.05, log_every=100, decouple=True)
            else:
                self.optimizer = optim.AdamW(parameters, lr=params.learning_rate)
        elif params.optimizer == 'adan':
            if self.params.learning_rate < 0:
                self.optimizer =  DAdaptAdan(parameters, lr=1., growth_rate=1.05, log_every=100)
            else:
                self.optimizer = Adan(parameters, lr=params.learning_rate)
        elif params.optimizer == 'sgd':
            self.optimizer = optim.SGD(self.model.parameters(), lr=params.learning_rate, momentum=0.9)
        else: 
            raise ValueError(f"Optimizer {params.optimizer} not supported")
        self.gscaler = torch.amp.GradScaler(enabled=(self.mp_type == torch.half and params.enable_amp))

    def initialize_scheduler(self, params):

        if params.scheduler_epochs > 0:
            sched_epochs = params.scheduler_epochs
        else:
            sched_epochs = params.max_epochs
        self.scheduler = None
        if params.scheduler == 'cosine':
            if self.params.learning_rate < 0:
                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    self.optimizer, last_epoch=(self.start_epoch*params.epoch_size)-1,
                    T_max=sched_epochs*params.epoch_size, eta_min=1e-5
                )
            else:
                k = params.warmup_steps
                if (self.start_epoch*params.epoch_size) < k:
                    warmup = torch.optim.lr_scheduler.LinearLR(self.optimizer, start_factor=.01, end_factor=1.0, total_iters=k)
                    decay = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, eta_min=params.learning_rate/100, T_max=sched_epochs)
                    self.scheduler = torch.optim.lr_scheduler.SequentialLR(self.optimizer, [warmup, decay], [k], last_epoch=(params.epoch_size*self.start_epoch)-1)
                else:
                    self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, eta_min=params.learning_rate/100, T_max=sched_epochs)

    def single_print(self, *text):
        if self.global_rank == 0 and self.log_to_screen:
            print(' '.join([str(t) for t in text]))

    def restore_checkpoint(self, checkpoint_path):
        """ Load model/opt from path """
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.local_rank))
        try:
            self.model.load_state_dict(checkpoint['model_state'])
        except:
            new_state_dict = OrderedDict()
            for key, val in checkpoint['model_state'].items():
                name = key[7:]
                new_state_dict[name] = val
            self.model.load_state_dict(new_state_dict)
        self.iters = checkpoint['iters']
        if self.params.resuming:  #restore checkpoint is used for finetuning as well as resuming. If finetuning (i.e., not resuming), restore checkpoint does not load optimizer state, instead uses config specified lr.
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.start_epoch = checkpoint['epoch']
            self.epoch = self.start_epoch
        if self.params.pretrained:
            if self.params.freeze_middle:
                self.model.module.freeze_middle()
            elif self.params.freeze_processor:
                self.model.module.freeze_processor()
            else:
                self.model.module.unfreeze()
        checkpoint = None

    def restore_MPP(self):
        """ Load model/opt from path """
        if self.params.resuming:
            raise ValueError("Cannot resume and restore MPP at the same time")
        checkpoint_path = ""  # should be replaced by the provided path to the model checkpoint
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.local_rank))
        pretrained_dict = OrderedDict()
        for key, val in checkpoint['model_state'].items():
            name = key[7:]
            pretrained_dict[name] = val

        # add the weights manually
        model_dict = self.model.state_dict()
        def replace_check(k_init, k_pretrained):
            assert model_dict[k_init].shape == pretrained_dict[k_pretrained].shape
            model_dict[k_init] = pretrained_dict[k_pretrained]
        # populate encoder weights
        replace_check('tokenizer.encoder.1.conv.weight', 'embed2d.in_proj.0.weight')
        replace_check('tokenizer.encoder.2.weight', 'embed2d.in_proj.1.weight')
        replace_check('tokenizer.encoder.4.conv.weight', 'embed2d.in_proj.3.weight')
        replace_check('tokenizer.encoder.5.weight', 'embed2d.in_proj.4.weight')
        replace_check('tokenizer.encoder.7.conv.weight', 'embed2d.in_proj.6.weight')
        replace_check('tokenizer.encoder.8.weight', 'embed2d.in_proj.7.weight')
        # populate time-space attention blocks weights
        for k in model_dict:
            if 'blocks' in k:
                replace_check(k, k)
        # populate decoder weights
        replace_check('tokenizer.decoder.0.convT.weight', 'debed2d.out_proj.0.weight')
        replace_check('tokenizer.decoder.1.weight', 'debed2d.out_proj.1.weight')
        replace_check('tokenizer.decoder.3.convT.weight', 'debed2d.out_proj.3.weight')
        replace_check('tokenizer.decoder.4.weight', 'debed2d.out_proj.4.weight')

        # one step more: retrieve the remaining weights, specifically learned 
        # on compNS128 (PDEBench) to make sure we did everything correctly
        if True:
            model_dict['tokenizer.encoder.0.weight'] = pretrained_dict['space_bag.weight'][:,[6, 7, 8, 9]].unsqueeze(-1).unsqueeze(-1)
            model_dict['tokenizer.encoder.0.bias'] = pretrained_dict['space_bag.bias']
            model_dict['tokenizer.decoder.6.convT.weight'] = pretrained_dict['debed2d.out_kernel'][:,[6, 7, 8, 9],:,:]
            model_dict['tokenizer.decoder.6.convT.bias'] = pretrained_dict['debed2d.out_bias'][[6, 7, 8, 9]]

        # populate the weights
        self.model.load_state_dict(model_dict)
        self.iters = checkpoint['iters']
        checkpoint = None

    def restore_avit(self):
        """ Load model from path """
        if self.params.resuming:
            raise ValueError("Cannot resume and restore at the same time")
        
        checkpoint_path = ""  # should be replaced by the provided path to the model checkpoint
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.local_rank))
        pretrained_dict = checkpoint['model_state']
        
        # add the weights manually # this will effectively replace all the weights except the ones associated to the new datasets
        model_dict = self.model.state_dict()

        for k in pretrained_dict:
            if ('out_kernel1d' in k) or ('out_kernel2d' in k) or ('space_bag.weight' in k):
                n_states = pretrained_dict[k].shape[1]
                model_dict[k][:,:n_states,...] = pretrained_dict[k]
                continue
            if ('out_bias1d' in k) or ('out_bias2d' in k):
                n_states = pretrained_dict[k].shape[0]
                model_dict[k][:n_states] = pretrained_dict[k]
                continue
            model_dict[k] = pretrained_dict[k]

        # populate the weights
        self.model.load_state_dict(model_dict)
        checkpoint = None

        # freeze everything that is not related to shearflow or euler dataset
        if self.params.finetune_freeze:
            for name, param in self.model.named_parameters():
                if 'shearflow' in name or 'euler' in name or 'space_bag' in name \
                    or 'out_kernel2d' in name or 'out_bias2d' in name \
                    or 'out_kernel1d' in name or 'out_bias1d' in name:
                    continue
                param.requires_grad = False

    def restore_oml(self):
        """ Load model from path """
        if self.params.resuming:
            raise ValueError("Cannot resume and restore at the same time")
        
        checkpoint_path = ""  # should be replaced by the provided path to the model checkpoint
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.local_rank))
        pretrained_dict = checkpoint['model_state']
        
        # add the weights manually # this will effectively replace all the weights except the ones associated to the new datasets
        model_dict = self.model.state_dict()
        for k in pretrained_dict:
            if 'space_bag.weight' in k:
                n_states = pretrained_dict[k].shape[1]
                model_dict[k][:,:n_states] = pretrained_dict[k]
                continue
            model_dict[k] = pretrained_dict[k]

        # populate the weights
        self.model.load_state_dict(model_dict)
        checkpoint = None

        # freeze everything that is not related to shearflow or euler dataset
        if self.params.finetune_freeze:
            for name, param in self.model.named_parameters():
                if 'shearflow' in name or 'euler' in name or 'space_bag' in name:
                    continue
                param.requires_grad = False

    def restore_ardiff(self):
        if self.params.resuming:
            raise ValueError("Cannot resume and restore at the same time")
        
        checkpoint_path = ""  # should be replaced by the provided path to the model checkpoint
        checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(self.local_rank))
        pretrained_dict = checkpoint['model_state']

        model_dict = self.model.state_dict()
        for k in pretrained_dict:  
            model_dict[k] = pretrained_dict[k]

        # populate the weights
        self.model.load_state_dict(model_dict)
        checkpoint = None
       

    def save_checkpoint(self, checkpoint_path, model=None):
        """ Save model and optimizer to checkpoint """
        if not model:
            model = self.model
        d = {
            'iters': self.epoch*self.params.epoch_size, 
            'epoch': self.epoch, 
            'model_state': model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }
        torch.save(d, checkpoint_path)

    @staticmethod
    def mse_loss(y_ref, y):
        var = 1e-7 + y_ref.var((-1,-2), keepdim=True)
        loss = F.gaussian_nll_loss(y, y_ref, torch.ones_like(y)*var, eps=1e-8, reduction='mean')
        with torch.no_grad():
            residual = y - y_ref
            norm_ref = 1e-7 + y_ref.pow(2).mean((-1,-2), keepdim=True)
            raw_loss = residual.pow(2.0).mean((-1,-2), keepdims=True) / norm_ref
        return loss, raw_loss

    def train_one_epoch(self):
        tt = TimeTracker()
        tt.track("data", "training")
        self.epoch += 1
        self.model.train()
        logs = {
            'train_nrmse': torch.zeros(1).to(self.device, dtype=self.base_dtype),
            'train_l1': torch.zeros(1).to(self.device, dtype=self.base_dtype),
        }
        steps = 0
        grad_logs = defaultdict(lambda: torch.zeros(1, device=self.device))
        grad_counts = defaultdict(lambda: torch.zeros(1, device=self.device))
        loss_logs = defaultdict(lambda: torch.zeros(1, device=self.device))
        loss_counts = defaultdict(lambda: torch.zeros(1, device=self.device))
        dset_counts = {}
        self.single_print("--------")
        self.single_print('train_loader_size', len(self.train_data_loader), len(self.train_dataset))
        for batch_idx, batch in enumerate(self.train_data_loader):
            if batch_idx >= self.params.epoch_size:  # certain dataloaders are not restricted in length
                break
            steps += 1
            tt.track("data", "training", "data_batch")
            if self.true_time:
                torch.cuda.synchronize()

            inp, tar = batch['input_fields'].to(self.device), batch['output_fields'].to(self.device)
            dset_name = batch['name'][0]
            state_labels = None
            if self.params.train_mode == "PDEBench" or ('finetune' in self.params and self.params.finetune):
                state_labels = batch['field_labels'].to(self.device)

            dset_counts[dset_name] = dset_counts.get(dset_name, 0) + 1
            loss_counts[dset_name] += 1

            # whether the model weights should be updated this batch
            self.model.require_backward_grad_sync = ((1+batch_idx) % self.params.accum_grad == 0)
            with torch.amp.autocast("cuda", enabled=self.params.enable_amp, dtype=self.mp_type):
                
                # forward 
                tt.track("forward", "training", "forw_batch")
                if self.params.model != "ardiff":
                    output, metadata = self.model(
                        inp, predict_normed=False, n_future_steps=self.params.n_future, 
                        state_labels=state_labels[0] if state_labels is not None else None,
                        dset_name=dset_name
                    )
                    tar = (tar - metadata['mean']) / metadata['std']  # normalize tar

                    # loss
                    tt.track("loss", "training", "loss_batch")
                    # g o psi o f(X) = x_{t+T+1}
                    loss, loss_raw = self.mse_loss(tar, output)
                    loss = loss / self.params.accum_grad

                    # logs
                    tt.track("logs", "training")
                    with torch.no_grad():
                        log_nrmse = loss_raw.sqrt().mean()
                        logs['train_nrmse'] += log_nrmse
                        loss_logs[dset_name] += loss.item()
                        loss_print = log_nrmse.item()

                else:
                    # the diffusion model uses as context 2 steps in the past
                    inp = inp[:,[-2,-1],...]

                    # preprocessing
                    spatial_dims = tuple(range(3,inp.squeeze(-1,-2).ndim))
                    inp, mean, std = standardize(inp, dims=(1,*spatial_dims), return_stats=True)
                    tar = (tar - mean) / std  # normalize tar

                    # t-2, t-1 steps as conditioning fed as b x (2c) x h x w tensor
                    conditionning = torch.cat([inp[:,[-2],...], inp[:,[-1],...]], dim=2)

                    # forward
                    noise, predictedNoise = self.model(conditionning, tar, self.params.n_future, dset_name=dset_name)  # seems strange to feed the output to the model, but this is what they do

                    # loss
                    tt.track("loss", "training", "loss_batch")
                    loss = F.smooth_l1_loss(noise, predictedNoise)
                    loss = loss / self.params.accum_grad

                    # logs
                    tt.track("logs", "training")
                    with torch.no_grad():
                        logs['train_l1'] += F.l1_loss(noise, predictedNoise)
                        loss_logs[dset_name] += loss.item()
                        loss_print = loss.item()

                # backward
                tt.track("backward", "training", "back_batch")
                self.gscaler.scale(loss).backward()  # Scaler is no op when not using AMP

                if self.true_time:
                    torch.cuda.synchronize()

                # gradient step
                tt.track("gradient_step", "training", "optim_batch")
                if self.model.require_backward_grad_sync:  # Only take step once per accumulation cycle
                    grad_logs[dset_name] += grad_norm(self.model.parameters())
                    grad_counts[dset_name] += 1
                    # clip the gradients 
                    if self.params.gnorm is not None:
                        nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.params.gnorm)
                    self.gscaler.step(self.optimizer)
                    self.gscaler.update()
                    self.optimizer.zero_grad(set_to_none=True)
                    if self.scheduler is not None:
                        self.scheduler.step()
                    if self.true_time:
                        torch.cuda.synchronize()

                # logs
                if self.true_time:
                    torch.cuda.synchronize()
                logs['learning_rate'] = self.optimizer.param_groups[0]['lr']
                if self.log_to_screen and batch_idx % self.params.log_interval == 0 and self.global_rank == 0:
                    print(f"Epoch {self.epoch} Batch {batch_idx} Train Loss {loss_print:.3f}")
                    print('Total Times. Batch: {}, Rank: {}, Data Shape: {}, Data time: {:.2f}, Forward: {:.2f}, Backward: {:.2f}, Optimizer: {:.2f}'.format(
                        batch_idx, self.global_rank, list(inp.shape), tt.get("data_batch"), tt.get("forw_batch"), tt.get("back_batch"), tt.get("optim_batch"))
                    )
                tt.reset("data_batch", "forw_batch", "back_batch", "loss_batch", "optim_batch")

        # logs
        logs = {k: v/steps if k == 'learning_rate' else v for k, v in logs.items()}
        if dist.is_initialized():
            for key in sorted(logs.keys()):
                if key == "learning_rate":
                    continue
                dist.all_reduce(logs[key].detach())
                logs[key] = float(logs[key]/dist.get_world_size())
            for key in sorted(loss_logs.keys()):
                dist.all_reduce(loss_logs[key].detach())
            for key in sorted(grad_logs.keys()):
                dist.all_reduce(grad_logs[key].detach())
            for key in sorted(loss_counts.keys()):
                dist.all_reduce(loss_counts[key].detach())
            for key in sorted(grad_counts.keys()):
                dist.all_reduce(grad_counts[key].detach())

        times = tt.stop()
        for k in times:
            logs[f'time/{k}'] = times[k] / steps

        for key in loss_logs.keys():
            logs[f'{key}/train_nrmse'] = loss_logs[key] / loss_counts[key]
        for key in grad_logs.keys():
            logs[f'{key}/train_grad_norm'] = grad_logs[key] / grad_counts[key]
        self.iters += steps
        if self.global_rank == 0:
            logs['iters'] = self.iters
            logs['parameter norm'] = param_norm(self.model.parameters())
        self.single_print('all reduces executed!')

        return times["training"], times["data"], logs

    def validate_one_epoch(self, cutoff):
        """
        Validates - for each batch just use a small subset to make it easier.

        Note: need to split datasets for meaningful metrics, but TBD. 
        """
        # Don't bother with full validation set between epochs
        self.model.eval()
        self.single_print('STARTING VALIDATION')
        with torch.inference_mode():
            with torch.amp.autocast("cuda", enabled=False, dtype=self.mp_type):
                if self.params.valid_mode == "sf_euler":
                    sub_dsets = [self.valid_dataset]
                else:
                    sub_dsets = self.valid_dataset.sub_dsets
                logs = {}
                distinct_dsets = [dset.dataset_name for dset in sub_dsets]
                counts = {dset: 0 for dset in distinct_dsets}
                # iterate over the validation datasets
                for subset in sub_dsets:
                    dset_name = subset.dataset_name
                    if self.params.use_ddp:
                        val_loader = torch.utils.data.DataLoader(
                            subset, batch_size=self.params.batch_size,
                            num_workers=self.params.num_data_workers,
                            sampler=torch.utils.data.distributed.DistributedSampler(subset, drop_last=True)
                        )
                    else:
                        val_loader = torch.utils.data.DataLoader(
                            subset, batch_size=self.params.batch_size,
                            num_workers=self.params.num_data_workers, 
                            shuffle=True, generator=torch.Generator().manual_seed(0), drop_last=True
                        )
                    count = 0
                    for batch_idx, batch in enumerate(val_loader):
                        # Only do a few batches of each dataset if not doing full validation
                        if count >= cutoff:  # validating on burgers equations is extremely long
                            del(val_loader)
                            break
                        count += 1
                        counts[dset_name] += 1

                        inp, tar = batch['input_fields'].to(self.device), batch['output_fields'].to(self.device)
                        state_labels = None
                        if self.params.train_mode == "PDEBench":
                            state_labels = torch.tensor(
                                self.train_dataset.subset_dict.get(subset.get_name(), [-1]*len(self.valid_dataset.subset_dict[subset.get_name()])),
                                device=self.device
                            ).unsqueeze(0)
                        elif ('finetune' in self.params and self.params.finetune):
                            state_labels = batch['field_labels'].to(self.device)  # shearflow and euler labels are created at the level of the dataset
                        
                        if self.params.model != "ardiff":
                            output, _ = self.model(
                                inp, predict_normed=True, n_future_steps=self.params.n_future,
                                state_labels=state_labels[0] if state_labels is not None else None,
                                dset_name=dset_name
                            )
                        else:
                            # the diffusion model uses as context 2 steps in the past
                            inp = inp[:,[-2,-1],...]
                            # preprocessing
                            spatial_dims = tuple(range(3,inp.squeeze(-1,-2).ndim))
                            inp, mean, std = standardize(inp, dims=(1,*spatial_dims), return_stats=True)
                            # t-2, t-1 steps as conditioning
                            conditionning = torch.cat([inp[:,[-2],...], inp[:,[-1],...]], dim=2)
                            # t step  # won't be used anyway
                            data = inp[:,[-1],...]
                            output = self.model(conditionning, data, n_future_steps=self.params.n_future, dset_name=dset_name)
                            # unnormalize
                            output = output * std + mean

                        # loss
                        residuals = output - tar
                        spatial_dims = tuple(range(residuals.ndim))[3:] # Assume 0, 1 are B, C
                        nrmse = (residuals.pow(2).mean(spatial_dims, keepdim=True) / (1e-7 + tar.pow(2).mean(spatial_dims, keepdim=True))).sqrt()#.mean()

                        logs[f'{dset_name}/valid_nrmse'] = logs.get(f'{dset_name}/valid_nrmse',0) + nrmse.mean()
                        logs[f'{dset_name}/valid_rmse'] = logs.get(f'{dset_name}/valid_rmse',0) + residuals.pow(2).mean(spatial_dims).sqrt().mean()
                        logs[f'{dset_name}/valid_l1'] = logs.get(f'{dset_name}/valid_l1', 0) + residuals.abs().mean()

            self.single_print('DONE VALIDATING - NOW SYNCING')
            
            # divide by number of batches
            for k, v in logs.items():
                dset_name = k.split('/')[0]
                logs[k] = v / counts[dset_name]

            # # replace keys <>
            # average nrmse across datasets
            logs['valid_nrmse'] = 0
            for dset_name in distinct_dsets:
                logs['valid_nrmse'] += logs[f'{dset_name}/valid_nrmse']/len(distinct_dsets)

            if dist.is_initialized():
                for key in sorted(logs.keys()):
                    dist.all_reduce(logs[key].detach())
                    logs[key] = float(logs[key].item()/dist.get_world_size())
                    if 'rmse' in key:
                        logs[key] = logs[key]
            self.single_print('DONE SYNCING - NOW LOGGING')
        return logs

    def train(self):

        if self.params.log_to_wandb:
            wandb.init(
                dir=self.params.experiment_dir, config=self.params, 
                name=self.params.name, group=self.params.group,
                project=self.params.project, entity=self.params.entity, 
                resume=True
            )
        if self.params.log_to_wandb:
            wandb.watch(self.model)
        self.single_print("Starting Training Loop...")
    
        # Actually train now, saving checkpoints, logging time, and logging to wandb
        best_valid_loss = 1.e6
        for epoch in range(self.start_epoch, self.params.max_epochs):
            if dist.is_initialized():
                self.train_sampler.set_epoch(epoch)
            start = time.time()

            tr_time, data_time, train_logs = self.train_one_epoch()
            
            valid_start = time.time()
            
            # decide whether to do a small/medium/complete validation
            val_cutoff = self.params.val_cutoff
            if epoch == self.params.max_epochs - 1:
                val_cutoff = 999
            if self.params.debug:
                val_cutoff = 1
            valid_logs = self.validate_one_epoch(val_cutoff)
            
            post_start = time.time()
            train_logs.update(valid_logs)
            train_logs['time/train_time'] = valid_start - start
            train_logs['time/train_data_time'] = data_time
            train_logs['time/train_compute_time'] = tr_time
            train_logs['time/valid_time'] = post_start - valid_start
            if self.params.log_to_wandb:
                wandb.log(train_logs) 
            gc.collect()
            torch.cuda.empty_cache()

            if self.global_rank == 0:
                if self.params.save_checkpoint:
                    self.save_checkpoint(self.params.checkpoint_path)
                if epoch % self.params.checkpoint_save_interval == 0:
                    self.save_checkpoint(self.params.checkpoint_path + f'_epoch{epoch}')
                if valid_logs['valid_nrmse'] <= best_valid_loss:
                    self.save_checkpoint(self.params.best_checkpoint_path)
                    best_valid_loss = valid_logs['valid_nrmse']
                
                cur_time = time.time()
                self.single_print(f'Time for train {valid_start-start:.2f}. For valid: {post_start-valid_start:.2f}. For postprocessing:{cur_time-post_start:.2f}')
                self.single_print('Time taken for epoch {} is {:.2f} sec'.format(1+epoch, time.time()-start))
                self.single_print('Train loss: {}. Valid loss: {}'.format(train_logs['train_nrmse'], valid_logs['valid_nrmse']))


if __name__ == '__main__':

    print(f"DEBUG : {is_debug()}")

    # arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_ddp", action='store_true', help='Use distributed data parallel')
    parser.add_argument("--yaml_config", default='_debug.yaml', type=str)
    args = parser.parse_args()

    # config 
    params = YParams(CONFIG_PATH/"_base.yaml")
    refined_params = YParams(CONFIG_PATH/args.yaml_config)
    params.update_params(refined_params.params)
    if is_debug():
        debug_params = YParams(CONFIG_PATH/"_debug.yaml")
        params.update_params(debug_params.params)
    params['debug'] = is_debug()
    params['use_ddp'] = args.use_ddp

    # set up distributed training
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    global_rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if args.use_ddp:
        dist.init_process_group("nccl")  # backend for nvidia gpus, multi-node, multi-gpu
    torch.cuda.set_device(local_rank)
        
    device = torch.device(local_rank) if torch.cuda.is_available() else torch.device("cpu")

    # modify params
    params['batch_size'] = int(params.batch_size//world_size)
    params['start_epoch'] = 0
    exp_dir = Path(params.exp_dir) / params.run_name
    params['experiment_dir'] = str(exp_dir)
    params['checkpoint_path'] = str(exp_dir / 'training_checkpoints' / 'ckpt.tar')
    # params['checkpoint_path'] = str(exp_dir / 'training_checkpoints' / 'ckpt.tar_epoch270')
    # params['checkpoint_path'] = str(exp_dir / 'training_checkpoints' / 'ckpt.tar_epoch70')
    params['best_checkpoint_path'] = str(exp_dir / 'training_checkpoints' / 'best_ckpt.tar')

    # Have rank 0 check for and/or make directory
    if global_rank == 0:
        if not exp_dir.exists():
            exp_dir.mkdir(parents=True)
            (exp_dir / 'training_checkpoints').mkdir(parents=True)
    params['resuming'] = Path(params.checkpoint_path).is_file()

    # wandb setup
    params['name'] = params.run_name
    if global_rank==0:
        logging_utils.log_to_file(logger_name=None, log_filename=exp_dir/'out.log')
        logging_utils.log_versions()
        params.log()

    params['log_to_wandb'] = (global_rank==0) and params['log_to_wandb']
    params['log_to_screen'] = (global_rank==0) and params['log_to_screen']

    if global_rank == 0:  # save config for this run
        hparams = ruamelDict()
        yaml = YAML()
        for key, value in params.params.items():
            hparams[str(key)] = value
        with open(exp_dir/'hyperparams.yaml', 'w') as hpfile:
            yaml.dump(hparams, hpfile)

    # start training
    trainer = Trainer(params, global_rank, local_rank, device)
    trainer.train()
    if params.log_to_screen:
        print('DONE ---- rank %d'%global_rank)