import numpy as np
import torch
from sklearn.cluster import KMeans
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.metrics import normalized_mutual_info_score
from sklearn.metrics import adjusted_rand_score
from sklearn import linear_model
import pdb

from sklearn.preprocessing import normalize
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor
from random import sample
import random
import os
import torch
import torch.nn as nn
from torch.distributions import Normal, Independent
from torch.nn.functional import softplus

import tensorly as tl
import torch
from tensorly.cp_tensor import cp_to_tensor
from tensorly.decomposition import parafac

from cca_zoo.deepmodels import architectures
import torch.nn.functional as F
from .utils import Distance_Correlation,calc_distance_correlation,calc_distance_correlation_AB
from tensorly.tenalg import multi_mode_dot
from .sqrtm import sqrtm

def _mat_pow(mat, pow_, epsilon):
    #pdb.set_trace()
    # Computing matrix to the power of pow (pow can be negative as well)
    [D, V] = torch.linalg.eigh(mat)
    D = torch.clamp(D,min=epsilon)  # linear cca 需要不clamp, nr-dcca 需要clamp 很多
    
    #pdb.set_trace()
    mat_pow = V @ torch.diag(D.pow(pow_)) @ V.T
    #mat_pow[mat_pow != mat_pow] = epsilon  # For stability
    return mat_pow



from typing import List
class _MCCALoss:
    """Differentiable MCCA Loss. Solves the multiset eigenvalue problem.

    References
    ----------
    https://arxiv.org/pdf/2005.11914.pdf

    """

    def __init__(self, eps: float = 0, clamp: float = 1e-5):
        self.eps = eps
        self.clamp = clamp

    def C(self, representations: List[torch.Tensor]):
        """Calculate cross-covariance matrix."""
        all_views = torch.cat(representations, dim=1)
        C = torch.cov(all_views.T)
        C = C - torch.block_diag(
            *[torch.cov(representation.T) for representation in representations]
        )
        return C / len(representations)

    def D(self, representations: List[torch.Tensor]):
        """Calculate block covariance matrix."""
        D = torch.block_diag(
            *[
                (1 - self.eps) * torch.cov(representation.T)
                + self.eps
                * torch.eye(representation.shape[1], device=representation.device)
                for representation in representations
            ]
        )
        return D / len(representations)

    def correlation(self, representations: List[torch.Tensor]):
        """Calculate correlation."""
        latent_dims = representations[0].shape[1]
        representations = [
            representation - representation.mean(dim=0)
            for representation in representations
        ]
        C = self.C(representations)
        D = self.D(representations)
        C += D
        R = _mat_pow(D, -0.5,epsilon=self.clamp)
        C_whitened = R @ C @ R.T
        eigvals = torch.linalg.eigvalsh(C_whitened)
        idx = torch.argsort(eigvals, descending=True)
        eigvals = eigvals[idx[:latent_dims]]
        return eigvals

    def __call__(self, representations: List[torch.Tensor]):
        """Calculate loss."""
        eigvals = self.correlation(representations)
        eigvals = torch.nn.LeakyReLU()(eigvals[torch.gt(eigvals, 0)]-1)
        corr = eigvals.sum()
        return -corr


class _GCCALoss:
    """Differentiable GCCA Loss. Solves the generalized CCA eigenproblem.

    References
    ----------
    https://arxiv.org/pdf/2005.11914.pdf
    """

    def __init__(self, eps: float = 0, clamp: float = 1e-5):
        self.eps = eps
        self.clamp = clamp

    def Q(self, representations: List[torch.Tensor]):
        """Calculate Q matrix."""
        projections = [
            representation
            @ _mat_pow(torch.cov(representation.T),-1,self.clamp)
            @ representation.T
            for representation in representations
        ]
        Q = torch.stack(projections, dim=0).sum(dim=0)
        return Q

    def correlation(self, representations: List[torch.Tensor]):
        """Calculate correlation."""
        latent_dims = representations[0].shape[1]
        representations = [
            representation - representation.mean(dim=0)
            for representation in representations
        ]
        Q = self.Q(representations)
        eigvals = torch.linalg.eigvalsh(Q)
        idx = torch.argsort(eigvals, descending=True)
        eigvals = eigvals[idx[:latent_dims]]
        return torch.nn.LeakyReLU()(eigvals)

    def __call__(self, representations: List[torch.Tensor]):
        """Calculate loss."""
        eigvals = self.correlation(representations)
        corr = eigvals.sum()
        return -corr

def _demean(views):
    return tuple([view - view.mean(dim=0) for view in views])


