from typing import Optional, Callable, List

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.optim.lr_scheduler import MultiStepLR

class CustomTensorDataset(Dataset):
    """Dataset over parallel tensors with an optional transform.

    All tensors are indexed along their first dimension. The transform, when
    set, is applied only to the item taken from the first tensor.

    The tuple returned for an index depends on how many tensors were given:

    * 2 tensors: ``(x, y)``
    * 3 tensors: ``(x, y, w)``
    * 4 tensors: ``(index, x, y, w, r)``
    * 5 tensors: ``(index, x, y, w, r, nx)``
    * any other count: ``(x,)``
    """

    def __init__(self, tensors, transform=None):
        # Every tensor must have as many rows as the first one.
        assert not any(tensors[0].size(0) != t.size(0) for t in tensors)
        self.transform = transform
        self.tensors = tensors

    def __len__(self):
        """Return the size of the first tensor along dimension 0."""
        return self.tensors[0].size(0)

    def __getitem__(self, index):
        """Return the sample tuple for ``index`` (layout in the class docs)."""
        sample = self.tensors[0][index]
        if self.transform:
            sample = self.transform(sample)

        count = len(self.tensors)
        if count < 2 or count > 5:
            # Only the (possibly transformed) first item is returned here.
            return (sample,)

        companions = tuple(t[index] for t in self.tensors[1:])
        if count >= 4:
            # With four or five tensors the index itself leads the tuple.
            return (index, sample) + companions
        return (sample,) + companions

def _generate_noise(
    std: float,
    reference: torch.Tensor,
) -> torch.Tensor:
    """Sample zero-mean Gaussian noise shaped like ``reference``.

    Args:
        std: standard deviation of the noise. ``0`` yields an all-zero tensor
            instead of calling ``torch.normal``.
        reference: tensor whose shape and device the result copies.

    Returns:
        A new tensor with ``reference.shape`` on ``reference.device``. No
        dtype is passed to the factory calls, so the result uses PyTorch's
        default dtype rather than ``reference.dtype``.
    """
    if std == 0:
        # Allocate the zero tensor only on the path that actually returns it.
        return torch.zeros(reference.shape, device=reference.device)
    return torch.normal(
        mean=0,
        std=std,
        size=reference.shape,
        device=reference.device,
    )

def params(optimizer: torch.optim.Optimizer) -> List[nn.Parameter]:
    """
    Collect the trainable parameters handled by an optimizer
    Args:
        optimizer: optimizer whose ``param_groups`` are scanned
    Returns:
        Flat list, in group order, of every parameter in ``param_groups``
        whose ``requires_grad`` flag is set
    """
    return [
        param
        for group in optimizer.param_groups
        for param in group["params"]
        if param.requires_grad
    ]

class NoisyOptimizerWrapper(torch.optim.Optimizer):
    """Optimizer wrapper that adds Gaussian noise to gradients before stepping.

    ``torch.optim.Optimizer.__init__`` is deliberately not called. The wrapper
    instead aliases ``param_groups``, ``state`` and ``defaults`` of the wrapped
    optimizer and delegates the parameter update, ``zero_grad`` and state
    (de)serialisation to it.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        noise_multiplier: float,
    ):
        """
        Args:
            optimizer: wrapped optimizer.
            noise_multiplier: standard deviation of the zero-mean Gaussian
                noise added to every gradient in ``step``.
        """
        self.original_optimizer = optimizer
        self.noise_multiplier = noise_multiplier
        self._sync_with_original()

    def _sync_with_original(self) -> None:
        """Point the shared containers at the wrapped optimizer's objects."""
        self.param_groups = self.original_optimizer.param_groups
        self.state = self.original_optimizer.state
        # Base-class helpers read ``defaults``; expose the wrapped one because
        # ``Optimizer.__init__`` (which normally sets it) is never run here.
        self.defaults = self.original_optimizer.defaults

    @property
    def params(self) -> List[nn.Parameter]:
        """
        Returns a flat list of ``nn.Parameter`` managed by the optimizer
        """
        return params(self)

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        """Add noise to the gradients, then run the wrapped optimizer's step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is called exactly once, before noise is added.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for p in self.params:
            if p.grad is None:
                # No gradient was produced for this parameter; nothing to perturb.
                continue
            noise = _generate_noise(
                std=self.noise_multiplier,
                reference=p.grad,
            )
            # The noise is created in the default dtype; cast it so the
            # gradient keeps its own dtype after the addition.
            p.grad = p.grad + noise.to(dtype=p.grad.dtype)

        # The closure is not forwarded: running it a second time would
        # recompute the gradients and discard the noise added above.
        self.original_optimizer.step()
        return loss

    def zero_grad(self, *args, **kwargs) -> None:
        """Delegate to the wrapped optimizer's ``zero_grad``.

        The inherited implementation may rely on attributes initialised by
        ``Optimizer.__init__``, which this wrapper does not call.
        """
        self.original_optimizer.zero_grad(*args, **kwargs)

    def __repr__(self):
        return self.original_optimizer.__repr__()

    def state_dict(self):
        """Return the wrapped optimizer's state dict."""
        return self.original_optimizer.state_dict()

    def load_state_dict(self, state_dict) -> None:
        """Load ``state_dict`` into the wrapped optimizer and re-alias its state."""
        self.original_optimizer.load_state_dict(state_dict)
        # Loading rebinds ``param_groups``/``state`` on the wrapped optimizer,
        # so the aliases held by the wrapper must be refreshed.
        self._sync_with_original()

