import random

import numpy as np

from rich import get_console
from rich.table import Table

import torch
import torch.nn as nn
import torch.nn.functional as F


def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed unchanged to ``random.seed``,
            ``np.random.seed`` and ``torch.manual_seed``, in that order.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed)
    for apply_seed in seeders:
        apply_seed(seed)


def print_table(title: str, metrics: dict) -> None:
    """Print a two-column table of metric names and values.

    Args:
        title: Title passed to the rich ``Table``.
        metrics: Mapping whose keys fill the "Metrics" column and whose
            values, converted with ``str``, fill the "Value" column.
            Rows follow the dict's iteration order.
    """
    metrics_table = Table(title=title)
    metrics_table.add_column("Metrics", style="cyan", no_wrap=True)
    metrics_table.add_column("Value", style="magenta")

    for name in metrics:
        metrics_table.add_row(name, str(metrics[name]))

    # Print through the console object returned by rich's get_console().
    get_console().print(metrics_table, justify="center")


def move_batch_to_device(batch: dict, device: torch.device) -> dict:
    """Move every top-level tensor value of ``batch`` to ``device``.

    The dict is updated in place and the same object is returned. Values
    that are not ``torch.Tensor`` instances (including containers that
    hold tensors) are left untouched.

    Args:
        batch: Mapping whose tensor values should be moved.
        device: Target device passed to ``Tensor.to``.

    Returns:
        The same ``batch`` dict, with tensor values replaced.
    """
    tensor_keys = [
        name for name, value in batch.items()
        if isinstance(value, torch.Tensor)
    ]
    for name in tensor_keys:
        batch[name] = batch[name].to(device)
    return batch


def count_parameters(module: nn.Module) -> float:
    """Return the parameter count of ``module`` in millions.

    Every tensor yielded by ``module.parameters()`` is counted.

    Args:
        module: Module whose parameters are counted.

    Returns:
        Total number of elements divided by 1e6, rounded to three decimals.
    """
    total = 0
    for param in module.parameters():
        total += param.numel()
    millions = total / 1e6
    return round(millions, 3)


def get_guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512,
                                 dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Build a sinusoidal embedding for a batch of guidance-scale values.

    Each value in ``w`` is multiplied by 1000 and projected onto
    ``embedding_dim // 2`` geometrically spaced frequencies. The result is
    the sines followed by the cosines along dim 1; when ``embedding_dim``
    is odd, one trailing zero column is appended.

    Args:
        w: 1-D tensor of guidance-scale values.
        embedding_dim: Width of the returned embedding. With values 2 or 3
            ``half_dim - 1`` is zero, so the frequency step divides by
            zero and the output is not finite.
        dtype: dtype of the frequency table and of the returned embedding.

    Returns:
        Tensor of shape ``(w.shape[0], embedding_dim)`` on ``w``'s device.

    Raises:
        AssertionError: If ``w`` is not 1-D or the result has an
            unexpected shape.
    """
    assert len(w.shape) == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # Build the frequency table on w's device: multiplying a table on the
    # default device with a `w` that lives elsewhere raises a device mismatch.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb


def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor:
    """Look up entries of ``a`` at indices ``t`` and shape them for broadcasting.

    Args:
        a: Tensor gathered along its last dimension.
        t: Index tensor; its first dimension size becomes the leading
            dimension of the result.
        x_shape: Only its length is used; the result has that many dims.

    Returns:
        The gathered values reshaped to ``(t.shape[0], 1, ..., 1)`` with
        ``len(x_shape)`` dimensions in total.
    """
    batch_size, *_ = t.shape
    gathered = torch.gather(a, -1, t)
    target_shape = [batch_size] + [1] * (len(x_shape) - 1)
    return gathered.reshape(target_shape)


def sum_flat(tensor: torch.Tensor) -> torch.Tensor:
    """Sum ``tensor`` over every dimension except the first.

    Args:
        tensor: Input tensor.

    Returns:
        Result of ``sum`` with ``dim`` set to ``[1, ..., tensor.dim() - 1]``.
    """
    trailing_dims = list(range(1, tensor.dim()))
    return torch.sum(tensor, dim=trailing_dims)




def control_loss_calculate(
        vaeloss_type: str, loss_func: str, src: torch.Tensor,
        tgt: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """Compute a masked, reduced loss between ``src`` and ``tgt``.

    The element-wise loss is chosen by ``loss_func`` and then reduced
    according to ``vaeloss_type``:

    * ``'sum'``: sum the loss over the last dim, multiply by ``mask``,
      and divide the grand total by ``mask.sum()``.
    * ``'sum_mask'``: same masking, but normalise per sample
      (``sum_flat(loss) / sum_flat(mask)``) and then average.
    * ``'mask'``: multiply the element-wise loss by ``mask``, sum per
      sample, divide by ``sum_flat(mask) * src.shape[-1]``, then average.

    Args:
        vaeloss_type: One of ``'sum'``, ``'sum_mask'``, ``'mask'``.
        loss_func: One of ``'l1'``, ``'l1_smooth'``, ``'l2'``.
        src: Predicted tensor.
        tgt: Target tensor.
        mask: Weight tensor multiplied into the loss. NOTE(review): the
            normalisation in ``'mask'`` mode looks like it assumes a
            size-1 last dimension -- confirm against callers.

    Returns:
        The reduced loss tensor.

    Raises:
        ValueError: If ``loss_func`` or ``vaeloss_type`` is not recognised.
            ``loss_func`` is checked first.
    """
    elementwise_losses = {
        'l1': F.l1_loss,
        'l1_smooth': F.smooth_l1_loss,
        'l2': F.mse_loss,
    }
    if loss_func not in elementwise_losses:
        raise ValueError(f'Unknown loss func: {loss_func}')
    per_element = elementwise_losses[loss_func](src, tgt, reduction='none')

    if vaeloss_type == 'sum':
        masked = per_element.sum(-1, keepdim=True) * mask
        return masked.sum() / mask.sum()

    if vaeloss_type == 'sum_mask':
        masked = per_element.sum(-1, keepdim=True) * mask
        per_sample = sum_flat(masked) / sum_flat(mask)
        return per_sample.mean()

    if vaeloss_type == 'mask':
        masked_total = sum_flat(per_element * mask)
        # Each unmasked position contributes src.shape[-1] entries.
        valid_entries = sum_flat(mask) * src.shape[-1]
        return (masked_total / valid_entries).mean()

    raise ValueError(f'Unsupported vaeloss_type: {vaeloss_type}')