
from unicodedata import decimal
from models.drnet import Drnet
from models.vcnet import Vcnet

from torch.optim import Adam
from models.mdnv2 import MDN
# from models.CRNet import CRNet
from models.crnet2 import CRNet
from models.crnet3 import CRNetv3
from models.nnv2 import NNv2
from models.nn import NN
import torch




def train_model(args, device, train_data, val_data=None, adrf=None, test_data=None):
    """Build the model selected by ``args.model``, train it, and return the result.

    Supported values of ``args.model`` are ``'cr'``, ``'nn'``, ``'mdn'``,
    ``'mdnr'``, ``'drnet'`` and ``'vcnet'``. Each branch constructs the model,
    creates ``Adam`` optimizer(s) with learning rate ``args.lr`` and delegates
    to the model's own ``train_model`` method.

    Args:
        args: Config namespace. Always read: ``model``, ``lr``. Per model:
            ``t_dim`` and ``x_dim`` (cr, nn, mdnr, drnet, vcnet; mdn reads only
            ``x_dim``), ``train_bs`` (cr), ``knots`` (vcnet). The whole
            namespace is also forwarded to the cr/nn/mdn/mdnr ``train_model``.
            NOTE: the ``'cr'`` branch overwrites ``args.y_epoch`` with 300.
        device: Passed to ``.to(...)`` for cr/nn/mdnr and forwarded to their
            ``train_model``. Not used by the mdn/drnet/vcnet branches.
        train_data: Training data, forwarded to every branch's ``train_model``.
        val_data: Forwarded only by the ``'nn'`` branch.
        adrf: Forwarded only by the ``'cr'`` and ``'nn'`` branches.
        test_data: Forwarded only by the ``'cr'`` branch.

    Returns:
        Whatever the selected model's ``train_model`` method returns.

    Raises:
        ValueError: If ``args.model`` is not one of the supported names.
    """
    name = args.model

    if name == 'cr':
        model = CRNet(args.t_dim, args.x_dim, args.train_bs).to(device)
        # CRNet is trained with one optimizer per sub-module; its
        # train_model receives them in (w_opt, y_opt) order.
        y_opt = Adam(model.b_module.parameters(), lr=args.lr)
        w_opt = Adam(model.w_module.parameters(), lr=args.lr)

        # NOTE(review): hard-coded override that discards any y_epoch the
        # caller configured and mutates the caller's ``args`` object. It is
        # kept for backward compatibility; consider moving it to a config
        # default instead.
        args.y_epoch = 300
        model = model.train_model(args, device, w_opt, y_opt, train_data,
                                  adrf, model, None, test_data)

    elif name == 'nn':
        stg1 = NN(args.t_dim + args.x_dim).to(device)
        opt = Adam(stg1.parameters(), lr=args.lr)
        model = stg1.train_model(args, device, opt, train_data, val_data, adrf)

    elif name == 'mdn':
        # NOTE(review): unlike 'mdnr', this branch builds MDN without t_dim,
        # does not move it to ``device``, and calls
        # train_model(args, opt, train_data). 'mdnr' calls
        # train_model(args, device, opt, train_data) on the same MDN class.
        # One of the two call sites is likely stale; verify against
        # models.mdnv2.MDN.train_model.
        # The constant 5 is MDN's first constructor argument. Its meaning is
        # defined in models.mdnv2 (presumably the number of mixture
        # components; confirm).
        stg1 = MDN(5, args.x_dim)
        opt = Adam(stg1.parameters(), lr=args.lr)
        model = stg1.train_model(args, opt, train_data)

    elif name == 'mdnr':
        stg1 = MDN(5, args.x_dim, args.t_dim).to(device)
        opt = Adam(stg1.parameters(), lr=args.lr)
        model = stg1.train_model(args, device, opt, train_data)

    elif name == 'drnet':
        # Layer specs in the tuple format consumed by Drnet. They appear to be
        # (in_dim, out_dim, <flag>, activation); confirm in models.drnet.
        # The "- 1" presumably excludes the treatment column from the density
        # input width; TODO confirm.
        cfg_density = [(args.x_dim + args.t_dim - 1, 50, 1, 'relu'), (50, 50, 1, 'relu')]
        num_grid = 10
        cfg = [(50, 50, 1, 'relu'), (50, 1, 1, 'id')]
        isenhance = 1
        stg1 = Drnet(cfg_density, num_grid, cfg, isenhance)
        stg1._initialize_weights()
        # NOTE(review): the model is not moved to ``device`` here, and
        # ``device`` is not forwarded to its train_model.
        opt = Adam(stg1.parameters(), lr=args.lr)
        model = stg1.train_model(opt, train_data)

    elif name == 'vcnet':
        # Same layer specs and grid size as the 'drnet' branch.
        cfg_density = [(args.x_dim + args.t_dim - 1, 50, 1, 'relu'), (50, 50, 1, 'relu')]
        num_grid = 10
        cfg = [(50, 50, 1, 'relu'), (50, 1, 1, 'id')]
        degree = 2
        knots = args.knots
        stg1 = Vcnet(cfg_density, num_grid, cfg, degree, knots)
        stg1._initialize_weights()
        # NOTE(review): the model is not moved to ``device`` here, and
        # ``device`` is not forwarded to its train_model.
        opt = Adam(stg1.parameters(), lr=args.lr)
        model = stg1.train_model(opt, train_data)

    else:
        # Previously an unknown name fell through every branch and surfaced as
        # an UnboundLocalError on ``model``; fail with a clear message instead.
        raise ValueError(
            f"Unknown model {name!r}; expected one of "
            "'cr', 'nn', 'mdn', 'mdnr', 'drnet', 'vcnet'."
        )

    return model