import torch
import torch.optim as optim
import rtdl_num_embeddings
from torch.optim import AdamW
from torch_ema import ExponentialMovingAverage

import model.Base_Transformer
import model.MLP
# import model.MambaTab
#import model.TTVAE
import model.TabDiff
import model.TabM
import model.TabNet
import model.ResMLP
import model.ResMLP2
import model.HybridMLP
import model.MLP2
import model.ddpm

def _make_training_state(net, config, milestones=None, weight_decay=0.001):
    """Attach the optimizer, LR scheduler and EMA tracker to ``net``.

    Args:
        net: The already constructed backbone network.
        config: Run configuration; ``config["learning_rate"]`` is the AdamW
            learning rate.
        milestones: Step indices at which ``MultiStepLR`` multiplies the
            learning rate by 0.1. When ``None``, ``ExponentialLR`` with
            ``gamma=0.95`` is used instead.
        weight_decay: AdamW weight decay.

    Returns:
        The tuple ``(net, optimizer, scheduler, ema)``.
    """
    optimizer = AdamW(
        net.parameters(),
        lr=config["learning_rate"],
        weight_decay=weight_decay,
    )
    if milestones is None:
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    else:
        # ``verbose`` is deliberately not passed: it is deprecated in recent
        # PyTorch releases and ``False`` was the default anyway.
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=0.1
        )
    ema = ExponentialMovingAverage(net.parameters(), decay=0.9999)
    return net, optimizer, scheduler, ema


def get_model(config, device, d_in, x_train_data):
    """Build the backbone selected by ``config["backbone_model"]``.

    Args:
        config: Run configuration. Always read: ``"backbone_model"`` and
            ``"learning_rate"``. Read by some backbones: ``"embedding_type"``
            (MLP2048, Base_Transformer, TabM with bins), ``"bins"`` and
            ``"arch_type"`` (TabM).
        device: Device handed to the model constructors (and to ``.to()`` for
            the MLP-style backbones).
        d_in: Number of input features.
        x_train_data: Training data; only used to compute bins for
            Base_Transformer with ``"PiecewiseLinearEmbeddings"`` and for
            TabM when ``config["bins"]`` is truthy.

    Returns:
        The tuple ``(net, optimizer, scheduler, ema)``.

    Raises:
        ValueError: If the backbone name is unknown, or if the embedding type
            is not supported by the MLP2048 / Base_Transformer backbone.
    """
    backbone = config["backbone_model"]

    if backbone == "MLP2048":
        embedding_type = config["embedding_type"]
        if embedding_type != "Linear":
            # Previously this fell through and crashed with UnboundLocalError.
            raise ValueError(
                f"Unsupported embedding_type {embedding_type!r} for backbone "
                "'MLP2048'; expected 'Linear'"
            )
        net = model.MLP.MLP2048(
            device, config, d_in=d_in, bins=None, d_embedding=None
        ).to(device)
        return _make_training_state(net, config, milestones=[11120, 11160])

    if backbone == "ResMLP":
        net = model.ResMLP.ResMLPDenoiser(device, config, data_dim=d_in).to(device)
        return _make_training_state(net, config, milestones=[1420, 1460])

    if backbone == "ResMLP2":
        net = model.ResMLP2.ResMLPDenoiser(device, config, data_dim=d_in).to(device)
        return _make_training_state(net, config, milestones=[11120, 11160])

    if backbone == "DDPM":
        params = {
            "d_main": 128,
            "n_blocks": 3,
            "d_hidden": 512,
            "dropout_first": 0,
            "dropout_second": 0,
        }
        net = model.ddpm.ResNetDiffusion(params, d_in=d_in).to(device)
        return _make_training_state(net, config, milestones=[11120, 11160])

    if backbone == "HybridMLP":
        net = model.HybridMLP.HybridResMLP_Denoiser(data_dim=d_in).to(device)
        return _make_training_state(net, config, milestones=[220, 260])

    if backbone == "MLP2":
        net = model.MLP2.MLP(
            device, config, d_in=d_in, bins=None, d_embedding=None
        ).to(device)
        return _make_training_state(net, config, milestones=[1220, 1260])

    # NOTE(review): the backbones below are not moved with ``.to(device)``
    # here; presumably they place themselves using the ``device`` argument --
    # confirm against the model classes.
    if backbone == "Base_Transformer":
        num_heads, num_layers = 4, 6
        embedding_type = config["embedding_type"]
        if embedding_type == "Linear":
            bins, d_embedding = None, None
        elif embedding_type == "PiecewiseLinearEmbeddings":
            bins, d_embedding = rtdl_num_embeddings.compute_bins(x_train_data), 8
        elif embedding_type == "PeriodicEmbeddings":
            bins, d_embedding = None, 8
        else:
            # Previously this fell through and crashed with UnboundLocalError.
            raise ValueError(
                f"Unsupported embedding_type {embedding_type!r} for backbone "
                "'Base_Transformer'; expected 'Linear', "
                "'PiecewiseLinearEmbeddings' or 'PeriodicEmbeddings'"
            )
        net = model.Base_Transformer.TransformerBackbone(
            device,
            config,
            dim=d_in,
            bins=bins,
            d_embedding=d_embedding,
            num_heads=num_heads,
            num_layers=num_layers,
        )
        return _make_training_state(net, config)

    if backbone == "TabNet":
        net = model.TabNet.TabNet(device, config, d_in)
        return _make_training_state(net, config, milestones=[220, 260])

    if backbone == "TabDiff":
        net = model.TabDiff.TabDiff(device, config, d_in)
        return _make_training_state(net, config)

    if backbone == "MambaTab":
        # The module-level ``import model.MambaTab`` is commented out, so load
        # the submodule lazily, only when this backbone is requested. A
        # ``from ... import ... as`` form is used so that ``model`` does not
        # become a function-local name.
        from model import MambaTab as mamba_tab_module

        net = mamba_tab_module.MambaTab(device, config, d_in)
        return _make_training_state(net, config, milestones=[220, 260])

    if backbone == "TabM":
        if config["bins"]:
            bins = rtdl_num_embeddings.compute_bins(x_train_data.squeeze())
        else:
            bins = None

        if bins is None:
            num_embeddings = None
        else:
            num_embeddings = {
                'type': config["embedding_type"],
                'd_embedding': 8,
                'activation': False,
                'version': 'B',
            }

        net = model.TabM.TabM(
            device=device,
            config=config,
            n_num_features=d_in,
            backbone={
                'type': 'MLP',
                'n_blocks': 3 if bins is None else 2,
                'd_block': 512,
                'dropout': 0.1,
                'd_out': None,
            },
            bins=bins,
            num_embeddings=num_embeddings,
            cat_cardinalities=[],
            n_output=d_in,
            arch_type=config["arch_type"],
            k=32,
            share_training_batches=True,
        )
        return _make_training_state(net, config, weight_decay=3e-4)

    raise ValueError(f"Model not defined: {backbone!r}")
    