import os
import numpy as np
import torch

def matrix_power(A, pow, is_torch=False):
    """Raise a matrix to a (possibly fractional) power via its SVD.

    With ``A = U @ diag(S) @ Vh`` this returns ``U @ diag(S ** pow) @ Vh``.
    This equals the true matrix power only when ``A`` is symmetric positive
    semi-definite (then ``U == Vh.T``). For other matrices it only rescales
    the singular values; it is NOT ``A`` multiplied by itself.

    Args:
        A: A 2-D array, or a stack of 2-D arrays with shape ``(..., m, n)``.
            Non-square matrices are supported.
        pow: Exponent applied to every singular value. A negative exponent
            with a zero singular value yields non-finite entries.
        is_torch: If True, ``A`` is a ``torch.Tensor``. It is detached and
            moved to the CPU for the NumPy SVD. The result is returned as a
            tensor on ``A``'s original device, and no gradient flows through
            this function.

    Returns:
        A NumPy array, or a ``torch.Tensor`` when ``is_torch`` is True.
    """
    device = None
    if is_torch:
        device = A.device
        # .numpy() refuses tensors that require grad or live on a GPU.
        A = A.detach().cpu().numpy()

    # Reduced SVD: U is (..., m, k), S is (..., k), Vh is (..., k, n) with
    # k = min(m, n). The reconstruction is therefore shape-consistent for
    # non-square inputs too; the full SVD's U and Vh would not line up.
    U, S, Vh = np.linalg.svd(A, full_matrices=False)
    S = np.power(S, pow)
    # Scaling U's columns by S is equivalent to U @ np.diag(S), and it also
    # works for stacked matrices.
    A = (U * S[..., None, :]) @ Vh

    if is_torch:
        A = torch.from_numpy(A).to(device)

    return A

def convert_one_hot(y, c):
    """Return a one-hot encoding of integer class labels.

    Args:
        y: Integer class labels, expected in ``[0, c)``. Any array-like is
            accepted and flattened, so a column vector of shape ``(n, 1)``
            is treated the same as a flat vector of length ``n``.
        c: Number of classes, i.e. the width of each one-hot row.

    Returns:
        A float ``np.ndarray`` of shape ``(number_of_labels, c)`` with a
        single 1 per row.
    """
    # Flatten so the row indices (length n) and the labels pair up
    # element-wise. With a 2-D y such as (n, 1) they would broadcast into an
    # (n, n) grid and set many wrong entries.
    labels = np.asarray(y).ravel()
    o = np.zeros((labels.size, c))
    o[np.arange(labels.size), labels] = 1
    return o

def setup_wandb(wandb, args):
    """Start a wandb run, create its local output directory and name the run.

    Args:
        wandb: The ``wandb`` module (or an object with the same ``init`` /
            ``run`` interface).
        args: Config object; it is passed to ``wandb.init`` as ``config``.
            Reads ``wandb_offline``, ``wandb_entity``, ``wandb_proj_name``,
            ``out_dir`` and ``dataset``, plus ``prime`` and
            ``training_fraction`` when ``dataset == 'modular_arithmetic'``.

    Returns:
        The path ``<out_dir>/<wandb_proj_name>/<run id>``, which this call
        creates if it does not exist.
    """
    mode = 'offline' if args.wandb_offline else 'online'

    wandb.init(entity=args.wandb_entity, project=args.wandb_proj_name, mode=mode, config=args,
               dir=args.out_dir)

    # Per-run directory keyed by the wandb run id.
    out_dir = os.path.join(args.out_dir, args.wandb_proj_name, wandb.run.id)
    os.makedirs(out_dir, exist_ok=True)

    # Other datasets keep the run name wandb assigned.
    if args.dataset == 'modular_arithmetic':
        wandb.run.name = f'{wandb.run.id} - p: {args.prime}, train_frac: {args.training_fraction}'
    elif args.dataset == 'cifar10':
        wandb.run.name = f'{wandb.run.id}'

    return out_dir

def get_optimizer(args, model):
    """Build the optimizer selected by ``args.opt`` over ``model.parameters()``.

    Args:
        args: Config object. Reads ``opt`` (``'sgd'`` or ``'adamw'``),
            ``learning_rate`` and ``weight_decay``, and also ``momentum``
            when ``opt == 'sgd'``.
        model: Object exposing ``parameters()``, e.g. a ``torch.nn.Module``.

    Returns:
        A ``torch.optim.SGD`` or ``torch.optim.AdamW`` instance. AdamW uses
        fixed ``betas=(0.9, 0.98)``.

    Raises:
        ValueError: If ``args.opt`` is not a supported optimizer name.
    """
    if args.opt == 'sgd':
        return torch.optim.SGD(
            model.parameters(),
            lr=args.learning_rate,
            weight_decay=args.weight_decay,
            momentum=args.momentum,
        )
    if args.opt == 'adamw':
        return torch.optim.AdamW(
            model.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.98),
            weight_decay=args.weight_decay,
        )
    # A bare `raise` outside an `except` block only yields
    # "RuntimeError: No active exception to reraise", which hides the cause.
    raise ValueError(f"Unsupported optimizer {args.opt!r}; expected 'sgd' or 'adamw'.")

def undiag(M, is_torch=False):
    """Return a copy of the 2-D matrix ``M`` with its main diagonal set to 0.

    Works for square and non-square matrices. ``M`` itself is not modified.

    Args:
        M: A 2-D array-like, or a 2-D ``torch.Tensor`` when ``is_torch``.
        is_torch: Use the PyTorch implementation.

    Returns:
        An array/tensor with the same shape and dtype as ``M`` in which
        every ``[i, i]`` entry is 0 and all other entries are unchanged.
    """
    if is_torch:
        # Unlike subtracting diag(diag(M)), this needs no square input and
        # leaves an exact 0 even when the diagonal holds inf/NaN.
        # masked_fill is differentiable.
        on_diag = torch.eye(M.shape[0], M.shape[1], dtype=torch.bool, device=M.device)
        return M.masked_fill(on_diag, 0)
    out = np.array(M, copy=True)
    np.fill_diagonal(out, 0)
    return out