class DCCA_Noise_Norm_M(torch.nn.Module):
    """Multi-view deep CCA model with one encoder per view.

    Args:
        in_dims: per-view input feature sizes, indexed by view.
        out_dim: latent size produced by each (shared) encoder.
        view_num: number of views.
        loss_name: ``'cca'`` (uses ``_MCCALoss``) or ``'gcca'`` (uses ``_GCCALoss``).
        recon: if True, build ``architectures.Encoder``/``Decoder`` pairs so
            ``forward(..., recon=True)`` can also return reconstructions.
        private: with ``recon``, adds a second ("private") encoder per view whose
            code is concatenated to the shared one, so decoders take
            ``2 * out_dim`` inputs. NOTE(review): ``encoder_p`` is only built when
            ``recon`` is True, so ``private=True, recon=False`` fails in ``forward``.
        linear: when ``recon`` is False, use ``nn.Linear`` encoders instead of
            ``architectures.Encoder``.
        noise: only consulted by ``get_loss``; ``'normal'`` disables the
            covariance ridge term.
    """

    def __init__(self, in_dims,out_dim,view_num,loss_name='cca',recon=False,private=False,linear=False,noise=None):
        super().__init__()
        self.view_num = view_num
        self.latent_dim = out_dim
        self.noise = noise
        self.cca_reg = 0  # NOTE(review): not read anywhere in this class
        self.private = private
        self.linear = linear
        hidden = (1024, 1024, 1024)
        # Module construction order is kept unchanged so seeded weight
        # initialisation stays reproducible.
        if recon:
            if private:
                self.encoder = nn.ModuleList([architectures.Encoder(latent_dims=out_dim, dropout=0.1, feature_size=in_dims[i], layer_sizes=hidden) for i in range(view_num)])
                self.encoder_p = nn.ModuleList([architectures.Encoder(latent_dims=out_dim, dropout=0.1, feature_size=in_dims[i], layer_sizes=hidden) for i in range(view_num)])
                # Decoders consume the concatenated shared + private code.
                self.decoder = nn.ModuleList([architectures.Decoder(latent_dims=2 * out_dim, layer_sizes=hidden, feature_size=in_dims[i]) for i in range(view_num)])
            else:
                self.encoder = nn.ModuleList([architectures.Encoder(latent_dims=out_dim, feature_size=in_dims[i], layer_sizes=hidden) for i in range(view_num)])
                self.decoder = nn.ModuleList([architectures.Decoder(latent_dims=out_dim, layer_sizes=hidden, feature_size=in_dims[i]) for i in range(view_num)])
        elif linear:
            self.encoder = nn.ModuleList([nn.Linear(in_dims[i], out_dim) for i in range(view_num)])
        else:
            self.encoder = nn.ModuleList([architectures.Encoder(latent_dims=out_dim, feature_size=in_dims[i], layer_sizes=hidden) for i in range(view_num)])
        self.loss_name = loss_name

    def get_loss(self, views):
        """Return the CCA-family loss of the projected ``views``.

        The ridge weight ``eps`` is 0 for linear models or when
        ``noise == 'normal'``, otherwise 1e-3; the eigenvalue clamp is 1e-4.

        Raises:
            ValueError: if ``loss_name`` is neither ``'cca'`` nor ``'gcca'``.
        """
        if self.loss_name == 'cca':
            loss_cls = _MCCALoss
        elif self.loss_name == 'gcca':
            loss_cls = _GCCALoss
        else:
            # Previously this fell through to an UnboundLocalError on `loss`.
            raise ValueError(f"unknown loss_name {self.loss_name!r}; expected 'cca' or 'gcca'")
        eps = 0 if (self.linear or self.noise == 'normal') else 1e-3
        return loss_cls(eps, 1e-4)(views)

    def _encode(self, i, x):
        """Project input ``x`` of view ``i``; with ``private`` the shared and private codes are concatenated on dim 1."""
        proj = self.encoder[i](x)
        if self.private:
            proj = torch.cat([proj, self.encoder_p[i](x)], dim=1)
        return proj

    def forward(self, multiview, noise=None, recon=False):
        """Encode every view (and optionally noise inputs / reconstructions).

        Args:
            multiview: per-view inputs, indexed by view.
            noise: optional per-view noise inputs, indexed by view (a list or a
                tensor whose first dim indexes views). ``None`` or empty means
                no noise projections are computed.
            recon: if True, also decode each view projection (requires the
                model to have been built with ``recon=True``).

        Returns:
            ``view_project`` (a list), or a tuple ``(view_project, view_recon)``,
            ``(view_project, noise_project)`` or
            ``(view_project, view_recon, noise_project)`` depending on the flags.
        """
        # Explicit check: `if noise:` raises on a multi-element tensor.
        has_noise = noise is not None and len(noise) > 0
        view_recon = []
        view_project = []
        noise_project = []

        for i in range(self.view_num):
            proj = self._encode(i, multiview[i])
            view_project.append(proj)
            if recon:
                view_recon.append(self.decoder[i](proj))
            if has_noise:
                # Noise goes through the same encoder(s) as the view, including
                # the private branch (previously skipped, yielding an empty list).
                noise_project.append(self._encode(i, noise[i]))

        if has_noise:
            if recon:
                return view_project, view_recon, noise_project
            return view_project, noise_project
        if recon:
            return view_project, view_recon
        return view_project
