'''
Author: pengjie pengjieb@mail.ustc.edu.cn
Date: 2022-05-25 22:30:26
LastEditors: pengjie pengjieb@mail.ustc.edu.cn
LastEditTime: 2022-07-15 14:11:10
FilePath: /MultiBench/utils/tools.py
Description: 

Copyright (c) 2022 by pengjie pengjieb@mail.ustc.edu.cn, All Rights Reserved. 
'''
import random
from typing import Any
import numpy as np
import torch
import os
from abc import ABC, abstractmethod

from torch.utils.data import DataLoader, Dataset

def count_param(model: torch.nn.Module):
    """Return the total number of scalar elements in ``model``'s parameters.

    Every tensor yielded by ``model.parameters()`` is counted, whether or not
    it requires gradients.

    Args:
        model: the module whose parameters are counted.

    Returns:
        The summed element count (0 for a module without parameters).
    """
    # numel() is independent of memory layout; the previous view(-1) raised a
    # RuntimeError for non-contiguous parameters.
    return sum(param.numel() for param in model.parameters())

def set_seed(seed):
    """Seed the random number generators used during training for reproducibility.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices). It also exports ``PYTHONHASHSEED`` and configures cuDNN
    for deterministic execution.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every CUDA device, so a separate torch.cuda.manual_seed is redundant.
    torch.cuda.manual_seed_all(seed)
    # The hash seed of the running interpreter is fixed at startup; this only
    # affects Python processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different algorithms between runs, which
    # breaks reproducibility even when deterministic=True.
    torch.backends.cudnn.benchmark = False

def create_dir(dir):
    """Create directory ``dir`` (and any missing parents) if it does not exist.

    Args:
        dir: path of the directory to create.
    """
    # makedirs with exist_ok avoids the exists()/mkdir() race and also creates
    # intermediate directories, which os.mkdir could not.
    os.makedirs(dir, exist_ok=True)

def dataloader_info(dataset:DataLoader):
    """Print a short summary of the modalities in a DataLoader's dataset.

    The function inspects the first sample of ``dataset.dataset``. If that
    sample is a list or tuple, it is treated as ``[modality_1, ..., modality_N,
    label, index]``: the last two entries are skipped, then the modality count
    and each modality's ``shape`` are printed. Entries without a ``shape``
    attribute are reported as ``'index'``. Empty datasets and other sample
    types print nothing.

    Args:
        dataset: a DataLoader (despite the parameter name) whose ``.dataset``
            supports integer indexing.
    """
    try:
        # Fetch a single sample once; __getitem__ may be expensive or random.
        sample = dataset.dataset[0]
    except IndexError:
        # Empty dataset: nothing to describe.
        return
    if not isinstance(sample, (list, tuple)):
        return
    modalities = sample[:-2]  # one trailing entry for label, another for index
    num_modalities = len(modalities)
    data_sizes = [
        item.shape if hasattr(item, 'shape') else 'index'
        for item in modalities
    ]
    print('number of modalities:{}'.format(num_modalities))
    print(*['modality {}\t\t'.format(i) for i in range(num_modalities)])
    print(*['{}\t'.format(item) for item in data_sizes])
    # print(type(dataset.dataset[0]))
    
class OutterCallback(ABC):
    def __init__(self) -> None:
        pass
    
    @abstractmethod
    def train_stage(self, *args: Any, **kwds: Any):
        pass
    
    @abstractmethod
    def valid_stage(self, *args: Any, **kwds: Any):
        pass
    
    @abstractmethod
    def test_stage(self, *args: Any, **kwds: Any):
        pass
    
    def __call__(self, stage, *args: Any, **kwds: Any) -> Any:
        if stage == 'train':
            self.train_stage(*args, **kwds)
        elif stage == 'valid':
            self.valid_stage(*args, **kwds)
        elif stage == 'test':
            self.test_stage(*args, **kwds)
        # return super().__call__(*args, **kwds)