import numpy as np
import torch
from sklearn.cluster import KMeans
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.metrics import normalized_mutual_info_score
from sklearn.metrics import adjusted_rand_score
from sklearn import linear_model
import pdb

from sklearn.preprocessing import normalize
from sklearn.svm import SVR
from sklearn.ensemble import RandomForestRegressor
from random import sample
import random
import os
import torch
import torch.nn as nn
from torch.distributions import Normal, Independent
from torch.nn.functional import softplus

import tensorly as tl
import torch
from tensorly.cp_tensor import cp_to_tensor
from tensorly.decomposition import parafac

from cca_zoo.deepmodels import architectures
import torch.nn.functional as F
from .utils import Distance_Correlation,calc_distance_correlation,calc_distance_correlation_AB
from tensorly.tenalg import multi_mode_dot
from .sqrtm import sqrtm


def _mat_pow(mat, pow_, epsilon):
    """Raise a matrix to a (possibly negative or fractional) power.

    The matrix is eigendecomposed with ``torch.linalg.eigh`` (which treats its
    input as symmetric/Hermitian), ``epsilon`` is added to every eigenvalue,
    the shifted eigenvalues are raised to ``pow_`` and the matrix is
    reassembled from the eigenvectors.

    Args:
        mat: Square matrix handed to ``torch.linalg.eigh``.
        pow_: Exponent applied to the shifted eigenvalues.
        epsilon: Shift added to the eigenvalues; also the value written over
            any NaN entries of the result.

    Returns:
        The reassembled matrix, with NaN entries replaced by ``epsilon``.
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(mat)
    powered = torch.pow(eigenvalues + epsilon, pow_)
    result = eigenvectors @ torch.diag(powered) @ eigenvectors.T
    # For stability: overwrite NaNs (e.g. from a negative base with a
    # fractional exponent) in place.
    nan_mask = torch.isnan(result)
    result[nan_mask] = epsilon
    return result


def _demean(views):
    """Return a tuple with each view centred by its mean over dim 0."""
    return tuple(single - single.mean(dim=0) for single in views)





class Encoder(nn.Module):
    """Vanilla MLP: ``Linear(in_dim, 256) -> ReLU -> Linear(256, out_dim)``.

    Args:
        in_dim: Size of the input features.
        out_dim: Size of the produced embedding.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        hidden_dim = 256
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` and return its output."""
        return self.net(x)


def get_cor(view_1,view_2):
    """Return ``trace(T.T @ T)`` for the whitened cross-covariance ``T``.

    Both views are centred over dim 0. With ``S11`` and ``S22`` the
    covariances of the centred views and ``S12`` their cross-covariance,
    ``T = S11^(-1/2) @ S12 @ S22^(-1/2)``, where the inverse square roots come
    from ``_mat_pow(..., -0.5, 1e-5)``.

    Args:
        view_1: 2-D tensor, samples along dim 0.
        view_2: 2-D tensor, samples along dim 0.

    Returns:
        Scalar tensor ``trace(T.T @ T)``.
    """

    def _center_and_whiten(view):
        # Row count, centred view and covariance^(-1/2) of a single view.
        rows = view.shape[0]
        centered = view - view.mean(dim=0)
        covariance = centered.T @ centered / (rows - 1)
        return rows, centered, _mat_pow(covariance, -0.5, 1e-5)

    _, centered_1, whitener_1 = _center_and_whiten(view_1)
    rows_2, centered_2, whitener_2 = _center_and_whiten(view_2)

    # The cross-covariance is normalised with the second view's row count.
    cross_covariance = centered_1.T @ centered_2 / (rows_2 - 1)
    T = whitener_1 @ cross_covariance @ whitener_2
    return torch.trace(T.T @ T)


class DCCA_Noise_Norm_M(torch.nn.Module):
    """Multi-view model with one encoder per view and CCA-style losses.

    Every view gets its own ``architectures.Encoder``. With ``recon=True`` a
    matching ``architectures.Decoder`` is built per view; with ``recon=True``
    and ``private=True`` a second ("private") encoder per view is built too,
    and its output is concatenated to the first encoder's output in
    ``forward``.

    Args:
        in_dims: Per-view input sizes; ``in_dims[i]`` is passed as
            ``feature_size`` for view ``i``.
        out_dim: Latent size of each encoder; also the number of leading
            eigenvalues kept by the losses.
        view_num: Number of views.
        loss_name: ``'cca'`` or ``'gcca'``; selects the loss in ``get_loss``.
        recon: Whether to build per-view decoders.
        private: Whether ``forward`` uses the private encoders.

    NOTE(review): ``encoder_p`` is only created when ``recon`` is also true,
    while ``forward`` checks ``self.private`` alone, so ``private=True`` with
    ``recon=False`` fails in ``forward``. Left unchanged -- confirm the
    intended combination.
    """

    def __init__(self, in_dims,out_dim,view_num,loss_name='cca',recon=False,private=False):
        super().__init__()
        self.view_num = view_num
        self.latent_dim = out_dim

        # Regularisation weight used by cca_loss (0 disables it).
        self.cca_reg = 0
        self.private = private

        if recon:
            if private:
                self.encoder = nn.ModuleList([architectures.Encoder(latent_dims=out_dim, dropout=0.1,feature_size=in_dims[i]) for i in range(view_num)])
                self.encoder_p = nn.ModuleList([architectures.Encoder(latent_dims=out_dim, dropout=0.1,feature_size=in_dims[i]) for i in range(view_num)])
                # Decoders consume the concatenated shared + private codes.
                self.decoder = nn.ModuleList([architectures.Decoder(latent_dims=2*out_dim,feature_size=in_dims[i]) for i in range(view_num)])
            else:
                self.encoder = nn.ModuleList([architectures.Encoder(latent_dims=out_dim,feature_size=in_dims[i]) for i in range(view_num)])
                self.decoder = nn.ModuleList([architectures.Decoder(latent_dims=out_dim, feature_size=in_dims[i]) for i in range(view_num)])
        else:
            self.encoder = nn.ModuleList([architectures.Encoder(latent_dims=out_dim, feature_size=in_dims[i]) for i in range(view_num)])
        self.loss_name = loss_name

    def get_loss(self,views):
        """Compute the loss selected by ``self.loss_name`` on ``views``.

        Raises:
            ValueError: If ``self.loss_name`` is neither ``'cca'`` nor
                ``'gcca'`` (previously this silently returned ``None``).
        """
        if self.loss_name == 'cca':
            return self.cca_loss(views)
        if self.loss_name == 'gcca':
            return self.GCCA_loss(views)
        raise ValueError(
            "Unknown loss_name %r; expected 'cca' or 'gcca'." % (self.loss_name,)
        )

    @staticmethod
    def _has_noise(noise):
        """Tell whether ``noise`` was supplied, without ambiguous tensor truthiness."""
        if noise is None:
            return False
        if torch.is_tensor(noise):
            # bool() of a multi-element tensor raises; test for content instead.
            return noise.numel() > 0
        return bool(noise)

    def forward(self, multiview,noise=None,recon=False):
        """Encode every view and optionally reconstruct it / encode noise.

        Args:
            multiview: Indexable of per-view inputs; ``multiview[i]`` feeds
                encoder ``i``.
            noise: Optional indexable of per-view noise inputs (a list/tuple
                or a tensor indexed along dim 0); ``noise[i]`` is passed
                through encoder ``i``.
            recon: If true, also decode each projection with ``self.decoder``.

        Returns:
            ``view_project`` alone, or a tuple
            ``(view_project[, view_recon][, noise_project])`` depending on
            ``recon`` and on whether noise was supplied.

        NOTE(review): in private mode noise is not encoded, so
        ``noise_project`` comes back empty -- behaviour kept as is.
        """
        has_noise = self._has_noise(noise)
        view_project = []
        view_recon = []
        noise_project = []

        for i in range(self.view_num):
            proj = self.encoder[i](multiview[i])
            if self.private:
                proj_p = self.encoder_p[i](multiview[i])
                proj = torch.cat([proj, proj_p], dim=1)
            view_project.append(proj)

            if recon:
                view_recon.append(self.decoder[i](proj))

            if has_noise and not self.private:
                noise_project.append(self.encoder[i](noise[i]))

        if has_noise:
            if recon:
                return view_project, view_recon, noise_project
            return view_project, noise_project
        if recon:
            return view_project, view_recon
        return view_project

    def _top_eigval_loss(self, sym_matrix):
        """Shared tail of both losses: negative sum over the leading eigenvalues."""
        eigvals = torch.linalg.eigvalsh(sym_matrix)

        # Sort eigenvalues so the largest come first and keep latent_dim of them.
        order = torch.argsort(eigvals, descending=True)
        eigvals = eigvals[order[: self.latent_dim]]

        # leaky relu encourages the gradient to be driven by positively
        # correlated dimensions while also encouraging dimensions associated
        # with spurious negative correlations to become more positive
        eigvals = torch.nn.functional.leaky_relu(eigvals[torch.gt(eigvals, 0)] - 1)

        return -eigvals.sum()

    def cca_loss(self, views):
        """Multi-view CCA loss on the per-view projections.

        Args:
            views: Sequence of 2-D tensors sharing the same number of rows.

        Returns:
            Scalar tensor: minus the sum of the leaky-ReLU'd top
            ``latent_dim`` eigenvalues (each shifted by -1) of the whitened
            covariance of the concatenated views.
        """
        n = views[0].shape[0]
        # Subtract the mean from each output
        views = _demean(views)

        # Concatenate all views and from this get the cross-covariance matrix
        all_views = torch.cat(views, dim=1)
        C = all_views.T @ all_views / (n - 1)

        # Per-view covariances X_i^T X_i, computed once and reused below.
        within = [view.T @ view / (n - 1) for view in views]

        # Block covariance matrix with the (optionally regularised) X_i^T X_i
        # on the diagonal; self.cca_reg == 0 leaves the covariances untouched.
        reg = self.cca_reg
        D = torch.block_diag(
            *[
                (1 - reg) * cov
                + reg * torch.eye(cov.shape[0], device=cov.device, dtype=cov.dtype)
                for cov in within
            ]
        )
        # Swap the raw diagonal blocks of C for the regularised ones.
        C = C - torch.block_diag(*within) + D

        R = _mat_pow(D, -0.5, 1e-5)

        # In MCCA our eigenvalue problem is Cv = lambda Dv
        C_whitened = R @ C @ R.T

        return self._top_eigval_loss(C_whitened)

    def GCCA_loss(self, views):
        """Generalised CCA loss on the per-view projections.

        Reference: https://www.uta.edu/math/_docs/preprint/2014/rep2014_04.pdf

        Args:
            views: Sequence of 2-D tensors sharing the same number of rows.

        Returns:
            Scalar tensor: minus the sum of the leaky-ReLU'd top
            ``latent_dim`` eigenvalues (each shifted by -1) of the summed
            per-view matrices ``X (X^T X / (n - 1))^-1 X^T``.
        """
        n = views[0].shape[0]
        views = _demean(views)

        # One n x n matrix per view: X @ cov^-1 @ X^T.
        eigen_views = [
            view @ _mat_pow(view.T @ view / (n - 1), -1, 1e-5) @ view.T
            for view in views
        ]

        Q = torch.stack(eigen_views, dim=0).sum(dim=0)

        return self._top_eigval_loss(Q)