import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.optimizer import Optimizer

from torch import linalg
import numpy as np


from typing import List
import torch.jit

@torch.jit.script
def block_matrix(blocks: List[List[torch.Tensor]], dim0: int = -2, dim1: int = -1):
    """Assemble one tensor from a nested list of blocks.

    ``blocks`` is a list of block-rows; ``[[A, B], [C, D]]`` produces::

        [A B]
        [C D]

    The tensors of each inner list are concatenated along ``dim1`` and the
    resulting block-rows are then concatenated along ``dim0``.
    """
    rows = [torch.cat(row, dim=dim1) for row in blocks]
    return torch.cat(rows, dim=dim0)


class RGD_QR(Optimizer):
    ''' Implementation of stochastic Riemannian gradient descent with QR retraction.

    Every parameter is treated as a 2-D matrix ``X``.  One step computes
    ``G - X sym(X^T G)`` from the gradient ``G``, moves ``X`` against it by
    the group's learning rate, and replaces ``X`` by the Q factor of a
    sign-normalised QR decomposition of the result.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: step size (may be overridden per parameter group).
    '''

    def __init__(self, params, lr):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        '''Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            is given.
        '''
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                # Parameters that took no part in the backward pass have
                # no gradient and are left untouched.
                if p.grad is None:
                    continue
                X = p.data
                G = p.grad.data

                # Riemannian gradient: G - X * sym(X^T G)
                XtG = X.t().mm(G)
                sym_XtG = 0.5 * (XtG + XtG.t())
                riem_grad = G - X.mm(sym_XtG)

                # QR retraction made unique: sign -> {-1, +1} with zeros
                # mapped to +1, so Q corresponds to an R whose diagonal is
                # non-negative.
                q, r = linalg.qr(X - lr * riem_grad)
                signs = torch.diagonal(r).sign().add(0.5).sign()
                q *= signs[..., None, :]

                p.data = q
        return loss



class RGD_GEN(Optimizer):
    ''' Implementation of stochastic Riemannian gradient descent with general retractions.

    Every parameter is treated as a 2-D matrix ``X``.  One step computes
    ``G - X sym(X^T G)`` from the gradient ``G`` and maps ``X`` and the
    scaled negative of that direction through the selected retraction.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: step size (may be overridden per parameter group).
        retraction: one of ``"QR"``, ``"EXP"``, ``"POL"`` or ``"CAY"``.
            The value is looked up on every call to :meth:`step`, so the
            ``retraction`` attribute may be changed between steps.
    '''

    def __init__(self, params, lr, retraction="QR"):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)

        self.retraction = retraction

    @staticmethod
    def _signed_qr(Y):
        '''Return the Q factor of ``Y`` with column signs normalised.

        Column signs are chosen as ``sign(diag(R))`` with zeros mapped to
        +1, which corresponds to an R factor with non-negative diagonal.
        '''
        q, r = linalg.qr(Y)
        signs = torch.diagonal(r).sign().add(0.5).sign()
        q *= signs[..., None, :]
        return q

    def qr(self, X, U):
        '''QR retraction: sign-normalised Q factor of ``X + U``.'''
        return self._signed_qr(X + U)

    def exp(self, X, U):
        '''Matrix-exponential based retraction of ``U`` at ``X``.

        Computes ``[X, U] expm([[X^T U, -U^T U], [I, X^T U]]) [expm(-X^T U); 0]``
        and re-orthonormalises the result with a sign-normalised QR.
        '''
        p = X.shape[1]
        eye = torch.eye(p, dtype=X.dtype, device=X.device)
        zeros = torch.zeros((p, p), dtype=X.dtype, device=X.device)
        XtU = X.T @ U  # reused in both block rows and in the trailing factor

        upper = torch.cat((XtU, -U.T @ U), dim=1)
        lower = torch.cat((eye, XtU), dim=1)
        exp_term = linalg.matrix_exp(torch.cat((upper, lower), dim=0))

        Y = torch.cat((X, U), dim=1) @ exp_term @ torch.cat((linalg.matrix_exp(-XtU), zeros), dim=0)

        # due to numerical instability, it seems necessary for re-normalization
        return self._signed_qr(Y)

    def polar(self, X, U):
        '''Polar retraction: ``W @ Vh`` from the thin SVD ``X + U = W S Vh``.'''
        left, _, right_h = linalg.svd(X + U, full_matrices=False)
        # due to numerical instability, it seems necessary for re-normalization
        return self._signed_qr(left @ right_h)

    def cayley(self, X, U):
        '''Cayley-transform retraction.

        Builds the skew-symmetric ``W = A - A^T`` with
        ``A = (U - 0.5 X X^T U) X^T`` and solves
        ``(I - 0.5 W) Y = (I + 0.5 W) X`` for ``Y``.
        '''
        n = X.shape[0]
        eye = torch.eye(n, dtype=X.dtype, device=X.device)
        Wu = (U - 0.5 * (X @ (X.T @ U))) @ X.T
        Wu = Wu - Wu.T
        rhs = (eye + 0.5 * Wu) @ X
        lhs = eye - 0.5 * Wu
        return linalg.solve(lhs, rhs)

    @torch.no_grad()
    def step(self, closure=None):
        '''Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            is given.

        Raises:
            ValueError: if ``self.retraction`` is not a supported name.
        '''
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        retractions = {
            'QR': self.qr,
            'EXP': self.exp,
            'POL': self.polar,
            'CAY': self.cayley,
        }
        try:
            retract = retractions[self.retraction]
        except (KeyError, TypeError):
            raise ValueError(
                f"Unknown retraction {self.retraction!r}; "
                f"expected one of {sorted(retractions)}"
            ) from None

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                # Parameters without a gradient are left untouched.
                if p.grad is None:
                    continue
                X = p.data
                G = p.grad.data

                # Riemannian gradient: G - X * sym(X^T G)
                XtG = X.t().mm(G)
                sym_XtG = 0.5 * (XtG + XtG.t())
                riem_grad = G - X.mm(sym_XtG)

                p.data = retract(X, -lr * riem_grad)

        return loss




