import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from . import BaseTask, register_task
from ..dataset import build_dataset
from ..dataset.masked import MaskedDataset

@register_task("modulation_classification")
class ModulationClassification(BaseTask):
    """Task definition for classifying the modulation scheme of a signal.

    All configuration comes from ``args``. This class reads
    ``args.signal_length``, ``args.dataset``, ``args.test_size`` and
    ``args.dataset_path`` from it (presumably an argparse-style namespace —
    confirm against the caller).

    The dataset metadata attributes (``classes``, ``minSNR``, ``maxSNR``,
    ``SNR_list``) are only set by :meth:`get_data`. Calling
    :meth:`get_classes` or :meth:`get_snr` before that raises
    ``AttributeError``.
    """

    def __init__(self, args):
        super(ModulationClassification, self).__init__()
        self.args = args
        # Not read by any method of this class.
        # NOTE(review): presumably consumed by model/training code — confirm.
        self.signal_length = args.signal_length

    def get_data(self, mode="default"):
        """Build the configured dataset and return its three splits.

        Only the first entry of ``self.args.dataset`` is used. As a side
        effect, the built dataset's ``classes``, ``minSNR``, ``maxSNR`` and
        ``SNR_list`` are copied onto this task instance.

        Args:
            mode: Passed unchanged to the dataset object together with each
                split name; its meaning is defined by the dataset code.

        Returns:
            A 3-tuple ``(train, valid, test)``. Each item is the result of
            calling the built dataset object with ``(split_name, mode)``.
        """
        opts = self.args
        source = build_dataset(opts.dataset[0], opts.test_size, opts.dataset_path)

        # Cache the label set and SNR metadata for the getters below.
        self.classes = source.classes
        self.minSNR = source.minSNR
        self.maxSNR = source.maxSNR
        self.SNR_list = source.SNR_list

        # The built object is callable; request each split in a fixed order.
        splits = [source(split_name, mode) for split_name in ("train", "valid", "test")]
        return tuple(splits)

    def get_loss_func(self):
        """Return the loss callable for this task: ``F.cross_entropy``."""
        return F.cross_entropy

    def get_classes(self):
        """Return the ``classes`` object cached by :meth:`get_data` (not a copy)."""
        return self.classes

    def get_snr(self):
        """Return a new list holding the SNR values cached by :meth:`get_data`."""
        return [*self.SNR_list]
