import math
import torch
import numpy as np
from open_spiel.python import rl_environment

def centering(K):
    """Double-center a square kernel (Gram) matrix.

    Computes ``H @ K @ H`` with ``H = I - 11^T / n`` without materializing
    ``H``. Column means and row means are subtracted and the grand mean is
    added back. This is O(n^2) instead of two O(n^3) matrix products.

    Note: one-sided centering ``K @ H`` is *not* equivalent to ``H @ K @ H``
    in general, so the two-sided form is used.

    Args:
        K: array-like of shape (n, n).

    Returns:
        float64 ndarray of shape (n, n) whose rows and columns each sum to
        (numerically) zero.

    Raises:
        ValueError: if ``K`` is not a square 2-D matrix.
    """
    K = np.asarray(K, dtype=float)
    if K.ndim != 2 or K.shape[0] != K.shape[1]:
        raise ValueError(
            f"centering expects a square 2-D matrix, got shape {K.shape}"
        )
    if K.size == 0:
        # Nothing to center; avoids "mean of empty slice" warnings.
        return K.copy()
    # (I - J/n) K (I - J/n) = K - colmean - rowmean + grandmean
    return (
        K
        - K.mean(axis=0, keepdims=True)
        - K.mean(axis=1, keepdims=True)
        + K.mean()
    )


def rbf(X, sigma=None):
    """Gaussian (RBF) kernel matrix over the rows of ``X``.

    ``K[i, j] = exp(-||x_i - x_j||^2 / (2 * sigma^2))``.

    Args:
        X: array-like of shape (n, d), one sample per row. Integer input is
            accepted (it is converted to float).
        sigma: kernel bandwidth. If None, the median heuristic is used:
            ``sigma = sqrt(median of the non-zero pairwise squared
            distances)``. If every pairwise distance is zero (all rows
            identical), the kernel is all ones for any bandwidth. In that
            case ``sigma = 1.0`` is used instead of the NaN that
            ``np.median`` of an empty array would produce.

    Returns:
        float ndarray of shape (n, n) with ones on the diagonal.
    """
    # Float conversion: the scaling below would fail on an integer array.
    X = np.asarray(X, dtype=float)
    GX = np.dot(X, X.T)
    sq_norms = np.diag(GX)
    # Pairwise squared distances ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>.
    KX = sq_norms[:, None] + sq_norms[None, :] - 2.0 * GX
    # Squared distances cannot be negative; drop floating-point round-off.
    np.maximum(KX, 0.0, out=KX)
    if sigma is None:
        nonzero = KX[KX != 0]
        sigma = math.sqrt(np.median(nonzero)) if nonzero.size else 1.0
    return np.exp(KX * (-0.5 / (sigma * sigma)))


def kernel_HSIC(X, Y, sigma):
    """Unnormalized HSIC between ``X`` and ``Y`` using RBF kernels.

    Both kernel matrices are built with ``rbf(..., sigma)`` and double-centered.
    The return value is the sum of their elementwise product. No
    ``1 / (n - 1)^2`` factor is applied; that scale cancels in CKA.

    Args:
        X: samples of shape (n, d1).
        Y: samples of shape (n, d2); same number of rows as ``X``.
        sigma: RBF bandwidth shared by both kernels. With None, ``rbf``
            picks a median-heuristic bandwidth separately for ``X`` and ``Y``.
    """
    centered_kx = centering(rbf(X, sigma))
    centered_ky = centering(rbf(Y, sigma))
    return np.sum(centered_kx * centered_ky)


def linear_HSIC(X, Y):
    """Unnormalized HSIC between ``X`` and ``Y`` with a linear kernel.

    The kernels are the Gram matrices ``X X^T`` and ``Y Y^T``. Each is
    double-centered, and the sum of their elementwise product is returned.
    No ``1 / (n - 1)^2`` factor is applied.

    Args:
        X: samples of shape (n, d1).
        Y: samples of shape (n, d2); same number of rows as ``X``.
    """
    gram_x, gram_y = np.dot(X, X.T), np.dot(Y, Y.T)
    centered_x = centering(gram_x)
    centered_y = centering(gram_y)
    return np.sum(centered_x * centered_y)


def linear_CKA(X, Y):
    """Linear centered kernel alignment (CKA) between ``X`` and ``Y``.

    Returns ``HSIC(X, Y) / sqrt(HSIC(X, X) * HSIC(Y, Y))`` with linear-kernel
    HSIC, a value in [0, 1] for real inputs. Returns NaN (with a numpy
    RuntimeWarning) when either input has zero centered variance, e.g. all
    rows identical.

    Args:
        X: samples of shape (n, d1).
        Y: samples of shape (n, d2); same number of rows as ``X``.
    """
    # Build each centered Gram matrix once and reuse it for all three HSIC
    # terms, rather than recomputing the O(n^2 d) products per term.
    centered_x = centering(np.dot(X, X.T))
    centered_y = centering(np.dot(Y, Y.T))
    hsic = np.sum(centered_x * centered_y)
    var1 = np.sqrt(np.sum(centered_x * centered_x))
    var2 = np.sqrt(np.sum(centered_y * centered_y))
    return hsic / (var1 * var2)


