from typing import Dict
import torch
from torch import Tensor
import torch.nn.functional as F
import math

# Shared output buffer for the 'multi_mae' metric of ``Evaluator.eval``: slot i
# receives the MAE of output column i on each call (slots not written by a
# call keep their previous values). Initially sized for 20 output columns.
multimae_results = torch.zeros(20)

def get_random_split(size: int, train_frac: float = 0.8, valid_frac: float = 0.1,
                     generator=None) -> Dict[str, list]:
    """Randomly partition the indices ``0 .. size-1`` into train/valid/test.

    Args:
        size: number of items to split.
        train_frac: fraction of items assigned to 'train' (truncated to int).
        valid_frac: fraction of items assigned to 'valid' (truncated to int).
            The remaining items go to 'test'.
        generator: optional ``torch.Generator``. Pass a seeded generator to
            get a reproducible split; when ``None`` the global torch RNG is
            used, matching the previous behaviour.

    Returns:
        ``{'train': [...], 'valid': [...], 'test': [...]}`` where the lists
        are disjoint and together contain every index in ``range(size)``.

    Raises:
        ValueError: if the fractions produce a negative split size.
    """
    train_size = int(train_frac * size)
    valid_size = int(valid_frac * size)
    test_size = size - train_size - valid_size
    if min(train_size, valid_size, test_size) < 0:
        raise ValueError(
            f"invalid split fractions train_frac={train_frac}, "
            f"valid_frac={valid_frac} for size={size}")

    lengths = [train_size, valid_size, test_size]
    # Only forward ``generator`` when given, so the default path calls
    # random_split exactly as before (with its own default generator).
    extra = {} if generator is None else {'generator': generator}
    subsets = torch.utils.data.random_split(range(size), lengths, **extra)
    # Each Subset of ``range(size)`` yields the original integer indices.
    return {name: list(subset)
            for name, subset in zip(('train', 'valid', 'test'), subsets)}


def get_sinusoid_pe_tensor(d_model: int, maxlen: int) -> Tensor:
    r"""
    Calculate sinusoidal positional encoding tensor.

    Even feature columns hold ``sin(pos * w_k)`` and odd feature columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / d_model)``. An odd
    ``d_model`` is supported; its last (even-indexed) column is a sine.

    Tensor shapes:
        pe: (maxlen, d_model)
    """
    pe = torch.zeros(maxlen, d_model)
    position = torch.arange(0, maxlen, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(
        0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    angles = position * div_term  # (maxlen, ceil(d_model / 2))
    pe[:, 0::2] = torch.sin(angles)
    # There are only floor(d_model / 2) odd columns; slicing the angles avoids
    # a shape mismatch when d_model is odd.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    # Constant buffer: never part of the autograd graph.
    pe.requires_grad = False
    return pe


def split_dict_transform(split_dict, is_valid):
    """Re-express a train/valid/test index split after dropping invalid items.

    Items whose ``is_valid`` flag equals 1 are kept and renumbered
    consecutively (0, 1, 2, ...) in their original order. Every index in
    ``split_dict['train'|'valid'|'test']`` is translated to that new
    numbering, and indices that point at dropped items are removed.

    Returns a new dict with exactly the keys 'train', 'valid' and 'test'.
    """
    # remap[old_index] -> new index, or -1 when the item is dropped.
    remap = []
    kept = 0
    for pos in range(len(is_valid)):
        if is_valid[pos] == 1:
            remap.append(kept)
            kept += 1
        else:
            remap.append(-1)

    def _compact(indices):
        # Translate old indices and drop the ones that map to removed items.
        translated = [remap[old] for old in indices]
        return [new for new in translated if new != -1]

    return {part: _compact(split_dict[part]) for part in ('train', 'valid', 'test')}


class Evaluator():
    """Compute one named metric between predictions and targets.

    Supported metric names: ``'mae'``, ``'multi_mae'`` and ``'rmse'``.
    ``'ap'`` is recognised but not implemented yet.
    """

    def __init__(self, metric: str):
        # Metric computed by ``eval``; also the key of the returned dict.
        self.metric = metric

    def eval(self, input_dict: Dict):
        """Evaluate ``self.metric`` on ``input_dict['y_pred']`` vs ``input_dict['y_true']``.

        Args:
            input_dict: mapping holding tensors under ``'y_pred'`` and
                ``'y_true'``. For ``'multi_mae'`` they are indexed as
                ``[:, col]``, so they must be at least 2-D (presumably
                (num_samples, num_targets) — confirm with callers).

        Returns:
            ``{self.metric: value}``: a 0-d tensor for ``'mae'`` and
            ``'multi_mae'``, a Python float for ``'rmse'``. For
            ``'multi_mae'`` the per-column MAEs are also written into the
            module-level ``multimae_results`` buffer.

        Raises:
            NotImplementedError: if the metric is ``'ap'``.
            ValueError: if the metric name is not recognised.
        """
        y_pred = input_dict['y_pred']
        y_true = input_dict['y_true']

        if self.metric == 'mae':
            result = F.l1_loss(y_pred, y_true, reduction='mean')
        elif self.metric == 'multi_mae':
            num_cols = y_pred.shape[1]
            # Grow the shared buffer in place (instead of rebinding the
            # global) so modules that imported it keep a valid reference.
            if num_cols > multimae_results.numel():
                multimae_results.resize_(num_cols)
            for col in range(num_cols):
                # detach(): the shared buffer must not keep an autograd graph
                # alive across calls.
                multimae_results[col] = F.l1_loss(
                    y_pred[:, col], y_true[:, col], reduction='mean').detach()
            # Overall mean times the column count equals the sum of the
            # per-column MAEs (all columns have the same number of rows).
            result = F.l1_loss(y_pred, y_true, reduction='mean') * num_cols
        elif self.metric == 'rmse':
            result = math.sqrt(F.mse_loss(y_pred, y_true, reduction='mean'))
        elif self.metric == 'ap':
            raise NotImplementedError("metric 'ap' is not implemented yet")
        else:
            raise ValueError(f"unknown metric: {self.metric!r}")

        return {self.metric: result}

