import numpy as np
import torch
from sklearn.cluster import KMeans
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.metrics import normalized_mutual_info_score
from sklearn.metrics import adjusted_rand_score
from sklearn import linear_model
import pdb

from sklearn.preprocessing import normalize
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor
from random import sample
import random
import os
import itertools
import torch.nn.functional as F
from cca_zoo.models import CCA,KCCA,KGCCA,MCCA

def cor(views):
    """Fit cca_zoo's linear CCA on ``views`` and return the sum of its scores.

    The number of CCA components is the feature count (``shape[1]``) of the
    first view.

    Args:
        views: sequence of 2-D views. Each may be a ``torch.Tensor`` (it is
            detached and moved to CPU first) or anything ``np.asarray``
            accepts.

    Returns:
        ``sum(...)`` over the values returned by ``CCA.score``.
    """
    # Convert explicitly instead of relying on a bare ``except``: tensors are
    # detached (no-op when they carry no graph) and copied to host memory,
    # everything else goes through ``np.asarray``.
    arrays = [
        view.detach().cpu().numpy() if torch.is_tensor(view) else np.asarray(view)
        for view in views
    ]
    return sum(CCA(arrays[0].shape[1]).fit(arrays).score(arrays))

def Initialize_Seed(seed=2, deterministic=True):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN for reproducibility.

    Args:
        seed: integer seed applied to ``random``, ``numpy.random`` and torch
            (CPU and every CUDA device).
        deterministic: value assigned to ``torch.backends.cudnn.deterministic``.
            ``cudnn.benchmark = False`` alone does not force cuDNN to pick
            deterministic kernels, so this defaults to True.

    Note:
        ``PYTHONHASHSEED`` is read by the interpreter at start-up, so setting
        it here only influences child processes, not ``hash()`` in the
        current one.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = deterministic
    os.environ["PYTHONHASHSEED"] = str(seed)


import torch
import torch.nn as nn
from torch.distributions import Normal, Independent
from torch.nn.functional import softplus
from cca_zoo.deepmodels import architectures

# Encoder architecture
class Encoder(nn.Module):
    """Encode a batch into a factorised Gaussian over a ``z_dim``-sized latent.

    The backbone is ``architectures.Encoder`` from cca_zoo, built with
    ``feature_size=in_dim``, ``latent_dims=2 * z_dim`` and three hidden layers
    of width 1024. Its first ``z_dim`` output columns are the means; the
    remaining columns go through ``softplus(.) + 1e-7`` to give strictly
    positive standard deviations.
    """

    def __init__(self, in_dim, z_dim):
        super().__init__()
        self.z_dim = z_dim
        self.net = architectures.Encoder(
            latent_dims=2 * z_dim,
            feature_size=in_dim,
            layer_sizes=(1024, 1024, 1024),
        )

    def forward(self, x):
        """Return ``Independent(Normal(mu, sigma), 1)`` for the batch ``x``.

        Every axis after the first is flattened before the backbone runs.
        """
        batch_size = x.size(0)
        flat_input = x.view(batch_size, -1)
        out = self.net(flat_input)

        split_at = self.z_dim
        loc = out[:, :split_at]
        raw_scale = out[:, split_at:]
        # softplus keeps the scale positive; the epsilon keeps it away from 0.
        scale = softplus(raw_scale) + 1e-7

        # Reinterpret the last axis as the event dimension.
        return Independent(Normal(loc=loc, scale=scale), 1)


class Decoder(nn.Module):
    """Decode latent codes into a fixed-scale factorised Gaussian over ``out_dim`` features.

    Args:
        z_dim: width of the latent code passed to ``forward``.
        scale: standard deviation shared by every output coordinate. The
            default 0.39894 is approximately 1/sqrt(2*pi), for which the
            Normal log-density is approximately ``-pi * (x - mu) ** 2`` (the
            normalising constant vanishes).
        out_dim: number of output features (default 100, matching the
            previous hard-coded width).
    """

    def __init__(self, z_dim, scale=0.39894, out_dim=100):
        super(Decoder, self).__init__()

        self.z_dim = z_dim
        self.scale = scale
        self.out_dim = out_dim

        # ``feature_size`` is the backbone's input width (cf. Encoder, which
        # passes ``feature_size=in_dim``). The backbone receives the latent
        # code, so it must be ``z_dim``, and the output is ``out_dim`` wide.
        # With the arguments reversed the layer shapes only matched when
        # z_dim == 100.
        self.net = architectures.Encoder(
            latent_dims=out_dim, feature_size=z_dim, layer_sizes=(512, 256)
        )

    def forward(self, z):
        """Return ``Independent(Normal(net(z), scale), 1)``."""
        x = self.net(z)
        return Independent(Normal(loc=x, scale=self.scale), 1)


