""" Functions for gradient descent and teleportations. """

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from curvature_utils import W_list_to_vec, vec_to_W_list, compute_curvature
import math

def init_param_MLP(dim, seed=54321):
    """Create the weight matrices of a bias-free MLP.

    Args:
        dim: list of dimensions. The weight matrix number ``i`` has shape
            ``dim[i+2] x dim[i+1]``, so ``dim[0]`` is not used here.
            Example: [4, 5, 6, 7, 8] -> X: 5x4, W1: 6x5, W2: 7x6, W3: 8x7, Y: 8x4
        seed: value passed to ``torch.manual_seed`` before sampling.

    Returns:
        A list of ``len(dim) - 2`` tensors. Each one is sampled uniformly
        from [-k, k) with k = 1 / sqrt(number of columns), the scheme used by
        ``torch.nn.Linear``
        (https://pytorch.org/docs/stable/generated/torch.nn.Linear.html).
        ``requires_grad=True`` is set on the raw random sample, so every
        returned tensor is the result of arithmetic on that sample.
    """
    torch.manual_seed(seed)

    def draw(fan_out, fan_in):
        bound = 1 / np.sqrt(fan_in)
        return 2 * bound * torch.rand(fan_out, fan_in, requires_grad=True) - bound

    # Pair every layer's input size with its output size, in layer order.
    return [draw(fan_out, fan_in) for fan_in, fan_out in zip(dim[1:], dim[2:])]

def init_param_transformer(input_dim, model_dim, out_dim, layer_num, seed=54321):
    """Create the weight matrices used by ``loss_multi_layer_transformer``.

    Args:
        input_dim: number of columns of the embedding matrix.
        model_dim: number of rows of the embedding matrix; every per-layer
            matrix is ``model_dim x model_dim``.
        out_dim: number of rows of the output matrix.
        layer_num: number of matrices in each per-layer list.
        seed: value passed to ``torch.manual_seed`` before sampling.

    Returns:
        ``[W_embed, Wq_list, Wk_list, Wv_list, W1_list, W2_list, W_out]`` with
        ``W_embed`` of shape ``model_dim x input_dim`` and ``W_out`` of shape
        ``out_dim x model_dim``. Each matrix is sampled uniformly from
        [-k, k) with k = 1 / sqrt(number of columns)
        (https://pytorch.org/docs/stable/generated/torch.nn.Linear.html).
    """
    torch.manual_seed(seed)

    def draw(fan_out, fan_in):
        bound = 1 / np.sqrt(fan_in)
        return 2 * bound * torch.rand(fan_out, fan_in, requires_grad=True) - bound

    def draw_stack():
        return [draw(model_dim, model_dim) for _ in range(layer_num)]

    # The sampling order below fixes which random numbers each matrix gets
    # for a given seed: embedding, output, then the five per-layer stacks.
    W_embed = draw(model_dim, input_dim)
    W_out = draw(out_dim, model_dim)
    Wq_list = draw_stack()
    Wk_list = draw_stack()
    Wv_list = draw_stack()
    W1_list = draw_stack()
    W2_list = draw_stack()

    return [W_embed, Wq_list, Wk_list, Wv_list, W1_list, W2_list, W_out]
    