class PCAL(Optimizer):
    ''' Implementation of PCAL infeasible method for optimization with orthogonality constraint.

    Every parameter is treated as a 2-D matrix ``X``.  One step builds a
    search direction from the gradient, a penalty term
    ``lam * (X^T X - I)`` and a per-column correction, moves ``X`` against
    that direction, and rescales every column of the result to unit
    Euclidean norm.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: step size used when ``lr_type`` is not ``'BB'``, and as a
            fallback when the ``'BB'`` estimate is degenerate.
        lam: penalty weight on ``X^T X - I``.
        lr_type: ``'BB'`` derives the step size from the previous iterate
            and direction; any other value uses the group's fixed ``lr``.

    Attributes:
        X_pre, dir_pre: the iterate (before the update) and the direction
            of the most recently updated parameter in ``'BB'`` mode.  The
            history actually used for the step size is kept per parameter
            in ``self.state``.
    '''

    def __init__(self, params, lr, lam, lr_type='fix'):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)

        print(f"Using {lr_type} stepsize rule")
        self.dir_pre = None
        self.X_pre = None
        self.lr_type = lr_type
        self.lam = lam

    @torch.no_grad()
    def step(self, closure=None):
        '''Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            is given.
        '''
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                # Parameters without a gradient are left untouched.
                if p.grad is None:
                    continue
                X = p.data
                G = p.grad

                GtX = G.t().mm(X)
                GtXsym = (GtX + GtX.t()) / 2
                XX = X.t().mm(X)
                penalFeaX = self.lam * (XX - torch.eye(XX.size()[0], device=XX.device, dtype=XX.dtype))
                dd = torch.diag(GtX.t() - XX.mm(GtXsym) + XX.mm(penalFeaX))
                direction = G - X.mm(GtXsym) - X * dd + X.mm(penalFeaX)

                if self.lr_type == 'BB':
                    # History is stored per parameter so that several
                    # parameters do not mix their iterates.
                    state = self.state[p]
                    prev_X = state.get('X_pre')
                    prev_dir = state.get('dir_pre')
                    if prev_dir is None:
                        lr = max(0.1, min(0.01 * torch.norm(direction).item(), 1))
                    else:
                        Sk = X - prev_X
                        Vk = direction - prev_dir
                        SV = (Sk * Vk).sum()
                        proxparam = (torch.norm(Vk).pow(2) / torch.abs(SV)).item()
                        proxparam = min(proxparam, 1000)
                        # A zero or NaN ratio (e.g. the direction did not
                        # change) gives no usable estimate; the comparison
                        # is False for NaN as well.
                        if proxparam > 0:
                            lr = 1 / proxparam
                        else:
                            lr = group['lr']

                    # Store a copy: the iterate must not alias the tensor
                    # that is about to be updated.
                    state['X_pre'] = X.clone()
                    state['dir_pre'] = direction
                    self.X_pre = state['X_pre']
                    self.dir_pre = direction
                else:
                    lr = group['lr']

                moved = X - lr * direction
                col_norms = (moved * moved).sum(dim=0).sqrt()
                p.data = moved / col_norms

        return loss




