import random

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import random_split


class LRScheduler:
    """Bundle a learning-rate scheduler object with its stepping-granularity flag.

    Both attributes are class-level defaults; assigning them on an instance
    shadows the default for that instance only.
    """

    # True when the wrapped scheduler is meant to be stepped per batch; False
    # otherwise (how and when it is actually stepped is decided by the caller).
    is_batch_scheduler: bool = False
    # The wrapped scheduler object, or None when no schedule was configured.
    scheduler: object = None


def count_trainable_parameters(model):
    """Return the total element count of parameters that have ``requires_grad`` set.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        Number of scalar entries across all trainable parameters (0 if none).
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameters are excluded from the count
        total += param.numel()
    return total


def count_total_parameters(model):
    """Return the total element count of all parameters, trainable or frozen.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        Number of scalar entries across every parameter (0 if none).
    """
    sizes = [tensor.numel() for tensor in model.parameters()]
    return sum(sizes)


def get_lr_scheduler(lr_scheduler: "str | None", optimizer, **kwargs):
    """Build an ``LRScheduler`` wrapper for the named learning-rate schedule.

    Args:
        lr_scheduler: ``"cswr"`` for ``CosineAnnealingWarmRestarts`` or
            ``"rlrp"`` for ``ReduceLROnPlateau``. ``None`` returns a wrapper
            whose ``scheduler`` stays ``None``.
        optimizer: the optimizer whose learning rate the scheduler adjusts.
        **kwargs: optional tuning values. ``"cswr"`` reads ``T_0`` (default 10),
            ``T_mult`` (default 2) and ``min_lr`` (default 1e-6, passed as
            ``eta_min``); ``"rlrp"`` reads ``patience`` (default 10) and uses
            half of it. Other keys are ignored.

    Returns:
        An ``LRScheduler`` with ``scheduler`` and ``is_batch_scheduler`` set.

    Raises:
        ValueError: if ``lr_scheduler`` is not ``None``, ``"cswr"`` or ``"rlrp"``.
    """
    lrs = LRScheduler()
    if lr_scheduler is None:
        return lrs

    if lr_scheduler == "cswr":
        # Flagged for per-batch stepping; the caller decides how to step it.
        lrs.is_batch_scheduler = True
        lrs.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=kwargs.get("T_0", 10),
            T_mult=kwargs.get("T_mult", 2),
            eta_min=kwargs.get("min_lr", 1e-6),
            last_epoch=-1,
        )
    elif lr_scheduler == "rlrp":
        lrs.is_batch_scheduler = False
        lrs.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            "min",
            # Half of the supplied patience; presumably so the LR is reduced
            # before a patience-based early stop fires — confirm with caller.
            patience=kwargs.get("patience", 10) // 2,
        )
    else:
        # ValueError subclasses Exception, so existing broad handlers still catch it.
        raise ValueError(
            f"{lr_scheduler} lr scheduler not defined only choices are [cswr, rlrp]"
        )
    return lrs


def load_model(model, checkpoint_path, map_location=None, strict=True):
    """Load weights from a checkpoint file into ``model`` in place and return it.

    The checkpoint must be a mapping holding the state dict under the
    ``"model_state_dict"`` key.

    Args:
        model: module whose ``load_state_dict`` receives the stored weights.
        checkpoint_path: path or file-like object accepted by ``torch.load``.
        map_location: forwarded to ``torch.load``; pass e.g. ``"cpu"`` to load a
            checkpoint saved on GPU onto a CPU-only machine. ``None`` (default)
            keeps the devices recorded in the file, as before.
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        The same ``model`` object.

    Raises:
        KeyError: if the checkpoint has no ``"model_state_dict"`` entry.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"], strict=strict)
    return model


def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch for reproducibility.

    Args:
        seed: integer seed applied to every generator (must fit NumPy's
            ``np.random.seed`` range).
    """
    # Python's stdlib RNG was previously left unseeded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Explicitly seed every visible CUDA device as well.
        torch.cuda.manual_seed_all(seed)


def split_dataset(dataset, val_split=0.2, seed=42):
    """Randomly and reproducibly split ``dataset`` into train and validation subsets.

    Args:
        dataset: any sized dataset accepted by ``torch.utils.data.random_split``.
        val_split: fraction of items assigned to validation, in ``[0, 1]``. The
            validation size is ``int(len(dataset) * val_split)`` (truncated);
            the remainder goes to training.
        seed: seed for a dedicated ``torch.Generator``, so the split does not
            depend on (or disturb) the global RNG state.

    Returns:
        ``[train_subset, val_subset]`` as returned by ``random_split``.

    Raises:
        ValueError: if ``val_split`` is outside ``[0, 1]`` (or NaN).
    """
    # Out-of-range fractions would give a negative split length and a
    # meaningless split instead of an error.
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(f"val_split must be in [0, 1], got {val_split}")
    total_len = len(dataset)
    val_len = int(total_len * val_split)
    train_len = total_len - val_len
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [train_len, val_len], generator=generator)


def l1_regularization(model):
    """Return the L1 norm (sum of absolute values) of all of ``model``'s parameters.

    All parameters are included, whether or not they require gradients. The
    result is a differentiable scalar tensor suitable for adding to a loss.

    Args:
        model: object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        0-dim tensor with the L1 norm; ``torch.tensor(0.0)`` if there are no
        parameters.
    """
    # Per-parameter abs-sums avoid ``view(-1)`` (which fails on non-contiguous
    # parameters) and avoid concatenating a full copy of every parameter.
    per_param = [param.abs().sum() for param in model.parameters()]
    if not per_param:
        # torch.cat / torch.stack reject an empty list.
        return torch.tensor(0.0)
    return torch.stack(per_param).sum()
