# Referred to: https://github.com/PrieureDeSion/drive-any-robot

import wandb
import argparse
import yaml
import time
import os
import random
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
from utils import load_model, get_saved_optimizer, train_eval_loop
from datasets import MapPretrainDataset
from model import IGNet
from torch.utils.data import DataLoader

def main_simulated(configs):
    """Log a synthetic training curve to wandb instead of running real training.

    Logs one ``{"acc", "loss"}`` record per simulated epoch (epochs 2 through 9)
    and then finishes the active wandb run. ``configs`` is accepted so the
    signature matches ``main`` but is not read.
    """
    total_epochs = 10
    # A single random bias shared by every epoch of this fake run
    bias = random.random() / 5

    for step in range(2, total_epochs):
        decay = 2 ** -step
        # One fresh noise draw for accuracy, then one for loss
        accuracy = 1 - decay - random.random() / step - bias
        loss_value = decay + random.random() / step + bias

        # Send this epoch's metrics to wandb
        wandb.log({"acc": accuracy, "loss": loss_value})

    # Close the wandb run explicitly (required when running inside notebooks)
    wandb.finish()
    
def main(configs):
    """Build the model, data loaders and optimizer from ``configs`` and train.

    Args:
        configs: Dict of training settings. Required keys read here:
            ``batch_size``, ``eval_batch_size``, ``num_workers``, ``lr``,
            ``optimizer``, ``epochs``, ``run_name``, ``project_folder``,
            ``print_log_freq``, ``image_log_freq``, ``num_images_log`` and
            ``use_wandb``. Optional keys: ``gpu_ids`` (defaults to ``[0]``;
            a single int is accepted), ``seed`` and ``load_run``.
            The dict is mutated: ``gpu_ids`` is defaulted/normalised to a
            list and ``optimizer`` is lower-cased.

    Raises:
        ValueError: If ``configs['optimizer']`` is not ``adam``, ``adamw``
            or ``sgd`` (case-insensitive).
    """
    # Default and normalise the GPU list up front, independent of CUDA
    # availability, so the later len()/index lookups also work on CPU-only
    # hosts and when the config gives a single integer id.
    gpu_ids = configs.get("gpu_ids", [0])
    if isinstance(gpu_ids, int):
        gpu_ids = [gpu_ids]
    configs["gpu_ids"] = gpu_ids

    # Set up GPU devices
    if torch.cuda.is_available():
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(x) for x in gpu_ids)
        print("Using cuda devices:", os.environ["CUDA_VISIBLE_DEVICES"])

    # Set up device
    # NOTE(review): if CUDA_VISIBLE_DEVICES takes effect, the visible devices
    # are renumbered from 0, so a non-zero gpu_ids[0] may not be a valid
    # index here or in DataParallel below - confirm on a multi-GPU host.
    device = torch.device(f'cuda:{gpu_ids[0]}' if torch.cuda.is_available() else 'cpu')

    # Set up seed (Python's RNG is seeded too so all three sources agree)
    if 'seed' in configs:
        random.seed(configs['seed'])
        np.random.seed(configs['seed'])
        torch.manual_seed(configs['seed'])

    # Create the model
    model = IGNet(configs)
    if len(gpu_ids) > 1:
        model = nn.DataParallel(model, device_ids=gpu_ids)
    model = model.to(device)

    # DataParallel hides the network behind `.module` and prefixes every
    # parameter name with "module."; resolve the underlying network once and
    # use it for attribute access and name-based parameter filtering.
    base_model = model.module if isinstance(model, nn.DataParallel) else model

    print('Build PM-Net model finished')
    num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_new_parameters = sum(
        p.numel()
        for name, p in base_model.named_parameters()
        if p.requires_grad and not name.startswith('image_encoder')
    )
    print(f'{num_parameters = }')
    print(f'{num_new_parameters = }')

    # Build image transforms
    # TODO: set which scale to use.
    # TODO: use matching image preprocessor.
    #       For pytorch models, use https://pytorch.org/vision/stable/models.html
    #       For huggingface models, use its loaded preprocessor.

    # Load the dataset; both splits use the preprocessor exposed by the model
    train_dataset = MapPretrainDataset(
        configs=configs,
        data_split_type='train',
        transform=base_model.image_processor,
        use_huggingface_transform=base_model.huggingface_image_processor,
    )

    val_dataset = MapPretrainDataset(
        configs=configs,
        data_split_type='val',
        transform=base_model.image_processor,
        use_huggingface_transform=base_model.huggingface_image_processor,
    )

    # Create Dataloader
    train_loader = DataLoader(
        train_dataset,
        batch_size=configs['batch_size'],
        shuffle=True,
        num_workers=configs['num_workers'],
        drop_last=True
    )
    # NOTE(review): validation is also shuffled and drops the final partial
    # batch, so not every validation sample is evaluated - confirm intended.
    val_loader = DataLoader(
        val_dataset,
        batch_size=configs['eval_batch_size'],
        shuffle=True,
        num_workers=configs['num_workers'],
        drop_last=True
    )

    # Build optimizer
    lr = float(configs['lr'])  # YAML may parse values such as "1e-4" as str

    configs['optimizer'] = configs['optimizer'].lower()
    if configs['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif configs['optimizer'] == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    elif configs['optimizer'] == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    else:
        raise ValueError(f'Optimizer {configs["optimizer"]} not supported')

    current_epoch = 0
    if 'load_run' in configs:
        # Resume: restore weights and optimizer from logs/<load_run>/latest.pth
        load_project_path = os.path.join('logs', configs['load_run'])
        print(f'Loading model from {load_project_path}')
        latest_path = os.path.join(load_project_path, 'latest.pth')
        latest_checkpoint = torch.load(latest_path, map_location=device)
        load_model(model, latest_checkpoint)
        # The freshly built optimizer above is replaced by the saved one
        optimizer = get_saved_optimizer(latest_checkpoint, device)
        # Continue from the epoch after the one stored in the checkpoint
        current_epoch = latest_checkpoint['epoch'] + 1

    # Anomaly detection is a debugging aid that slows autograd; it is enabled
    # unconditionally here.
    torch.autograd.set_detect_anomaly(True)

    train_eval_loop(
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=configs['epochs'],
        device=device,
        run_name=configs['run_name'],
        project_path=configs['project_folder'],
        print_log_freq=configs['print_log_freq'],
        image_log_freq=configs['image_log_freq'],
        num_images_log=configs['num_images_log'],
        current_epoch=current_epoch,
        use_wandb=configs['use_wandb'],
        configs=configs
    )

if __name__ == '__main__':
    # Script entry point: parse CLI arguments, load the YAML config, create
    # the run folder, optionally start a wandb run, then hand off to main().
    parser = argparse.ArgumentParser(description="PM-Net Training")

    # Config files
    parser.add_argument(
        '--config',
        default="configs/default.yaml",
        type=str,
        help="Path to the config file.",
    )

    # Prefix of the run name; a timestamp is appended below
    parser.add_argument(
        '--run-name',
        default='develop_test',
        type=str
    )

    args = parser.parse_args()

    with open(args.config, "r") as f:
        configs = yaml.safe_load(f)

    # Experiment name and experiment folder: the run name always comes from
    # --run-name plus a timestamp, overriding any run_name in the YAML file.
    configs["run_name"] = args.run_name + "_" + time.strftime("%Y_%m_%d_%H_%M_%S")
    configs["project_folder"] = os.path.join(
        "logs", configs["project_name"], configs["run_name"]
    )

    # Raises if the folder already exists, so an earlier run is never reused
    os.makedirs(configs["project_folder"])

    if configs["use_wandb"]:
        wandb.login()
        wandb.init(
            project=configs["project_name"], settings=wandb.Settings(start_method="fork")
        )

        # Touch the run only after confirming init produced one; previously
        # the name was assigned before this guard, which defeated it.
        if wandb.run:
            wandb.run.name = configs["run_name"]
            # Update the wandb args with the training configurations
            wandb.config.update(configs)

    print('All configs:')
    for key, val in configs.items():
        print(f'{key}: {val}')

    main(configs)