class Landing(Optimizer):
    ''' Implementation of Landing flow method for optimization with orthogonality constraint.

    Every parameter is treated as a 2-D matrix ``X``.  One step moves ``X``
    against the direction

        ``0.5 * (G X^T X - X G^T X) + lam * (X X^T X - X)``

    where ``G`` is the gradient.  No projection or retraction is applied.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: step size (may be overridden per parameter group).
        lam: weight of the ``X X^T X - X`` term.
    '''

    def __init__(self, params, lr, lam):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)

        self.lam = lam

    @torch.no_grad()
    def step(self, closure=None):
        '''Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            is given.
        '''
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                # Parameters without a gradient are left untouched.
                if p.grad is None:
                    continue
                X = p.data
                G = p.grad

                XtX = X.t().mm(X)
                relgradX = (G @ XtX - X @ (G.t() @ X)) / 2
                distX = X @ XtX - X
                direction = relgradX + self.lam * distX

                p.data = X - lr * direction

        return loss




# class PLAM(Optimizer):
#     ''' Implementation of PLAM infeasible method for optimization with orthogonality constraint
#     '''
#     def __init__(self, params, lr, lam):
#         defaults = dict(lr=lr)
#         super(PLAM, self).__init__(params, defaults)
#
#     @torch.no_grad()
#     def step(self):
#         group = self.param_groups[0]
#         loss = None
#         for p in group['params']:
#             d_p = p.grad
#             if not p.orth:
#                 p.data.add_(-group['lr'] * d_p.data)
#             else:
#                 GtX = d_p.t().mm(p.data)








# class RCD(Optimizer):
#     ''' Implementation of stochastic Riemannian (linearized) coordinate descent
#         for orthogonal constraints.
#     '''
#
#     def __init__(self, params, lr, n, p):
#         defaults = dict(lr=lr)
#         super(RCD, self).__init__(params, defaults)
#         self.flops = 0
#         idx1, idx2 = np.triu_indices(n, k=1)
#         self.idx1 = idx1
#         self.idx2 = idx2
#
#     @torch.no_grad()
#     def step(self, idx):
#         group = self.param_groups[0]
#         loss = None
#         for p in group['params']:
#             d_p = p.grad
#             if not p.orth:
#                 p.data.add_(-group['lr'] * d_p.data)
#             else:
#                 n = p.data.size()[0]
#                 pp = p.data.size()[1]
#                 assert (n >= pp)
#
#                 ii = self.idx1[idx]
#                 jj = self.idx2[idx]
#                 XVt = torch.inner(p.data[ii, :], d_p.data[jj, :])
#                 VXt = torch.inner(d_p.data[ii, :], p.data[jj, :])
#                 eta = -group['lr'] * (XVt - VXt)
#                 vi = torch.cos(eta) * p.data[ii, :] + torch.sin(eta) * p.data[jj, :]
#                 vj = -torch.sin(eta) * p.data[ii, :] + torch.cos(eta) * p.data[jj, :]
#                 p.data[ii, :] = vi
#                 p.data[jj, :] = vj
#
#                 self.flops = self.flops + 10 * pp
#
#         return loss


