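''' Copied from `train_single_level.py`, originally meant for 2D images for better visualization and debugging. NOTE: the code below currently only trains a small MLP on synthetic data with DistributedDataParallel. '''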
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 
# The line above puts the parent of this folder on sys.path so that
# everything under it can be imported from here.
from TransMorph.models.TransMorph import TransFeX
from TransMorph.models.unet3d import UNet2D, UNet3D, UNetEncoder3D
from TransMorph.models.lku import LKUNet, LKUEncoder
from TransMorph.models.configs_TransMorph import get_3DTransFeX_config
from solver.adam import multi_scale_warp_solver, multi_scale_diffeomorphic_solver, multi_scale_affine2d_solver
from solver.utils import gaussian_1d, img2v_3d, v2img_3d, separable_filtering
from solver.losses import NCC_vxm, DiceLossWithLongLabels, _get_loss_function_factory
from solver.losses import LocalNormalizedCrossCorrelationLoss
from omegaconf import OmegaConf
# logging
import wandb
import hydra
from model_utils import displacements_to_warps, downsample
from utils import set_seed, init_wandb, open_log, cleanup
from datasets.oasis import OASIS, OASISNeurite3D
import numpy as np
from scipy.ndimage import zoom
# torch.set_num_threads(1)
import argparse
import datetime

# ddp
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch import nn

# this is a global setting to stay compatible with scipy's grid sampling
from solver.adam import ALIGN_CORNERS as align_corners

def setup_ddp(rank, world_size, port=12355):
    """Join this process to the default distributed process group.

    The rendezvous address is fixed to ``localhost``; only the port is
    configurable. The group is created with the NCCL backend.

    Args:
        rank: Index of the calling process within the group.
        world_size: Total number of processes that will join the group.
        port: TCP port exported as ``MASTER_PORT`` for the rendezvous.
    """
    # Rendezvous settings are read from the environment by torch.distributed.
    os.environ.update({
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(port),
    })
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

def cleanup_ddp():
    """Destroy the default process group if one is currently initialised.

    Safe to call when no group exists (e.g. setup failed, or teardown was
    already performed), so it can be used unconditionally in ``finally``
    blocks.
    """
    # destroy_process_group() raises when there is no default group, so
    # check first to keep teardown idempotent.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

# we will run ddp here
@hydra.main(config_path="configs", config_name="default.yaml")
def mainfunc(cfg):
    """Hydra entry point: launch the DDP worker processes.

    Spawns a fixed number of processes, each running ``main``.  ``mp.spawn``
    passes the process index as the first positional argument (``rank``),
    followed by the arguments given here, and blocks until all workers exit.

    Args:
        cfg: Configuration object supplied by hydra; forwarded to every worker.
    """
    num_procs = 4
    worker_args = (cfg, num_procs)
    mp.spawn(main, args=worker_args, nprocs=num_procs, join=True)

def main(rank, cfg, world_size):
    """Per-rank worker: fit a small MLP to a synthetic target under DDP.

    Joins the process group, trains for 100 epochs of 1000 steps on random
    inputs, prints the loss at every step, and has rank 0 save a checkpoint
    to ``/tmp/model_{epoch}.pt`` after each epoch.  The process group is
    always torn down on exit, including when training raises.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        cfg: Configuration object; it is only printed here.
        world_size: Total number of processes in the group.
    """
    setup_ddp(rank, world_size)
    try:
        print("Running DDP on rank", rank)
        torch.cuda.set_device(rank)
        print(cfg)

        model = nn.Sequential(
            nn.Linear(32, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        ).cuda()
        # `mat` defines the regression target.  Each process draws it from its
        # own torch RNG, so without synchronisation every rank would regress
        # toward a different target while gradients are averaged across ranks.
        # Use rank 0's matrix everywhere.
        mat = torch.randn(10, 32).cuda()
        dist.broadcast(mat, src=0)
        model = DDP(model, device_ids=[rank])
        optim = torch.optim.Adam(model.parameters(), lr=1e-3)
        # The seed depends on the rank, so each rank draws its own input stream.
        rng = np.random.RandomState((100123*rank % 123147))
        for epoch in range(100):
            print("Epoch")
            for i in range(1000):
                x = torch.from_numpy(rng.randn(32, 32)).float().cuda()
                y = (x**2) @ mat.T  # [B, 10]
                out = model(x)
                optim.zero_grad()
                loss = F.mse_loss(out, y)
                loss.backward()
                optim.step()
                print(rank, loss.item())
            # Only one rank writes the checkpoint; `.module` unwraps the DDP
            # wrapper so the saved keys match the plain model.
            if rank == 0:
                torch.save(model.module.state_dict(), f"/tmp/model_{epoch}.pt")
    finally:
        # Release the process group even if training fails part-way through.
        cleanup_ddp()

    

# Script entry point. `mainfunc` is wrapped by @hydra.main, which supplies its
# `cfg` argument, so it is called here without arguments.
if __name__ == '__main__':
    mainfunc()
