import argparse
import torch
import random
import numpy as np
from trainer_twoway import SIRENTrainer, FunctaTrainer
from trainer_twoway_3d import SIRENTrainer3D, FunctaTrainer3D
import yaml
from pathlib import Path
import os
from utils import get_save_paths, load_config

def get_args():
    """Parse the command-line options for a training run.

    Returns:
        argparse.Namespace with:
          - ``dataset``: one of 'fwi', 'ns3d_twoway' (required)
          - ``inr``: one of 'siren', 'functa' (required)
          - ``modulation``: one of 'shift', 'gfm', 'scale', 'film' (required)
          - ``config_name``: str (required)
          - ``seed``: int, default 0
          - ``device``: str, default '0'

    Invalid or missing required options make argparse exit via SystemExit.
    """
    cli = argparse.ArgumentParser()

    # Required string flags restricted to a fixed set of values, in the
    # order they should appear in --help.
    restricted_flags = (
        ('--dataset', ['fwi', 'ns3d_twoway']),
        ('--inr', ['siren', 'functa']),
        ('--modulation', ['shift', 'gfm', 'scale', 'film']),
    )
    for flag, allowed in restricted_flags:
        cli.add_argument(flag, type=str, choices=allowed, required=True)

    cli.add_argument('--config_name', type=str, required=True)
    cli.add_argument('--seed', type=int, default=0)
    cli.add_argument('--device', type=str, default='0')

    return cli.parse_args()


def main():
    """Seed RNGs, construct the trainer selected on the command line, and train.

    Steps, in order:
      1. Parse CLI options with ``get_args``.
      2. Seed ``torch`` (CPU and current CUDA device), NumPy and ``random``
         with ``--seed``.
      3. Select the device ``cuda:<--device>``.
      4. Call ``load_config(args)`` and print the parsed ``args``.
      5. Set ``args.enc_dim`` (3 for ``ns3d_twoway``, 2 otherwise).
      6. Resolve output paths with ``get_save_paths(args)``.
      7. Instantiate the matching trainer and run ``trainer.train()``.

    Raises:
        ValueError: if no trainer is registered for the (dataset, inr) pair.
    """
    args = get_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    device = torch.device(f'cuda:{args.device}')
    # NOTE(review): the result was previously bound to an unused `config`
    # local. The call is kept because load_config may update `args` in place.
    # Confirm whether its return value should be consumed.
    load_config(args)
    print(args)

    # ns3d_twoway is the only dataset routed to the *3D trainers. Decide this
    # once so enc_dim and the trainer choice cannot drift apart.
    is_3d = args.dataset == 'ns3d_twoway'
    # Presumably the input-coordinate dimension consumed by the trainers;
    # confirm in trainer_twoway / trainer_twoway_3d.
    args.enc_dim = 3 if is_3d else 2

    VIS_TRAIN_SAVE_PATH, PARAM_SAVE_PATH = get_save_paths(args)

    # Explicit dispatch table. The previous nested if/elif had no else, so an
    # unmatched --inr left `trainer` unbound and failed with
    # UnboundLocalError at trainer.train().
    trainer_classes = {
        (True, 'siren'): SIRENTrainer3D,
        (True, 'functa'): FunctaTrainer3D,
        (False, 'siren'): SIRENTrainer,
        (False, 'functa'): FunctaTrainer,
    }
    try:
        trainer_cls = trainer_classes[(is_3d, args.inr)]
    except KeyError:
        raise ValueError(
            f"No trainer registered for dataset={args.dataset!r}, inr={args.inr!r}"
        ) from None

    trainer = trainer_cls(args, device, VIS_TRAIN_SAVE_PATH, PARAM_SAVE_PATH)
    trainer.train()

if __name__ == '__main__':
    # Start training only when run as a script, not when imported.
    main()