def loss_multi_layer(W_list, X, Y, sigma):
    """Forward pass and negative log-likelihood loss of a bias-free MLP.

    Every matrix except the last is followed by the activation ``sigma``.
    Class scores are laid out as (classes, examples): the log-softmax is
    taken over dim 0 and the result is transposed before ``F.nll_loss``.

    Args:
        W_list: list of weight matrices, applied in order by left
            multiplication.
        X: network input.
        Y: integer class targets passed to ``F.nll_loss``.
        sigma: elementwise activation function.

    Returns:
        ``(loss, log_probs, hidden)`` where ``hidden`` is the input of the
        last weight matrix.
    """
    hidden = X
    for W in W_list[:-1]:
        hidden = sigma(W @ hidden)
    log_probs = F.log_softmax(W_list[-1] @ hidden, dim=0)
    loss = F.nll_loss(log_probs.t(), Y)
    return loss, log_probs, hidden

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a sequence-first input.

    Even feature indices hold ``sin(position * div_term)`` and odd feature
    indices hold ``cos(position * div_term)``. The table is stored as the
    non-trainable buffer ``pe`` of shape ``[max_len, 1, d_model]``.

    Args:
        d_model: size of the feature (last) dimension; may be even or odd.
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions in the table.
    """

    def __init__(self, d_model: int, dropout: float = 0, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angles to the number of odd indices. For an
        # even d_model the slice keeps every column.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        # The buffer is moved explicitly because the module itself may live on
        # a different device than the input.
        device = x.device
        x = x + self.pe[:x.size(0)].to(device)
        return self.dropout(x)
        
def loss_multi_layer_transformer(W_list_all, X, Y, sigma):
    """Forward pass and negative log-likelihood loss of the transformer-style model.

    Args:
        W_list_all: ``[W_embed, Wq_list, Wk_list, Wv_list, W1_list, W2_list,
            W_out]`` as produced by ``init_param_transformer``.
        X: 3-D input whose last axis is the sequence axis (its size is used
            as the positional-encoding length).
            # assumes layout (batch, input_dim, seq_len) - TODO confirm
        Y: integer class targets passed to ``F.nll_loss``.
        sigma: elementwise activation used inside the feed-forward part.

    Returns:
        ``(loss, log_probs, attn_output[0])`` where ``attn_output`` is the
        attention output of the last layer. At least one layer is required,
        otherwise ``attn_output`` is never assigned.
    """
    W_embed, Wq_list, Wk_list, Wv_list, W1_list, W2_list, W_out = W_list_all

    # A fresh encoding table is built on every call, sized to this input.
    pos_enc = PositionalEncoding(W_embed.shape[0], max_len=X.shape[-1])
    embedded = torch.matmul(W_embed, X)
    # PositionalEncoding indexes positions along axis 0, so move the last
    # axis to the front and restore the layout afterwards.
    hidden = pos_enc(embedded.permute(2, 0, 1)).permute(1, 2, 0)

    for layer in range(len(Wq_list)):
        # Query, key and value projections of the layer input.
        queries = Wq_list[layer] @ hidden
        keys = Wk_list[layer] @ hidden
        values = Wv_list[layer] @ hidden

        # Scaled dot-product scores, normalised over the last axis.
        scale = torch.sqrt(torch.tensor(keys.size(-2), dtype = torch.float32))
        attn_weights = F.softmax((queries.transpose(-2, -1) @ keys) / scale, dim=-1)
        attn_output = values @ attn_weights

        # Two-matrix feed-forward part, then residual connection with the
        # layer input and layer norm over every axis except the first.
        feed_forward = W2_list[layer] @ sigma(W1_list[layer] @ attn_output)
        summed = feed_forward + hidden
        hidden = F.layer_norm(summed, summed.size()[1:])

    # Average over the last axis before the output projection.
    log_probs = F.log_softmax(W_out @ hidden.mean(-1).T, dim=0)
    return F.nll_loss(log_probs.t(), Y), log_probs, attn_output[0]

def loss_MLP_from_vec(W_vec_all, X, Y, dim, sigma):
    """Evaluate the MLP loss for weights given as one flat vector.

    Args:
        W_vec_all: all weights in a single vector, converted back to a list
            of matrices by ``vec_to_W_list(W_vec_all, dim)``.
        X, Y, sigma: forwarded unchanged to ``loss_multi_layer``.
        dim: list of dimensions forwarded to ``vec_to_W_list``.

    Returns:
        The scalar loss tensor only.
    """
    W_list = vec_to_W_list(W_vec_all, dim)
    # loss_multi_layer returns (loss, log-probabilities, last hidden
    # activation); unpacking it into two names raised ValueError.
    return loss_multi_layer(W_list, X, Y, sigma)[0]

def valid_MLP(model, criterion, valid_loader, device):
    """Compute the mean loss and accuracy (in percent) over a validation loader.

    Each batch is flattened to ``(batch, features)`` and transposed, so the
    model receives ``(features, batch)`` and must return ``(classes, batch)``.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        criterion: loss taking ``(batch, classes)`` scores and the targets.
        valid_loader: iterable of ``(data, target)`` batches exposing a
            ``sampler`` whose length is the number of examples.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` as Python floats.
    """
    model.eval()
    test_loss = 0.0
    test_correct = 0
    # Nothing is back-propagated here, so skip building autograd graphs.
    with torch.no_grad():
        for data, target in valid_loader:
            batch_size = data.shape[0]
            data = torch.t(data.view(batch_size, -1)).to(device)
            target = target.to(device)
            output = model(data)
            L = criterion(output.T, target)
            # Weight the batch loss by the batch size (columns of data) so
            # that the final division yields a per-example mean.
            test_loss += L.item()*data.size(1)

            _, pred = torch.max(output, 0)
            test_correct += pred.eq(target.view_as(pred)).sum().item()

    num_examples = len(valid_loader.sampler)
    test_loss = test_loss / num_examples
    test_correct = 100.0 * test_correct / num_examples
    return test_loss, test_correct

def valid_transformer(model, criterion, valid_loader, device):
    """Compute the mean loss and accuracy (in percent) over a validation loader.

    Each batch is flattened to ``(batch, features)`` and given an extra axis,
    so the model receives ``(batch, 1, features)`` and must return
    ``(classes, batch)``.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        criterion: loss taking ``(batch, classes)`` scores and the targets.
        valid_loader: iterable of ``(data, target)`` batches exposing a
            ``sampler`` whose length is the number of examples.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` as Python floats.
    """
    model.eval()
    test_loss = 0.0
    test_correct = 0
    # Nothing is back-propagated here, so skip building autograd graphs.
    with torch.no_grad():
        for data, target in valid_loader:
            batch_size = data.shape[0]
            data = data.view(batch_size, -1).unsqueeze(1).to(device)
            target = target.to(device)
            output = model(data)
            L = criterion(output.T, target)
            # Weight the batch loss by the batch size so that the final
            # division yields a per-example mean.
            test_loss += L.item()*data.size(0)

            _, pred = torch.max(output, 0)
            test_correct += pred.eq(target.view_as(pred)).sum().item()

    num_examples = len(valid_loader.sampler)
    test_loss = test_loss / num_examples
    test_correct = 100.0 * test_correct / num_examples
    return test_loss, test_correct

def train_step(x_train, y_train, model, criterion, optimizer):
    """Run one optimisation step on a single batch.

    The model is put in training mode, gradients are cleared, the loss of
    ``criterion(model.forward(x_train).T, y_train)`` is back-propagated and
    the optimizer takes one step.

    Returns:
        The loss tensor computed before the parameter update.
    """
    model.train()
    optimizer.zero_grad()
    # The model output is transposed before it reaches the criterion.
    batch_loss = criterion(model.forward(x_train).T, y_train)
    batch_loss.backward()
    optimizer.step()
    return batch_loss

def test_MLP(model, criterion, test_loader, device, num_classes=10):
    """Evaluate the model on a test loader and print per-class accuracy.

    Each batch is flattened to ``(batch, features)`` and transposed, so the
    model receives ``(features, batch)`` and must return ``(classes, batch)``.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        criterion: loss taking ``(batch, classes)`` scores and the targets.
        test_loader: iterable of ``(data, target)`` batches exposing a
            ``sampler`` whose length is the number of examples.
        device: device the batches are moved to.
        num_classes: number of classes reported; labels must lie in
            ``range(num_classes)``.

    Returns:
        ``(mean_loss, overall_accuracy)`` with the accuracy as a fraction.
    """
    model.eval()
    test_loss = 0.0
    class_correct = [0.0] * num_classes
    class_total = [0.0] * num_classes

    # Nothing is back-propagated here, so skip building autograd graphs.
    with torch.no_grad():
        for data, target in test_loader:
            batch_size = data.shape[0]
            data = torch.t(data.view(batch_size, -1)).to(device)
            target = target.to(device)
            output = model(data)
            L = criterion(output.T, target)
            test_loss += L.item()*data.size(1)
            _, pred = torch.max(output, 0)
            # Keep the comparison 1-D: squeezing it collapses a batch of one
            # example to a 0-dim tensor that cannot be indexed.
            correct = pred.eq(target.view_as(pred)).view(-1)
            for label, hit in zip(target.view(-1).tolist(), correct.tolist()):
                class_correct[label] += hit
                class_total[label] += 1

    test_loss = test_loss/len(test_loader.sampler)
    print('Test Loss: {:.6f}\n'.format(test_loss))

    for i in range(num_classes):
        if class_total[i] > 0:
            print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (
                str(i), 100 * class_correct[i] / class_total[i],
                class_correct[i], class_total[i]))
        else:
            # A class absent from the test set has no defined accuracy.
            print('Test Accuracy of %5s: N/A (no test examples)' % str(i))
    print('\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (
        100. * np.sum(class_correct) / np.sum(class_total),
        np.sum(class_correct), np.sum(class_total)))
    return test_loss, np.sum(class_correct) / np.sum(class_total)


##############################################################
# group actions
def group_action(U, V, X, X_inv, T, sigma):
    """Apply the group element ``g = I + T`` to a pair of consecutive layers.

    U, V -> U sigma(VX) sigma((I+T)VX)^+, (I+T)V

    Args:
        U: weight matrix applied after the activation.
        V: weight matrix applied to ``X``.
        X: input of ``V``; its device is used for the identity matrix.
        X_inv: not used by this function.
        T: square matrix with as many rows as ``V``.
        sigma: elementwise activation function.

    Returns:
        ``(U_out, V_out)``, the transformed pair.
    """
    g = torch.eye(T.size(0)).to(X.device) + T

    pre_activation = V @ X
    activation = sigma(pre_activation)
    # '^+' in the formula above is the Moore-Penrose pseudo-inverse.
    transformed_pinv = torch.linalg.pinv(sigma(g @ pre_activation))

    U_out = (U @ activation) @ transformed_pinv
    V_out = g @ V
    return U_out, V_out

def group_action_large(U, V, X, X_inv, g, g_inv, sigma):
    """Apply an explicit group element ``g`` to a pair of consecutive layers.

    U, V -> U sigma(VX) sigma(gVX)^+, gV

    Args:
        U: weight matrix applied after the activation.
        V: weight matrix applied to ``X``.
        X: input of ``V``.
        X_inv: not used by this function.
        g: square matrix with as many rows as ``V``.
        g_inv: not used by this function.
        sigma: elementwise activation function.

    Returns:
        ``(U_out, V_out)``, the transformed pair.
    """
    # No identity matrix is needed here because g is given explicitly.
    V_out = torch.matmul(g, V)
    Wh = torch.matmul(V, X)
    sigma_Wh = sigma(Wh)
    sigma_gWh = sigma(torch.matmul(g, Wh))
    # '^+' in the formula above is the Moore-Penrose pseudo-inverse.
    sigma_gWh_inv = torch.linalg.pinv(sigma_gWh)
    U_out = torch.matmul(torch.matmul(U, sigma_Wh), sigma_gWh_inv)
    return U_out, V_out

def group_action_exp(t, U, V, X, X_inv, M, sigma):
    """Apply the group element ``g = exp(tM)`` to a pair of consecutive layers.

    U, V -> U sigma(VX) sigma(exp(tM)VX)^+, exp(tM)V

    Args:
        t: scalar-like tensor multiplying ``M``.
        U: weight matrix applied after the activation.
        V: weight matrix applied to ``X``.
        X: input of ``V``.
        X_inv: not used by this function.
        M: square matrix with as many rows as ``V``.
        sigma: elementwise activation function.

    Returns:
        ``(U_out, V_out)``, the transformed pair.
    """
    # The inverse of g is never used, so it is not computed: it would cost an
    # extra SVD on every call, including calls traced by autograd.
    g = torch.linalg.matrix_exp(t * M)

    V_out = torch.matmul(g, V)
    Wh = torch.matmul(V, X)
    sigma_Wh = sigma(Wh)
    sigma_gWh = sigma(torch.matmul(g, Wh))
    # '^+' in the formula above is the Moore-Penrose pseudo-inverse.
    sigma_gWh_inv = torch.linalg.pinv(sigma_gWh)
    U_out = torch.matmul(torch.matmul(U, sigma_Wh), sigma_gWh_inv)
    return U_out, V_out

##############################################################
# first (or second) derivatives of the component of gamma corresponding to U (or V)
def compute_gamma_1_U(t, U, V, h, h_inv, M, sigma):
    """First derivative w.r.t. ``t`` of the U component of ``group_action_exp``.

    The Jacobian is built with ``create_graph=True`` so it can be
    differentiated again; size-1 axes are squeezed out of the result.
    """
    def transformed_U(t_):
        return group_action_exp(t_, U, V, h, h_inv, M, sigma)[0]

    jac = torch.autograd.functional.jacobian(transformed_U, t, create_graph=True)
    return jac.squeeze()

def compute_gamma_1_V(t, U, V, h, h_inv, M, sigma):
    """First derivative w.r.t. ``t`` of the V component of ``group_action_exp``.

    The Jacobian is built with ``create_graph=True`` so it can be
    differentiated again; size-1 axes are squeezed out of the result.
    """
    def transformed_V(t_):
        return group_action_exp(t_, U, V, h, h_inv, M, sigma)[1]

    jac = torch.autograd.functional.jacobian(transformed_V, t, create_graph=True)
    return jac.squeeze()

def compute_gamma_2_U(t, U, V, h, h_inv, M, sigma):
    """Second derivative w.r.t. ``t`` of the U component of ``group_action_exp``.

    Obtained as the Jacobian of ``compute_gamma_1_U``, again with
    ``create_graph=True``; size-1 axes are squeezed out of the result.
    """
    def first_derivative(t_):
        return compute_gamma_1_U(t_, U, V, h, h_inv, M, sigma)

    jac = torch.autograd.functional.jacobian(first_derivative, t, create_graph=True)
    return jac.squeeze()

def compute_gamma_2_V(t, U, V, h, h_inv, M, sigma):
    """Second derivative w.r.t. ``t`` of the V component of ``group_action_exp``.

    Obtained as the Jacobian of ``compute_gamma_1_V``, again with
    ``create_graph=True``; size-1 axes are squeezed out of the result.
    """
    def first_derivative(t_):
        return compute_gamma_1_V(t_, U, V, h, h_inv, M, sigma)

    jac = torch.autograd.functional.jacobian(first_derivative, t, create_graph=True)
    return jac.squeeze()

##############################################################
# teleportation

def teleport_curvature(W_list, X, Y, lr_teleport, dim, sigma, telestep=10, reverse=False):
    """Teleport one pair of layers to change the curvature of the group orbit.

    In every step a random matrix ``M`` is drawn, the curvature ``kappa`` of
    the curve ``t -> exp(tM) . (U, V)`` is evaluated at ``t = 0`` together
    with its derivative in ``t``, one gradient step is taken on ``t``, and
    the pair ``(W_list[2], W_list[1])`` is replaced by its image under
    ``group_action_exp``.

    Args:
        W_list: list of weight matrices; it is modified in place and also
            returned. At least three matrices are required because the
            teleported pair is fixed to indices 1 and 2.
        X: network input.
        Y: targets, used only for the loss values that are printed.
        lr_teleport: step size of the update on ``t``.
        dim: list of dimensions; ``dim[3]`` is the size of ``M``.
        sigma: elementwise activation function.
        telestep: number of teleportation steps.
        reverse: True if minimizing curvature, False if maximizing curvature.

    Returns:
        ``W_list`` after teleportation.

    NOTE(review): the updated ``t`` is built from ``kappa_1``, which is
    computed with ``create_graph=True``, so the new weights stay attached to
    that autograd graph. Confirm callers expect this.
    """
    print("before teleport", loss_multi_layer(W_list, X, Y, sigma)[0])

    X_inv = torch.linalg.pinv(X)

    def hidden_inputs():
        # inputs[m] is the input fed to W_list[m]; pinvs[m] its pseudo-inverse.
        inputs = [X]
        pinvs = [X_inv]
        for m in range(len(W_list) - 2):
            h = sigma(torch.matmul(W_list[m], inputs[-1]))
            inputs.append(h)
            pinvs.append(torch.linalg.pinv(h))
        return inputs, pinvs

    h_list, h_inv_list = hidden_inputs()

    layer = 1  # index of V; U is the matrix right after it
    for teleport_step in range(telestep):
        t = torch.zeros(1, requires_grad=True)
        M = torch.rand(dim[layer+2], dim[layer+2], requires_grad=True)

        U, V = W_list[layer+1], W_list[layer]
        h, h_inv = h_list[layer], h_inv_list[layer]

        # compute curvature using autograd
        gamma_1_U = compute_gamma_1_U(t, U, V, h, h_inv, M, sigma)
        gamma_1_V = compute_gamma_1_V(t, U, V, h, h_inv, M, sigma)

        gamma_2_U = compute_gamma_2_U(t, U, V, h, h_inv, M, sigma)
        gamma_2_V = compute_gamma_2_V(t, U, V, h, h_inv, M, sigma)

        # Layers that are not teleported do not move along the curve.
        gamma_1_list = [torch.zeros_like(W) for W in W_list]
        gamma_2_list = [torch.zeros_like(W) for W in W_list]

        # Store each derivative at the index of the weight it belongs to:
        # V lives at `layer` and U at `layer + 1`, matching the placeholders
        # above and the weight assignment below.
        gamma_1_list[layer] = gamma_1_V
        gamma_1_list[layer+1] = gamma_1_U
        gamma_2_list[layer] = gamma_2_V
        gamma_2_list[layer+1] = gamma_2_U

        kappa = compute_curvature(gamma_1_list, gamma_2_list) # curvature
        kappa_1 = torch.autograd.grad(kappa, inputs=t, create_graph=True)[0] # derivative of curvature

        # gradient descent/ascent on t to decrease/increase curvature
        if reverse:
            t = t - lr_teleport * kappa_1
        else:
            t = t + lr_teleport * kappa_1
        print(kappa, kappa_1, t)

        # transform weights using the updated t
        W_list[layer+1], W_list[layer] = group_action_exp(t, U, V, h, h_inv, M, sigma)

        # The teleported weights change the hidden inputs, so recompute them.
        h_list, h_inv_list = hidden_inputs()

    print("after teleport", loss_multi_layer(W_list, X, Y, sigma)[0])

    return W_list


def teleport_sharpness(W_list, X, Y, lr_teleport, dim, sigma, telestep=10, loss_perturb_cap=2.0, reverse=False, \
    t_start=0.001, t_end=0.2, t_interval=0.001):
    """Teleport all consecutive layer pairs to change the sharpness of the loss.

    Sharpness is estimated as the mean loss at randomly perturbed weights.
    In every step the group elements ``I + T[m]`` are initialised at the
    identity, the sharpness is differentiated with respect to every ``T[m]``,
    one gradient step is taken on ``T``, and the weights are transformed with
    ``group_action``.

    Args:
        W_list: list of weight matrices; it is modified in place and also
            returned.
        X: network input.
        Y: targets.
        lr_teleport: step size of the update on ``T``.
        dim: list of dimensions; ``T[m]`` is ``dim[m+2] x dim[m+2]``.
        sigma: elementwise activation function.
        telestep: maximum number of teleportation steps.
        loss_perturb_cap: teleportation stops as soon as the estimated
            sharpness exceeds this value.
        reverse: True if minimizing sharpness, False if maximizing sharpness.
        t_start, t_end, t_interval: currently unused; the perturbation radii
            are fixed to ``np.arange(0.1, 1.0, 0.1)``.

    Returns:
        ``W_list`` after teleportation.
    """
    X_inv = torch.linalg.pinv(X)
    h_list = [X]
    h_inv_list = [X_inv]
    for m in range(0, len(W_list)-2):
        h = sigma(torch.matmul(W_list[m], h_list[-1]))
        h_list.append(h)
        h_inv_list.append(torch.linalg.pinv(h))

    perturb_radii = np.arange(0.1, 1.0, 0.1)  # norms of the random perturbations
    num_d = 100  # random directions sampled per radius

    for teleport_step in range(telestep):
        gW_list = W_list.copy()
        T = [] # list of elements of Lie algebras

        # initialize T[i] = 0 and g.W = (I+T).W
        for m in range(0, len(gW_list)-1):
            T.append(torch.zeros(dim[m+2], dim[m+2], requires_grad=True))
            gW_list[m+1], gW_list[m] = group_action(gW_list[m+1], gW_list[m], h_list[m], h_inv_list[m], T[m], sigma)

        # compute sharpness (loss_perturb_mean)
        loss_perturb_mean = 0.0
        for t in perturb_radii:
            for d_idx in range(num_d):
                W_vec_all = W_list_to_vec(gW_list)
                random_dir = torch.rand(W_vec_all.size()[0], requires_grad=True)
                random_dir = random_dir / torch.norm(random_dir) * t
                W_vec_all_perturb = W_vec_all + random_dir
                loss_perturb = loss_MLP_from_vec(W_vec_all_perturb, X, Y, dim, sigma)
                loss_perturb_mean += loss_perturb
        # Divide by the number of losses actually summed. The count used to
        # come from the length of a different range than the one iterated,
        # which mis-scaled the mean.
        loss_perturb_mean = loss_perturb_mean / len(perturb_radii) / num_d
        print(teleport_step, loss_perturb_mean)
        if loss_perturb_mean > loss_perturb_cap:
            break

        # gradient descent/ascent on T to decrease/increase sharpness (loss_perturb_mean)
        dLdt_dT_list = torch.autograd.grad(loss_perturb_mean, inputs=T, create_graph=True)
        for i in range(len(T)):
            if reverse:
                T[i] = T[i] - lr_teleport * dLdt_dT_list[i]
            else:
                T[i] = T[i] + lr_teleport * dLdt_dT_list[i]

        # transform weights using the updated T
        for m in range(0, len(W_list)-1):
            W_list[m+1], W_list[m] = group_action(W_list[m+1], W_list[m], h_list[m], h_inv_list[m], T[m], sigma)

        # update the list of hidden representations h_list; (I - T) is used
        # in place of the inverse of (I + T) for the stored pseudo-inverses
        for m in range(1, len(h_list)):
            k = list(T[m-1].size())[0]
            I = torch.eye(k)
            h_list[m] = torch.matmul(I + T[m-1], h_list[m])
            h_inv_list[m] = torch.matmul(h_inv_list[m], I - T[m-1])

    return W_list


import numpy as np
def compute_conv_output_size(Lin,kernel_size,stride=1,padding=1,dilation=1):
    """Return the output length of a convolution along one spatial axis.

    Evaluates floor((Lin + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1)
    and returns it as a Python int.
    """
    covered = Lin + 2 * padding - dilation * (kernel_size - 1) - 1
    positions = covered / float(stride) + 1
    return int(np.floor(positions))
    
def get_projection_matrix(model, x, y=None, CNN = False, attention_mask = None): 
    """Build one projection matrix per activation recorded by the model on ``x``.

    For every activation matrix an SVD is taken, the leading left-singular
    vectors ``U[:, :r]`` are kept, and the matrix ``U_r @ U_r.T`` is returned.

    Args:
        model: called as ``model(x)`` (or ``model(x, attention_mask)``) and
            expected to return ``(output, activations)``; with ``CNN=True``
            it must return ``(output, activations, maps, channel_ls)``.
        x: input batch; ``x.size(0)`` is used as the batch size and
            ``x.device`` as the device of the results.
        y: not used.
        CNN: if True, the first three activations are unrolled into patch
            matrices before the SVD.
        attention_mask: optional second argument for the model (non-CNN only).

    Returns:
        A list of torch tensors on ``x.device``, one per activation.
    """
    # The model is switched to eval mode and is not switched back.
    model.eval()
    bsz = x.size(0)
    with torch.no_grad():
        if not CNN:
            if attention_mask is None:
                _, activations = model(x)
            else:
                _, activations = model(x, attention_mask)
        else:
            _, activations, maps, channel_ls = model(x)
            # strides = [3,1,1,3,1,1,3,1,1]
            strides = [1,1,1]
            mat_list = []
            for i in range(len(activations)):
                k=0
                if i<3:
                    # Unroll the activation into columns: one column per
                    # (sample, row, column) position, holding the flattened
                    # 3x3 window over all channels at that position.
                    # Windows advance by one pixel; the stride only enters
                    # through the number of positions s.
                    # NOTE(review): s is computed with the default padding=1,
                    # so the activations presumably arrive already padded -
                    # confirm against the model.
                    ksz= 3
                    s=compute_conv_output_size(maps[i],ksz, strides[i])
                    mat = torch.zeros((ksz*ksz*channel_ls[i],s*s*bsz))
                    act = activations[i]
                    for kk in range(bsz):
                        for ii in range(s):
                            for jj in range(s):
                                mat[:,k]=act[kk,:,ii:ksz+ii,jj:ksz+jj].reshape(-1) 
                                k +=1
                    mat_list.append(mat.to(x.device))
                else:
                    # Activations from index 3 on are used as they are.
                    mat_list.append(activations[i])
    
            activations = mat_list
    ## Calculate feature matrix
    threshold = 1
    feature_list=[]
    for act in activations:
        # The SVD runs in NumPy on the CPU.
        U,S,Vh = np.linalg.svd(act.cpu().numpy(), full_matrices=False)
        # criteria (Eq-5)
        sval_total = (S**2).sum()
        sval_ratio = (S**2)/sval_total
        # r counts the leading singular values whose cumulative share of the
        # squared spectrum stays below the threshold; one more is added
        # unless that count already covers all of them.
        r = np.sum(np.cumsum(sval_ratio, axis = 0)<threshold) #+1  
        if r != len(sval_ratio):
            r+=1
        feature_list.append(U[:,0:r])

    projection_mat_list = []
    for f in feature_list:
        # Projection onto the span of the selected singular vectors.
        projection_matrix = f@f.T
        projection_mat_list.append(torch.tensor(projection_matrix).to(x.device))
    return projection_mat_list    


def teleport(model, X, Y, criterion, lr_teleport, sigma, telestep=10, dL_dt_cap=30, random_teleport=False, reverse=False, PennTree = False, CNN = False, attention_mask = None):
    """Update the model's parameters in place to increase the squared gradient norm of the loss.

    Each step computes ``dL_dt``, the sum of squared gradient norms of the
    loss over the selected parameters, differentiates it with respect to
    those parameters, removes from each gradient the component
    ``grad @ projection`` (projection matrices come from
    ``get_projection_matrix`` and are computed once, before the loop), and
    adds ``lr_teleport`` times the remainder to the parameter.

    Args:
        model: called as ``model(X)`` or ``model(X, attention_mask)``; must
            return the outputs first (two values, or four with ``CNN=True``).
            It is switched to eval mode and returned.
        X, Y: batch and targets used for the loss.
        criterion: loss function applied to the model outputs and ``Y``.
        lr_teleport: step size of the parameter update.
        sigma: not used.
        telestep: maximum number of teleportation steps.
        dL_dt_cap: teleportation stops once ``dL_dt`` exceeds this value or
            five times its value at the first step.
        random_teleport: not used.
        reverse: not used; the update is always an ascent step.
        PennTree: if True, outputs and targets are flattened before the
            loss, and the first model parameter is skipped.
        CNN: if True, only parameters that are not 1-D are teleported.
        attention_mask: optional second model argument; when given, the
            first model parameter is skipped as with ``PennTree``.

    Returns:
        The same ``model`` object, with updated parameters.
    """
    model.eval()
    if attention_mask is not None:
        projection_mat = get_projection_matrix(model, X, y=None, CNN = CNN, attention_mask = attention_mask)
    else:
        projection_mat = get_projection_matrix(model, X, y=None, CNN = CNN)

    
    for teleport_step in range(telestep):
        # Forward pass; only the outputs are needed here.
        if CNN:
            outs, _,_,_ = model(X)
        elif attention_mask is not None:
            outs, _ = model(X, attention_mask)
        else:
            outs, _ = model(X)
            
        if PennTree:
            L = criterion(outs.reshape(-1,outs.shape[-1]), Y.reshape(-1))
        else:
            L = criterion(outs, Y)
            
        # dL_dt accumulates the squared gradient norm of the loss.
        # create_graph=True keeps it differentiable w.r.t. the parameters.
        dL_dt = 0
        if PennTree or attention_mask is not None:
            # The first parameter is excluded in this mode.
            for param in list(model.parameters())[1:]:
                dL_dt += torch.norm(torch.autograd.grad(L, inputs=param, create_graph=True)[0])**2
        elif not CNN:
            for param in model.parameters():
                dL_dt += torch.norm(torch.autograd.grad(L, inputs=param, create_graph=True)[0])**2
        else:
            # 1-D parameters are excluded in CNN mode.
            for param in model.parameters():
                if len(param.size())!=1:
                    dL_dt += torch.norm(torch.autograd.grad(L, inputs=param, create_graph=True)[0])**2
        if teleport_step == 0:
            # Reference value for the relative stopping criterion below.
            init_dL_dt = dL_dt.cpu().detach()
            print('new')
        print(L.item(), dL_dt.item())
        # Stop once the gradient norm is large in absolute terms or relative
        # to its value at the first step.
        if dL_dt.cpu().detach()> dL_dt_cap or dL_dt.cpu().detach() > 5 * init_dL_dt:
            # # do one backward to free, delete this line for efficiency
            # torch.autograd.grad(dL_dt, inputs=list(model.parameters())[0])
            break
        # Gradient of dL_dt w.r.t. every teleported parameter. Every call but
        # the last passes retain_graph=True; the last one lets autograd free
        # the graph. proj_mat is not used in these loops, the zip only limits
        # the iteration to the number of projection matrices.
        dL_dt_dW = []
        if PennTree or attention_mask is not None:
            for i, (param, proj_mat) in enumerate(zip(list(model.parameters())[1:], projection_mat)):
                if i == len(projection_mat) - 1:
                    dL_dt_dW.append(torch.autograd.grad(dL_dt, inputs=param)[0])
                else:
                    dL_dt_dW.append(torch.autograd.grad(dL_dt, inputs=param, retain_graph=True)[0])
        elif not CNN:
            for i, (param, proj_mat) in enumerate(zip(model.parameters(), projection_mat)):
                if i == len(projection_mat) - 1:
                    dL_dt_dW.append(torch.autograd.grad(dL_dt, inputs=param)[0])
                else:
                    dL_dt_dW.append(torch.autograd.grad(dL_dt, inputs=param, retain_graph=True)[0])
        else:
            i = 0
            for param in model.parameters():
                if len(param.size())!=1:
                    if i == len(projection_mat) - 1:
                        dL_dt_dW.append(torch.autograd.grad(dL_dt, inputs=param)[0])
                    else:
                        dL_dt_dW.append(torch.autograd.grad(dL_dt, inputs=param, retain_graph=True)[0])
                    i+=1
                    
        # Every teleported parameter must have received a gradient, i.e.
        # there must be one projection matrix per teleported parameter.
        if PennTree or attention_mask is not None:
            assert len(list(model.parameters())) - 1 ==len(dL_dt_dW)
        elif not CNN:
            assert len(list(model.parameters()))==len(dL_dt_dW)

        # In-place ascent step with the projected component removed.
        # requires_grad is switched off around the update because autograd
        # rejects in-place changes to leaf tensors that require grad.
        if PennTree or attention_mask is not None: 
            for i, (param, proj_mat) in enumerate(zip(list(model.parameters())[1:], projection_mat)):
                param.requires_grad_(False)
                param += (lr_teleport * (dL_dt_dW[i] - torch.mm(dL_dt_dW[i], proj_mat)))
                param.requires_grad_(True)
        else:
            if not CNN:
                for i, (param, proj_mat) in enumerate(zip(model.parameters(), projection_mat)):
                    param.requires_grad_(False)
                    param += (lr_teleport * (dL_dt_dW[i] - torch.mm(dL_dt_dW[i], proj_mat)))
                    param.requires_grad_(True)
            else:
                i = 0
                for param in model.parameters():
                    if len(param.size())!=1:
                        param.requires_grad_(False)
                        sz =  param.size(0)
                        # The gradient is flattened to (param.size(0), -1)
                        # for the projection and reshaped back afterwards.
                        param += (lr_teleport * (dL_dt_dW[i] - torch.mm(dL_dt_dW[i].reshape(sz,-1), projection_mat[i]).reshape(param.size())))
                        param.requires_grad_(True)
                        i += 1
    return model
    
    