def kernel_CKA(X, Y, sigma=None):
    """RBF-kernel centered kernel alignment (CKA) between ``X`` and ``Y``.

    Returns ``HSIC(X, Y) / sqrt(HSIC(X, X) * HSIC(Y, Y))`` using RBF kernels.
    The result is NaN (with a numpy RuntimeWarning) when a centered kernel
    is all zeros.

    Args:
        X: samples of shape (n, d1).
        Y: samples of shape (n, d2); same number of rows as ``X``.
        sigma: RBF bandwidth. None selects a median-heuristic bandwidth
            separately for ``X`` and ``Y`` (see ``rbf``).
    """
    # Each centered kernel is computed once and shared by the three HSIC
    # terms; the RBF kernel + centering is the expensive part.
    centered_x = centering(rbf(X, sigma))
    centered_y = centering(rbf(Y, sigma))
    hsic = np.sum(centered_x * centered_y)
    var1 = np.sqrt(np.sum(centered_x * centered_x))
    var2 = np.sqrt(np.sum(centered_y * centered_y))
    return hsic / (var1 * var2)



def get_similarity(game, policy1, policy2, device="cuda", num_samples=1000,
                   seed=None):
    """Layer-wise RBF-kernel CKA between two policy networks.

    Both policies are fed the same random binary probe inputs. The outputs at
    layer indices 0, 1, 2 and 3 of ``get_layer_out`` are compared pairwise.

    Args:
        game: OpenSpiel game passed to ``rl_environment.Environment``; used
            only to read the info-state size.
        policy1, policy2: objects exposing ``get_layer_out(tensor) -> list``
            of per-layer torch tensors with at least 4 entries (e.g.
            ``bc_model_cka`` / ``mb_model_cka`` wrappers).
        device: torch device the probe inputs are placed on.
        num_samples: number of probe inputs.
        seed: if given, the probe inputs are drawn from
            ``np.random.default_rng(seed)`` for reproducibility. If None, the
            global ``np.random`` state is used.

    Returns:
        Tuple of four kernel-CKA values, one per compared layer.
    """
    env = rl_environment.Environment(game)
    # NOTE(review): env is only used to read the observation size; this
    # seeding does not visibly affect the probe inputs drawn below.
    env.seed(seed=1)
    state_num = env.observation_spec()["info_state"][0]

    # Random binary probe inputs, shape [num_samples, state_num].
    if seed is None:
        x = np.random.randint(2, size=[num_samples, state_num])
    else:
        x = np.random.default_rng(seed).integers(
            0, 2, size=[num_samples, state_num])
    samples = torch.as_tensor(x, dtype=torch.float32, device=device)

    # Inference only: no autograd graph is needed for the activations.
    with torch.no_grad():
        outs1 = policy1.get_layer_out(samples)
        outs2 = policy2.get_layer_out(samples)

    # Compare the SAME indices in both networks, so the result is symmetric in
    # (policy1, policy2). The original used index 4 for policy2's last layer,
    # which raises IndexError for 4-output wrappers.
    layer_indices = (0, 1, 2, 3)
    return tuple(
        kernel_CKA(outs1[i].detach().cpu().numpy(),
                   outs2[i].detach().cpu().numpy())
        for i in layer_indices
    )


class bc_model_cka:
    """Expose the per-layer activations of a behaviour-cloning network.

    The wrapped model must provide the sub-modules ``state``,
    ``linear_hidden``, ``linear_hidden_2`` and ``out``; these are the only
    attributes ``get_layer_out`` uses.
    """

    # Hidden sub-modules applied in order, each followed by a ReLU.
    _HIDDEN_LAYERS = ("state", "linear_hidden", "linear_hidden_2")

    def __init__(self, bc_model):
        self.model = bc_model

    def get_layer_out(self, state):
        """Return ``[relu(h1), relu(h2), relu(h3), softmax(logits, dim=1)]``.

        Args:
            state: input tensor for ``self.model.state``; softmax over dim 1
                assumes a (batch, features) layout — TODO confirm.
        """
        activations = []
        hidden = state
        for layer_name in self._HIDDEN_LAYERS:
            hidden = torch.relu(getattr(self.model, layer_name)(hidden))
            activations.append(hidden)
        activations.append(torch.softmax(self.model.out(hidden), dim=1))
        return activations

class mb_model_cka:
    """Expose the per-layer activations of a sequential network.

    The wrapped object must have an iterable ``model`` attribute of callable
    layers (e.g. an ``nn.Sequential``), applied in order.
    """

    def __init__(self, bc_model):
        self.model = bc_model

    def get_layer_out(self, x):
        """Return every layer's output followed by a softmax of the last one.

        The list has ``len(self.model.model) + 1`` entries. If there are no
        layers, it holds only the softmax of ``x``. Softmax is taken over
        dim 1.
        """
        hidden = x
        activations = []
        for module in self.model.model:
            hidden = module(hidden)
            activations.append(hidden)
        return activations + [torch.softmax(hidden, dim=1)]


