import time
import torch
from math import prod
from sympy import factorint
from cola.ops import Dense
from ops.operators import OptBlockTT


def build_op(struct, device, dim_in, dim_out, **kwargs):
    """Build the operator registered under ``struct`` in ``STRUCT_MAP``.

    All arguments after ``struct`` are forwarded unchanged to that structure's
    ``"mvm"`` builder. An unknown ``struct`` raises ``KeyError``.
    """
    builder = STRUCT_MAP[struct]["mvm"]
    return builder(device, dim_in, dim_out, **kwargs)


def get_flops(struct, batch_size, dim_in, dim_out, **kwargs):
    """Return the FLOP count reported by ``struct``'s ``"flops"`` function.

    Arguments after ``struct`` are passed straight through. An unknown
    ``struct`` raises ``KeyError``.
    """
    flop_counter = STRUCT_MAP[struct]["flops"]
    return flop_counter(batch_size, dim_in, dim_out, **kwargs)


def build_dense(device, dim_in, dim_out, **kwargs):
    """Return a CoLA ``Dense`` operator wrapping a random ``(dim_in, dim_out)`` matrix.

    Entries are drawn from a standard normal distribution with an explicit
    ``float32`` dtype. This matches the dtype ``construct_cores`` uses for the
    BTT cores, so the comparison does not depend on torch's global default
    dtype.

    Extra keyword arguments (such as BTT-only settings) are accepted and
    ignored. This lets every builder be called with the same ``**kwargs``.
    """
    W = torch.randn(dim_in, dim_out, dtype=torch.float32, device=device)
    return Dense(W)


def get_dense_flops(batch_size, dim_in, dim_out, **kwargs):
    """Count FLOPs for applying a dense ``dim_in x dim_out`` matrix to a batch.

    Every weight costs one multiply and one add per batch element, hence the
    factor of two. Extra keyword arguments are ignored.
    """
    multiply_adds = batch_size * dim_in * dim_out
    return 2 * multiply_adds


def construct_cores(device, dim_in, dim_out, tt_rank, tt_dim):
    """Create random ``float32`` Block-TT cores and their shape metadata.

    ``dim_out`` and ``dim_in`` are each split into ``tt_dim`` factors
    (``ns`` and ``ms`` respectively). Ranks are ``[1, tt_rank, ..., tt_rank, 1]``.
    Core ``k`` has shape
    ``ns[:k] + ms[k+1:] + (rs[k] * ms[k], rs[k+1] * ns[k])``.

    Returns:
        ``(cores, (rs, ms, ns))``, where ``cores`` is a list of ``tt_dim``
        tensors. NOTE(review): the exact contraction order these shapes
        encode is defined by ``OptBlockTT``, which is not visible here.
    """
    out_factors = factorize(dim_out, tt_dim)
    in_factors = factorize(dim_in, tt_dim)
    ranks = [1, *([tt_rank] * (tt_dim - 1)), 1]

    # Cores are drawn in index order, so the RNG sequence is deterministic.
    cores = [
        torch.randn(
            *out_factors[:k],
            *in_factors[k + 1:],
            ranks[k] * in_factors[k],
            ranks[k + 1] * out_factors[k],
            dtype=torch.float32,
            device=device,
        )
        for k in range(tt_dim)
    ]
    return cores, (ranks, in_factors, out_factors)


def build_btt(device, dim_in, dim_out, tt_rank, tt_dim, **kwargs):
    """Build an ``OptBlockTT`` operator from freshly generated random cores.

    The ``(cores, shapes)`` pair from ``construct_cores`` is passed
    positionally to ``OptBlockTT``. Extra keyword arguments are ignored.
    """
    return OptBlockTT(*construct_cores(device, dim_in, dim_out, tt_rank, tt_dim))


def get_btt_flops(batch_size, dim_in, dim_out, tt_rank, tt_dim, **kwargs):
    """Count FLOPs for applying a Block-TT operator to a batch.

    Uses the same core shapes as ``construct_cores``. The result is
    ``2 * batch_size * numel(core)`` summed over all ``tt_dim`` cores, i.e.
    one multiply-add per core element per batch element. NOTE(review): this
    assumes ``OptBlockTT`` touches each core element once per input row;
    confirm against its implementation.
    """
    ns = factorize(dim_out, tt_dim)
    ms = factorize(dim_in, tt_dim)
    rs = [1, *([tt_rank] * (tt_dim - 1)), 1]

    def core_numel(i):
        # Same shape formula as construct_cores, core i.
        return prod(ns[:i] + ms[i + 1:] + (rs[i] * ms[i], rs[i + 1] * ns[i]))

    return 2 * sum(batch_size * core_numel(i) for i in range(tt_dim))


# Registry of benchmarkable structures. Each name maps to:
#   "mvm":   builder(device, dim_in, dim_out, **kwargs) -> operator
#   "flops": counter(batch_size, dim_in, dim_out, **kwargs) -> FLOP count
STRUCT_MAP = {
    "dense": {"mvm": build_dense, "flops": get_dense_flops},
    "btt": {"mvm": build_btt, "flops": get_btt_flops},
}


def compute_elapsed_gpu(fn):
    """Time ``fn()`` with CUDA timing events and return the result in seconds.

    Events are recorded on the current CUDA stream immediately before and
    after the call. The device is then synchronized so the end event has
    completed before the interval is read. ``elapsed_time`` reports
    milliseconds, hence the division by 1000.
    """
    start_evt, end_evt = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    start_evt.record()
    fn()
    end_evt.record()
    torch.cuda.synchronize()
    millis = start_evt.elapsed_time(end_evt)
    return millis / 1000


def compute_elapsed_cpu(fn):
    """Run ``fn()`` once and return its elapsed time in seconds.

    Uses ``time.perf_counter()``, which is monotonic and has the highest
    available resolution. ``time.time()`` (used previously) follows the
    system wall clock: it can jump when the clock is adjusted and is coarse
    on some platforms, so it is unsuitable for interval benchmarks.
    """
    t0 = time.perf_counter()
    fn()
    t1 = time.perf_counter()
    return t1 - t0


def factorize(x, n):
    """Split ``x`` into ``n`` integer factors that are as balanced as possible.

    The prime factors of ``x`` (from ``sympy.factorint``) are distributed
    greedily. Each prime, largest first, multiplies whichever of the ``n``
    running factors is currently smallest. Placing large primes first
    (longest-processing-time heuristic) gives much more balanced splits than
    ascending order: e.g. ``factorize(12, 2) == (3, 4)`` rather than
    ``(2, 6)``, and ``factorize(20, 2) == (4, 5)`` rather than ``(2, 10)``.

    Args:
        x: Positive integer to split.
        n: Number of factors to produce.

    Returns:
        A tuple of ``n`` integers in ascending order whose product is ``x``.
        Slots that receive no prime stay at 1.
    """
    numbers = [1] * n
    # Sort explicitly (largest prime first) instead of relying on dict order.
    for prime, count in sorted(factorint(x).items(), reverse=True):
        for _ in range(count):
            # First index holding the current minimum; ties do not affect the
            # sorted result.
            smallest = numbers.index(min(numbers))
            numbers[smallest] *= prime
    return tuple(sorted(numbers))
