from models.moe_v3 import GuneyNet_MoE_v3
from models.tmae_former import CausalEEGAutoencoder2D_former
from training_utils import concatenate_subjects, make_loader
import numpy as np
import torch
from torch import nn
from train import train, test
import copy
from log import log

def guney_int_moe(dataset_config):
    """Build the mixture-of-experts GuneyNet selected by ``dataset_config``.

    Args:
        dataset_config: Configuration object. ``moe_version`` selects the
            architecture; the remaining attributes read below are forwarded
            to the model constructor.

    Returns:
        A ``GuneyNet_MoE_v3`` instance (currently the only supported version).

    Raises:
        ValueError: If ``dataset_config.moe_version`` is not a supported
            version. Previously this fell through and surfaced as an opaque
            ``UnboundLocalError`` on the return statement.
    """
    version = dataset_config.moe_version
    if version != 'v3':
        raise ValueError(
            f"Unsupported moe_version {version!r}; supported versions: 'v3'"
        )

    return GuneyNet_MoE_v3(
        n_classes=dataset_config.targets,
        n_channels=dataset_config.channels,
        n_samples=dataset_config.samples,
        n_bands=dataset_config.multi_band,
        moe_num_experts=dataset_config.moe_experts,
        moe_top_k=dataset_config.moe_top_k,
        moe_switch=dataset_config.moe_switch,
        n_spatial_filters=dataset_config.n_spatial_filters,
        n_time1_filters=dataset_config.n_time1_filters,
        n_time2_filters=dataset_config.n_time2_filters,
        time2_kernel=dataset_config.time2_kernel,
    )


def tmae_init(dataset_config):
    """Construct a ``CausalEEGAutoencoder2D`` from fields of ``dataset_config``.

    Reads ``time_period_real``, ``tmae_hidden_bank``, ``tmae_time_kernel``,
    ``tmae_dropout`` and ``total_len_time_period`` and passes them to the
    constructor as ``input_len``, ``hidden_bank``, ``time_kernel``,
    ``dropout`` and ``full_len`` respectively.

    NOTE(review): ``CausalEEGAutoencoder2D`` is not imported at the top of this
    file (only ``CausalEEGAutoencoder2D_former`` is), so calling this raises
    ``NameError`` unless the name is provided some other way. TODO: add the
    correct import.
    """
    # Resolve the class before touching the config, as the original did.
    autoencoder_cls = CausalEEGAutoencoder2D
    cfg = dataset_config
    return autoencoder_cls(
        input_len=cfg.time_period_real,
        hidden_bank=cfg.tmae_hidden_bank,
        time_kernel=cfg.tmae_time_kernel,
        dropout=cfg.tmae_dropout,
        full_len=cfg.total_len_time_period,
    )




def tmaeformer_init(dataset_config):
    """Construct a ``CausalEEGAutoencoder2D_former`` from ``dataset_config``.

    Config fields are forwarded as constructor keyword arguments:
    ``time_period_real`` -> ``input_len``, ``tmae_hidden_bank`` ->
    ``hidden_bank``, ``tmae_time_kernel`` -> ``time_kernel``,
    ``tmae_dropout`` -> ``dropout``, ``total_len_time_period`` ->
    ``full_len``, ``vit_patch_size_1`` -> ``vit_patch_size_1``,
    ``depth`` -> ``depth``, and ``channels`` -> both ``channel`` and
    ``vit_patch_size_0``.
    """
    former_cls = CausalEEGAutoencoder2D_former
    cfg = dataset_config
    return former_cls(
        input_len=cfg.time_period_real,
        hidden_bank=cfg.tmae_hidden_bank,
        time_kernel=cfg.tmae_time_kernel,
        dropout=cfg.tmae_dropout,
        full_len=cfg.total_len_time_period,
        vit_patch_size_1=cfg.vit_patch_size_1,
        depth=cfg.depth,
        channel=cfg.channels,
        # The first patch dimension is tied to the channel count; presumably
        # so each patch covers every channel -- confirm against the model.
        vit_patch_size_0=cfg.channels,
    )


def causal_masked_loss(reconstructed, target, known_len=50):
    """Mean squared error over the positions after a known prefix.

    The first ``known_len`` entries along the last axis are treated as given
    context and excluded. The result is the mean of the squared differences
    over the remaining entries, i.e. ``mean((r - t)[..., known_len:] ** 2)``.
    This equals the original masked formulation but avoids allocating a full
    mask. It also accepts tensors of any rank, not just 4-D. Dividing by an
    element count summed in the input dtype could overflow for large
    half-precision tensors; taking the mean avoids that.

    Args:
        reconstructed: Model output; expected to have the same shape as
            ``target``.
        target: Ground-truth tensor. The last axis is the one being masked
            (presumably time, given the "causal" naming -- confirm).
        known_len: Number of leading positions along the last axis to exclude
            from the loss. Python slice semantics apply, so a negative value
            keeps only the last ``-known_len`` positions.

    Returns:
        A scalar tensor with the masked mean squared error.

    Raises:
        ValueError: If no positions remain along the last axis after
            excluding the prefix. The previous implementation silently
            returned NaN (0 / 0) in this case.
    """
    # Slice after subtracting so broadcasting between the inputs still works.
    err = (reconstructed - target)[..., known_len:]
    if err.shape[-1] == 0:
        raise ValueError(
            f"known_len={known_len} leaves no positions to score along the "
            f"last axis (size {(reconstructed - target).shape[-1]})"
        )
    return err.pow(2).mean()