# class RCD(Optimizer):
#     ''' Implementation of stochastic Riemannian (linearized) coordinate descent
#         for orthogonal constraints.
#     '''
#
#     def __init__(self, params, lr, numupdate=1):
#         defaults = dict(lr=lr)
#         super(RCD, self).__init__(params, defaults)
#         self.numupdate = numupdate
#         self.flops = 0
#
#     @torch.no_grad()
#     def step(self):
#         group = self.param_groups[0]
#         loss = None
#         for p in group['params']:
#             d_p = p.grad
#             if not p.orth:
#                 p.data.add_(-group['lr'] * d_p.data)
#             else:
#                 n = p.data.size()[0]
#                 pp = p.data.size()[1]
#                 assert (n >= pp)
#                 # GXt = torch.mm(d_p.data, torch.transpose(p.data,0,1))
#                 # skewGXt = -group['lr']*(GXt - torch.transpose(GXt,0,1))
#
#                 for it in range(self.numupdate):
#                     # assert self.numupdate < nn/2
#                     idx = (torch.randperm(n)).to(p.device)
#                     ii = idx[0]
#                     jj = idx[1]
#                     # eta = skewGXt[ii,jj] # 4p
#                     XVt = torch.inner(p.data[jj, :], d_p.data[ii, :])
#                     VXt = torch.inner(d_p.data[jj, :], p.data[ii, :])
#                     eta = -group['lr'] * (XVt - VXt)
#                     vi = torch.cos(eta) * p.data[ii, :] + torch.sin(eta) * p.data[jj, :]
#                     vj = -torch.sin(eta) * p.data[ii, :] + torch.cos(eta) * p.data[jj, :]
#                     p.data[ii, :] = vi
#                     p.data[jj, :] = vj
#
#                 self.flops = self.flops + self.numupdate * 10 * pp
#
#         return loss



class TSD(Optimizer):
    '''Randomized column-wise update for matrices with orthonormal columns.

    Every parameter must carry a boolean attribute ``orth``.  Parameters
    with ``orth`` false receive a plain gradient step.  For the others,
    each inner iteration does one of two things for an ``n x pp`` matrix
    ``X`` with gradient ``G``:

    * with probability ``(pp - 1) / (pp + 1)``: pick two random columns
      and rotate them in their common plane by an angle computed from
      ``X`` and ``G``;
    * otherwise: pick one random column, remove from the matching column
      of ``G`` its component in the column span of ``X``, and move the
      column of ``X`` along that direction with a cos/sin update.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: step size (may be overridden per parameter group).
        n: number of rows; only used to check ``n > p``.
        p: number of columns; only used to check ``n > p``.
        numupdate: number of inner iterations per parameter per step.

    Attributes:
        flops: running operation-count estimate accumulated by ``step``.
    '''

    def __init__(self, params, lr, n, p, numupdate=1):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
        assert n > p, "Use RCD instead for the case n = p"
        # `step` loops `numupdate` times per parameter.
        self.numupdate = numupdate
        self.flops = 0

    @torch.no_grad()
    def step(self, closure=None):
        '''Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            is given.
        '''
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                # Parameters without a gradient are left untouched.
                if p.grad is None:
                    continue
                G = p.grad.data
                if not p.orth:
                    p.data.add_(-lr * G)
                    continue

                X = p.data  # shares storage with p; writes below update p
                n = X.size()[0]
                pp = X.size()[1]
                assert (n > pp)
                pair_prob = (pp - 1) / (pp + 1)

                for _ in range(self.numupdate):
                    if torch.rand(1).item() < pair_prob:
                        # pick i,j pairs
                        b = torch.randperm(pp).to(p.device)
                        b0 = b[0]
                        b1 = b[1]
                        XtV = torch.inner(X[:, b0], G[:, b1])  # 2n
                        VtX = torch.inner(G[:, b0], X[:, b1])  # 2n
                        alpha = -lr * (XtV - VtX)
                        # Both new columns are computed before either is
                        # written back.
                        v = torch.cos(alpha) * X[:, b0] - torch.sin(alpha) * X[:, b1]
                        w = torch.sin(alpha) * X[:, b0] + torch.cos(alpha) * X[:, b1]
                        X[:, b0] = v
                        X[:, b1] = w

                        self.flops = self.flops + 10 * n
                    else:
                        # pick k index
                        b = torch.randperm(pp).to(p.device)
                        b0 = b[0]
                        Vj = G[:, b0:b0 + 1]  # keep dim
                        projVj = torch.mm(torch.transpose(X, 0, 1), Vj)  # 2np
                        projVj = Vj - torch.mm(X, projVj)  # 2np + n
                        projVj = -lr * projVj  # n
                        normVj = torch.norm(projVj)  # 2n
                        # A zero direction leaves the column unchanged;
                        # evaluating sin(0)/0 would write NaN instead.
                        if normVj > 0:
                            v = torch.cos(normVj) * X[:, b0] + torch.sin(normVj) / normVj * projVj.view(-1)  # 3n
                            X[:, b0] = v

                        self.flops = self.flops + 4 * n * pp + 7 * n

        return loss


