import argparse
from models.noise2noisefields import Noise2NoiseEncoder
from models.noise2noise_dataset import ShapeDataModule
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger
import torch

from models import models_ae
import os
import multiprocessing

import torch, gc
import json


# Configure the CUDA caching allocator to use expandable segments. PyTorch
# reads this variable when the allocator is first initialised, so it only
# takes effect if it is set before the first CUDA allocation in the process.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")

# Allow float32 matmuls to use a faster internal precision (e.g. TF32)
# where the backend supports it.
torch.set_float32_matmul_precision('high')

# Pretrained autoencoder weights used when the config has no
# "checkpoint_path" entry.
_DEFAULT_AE_CHECKPOINT = "/vectsetfeature/checkpoints/ae_d512_m512/checkpoint-199.pth"


def _load_pretrained_autoencoders(checkpoint_path):
    """Return ``(fix_model, finetune_model)``, two ``ae_d512_m512`` models loaded from one checkpoint.

    The weights are read from the ``'model'`` key of the checkpoint at
    *checkpoint_path*. ``fix_model`` is switched to eval mode;
    ``finetune_model`` is left in the mode it was constructed in.
    """
    fix_model = models_ae.ae_d512_m512()
    finetune_model = models_ae.ae_d512_m512()

    # Read the file once and keep the tensors on CPU. map_location='cuda'
    # would place them on the current CUDA device (cuda:0 at this point),
    # so every process running main() would allocate a copy on the same GPU.
    # load_state_dict copies values into each model's own parameters, so a
    # single dict can initialise both models.
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True;
    # confirm the checkpoint contains only tensors / plain containers.
    state_dict = torch.load(checkpoint_path, map_location='cpu')['model']
    fix_model.load_state_dict(state_dict)
    finetune_model.load_state_dict(state_dict)

    # NOTE(review): Lightning may switch submodules back to train mode during
    # fit; confirm Noise2NoiseEncoder keeps fix_model in eval mode.
    fix_model.eval()
    return fix_model, finetune_model


def _build_dataset_config(config):
    """Translate the flat training *config* into the ``dataset_config`` dict for ShapeDataModule.

    ``noise_type`` and ``noise_mean`` are optional (defaulting to
    ``'gaussian'`` and ``0.0``); every other key read here is required and a
    missing one raises ``KeyError``.
    """
    return {
        "dataset_folder": config['dataset_path'],
        "categories": config['categories'],
        "num_queries": config['num_queries'],
        "pc_size": config['pc_size'],
        "surface_noise_std": config['surface_noise'],
        "point_noise_std": config['point_noise'],
        "uniform_padding": config['uniform_padding'],
        "distributed": config['num_gpus'] > 1,
        "max_samples": config['max_samples'],
        "replica": config['replica'],
        "one_train_shape_only": config['one_train_shape_only'],
        "shuffle_seed": config['shuffle_seed'],
        "noise_type": config.get('noise_type', 'gaussian'),
        "noise_mean": config.get('noise_mean', 0.0),
    }


def main(config):
    """Fine-tune a pretrained autoencoder with the Noise2Noise objective.

    Args:
        config: Parsed training configuration (the JSON loaded in
            ``__main__``). An optional ``checkpoint_path`` entry overrides the
            default pretrained autoencoder checkpoint.
    """
    # set_start_method raises RuntimeError once a start method has been fixed
    # (e.g. main() called twice in one interpreter), so only set it when needed.
    # Presumably 'spawn' is used because CUDA cannot be used in forked workers.
    if multiprocessing.get_start_method(allow_none=True) != 'spawn':
        multiprocessing.set_start_method('spawn', force=True)

    # Start from a clean CUDA-allocator / GC state before building models.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    gc.collect()

    fix_model, finetune_model = _load_pretrained_autoencoders(
        config.get('checkpoint_path', _DEFAULT_AE_CHECKPOINT)
    )

    dm = ShapeDataModule(
        dataset_config=_build_dataset_config(config),
        batch_size=config['batch_size'],
        num_workers=config['num_workers'],
    )
    # Set up eagerly because the model below needs dm.dataset at construction.
    dm.setup('fit')

    model = Noise2NoiseEncoder(
        fix_model=fix_model,
        finetune_model=finetune_model,
        trainable_layers=config['trainable_layers'],
        dataset=dm.dataset,
        output_dir=config['output_dir'],
        samples_to_evaluate=config['samples_to_evaluate'],
    )

    # NOTE(review): with enable_checkpointing=False no ModelCheckpoint runs,
    # so log_model=True has no checkpoints to upload — confirm intent.
    wandb_logger = WandbLogger(
        project=config['wandb_project'],
        entity=config['wandb_entity'],
        log_model=True,
        name=config['run_name'],
    )

    multi_gpu = config['num_gpus'] > 1
    trainer = pl.Trainer(
        accelerator='gpu',
        devices=config['num_gpus'],
        strategy='ddp' if multi_gpu else 'auto',
        max_epochs=config['max_epochs'],
        logger=wandb_logger,
        check_val_every_n_epoch=config['val_interval'],
        log_every_n_steps=config['log_every_n_steps'],
        precision=config['precision'],
        enable_checkpointing=False,
    )

    trainer.fit(model, datamodule=dm)
    

if __name__ == "__main__":
    # Script entry point: load the JSON training config, derive a per-run
    # output directory from it, echo the resolved config and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config',
        type=str,
        default='config/train_config/full_train.json',
        help='Path to the JSON training configuration file.',
    )
    args = parser.parse_args()

    # Explicit encoding so the config is not decoded with the platform's
    # locale encoding (e.g. cp1252 on Windows).
    with open(args.config, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Each run writes into its own sub-directory of the configured output root.
    config['output_dir'] = os.path.join(config['output_dir'], config['run_name'])
    os.makedirs(config['output_dir'], exist_ok=True)

    print(json.dumps(config, indent=4))

    main(config)
