import getpass
import os
import shutil
import socket
import sys
import tempfile
import warnings
from contextlib import contextmanager
import torch

import wandb

from runner import Runner
from utilities import GeneralUtility

# Silence every warning for the lifetime of the process.
warnings.filterwarnings(action='ignore')

# Debug mode is enabled by passing the exact flag ``--debug`` on the command line.
debug = any(argument == "--debug" for argument in sys.argv)
# Baseline run configuration. These values are handed to ``wandb.init`` as the
# run config further down in this file; outside debug mode they are replaced
# by ``None`` placeholders before that happens.
defaults = {
    # ---- System ----
    'seed': 1,

    # ---- Data ----
    'dataset': 'REDACTED_FOR_ANONYMITY',
    'batch_size': 1,

    # ---- Architecture ----
    'arch': 'swin_video_unet',  # Defaults to unet
    'aggregation': 'attention',
    'use_pretrained_model': False,
    'temporal_skip_reduction': 'transformer_year',
    'use_final_convs': False,
    'downsample_per_year': True,
    'slope_no_disturbance': -0.0,
    'freeze_model': 'full',
    'no_disturbance_factor': 1,

    # ---- Optimization ----
    'optim': 'AdamW',  # Defaults to AdamW
    'loss_name': 'combi_2heads',  # Defaults to shift_l1
    'use_l2': True,  # Whether to use L2 loss in disturbance regression loss
    'lambda_regression': 3.0,
    'full_disturbance_window': True,
    'disturbance_indicator': -1,
    'n_iterations': 10,
    'log_freq': 1,
    'initial_lr': 0.0001,
    'weight_decay': 1e-5,
    'use_overlapping_patches': False,
    'overlap_lambda': 1.0,
    'overlap_size': 40,

    # ---- Efficiency ----
    'fp16': False,
    'num_workers_per_gpu': 8,  # Defaults to 8
    'prefetch_factor': None,

    # ---- Other ----
    'use_grad_clipping': False,
    'use_weighted_sampler': None,  # Currently deactivated
    'early_stopping': True,  # Flag for early stopping
    # 'model_checkpoint': os.path.join('models_trained', 'REDACTED_FOR_ANONYMITY'),  # Path to model checkpoint

    # ---- Input size configuration ----
    # Original note: "If True, crop input from 256x256 to 64x64".
    # NOTE(review): the value is 96 rather than a bool, so this looks like a
    # target size -- confirm the expected type against the consumer.
    'use_reduced_input_size': 96,
    'patch_size_time': 1,  # Temporal patch size for SwinVideoUnet
    # Spatial patch size for SwinVideoUnet (per the original note it is
    # overridden to 1 when use_reduced_input_size is set -- TODO confirm).
    'patch_size_image': 1,

    # ---- Scaling adjustments ----
    'scale_adjust_1234': 0.0,   # Adjustment for channels 1, 2, 3, 4.
    'scale_adjust_6789': 0.0,   # Adjustment for channels 6, 7, 8, 9.
    'scale_adjust_0': 0.0,      # Adjustment for channel 0.
    'scale_adjust_51011': 0.0,  # Adjustment for channels 5, 10, 11.

    'reduce_time': (28, 14, 7),
    'window_size_temporal': 2,
    'window_size_spatial': 6,

    # ---- SwinVideoUnet architecture configuration ----
    'encoder_depths': (6, 4, 4, 6),
    'decoder_depths': (4, 6, 8, 16),
}

# Outside debug mode, set every default to None recursively; in debug mode the
# values defined above are used as they are.
defaults = defaults if debug else GeneralUtility.fill_dict_with_none(defaults)

# Record the name of the machine executing this run alongside the defaults.
hostname = socket.gethostname()
defaults['computer'] = hostname

# Start the wandb run with the (possibly blanked) defaults as its config.
wandb.init(
    entity='REDACTED_FOR_ANONYMITY',  # automatically changed in sweep
    project='REDACTED_FOR_ANONYMITY',  # automatically changed in sweep
    config=defaults,
)

# Combine the config held by wandb with the local defaults.
config = GeneralUtility.update_config_with_default(wandb.config, defaults)

@contextmanager
def tempdir():
    """Yield the path of a fresh temporary directory and remove it on exit.

    The directory is created with ``tempfile.mkdtemp`` and deleted again when
    the ``with`` block ends, whether it finishes normally or raises.

    NOTE(review): the body previously consisted of a bare placeholder string,
    which left this function a non-generator so that ``with tempdir()`` failed
    on entry. This implementation uses the system default temporary location;
    confirm whether a specific scratch location is required.

    Yields:
        str: Path of the newly created temporary directory.
    """
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        # Best-effort cleanup: a failed removal must not mask an exception
        # raised inside the ``with`` body.
        shutil.rmtree(path, ignore_errors=True)


# Script entry point: run the experiment inside a temporary working directory,
# then close the wandb run and remove its local files.
with tempdir() as tmp_dir:
    "REDACTED_FOR_ANONYMITY"

    # Build the runner from the merged config and execute it.
    runner = Runner(config=config, tmp_dir=tmp_dir, debug=debug)
    runner.run()

    # Capture the run's local directory before the run is closed, then close
    # the run. ``wandb.finish()`` replaces the deprecated ``wandb.join()``.
    wandb_dir_path = wandb.run.dir
    wandb.finish()

    # Delete the local files
    if os.path.exists(wandb_dir_path):
        shutil.rmtree(wandb_dir_path)