def get_optimizer(model, optimizer: str, learning_rate: float, momentum,
                  weight_decay, additional_vars=None, noise_multiplier=0.):
    """Build the optimizer named by ``optimizer`` for ``model``.

    The name is matched by substring: a name containing ``'adam'`` selects
    ``optim.Adam`` (checked first), otherwise one containing ``'sgd'`` selects
    ``optim.SGD``. If the name also contains ``'noisy'`` the result is wrapped
    in ``NoisyOptimizerWrapper``.

    Args:
        model: object whose ``parameters()`` are optimized.
        optimizer: optimizer name, e.g. ``'sgd'``, ``'adam'``, ``'noisy_sgd'``.
        learning_rate: learning rate passed as ``lr``.
        momentum: momentum for SGD; not used for Adam.
        weight_decay: weight decay passed to the optimizer.
        additional_vars: optional iterable of extra parameters optimized
            together with the model's parameters.
        noise_multiplier: forwarded to ``NoisyOptimizerWrapper``; only used
            for ``'noisy'`` optimizers.

    Returns:
        The constructed optimizer (possibly wrapped).

    Raises:
        ValueError: if the name contains neither ``'adam'`` nor ``'sgd'``.
    """
    if additional_vars is None:
        parameters = model.parameters()
    else:
        # Unpack instead of ``list + additional_vars`` so that tuples,
        # generators and other iterables are accepted as well as lists.
        parameters = [*model.parameters(), *additional_vars]

    if 'adam' in optimizer:
        ret = optim.Adam(parameters, lr=learning_rate, weight_decay=weight_decay)
    elif 'sgd' in optimizer:
        ret = optim.SGD(parameters, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
    else:
        raise ValueError(f"Not supported optimizer {optimizer}")

    if 'noisy' in optimizer:
        ret = NoisyOptimizerWrapper(optimizer=ret, noise_multiplier=noise_multiplier)
    return ret

def get_loss(loss_name: str, reduction):
    """Return the loss module selected by ``loss_name``.

    The name is matched by substring, in this order: ``'ce'`` gives
    ``nn.CrossEntropyLoss`` and ``'mse'`` gives ``nn.MSELoss``. ``reduction``
    is forwarded to the chosen loss. A name matching neither raises
    ``ValueError``.
    """
    candidates = (
        ('ce', nn.CrossEntropyLoss),
        ('mse', nn.MSELoss),
    )
    for key, loss_cls in candidates:
        if key in loss_name:
            return loss_cls(reduction=reduction)
    raise ValueError(f"Not supported loss {loss_name}")

def get_scheduler(optimizer, n_epochs: int, loss_name=None):
    """Create a ``MultiStepLR`` whose milestones depend on ``n_epochs``.

    The first row of the table below whose upper bound is ``>= n_epochs``
    supplies the milestones; anything above 200 epochs uses ``[60, 120]``.
    The learning rate is multiplied by ``gamma=0.1`` at every milestone.

    Args:
        optimizer: optimizer handed to the scheduler.
        n_epochs: planned number of training epochs.
        loss_name: accepted but not used by this function.

    Returns:
        The configured ``MultiStepLR`` instance.
    """
    # (inclusive upper bound on n_epochs, milestones)
    # NOTE(review): the 180 row repeats the 200 row, so its last milestone
    # equals the epoch bound itself -- confirm this is intended.
    schedule_table = (
        (30, [15, 25]),
        (40, [20, 30]),
        (50, [25, 40]),
        (60, [30, 50]),
        (70, [40, 50, 60]),
        (80, [40, 60, 70]),
        (120, [50, 90, 110]),
        (140, [60, 100, 120]),
        (160, [40, 80, 120, 140]),
        (180, [80, 140, 180]),
        (200, [80, 140, 180]),
    )
    milestones = next(
        (steps for bound, steps in schedule_table if n_epochs <= bound),
        [60, 120],
    )
    return MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
