import torch
import torch.nn as nn

from models.model.prn import PRN

from models.model.encoder import Encoder
from models.block.encoder_block import EncoderBlock

from models.layer.multi_head_attention_layer import MultiHeadAttentionLayer
from models.embedding.token_embedding import TokenEmbedding

def build_model(device=torch.device("cpu"),
                n_layer: int = 1,
                d_model: int = 512,  # NOTE(review): original note said "num_class" — confirm d_model doubles as class count
                h: int = 1,
                dr_rate: float = 0.1,
                norm_eps: float = 1e-5,
                residual_block: bool = True):
    """Build a PRN model whose encoder is made of self-attention blocks.

    A single ``MultiHeadAttentionLayer`` and ``nn.LayerNorm`` are built as
    templates. Deep copies of them are wired into an ``EncoderBlock``, and
    that block is handed to an ``Encoder`` together with a separate copy of
    the layer norm. The encoder is wrapped in ``PRN`` and moved to ``device``.

    Args:
        device: Target device passed to ``nn.Module.to``. The same value is
            also stored on the returned model as ``model.device``.
        n_layer: Forwarded to ``Encoder``. Presumably this is the number of
            stacked encoder blocks (TODO confirm against ``Encoder``).
        d_model: Feature width used for the attention layer, its output
            projection (``nn.Linear(d_model, d_model)``) and the layer norms.
        h: Number of attention heads, forwarded to ``MultiHeadAttentionLayer``.
        dr_rate: Dropout rate, forwarded to the attention layer and the
            encoder block.
        norm_eps: ``eps`` for every ``nn.LayerNorm`` built here.
        residual_block: Forwarded unchanged to ``EncoderBlock``.

    Returns:
        The ``PRN`` module on ``device``, with a ``device`` attribute set.
    """
    # Import the function by name. The original rebound the module name
    # ``copy`` to ``copy.deepcopy``, which was misleading.
    from copy import deepcopy

    attention = MultiHeadAttentionLayer(
        d_model=d_model,
        h=h,
        out_fc=nn.Linear(d_model, d_model),
        dr_rate=dr_rate,
    )
    norm = nn.LayerNorm(d_model, eps=norm_eps)

    # Deep copies give the block's norm and the encoder's norm their own
    # independent parameters instead of sharing one LayerNorm instance.
    encoder_block = EncoderBlock(
        self_attention=deepcopy(attention),
        norm=deepcopy(norm),
        dr_rate=dr_rate,
        residual_block=residual_block,
    )
    encoder = Encoder(
        encoder_block=encoder_block,
        n_layer=n_layer,
        norm=deepcopy(norm),
    )

    model = PRN(encoder=encoder).to(device)
    # Keep the target device on the module so callers can read it back
    # (e.g. to place input tensors) without inspecting parameters.
    model.device = device
    return model
