import math
import numpy as np
import torch

from itertools import repeat
from time import perf_counter

from utils import cubic_roots_real, fit_newtonschulz_coeff

def warm_up(device, size=5000, num_iterations=10):
    """Run a few throwaway matrix products on ``device``.

    A random ``size x size`` matrix is allocated on ``device`` and
    ``A.mT @ A`` is evaluated ``num_iterations`` times; every result is
    discarded. Intended to be called once before the timing functions so that
    one-off start-up costs are not attributed to the first measured step.

    Args:
        device: Anything accepted as ``device=`` by ``torch.randn``.
        size: Side length of the square warm-up matrix.
        num_iterations: Number of products to issue.
    """
    A = torch.randn(size, size, device=device)
    for _ in range(num_iterations):
        torch.matmul(A.mT, A)

def time_polar_express_sign(A, num_iterations):
    """Benchmark the fixed-coefficient quintic "polar express" iteration.

    Starting from ``X = A``, every step applies

        X <- a * X + X @ (b * G + c * G @ G),    G = X.mT @ X,

    with ``(a, b, c)`` taken from a hard-coded schedule. When
    ``num_iterations`` exceeds the schedule length, the last entry is reused
    for the remaining steps.

    Only the update itself is timed; the residual ``torch.norm(I - X.mT @ X)``
    is evaluated outside the timed region. CUDA tensors are synchronized
    before and after each step so the wall-clock measurement covers the
    queued kernels; tensors on other devices are timed without
    synchronization.

    Args:
        A: Input matrix of shape ``(..., m, n)``. It is not modified.
        num_iterations: Number of steps to run (expected to be >= 0).

    Returns:
        ``(times, residuals)``: two lists of length ``num_iterations + 1``.
        ``times[k]`` is the cumulative time in seconds spent in the first
        ``k`` steps (``times[0] == 0``) and ``residuals[k]`` is the residual
        after ``k`` steps (``residuals[0]`` is that of ``A`` itself).
    """
    # Per-step (a, b, c) coefficients of the odd polynomial a*x + b*x^3 + c*x^5.
    coeffs_list = [
        (8.28721201814563, -23.595886519098837, 17.300387312530933),
        (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
        (3.9486908534822946, -2.908902115962949, 0.5518191394370137),
        (3.3184196573706015, -2.488488024314874, 0.51004894012372),
        (2.300652019954817, -1.6689039845747493, 0.4188073119525673),
        (1.891301407787398, -1.2679958271945868, 0.37680408948524835),
        (1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
        (1.875, -1.25, 0.375),
    ]

    # Dividing (a, b, c) by 1.01, 1.01**3, 1.01**5 evaluates the same
    # polynomial at x / 1.01. The final entry is left unscaled.
    coeffs_list = [
        (a / 1.01, b / 1.01**3, c / 1.01**5)
        for (a, b, c) in coeffs_list[:-1]
    ] + [coeffs_list[-1]]

    # Schedule actually used: truncate, or pad with the last entry.
    num_extra = max(0, num_iterations - len(coeffs_list))
    hs = coeffs_list[:num_iterations] + [coeffs_list[-1]] * num_extra

    identity = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
    X = A

    # torch.cuda.synchronize raises for non-CUDA devices, so only call it
    # when the data actually lives on a CUDA device.
    on_cuda = A.is_cuda

    time_elapsed = 0
    times = [time_elapsed]
    residuals = [torch.norm(identity - A.mT @ A).item()]

    for a, b, c in hs:

        if on_cuda:
            torch.cuda.synchronize(A.device)
        time_start = perf_counter()

        AA = X.mT @ X
        BB = b * AA + c * AA @ AA
        X = a * X + X @ BB

        if on_cuda:
            torch.cuda.synchronize(A.device)
        time_end = perf_counter()

        time_elapsed += time_end - time_start
        times.append(time_elapsed)
        # Residual evaluation is deliberately outside the timed region.
        residuals.append(torch.norm(identity - X.mT @ X).item())

    return times, residuals

def time_polar_express_sqrt(A, num_iterations):
    """Benchmark the coupled "polar express" iteration with fixed coefficients.

    Starting from ``X = A`` and ``Y = I``, every step forms

        H = a * I + b * (Y @ X) + c * (Y @ X) @ (Y @ X)
        X <- X @ H
        Y <- H @ Y

    with ``(a, b, c)`` taken from a hard-coded schedule. When
    ``num_iterations`` exceeds the schedule length, the last entry is reused
    for the remaining steps.

    Only the update itself is timed; the residual ``torch.norm(I - Y @ X)``
    is evaluated outside the timed region. CUDA tensors are synchronized
    before and after each step; tensors on other devices are timed without
    synchronization.

    Args:
        A: Square input matrix. It is not modified.
        num_iterations: Number of steps to run (expected to be >= 0).

    Returns:
        ``(times, residuals)``: two lists of length ``num_iterations + 1``.
        ``times[k]`` is the cumulative time in seconds spent in the first
        ``k`` steps (``times[0] == 0``) and ``residuals[k]`` is the residual
        after ``k`` steps (``residuals[0]`` is ``torch.norm(I - A)``).
    """
    # Per-step (a, b, c) coefficients; H is a + b*t + c*t^2 in t = Y @ X.
    coeffs_list = [
        (8.28721201814563, -23.595886519098837, 17.300387312530933),
        (4.107059111542203, -2.9478499167379106, 0.5448431082926601),
        (3.9486908534822946, -2.908902115962949, 0.5518191394370137),
        (3.3184196573706015, -2.488488024314874, 0.51004894012372),
        (2.300652019954817, -1.6689039845747493, 0.4188073119525673),
        (1.891301407787398, -1.2679958271945868, 0.37680408948524835),
        (1.8750014808534479, -1.2500016453999487, 0.3750001645474248),
        (1.875, -1.25, 0.375),
    ]

    # Every entry but the last is rescaled by 1.01, 1.01**3, 1.01**5.
    coeffs_list = [
        (a / 1.01, b / 1.01**3, c / 1.01**5)
        for (a, b, c) in coeffs_list[:-1]
    ] + [coeffs_list[-1]]

    # Schedule actually used: truncate, or pad with the last entry.
    num_extra = max(0, num_iterations - len(coeffs_list))
    hs = coeffs_list[:num_iterations] + [coeffs_list[-1]] * num_extra

    identity = torch.eye(A.shape[1], dtype=A.dtype, device=A.device)
    X = A
    Y = identity

    # torch.cuda.synchronize raises for non-CUDA devices, so only call it
    # when the data actually lives on a CUDA device.
    on_cuda = A.is_cuda

    time_elapsed = 0
    times = [time_elapsed]
    residuals = [torch.norm(identity - A).item()]

    for a, b, c in hs:

        if on_cuda:
            torch.cuda.synchronize(A.device)
        time_start = perf_counter()

        YX = Y @ X
        H = a * identity + b * YX + c * YX @ YX
        # Out-of-place products: neither A nor the identity is mutated.
        X = X @ H
        Y = H @ Y

        if on_cuda:
            torch.cuda.synchronize(A.device)
        time_end = perf_counter()

        time_elapsed += time_end - time_start
        times.append(time_elapsed)
        # Residual of the coupled pair, using the same product Y @ X that the
        # update is built from (no transpose, so it is also meaningful when
        # A is not symmetric). Evaluated outside the timed region.
        residuals.append(torch.norm(identity - Y @ X).item())

    return times, residuals


def time_newton_schulz_sqrt(A, degree, adaptive, num_iterations):
    """Benchmark the coupled Newton-Schulz iteration, optionally adaptive.

    Starting from ``X = A`` and ``Y = I``, every step forms the residual
    ``R = I - Y @ X`` and updates

    * ``degree == 3``: ``X <- X @ (I + alpha R)``, ``Y <- (I + alpha R) @ Y``
    * ``degree == 5``: ``H = I + R / 2 + alpha R @ R``,
      ``X <- X @ H``, ``Y <- H @ Y``

    With ``adaptive=False`` the step size is the constant ``0.5`` (degree 3)
    or ``0.375`` (degree 5). With ``adaptive=True`` it is obtained per step
    from ``fit_newtonschulz_coeff`` using moments of ``R`` estimated with a
    single random sketch vector; this work is part of the timed region.

    CUDA tensors are synchronized before and after each step; tensors on
    other devices are timed without synchronization. The reported residual
    ``torch.norm(I - X @ Y)`` is evaluated outside the timed region.

    Args:
        A: Square input matrix. It is not modified.
        degree: ``3`` or ``5``.
        adaptive: Whether to fit ``alpha`` per step instead of using the
            constant value.
        num_iterations: Number of steps to run.

    Returns:
        ``(times, residuals, alpha_list)``. ``times`` and ``residuals`` have
        ``num_iterations + 1`` entries (cumulative seconds starting at 0, and
        the residual before the first step and after each step);
        ``alpha_list`` holds the ``alpha`` used in each step.

    Raises:
        ValueError: If ``degree`` is neither 3 nor 5.
    """
    if degree not in (3, 5):
        # Previously an unsupported degree silently ran zero-work iterations.
        raise ValueError(f"degree must be 3 or 5, got {degree!r}")

    identity = torch.eye(A.shape[1], dtype=A.dtype, device=A.device)
    X = A
    Y = identity
    sketch_dim = 1
    sketch = torch.randn(A.shape[1], sketch_dim, device=A.device, dtype=A.dtype) / math.sqrt(sketch_dim)

    def dot(u, v):
        # Inner product of two sketched vectors as a Python float.
        return (u * v).sum().item()

    # torch.cuda.synchronize raises for non-CUDA devices, so only call it
    # when the data actually lives on a CUDA device.
    on_cuda = A.is_cuda

    time_elapsed = 0
    times = [time_elapsed]
    residuals = [torch.norm(identity - A).item()]
    alpha_list = []

    for _ in range(num_iterations):

        if on_cuda:
            torch.cuda.synchronize(A.device)
        time_start = perf_counter()

        # R = I - Y X; the in-place ops act on the fresh product only.
        R = (Y @ X).neg_().add_(identity)

        if degree == 3:
            if adaptive:
                # R1, R2, R3 = R s, R^2 s, R^3 s for the sketch vector s.
                R1 = R @ sketch
                R2 = R @ R1
                R3 = R @ R2
                # r_k estimates s^T R^k s (pairing as below assumes R is
                # symmetric).
                r2 = dot(R1, R1)
                r3 = dot(R2, R1)
                r4 = dot(R2, R2)
                r5 = dot(R3, R2)
                r6 = dot(R3, R3)
                # Coefficients of alpha^4, alpha^3, alpha^2, alpha in the
                # sketched squared norm of the next residual.
                a = r6 - 2*r5 + r4
                b = 4*r5 - 8*r4 + 4*r3
                c = 6*r4 - 10*r3 + 4*r2
                d = 4*r3 - 4*r2
                alpha = fit_newtonschulz_coeff(a, b, c, d, low=0.5, high=1, x0=0.5)
            else:
                alpha = 0.5
            X = (X @ R).mul_(alpha).add_(X)
            Y = (R @ Y).mul_(alpha).add_(Y)

        else:  # degree == 5
            if adaptive:
                # R1..R5 = R s, ..., R^5 s for the sketch vector s.
                R1 = R @ sketch
                R2 = R @ R1
                R3 = R @ R2
                R4 = R @ R3
                R5 = R @ R4
                r2 = dot(R1, R1)
                r3 = dot(R2, R1)
                r4 = dot(R2, R2)
                r5 = dot(R3, R2)
                r6 = dot(R3, R3)
                r7 = dot(R4, R3)
                r8 = dot(R4, R4)
                r9 = dot(R5, R4)
                r10 = dot(R5, R5)
                # Coefficients of alpha^4, alpha^3, alpha^2, alpha in the
                # sketched squared norm of the next residual.
                a = r10 - 2*r9 + r8
                b = 2*r9 - 6*r7 + 4*r6
                c = 1.5*r8 + 3*r7 - 4.5*r6 - 4*r5 + 4*r4
                d = 0.5*r7 + 2*r6 + 0.5*r5 - 3*r4
                alpha = fit_newtonschulz_coeff(a, b, c, d, low=0.375, high=1.45, x0=0.375)
            else:
                alpha = 0.375
            # H = I + R/2 + alpha R^2, accumulated into the fresh R @ R
            # buffer without mutating R.
            H = (R @ R).mul_(alpha).add_(R, alpha=0.5).add_(identity)
            X = X @ H
            Y = H @ Y

        alpha_list.append(alpha)

        if on_cuda:
            torch.cuda.synchronize(A.device)
        time_end = perf_counter()

        time_elapsed += time_end - time_start
        times.append(time_elapsed)
        # Residual evaluation is deliberately outside the timed region.
        residuals.append(torch.norm(identity - X @ Y).item())

    return times, residuals, alpha_list


def time_newton_schulz_sign(A, degree, adaptive, num_iterations):
    """Benchmark the Newton-Schulz orthogonalization, optionally adaptive.

    Starting from ``X = A``, every step forms the residual
    ``R = I - X.mT @ X`` and updates

    * ``degree == 3``: ``X <- X @ (I + alpha R)``
    * ``degree == 5``: ``X <- X @ (I + R / 2 + alpha R @ R)``

    With ``adaptive=False`` the step size is the constant ``0.5`` (degree 3)
    or ``0.375`` (degree 5). With ``adaptive=True`` it is obtained per step
    from ``fit_newtonschulz_coeff`` using moments of ``R`` estimated with a
    single random sketch vector; this work is part of the timed region.

    CUDA tensors are synchronized before and after each step; tensors on
    other devices are timed without synchronization. The reported residual
    ``torch.norm(I - X.mT @ X)`` is evaluated outside the timed region.

    Args:
        A: Input matrix of shape ``(..., m, n)``. It is not modified.
        degree: ``3`` or ``5``.
        adaptive: Whether to fit ``alpha`` per step instead of using the
            constant value.
        num_iterations: Number of steps to run.

    Returns:
        ``(times, residuals, alpha_list)``. ``times`` and ``residuals`` have
        ``num_iterations + 1`` entries (cumulative seconds starting at 0, and
        the residual before the first step and after each step);
        ``alpha_list`` holds the ``alpha`` used in each step.

    Raises:
        ValueError: If ``A`` has fewer than two dimensions or ``degree`` is
            neither 3 nor 5.
    """
    # Validate before touching A.shape so the error is explicit.
    if A.ndim < 2:
        raise ValueError(f"A must have at least 2 dimensions, got {A.ndim}")
    if degree not in (3, 5):
        # Previously an unsupported degree silently ran zero-work iterations.
        raise ValueError(f"degree must be 3 or 5, got {degree!r}")

    num_cols = A.shape[-1]
    identity = torch.eye(num_cols, dtype=A.dtype, device=A.device)
    X = A

    sketch_dim = 1
    sketch = torch.randn(num_cols, sketch_dim, device=A.device, dtype=A.dtype) / math.sqrt(sketch_dim)

    def dot(u, v):
        # Inner product of two sketched vectors as a Python float.
        return (u * v).sum().item()

    # torch.cuda.synchronize raises for non-CUDA devices, so only call it
    # when the data actually lives on a CUDA device.
    on_cuda = A.is_cuda

    time_elapsed = 0
    times = [time_elapsed]
    residuals = [torch.norm(identity - X.mT @ X).item()]
    alpha_list = []

    for _ in range(num_iterations):

        if on_cuda:
            torch.cuda.synchronize(A.device)
        time_start = perf_counter()

        # R = I - X^T X; the in-place ops act on the fresh product only.
        R = (X.mT @ X).neg_().add_(identity)

        if degree == 3:
            if adaptive:
                # R1, R2, R3 = R s, R^2 s, R^3 s for the sketch vector s.
                R1 = R @ sketch
                R2 = R @ R1
                R3 = R @ R2
                # r_k estimates s^T R^k s (R is symmetric here).
                r2 = dot(R1, R1)
                r3 = dot(R2, R1)
                r4 = dot(R2, R2)
                r5 = dot(R3, R2)
                r6 = dot(R3, R3)
                # Coefficients of alpha^4, alpha^3, alpha^2, alpha in the
                # sketched squared norm of the next residual.
                a = r6 - 2*r5 + r4
                b = 4*r5 - 8*r4 + 4*r3
                c = 6*r4 - 10*r3 + 4*r2
                d = 4*r3 - 4*r2
                alpha = fit_newtonschulz_coeff(a, b, c, d, low=0.5, high=1, x0=0.5)
            else:
                alpha = 0.5
            X = (X @ R).mul_(alpha).add_(X)

        else:  # degree == 5
            if adaptive:
                # R1..R5 = R s, ..., R^5 s for the sketch vector s.
                R1 = R @ sketch
                R2 = R @ R1
                R3 = R @ R2
                R4 = R @ R3
                R5 = R @ R4
                r2 = dot(R1, R1)
                r3 = dot(R2, R1)
                r4 = dot(R2, R2)
                r5 = dot(R3, R2)
                r6 = dot(R3, R3)
                r7 = dot(R4, R3)
                r8 = dot(R4, R4)
                r9 = dot(R5, R4)
                r10 = dot(R5, R5)
                # Coefficients of alpha^4, alpha^3, alpha^2, alpha in the
                # sketched squared norm of the next residual.
                a = r10 - 2*r9 + r8
                b = 2*r9 - 6*r7 + 4*r6
                c = 1.5*r8 + 3*r7 - 4.5*r6 - 4*r5 + 4*r4
                d = 0.5*r7 + 2*r6 + 0.5*r5 - 3*r4
                alpha = fit_newtonschulz_coeff(a, b, c, d, low=0.375, high=1.45, x0=0.375)
            else:
                alpha = 0.375
            # H = I + R/2 + alpha R^2, accumulated into the fresh R @ R
            # buffer without mutating R.
            H = (R @ R).mul_(alpha).add_(R, alpha=0.5).add_(identity)
            X = X @ H

        alpha_list.append(alpha)

        if on_cuda:
            torch.cuda.synchronize(A.device)
        time_end = perf_counter()

        time_elapsed += time_end - time_start
        times.append(time_elapsed)
        # Residual evaluation is deliberately outside the timed region.
        residuals.append(torch.norm(identity - X.mT @ X).item())

    return times, residuals, alpha_list