# Auxiliary network for mutual information estimation
class MIEstimator(nn.Module):
    """Critic MLP ``T([x1, x2])`` used to form two mutual-information estimates.

    Aligned rows of ``x1`` and ``x2`` are the positive (joint) pairs. Negative
    pairs come from rolling ``x1`` by one position along the batch axis, so
    row ``i`` of ``x2`` is paired with row ``i - 1`` of ``x1``.
    """

    def __init__(self, size1, size2):
        super().__init__()
        hidden = 1024
        # Same layer order/indices as before, so state_dict keys stay net.0/2/4.
        layers = [
            nn.Linear(size1 + size2, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x1, x2):
        """Return ``(jsd_estimate, exp_estimate)`` for the batch pair.

        * ``jsd_estimate`` = mean(-softplus(-T_pos)) - mean(softplus(T_neg)),
          the JSD-style objective.
        * ``exp_estimate`` = mean(T_pos) - mean(exp(T_neg)) + 1, the estimate
          the original comment calls "EB-based".
        """
        joint_scores = self.net(torch.cat([x1, x2], 1))
        shifted_x1 = x1.roll(1, 0)
        marginal_scores = self.net(torch.cat([shifted_x1, x2], 1))

        jsd_estimate = -softplus(-joint_scores).mean() - softplus(marginal_scores).mean()
        exp_estimate = joint_scores.mean() - marginal_scores.exp().mean() + 1
        return jsd_estimate, exp_estimate

import math


# Schedulers for beta
class Scheduler:
    """Base class for callables that map a training iteration to a value."""

    def __call__(self, **kwargs):
        # ``NotImplemented`` is a constant, not an exception; calling it would
        # raise TypeError instead of the intended NotImplementedError.
        raise NotImplementedError()


class LinearScheduler(Scheduler):
    """Linearly ramp from ``start_value`` to ``end_value``.

    * ``iteration <= start_iteration``: returns ``start_value``.
    * ``iteration > start_iteration + n_iterations``: returns ``end_value``.
    * In between: linear interpolation, reaching ``end_value`` exactly at
      ``start_iteration + n_iterations``.

    With ``n_iterations == 0`` the schedule is a step at ``start_iteration``
    rather than a ZeroDivisionError.
    """

    def __init__(self, start_value, end_value, n_iterations, start_iteration=0):
        self.start_value = start_value
        self.end_value = end_value
        self.n_iterations = n_iterations
        self.start_iteration = start_iteration
        # Slope per iteration. With n_iterations == 0 the interpolation branch
        # in __call__ is unreachable, so any finite slope is fine.
        self.m = (end_value - start_value) / n_iterations if n_iterations else 0.0

    def __call__(self, iteration):
        if iteration > self.start_iteration + self.n_iterations:
            return self.end_value
        elif iteration <= self.start_iteration:
            return self.start_value
        else:
            return (iteration - self.start_iteration) * self.m + self.start_value


class ExponentialScheduler(LinearScheduler):
    """Interpolate linearly in ``log_base`` space (geometric ramp).

    ``start_value`` and ``end_value`` must be positive, because ``math.log``
    raises ValueError otherwise. The returned value is
    ``base ** linear_value``, so the endpoints are reproduced up to
    floating-point rounding.
    """

    def __init__(self, start_value, end_value, n_iterations, start_iteration=0, base=10):
        self.base = base

        super().__init__(start_value=math.log(start_value, base),
                         end_value=math.log(end_value, base),
                         n_iterations=n_iterations,
                         start_iteration=start_iteration)

    def __call__(self, iteration):
        linear_value = super().__call__(iteration)
        return self.base ** linear_value


def Distance_Correlation(view):
    """Return the squared distance correlation between two batches.

    ``view`` is a pair ``(latent, control)`` of tensors with the same number
    of rows. Each tensor is L2-normalised with ``F.normalize`` (defaults:
    ``p=2``, ``dim=1``). Its pairwise Euclidean distance matrix is then
    double-centred, and the function returns

        dCov^2(X, Y) / sqrt(dVar^2(X) * dVar^2(Y) + 1e-9)

    where each term is the mean of the elementwise product of the centred
    matrices. The pairwise-difference broadcasting assumes 2-D inputs
    ``(batch, features)``.
    """

    def _centred_distances(points):
        # Pairwise distances; the 1e-12 keeps sqrt differentiable on the
        # zero diagonal.
        deltas = points.unsqueeze(0) - points.unsqueeze(1)
        dist = torch.sqrt(deltas.square().sum(dim=-1) + 1e-12)
        return (
            dist
            - dist.mean(dim=0, keepdim=True)
            - dist.mean(dim=1, keepdim=True)
            + dist.mean()
        )

    latent, control = view
    centred_x = _centred_distances(F.normalize(latent))
    centred_y = _centred_distances(F.normalize(control))

    n_entries = centred_x.shape[0] * centred_x.shape[1]
    dcov_xy = (centred_x * centred_y).sum() / n_entries
    dvar_x = (centred_x * centred_x).sum() / n_entries
    dvar_y = (centred_y * centred_y).sum() / n_entries

    return dcov_xy / torch.sqrt(dvar_x * dvar_y + 1e-9)


def Distance_Correlation_1(view):
    """Return the squared distance correlation between ``latent`` and ``control``.

    Same computation as ``Distance_Correlation``. The stray
    ``pdb.set_trace()`` that halted every call has been removed.

    Args:
        view: pair ``(latent, control)`` of 2-D tensors ``(batch, features)``
            with the same batch size. Rows are L2-normalised first.

    Returns:
        Scalar tensor ``dCov^2 / sqrt(dVar^2_x * dVar^2_y + 1e-9)``.
    """
    latent, control = view
    latent = F.normalize(latent)
    control = F.normalize(control)

    # Pairwise Euclidean distances (1e-12 keeps sqrt differentiable at 0).
    matrix_a = torch.sqrt(torch.sum(torch.square(latent.unsqueeze(0) - latent.unsqueeze(1)), dim=-1) + 1e-12)
    matrix_b = torch.sqrt(torch.sum(torch.square(control.unsqueeze(0) - control.unsqueeze(1)), dim=-1) + 1e-12)

    # Double centring: remove column means, row means, add back grand mean.
    matrix_A = matrix_a - torch.mean(matrix_a, dim=0, keepdim=True) - torch.mean(matrix_a, dim=1, keepdim=True) + torch.mean(matrix_a)
    matrix_B = matrix_b - torch.mean(matrix_b, dim=0, keepdim=True) - torch.mean(matrix_b, dim=1, keepdim=True) + torch.mean(matrix_b)

    n_entries = matrix_A.shape[0] * matrix_A.shape[1]
    Gamma_XY = torch.sum(matrix_A * matrix_B) / n_entries
    Gamma_XX = torch.sum(matrix_A * matrix_A) / n_entries
    Gamma_YY = torch.sum(matrix_B * matrix_B) / n_entries

    return Gamma_XY / torch.sqrt(Gamma_XX * Gamma_YY + 1e-9)


def calc_distance_correlation_AB(A, A_1):
    """Return the coordinate-wise distance-covariance matrix between ``A_1`` and ``A``.

    For every column ``k`` of a 2-D tensor ``X``, ``|X[i, k] - X[j, k]|`` is
    an N x N distance matrix. It is double-centred by subtracting the row
    and column means and adding back the grand mean. Entry ``[p, q]`` of the
    result is the mean over all N*N pairs of the product of the centred
    matrices for column ``p`` of ``A_1`` and column ``q`` of ``A``. This is a
    (V-statistic) squared distance covariance. No normalisation to a
    correlation is applied.

    Args:
        A: tensor of shape (N, d_A).
        A_1: tensor of shape (N, d_A1). The column count may differ from
            ``A``'s; the row count must match (otherwise the matmul raises).

    Returns:
        Tensor of shape (d_A1, d_A).
    """

    def _centred_flat(x):
        # (N, N, d) absolute coordinate-wise differences, double-centred per
        # column, then flattened to (N*N, d) using x's own column count.
        dist = torch.abs(x.unsqueeze(0) - x.unsqueeze(1))
        centred = (
            dist
            - torch.mean(dist, dim=0, keepdim=True)
            - torch.mean(dist, dim=1, keepdim=True)
            + torch.mean(dist, dim=(0, 1), keepdim=True)
        )
        return centred.reshape(-1, centred.shape[-1])

    a_flat = _centred_flat(A)
    a1_flat = _centred_flat(A_1)
    return a1_flat.T @ a_flat / a_flat.shape[0]

def calc_distance_correlation(A):
    """Return the matrix of coordinate-wise (squared) distance correlations of ``A``.

    For every column ``k``, ``|A[i, k] - A[j, k]|`` is an N x N distance
    matrix that is double-centred. With ``cov[p, q]`` the mean elementwise
    product of the centred matrices of columns ``p`` and ``q``, the result is
    ``cov[p, q] / sqrt(cov[p, p] * cov[q, q])``.

    Args:
        A: 2-D tensor of shape (N, d).

    Returns:
        (d, d) tensor. Entries involving a constant column (zero distance
        variance) are 0 instead of the NaN a raw 0/0 division produced.
        Gradients stay NaN-free in that case as well.
    """
    dist = torch.abs(A.unsqueeze(0) - A.unsqueeze(1))  # (N, N, d)
    centred = (
        dist
        - torch.mean(dist, dim=0, keepdim=True)
        - torch.mean(dist, dim=1, keepdim=True)
        + torch.mean(dist, dim=(0, 1), keepdim=True)
    )
    flat = centred.reshape(-1, centred.shape[-1])  # (N*N, d)
    cov = flat.T @ flat / flat.shape[0]

    var = torch.diagonal(cov)
    var_prod = torch.outer(var, var)  # torch.ger is deprecated
    valid = var_prod > 0
    # Take sqrt/divide only on a placeholder of 1 where invalid. Otherwise
    # NaN/inf from the unused torch.where branch would leak into gradients.
    denom = torch.sqrt(torch.where(valid, var_prod, torch.ones_like(var_prod)))
    return torch.where(valid, cov / denom, torch.zeros_like(cov))

#torch.cov(torch.hstack((zi, zj)).T)

def _mat_pow(mat, pow_, epsilon):
    # Computing matrix to the power of pow (pow can be negative as well)
    [D, V] = torch.linalg.eigh(mat)
    mat_pow = V @ torch.diag((D + epsilon).pow(pow_)) @ V.T
    mat_pow[mat_pow != mat_pow] = epsilon  # For stability
    return mat_pow


def _demean(views):
    return tuple([view - view.mean(dim=0) for view in views])

def cca_loss(views, dim=100, r=0.0, eps=1e-3):
    """Return the negative leaky-rectified MCCA correlation of ``views`` as a loss.

    Steps:
        1. Centre the views and form the joint covariance ``C`` of their
           feature-wise concatenation.
        2. Replace the within-view diagonal blocks of ``C`` with the
           (optionally shrunk) within-view covariances ``D``.
        3. Compute the eigenvalues of ``D^{-1/2} C D^{-1/2}`` and keep the
           ``dim`` largest.
        4. Sum ``LeakyReLU(lambda - 1)`` over the positive ones.

    For two views (and ``eps`` -> 0) the eigenvalues include ``1 +/- rho_i``,
    with ``rho_i`` the canonical correlations.

    Args:
        views: sequence of 2-D tensors ``(n_samples, n_features_i)`` sharing
            ``n_samples``. They are concatenated along ``dim=1``.
        dim: maximum number of leading eigenvalues that contribute.
        r: shrinkage of each within-view covariance towards the identity,
            ``(1 - r) * cov + r * I``. The default 0.0 is the unregularised
            loss, identical to the previous hard-coded value.
        eps: offset added to the eigenvalues of ``D`` before the inverse
            square root is taken (passed to ``_mat_pow``). Defaults to the
            previous hard-coded 1e-3.

    Returns:
        Scalar tensor equal to minus the summed correlation term, so that
        minimising it maximises correlation.
    """
    n = views[0].shape[0]
    # Subtract the per-feature mean from each view.
    views = _demean(views)

    # Joint covariance of all views concatenated feature-wise.
    all_views = torch.cat(views, dim=1)
    C = all_views.T @ all_views / (n - 1)

    # Within-view covariances X_i^T X_i / (n - 1), computed once.
    within = [view.T @ view / (n - 1) for view in views]
    D = torch.block_diag(
        *[
            (1 - r) * cov
            + r * torch.eye(cov.shape[0], device=cov.device, dtype=cov.dtype)
            for cov in within
        ]
    )
    # Swap the raw diagonal blocks of C for the (regularised) blocks of D.
    C = C - torch.block_diag(*within) + D

    R = _mat_pow(D, -0.5, eps)

    # In MCCA the eigenvalue problem is C v = lambda D v. Whitening with
    # R = D^{-1/2} turns it into a standard symmetric eigenproblem.
    C_whitened = R @ C @ R.T
    eigvals = torch.linalg.eigvalsh(C_whitened)

    # Sort eigenvalues so the largest come first, then keep at most ``dim``.
    idx = torch.argsort(eigvals, descending=True)
    eigvals = eigvals[idx[:dim]]

    # leaky relu encourages the gradient to be driven by positively correlated dimensions while also encouraging
    # dimensions associated with spurious negative correlations to become more positive
    eigvals = torch.nn.LeakyReLU()(eigvals[torch.gt(eigvals, 0)] - 1)

    return -eigvals.sum()