class RandomRCD(Optimizer):
    ''' Implementation of stochastic Riemannian (linearized) coordinate descent
        for orthogonal constraints.

    Every parameter must carry a boolean attribute ``orth``.  Parameters
    with ``orth`` false receive a plain gradient step.  For the others,
    each inner iteration picks two distinct random rows of the ``n x pp``
    matrix and rotates them in their common plane by an angle computed
    from those rows of the matrix and of its gradient.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: step size (may be overridden per parameter group).
        numupdate: number of row-pair rotations per parameter per step.

    Attributes:
        flops: running operation-count estimate accumulated by ``step``.
    '''

    def __init__(self, params, lr, numupdate=1):
        defaults = dict(lr=lr)
        super().__init__(params, defaults)
        self.numupdate = numupdate
        self.flops = 0

    @torch.no_grad()
    def step(self, closure=None):
        '''Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            is given.
        '''
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                # Parameters without a gradient are left untouched.
                if p.grad is None:
                    continue
                G = p.grad.data
                if not p.orth:
                    p.data.add_(-lr * G)
                    continue

                X = p.data  # shares storage with p; writes below update p
                n = X.size()[0]
                pp = X.size()[1]
                assert (n >= pp)
                # A single row offers no pair to rotate.
                if n < 2:
                    continue

                for _ in range(self.numupdate):
                    idx = torch.randperm(n).to(p.device)
                    ii = idx[0]
                    jj = idx[1]
                    XVt = torch.inner(X[ii, :], G[jj, :])
                    VXt = torch.inner(G[ii, :], X[jj, :])
                    eta = -lr * (XVt - VXt)
                    # Both new rows are computed before either is written
                    # back.
                    vi = torch.cos(eta) * X[ii, :] + torch.sin(eta) * X[jj, :]
                    vj = -torch.sin(eta) * X[ii, :] + torch.cos(eta) * X[jj, :]
                    X[ii, :] = vi
                    X[jj, :] = vj

                self.flops = self.flops + self.numupdate * 10 * pp

        return loss




class CyclicRCD(Optimizer):
    def __init__(self, params, lr, n, p):
        defaults = dict(lr=lr)
        super(CyclicRCD, self).__init__(params, defaults)
        idx1, idx2 = np.triu_indices(n, k=1)
        self.idx1 = idx1
        self.idx2 = idx2

    @torch.no_grad()
    def step(self, idx):
        group = self.param_groups[0]
        loss = None
        for p in group['params']:
            d_p = p.grad
            if not p.orth:
                p.data.add_(-group['lr'] * d_p.data)
            else:
                ii = self.idx1[idx]
                jj = self.idx2[idx]

