import torch
import torch.nn as nn
import torch.nn.functional as Fnn
import torch.optim as optim
import os

import numpy as np
import sys
from matplotlib import pyplot as plt
import warnings
from scipy.linalg import qr, sqrtm
import seaborn as sns
from tqdm import tqdm
from pytorch_metric_learning import losses
from sklearn.decomposition import PCA
import argparse
import math
import pandas as pd
from scipy.linalg import block_diag

from scipy.spatial.distance import cdist, pdist, squareform
from scipy.special import digamma
from sklearn.neighbors import NearestNeighbors

def avg_abs_pearson_corr(U, V):
    """Return the mean absolute Pearson correlation between columns of U and V.

    Args:
        U: 2-D array whose columns are treated as variables (rows as samples).
        V: 2-D array with the same number of rows as ``U``.

    Returns:
        The average of ``|corr(U[:, i], V[:, j])|`` over all column pairs.
    """
    # np.corrcoef treats rows as variables, hence the transposes.
    full_corr = np.corrcoef(U.T, V.T)
    n_u_cols = U.shape[1]
    # The top-right block holds the U-versus-V correlations.
    u_vs_v = full_corr[:n_u_cols, n_u_cols:]
    return np.abs(u_vs_v).mean()

def distance_correlation(U, V):
    """Compute the sample distance correlation between U and V.

    Args:
        U: 2-D array of samples (one sample per row).
        V: 2-D array of samples with the same number of rows as ``U``.

    Returns:
        ``sqrt(dcov(U, V) / sqrt(dcov(U, U) * dcov(V, V)))``, or the integer
        ``0`` when the denominator is exactly zero.
    """
    def _centered_distances(X):
        # Pairwise Euclidean distances, centred by columns, then rows, then
        # shifted by the mean of the result (in place on a fresh matrix).
        D = cdist(X, X, metric='euclidean')
        D -= D.mean(axis=0)[None, :]
        D -= D.mean(axis=1)[:, None]
        D += D.mean()
        return D

    n = U.shape[0]
    A = _centered_distances(U)
    B = _centered_distances(V)

    n_sq = n**2
    dcov_xy = (A * B).sum() / n_sq
    dcov_xx = (A * A).sum() / n_sq
    dcov_yy = (B * B).sum() / n_sq

    denom = np.sqrt(dcov_xx * dcov_yy)
    if denom == 0:
        return 0
    return np.sqrt(dcov_xy / denom)

def rbf_kernel(X, sigma=None):
    """Build the Gaussian (RBF) kernel matrix of the rows of X.

    Args:
        X: 2-D array of samples (one sample per row).
        sigma: Kernel bandwidth. When ``None``, the median pairwise Euclidean
            distance is used, falling back to ``1.0`` if that median is zero.

    Returns:
        Square matrix with entries ``exp(-d_ij**2 / (2 * sigma**2))`` and a
        diagonal of ones.
    """
    pairwise = pdist(X, 'euclidean')
    if sigma is None:
        median_dist = np.median(pairwise)
        sigma = median_dist if median_dist else 1.0
    condensed = np.exp(-pairwise**2 / (2 * sigma**2))
    K = squareform(condensed)
    # squareform leaves the diagonal at zero; k(x, x) is 1.
    np.fill_diagonal(K, 1)
    return K

def hsic(U, V, sigma=None):
    """Compute an HSIC statistic between U and V using RBF kernels.

    Args:
        U: 2-D array of samples (one sample per row).
        V: 2-D array of samples with the same number of rows as ``U``.
        sigma: Bandwidth forwarded to ``rbf_kernel`` for both inputs; when
            ``None`` each kernel picks its own bandwidth.

    Returns:
        ``trace(HKH @ HLH) / (n - 1)**2`` where ``H`` is the centering matrix.
    """
    n = U.shape[0]
    gram_u = rbf_kernel(U, sigma)
    gram_v = rbf_kernel(V, sigma)
    # H = I - 11^T / n removes row and column means when applied on both sides.
    centering = np.eye(n) - np.ones((n, n)) / n
    gram_u_centered = centering @ gram_u @ centering
    gram_v_centered = centering @ gram_v @ centering
    return np.trace(gram_u_centered @ gram_v_centered) / (n - 1)**2

def mutual_information_knn(U, V, k=5):
    """Estimate the mutual information between U and V (in bits) from k-NN counts.

    The estimate is ``psi(k) + psi(n) - mean(psi(nx + 1) + psi(ny + 1))``
    with ``psi`` the digamma function, which matches the
    Kraskov-Stoegbauer-Grassberger formula. Distances use the Chebyshev
    (max) norm.

    Args:
        U: 2-D array of samples (one sample per row).
        V: 2-D array of samples with the same number of rows as ``U``.
        k: Number of neighbours used in the joint space.

    Returns:
        The estimate converted from nats to bits.
    """
    n = U.shape[0]
    joint = np.concatenate([U, V], axis=1)

    # k + 1 neighbours are requested because every query point is itself part
    # of the fitted set.
    joint_index = NearestNeighbors(metric='chebyshev', n_neighbors=k + 1).fit(joint)
    dist, _ = joint_index.kneighbors(joint)
    # Shrink the radius slightly so marginal counts exclude the k-th neighbour.
    eps = dist[:, k] - 1e-15

    def _marginal_counts(points):
        # Fit once per marginal; refitting for every sample is redundant work.
        index = NearestNeighbors(metric='chebyshev').fit(points)
        counts = np.array([
            len(index.radius_neighbors([p], radius=r, return_distance=False)[0]) - 1
            for p, r in zip(points, eps)
        ])
        # The "- 1" removes the query point itself. If the radius is not
        # positive (duplicate points) nothing is returned, so clamp at zero
        # instead of feeding digamma(0) into the estimate.
        return np.maximum(counts, 0)

    nx = _marginal_counts(U)
    ny = _marginal_counts(V)

    mi = digamma(k) + digamma(n) - np.mean(digamma(nx + 1) + digamma(ny + 1))
    return mi / np.log(2)

def dependence_summary(U, V):
    """Collect the dependence measures defined in this file for (U, V).

    Args:
        U: 2-D array of samples (one sample per row).
        V: 2-D array of samples with the same number of rows as ``U``.

    Returns:
        Dict with keys ``avg_abs_corr``, ``distance_corr``, ``hsic`` and
        ``mi_knn_bits``.
    """
    summary = {}
    summary["avg_abs_corr"] = avg_abs_pearson_corr(U, V)
    summary["distance_corr"] = distance_correlation(U, V)
    summary["hsic"] = hsic(U, V)
    summary["mi_knn_bits"] = mutual_information_knn(U, V)
    return summary


def info_nce_loss(features, temperature, batch_size, num_aug):
    """Build InfoNCE logits and targets from stacked augmented views.

    Args:
        features: Tensor of shape ``(num_aug * batch_size, dim)``; rows are
            expected view by view (all samples of view 0, then view 1, ...).
        temperature: Scalar (float or tensor) dividing the similarities.
        batch_size: Number of distinct samples per view.
        num_aug: Number of views. The same-view mask below is made of two
            half-size blocks, so it lines up with the views only when
            ``num_aug == 2``.

    Returns:
        ``(logits, labels)``: ``logits`` has the positive similarity in
        column 0 followed by the negatives; ``labels`` is an all-zero long
        tensor on the same device as ``features``.
    """
    device = features.device

    # sample_ids[i] is the index of the underlying sample of row i.
    sample_ids = torch.arange(batch_size, device=device).repeat(num_aug)
    labels = (sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)).float()

    # The file imports torch.nn.functional as Fnn (there is no global F).
    features = Fnn.normalize(features, dim=1)
    similarity_matrix = torch.matmul(features, features.T)

    # Drop every pair that lies in the same half (same view for num_aug == 2).
    half = labels.shape[0] // 2
    same_half = torch.ones((half, half), device=device)
    mask = torch.block_diag(same_half, same_half).bool()
    labels = labels[~mask].view(labels.shape[0], -1)
    similarity_matrix = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)

    positives = similarity_matrix[labels.bool()].view(labels.shape[0], -1)
    negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1)

    # NOTE(review): negatives are divided by `temperature` here and again two
    # lines below, i.e. by temperature**2 overall. Kept as in the original;
    # confirm whether this is intended.
    logits = torch.cat([positives, negatives / temperature], dim=1)
    labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)
    logits = logits / temperature

    return logits, labels


# Module-level activation instances available to code in this file.
relu = nn.ReLU()
tanh = nn.Tanh()
# Initial temperature value: log(tau_init) is the starting value of the
# `logit_scale` parameter in LinearNet, NonLinearNet and NonLinearNetD.
tau_init = 1e0



class LinearNet(nn.Module):
    """Two stacked bias-free linear layers with a learnable log-temperature.

    Args:
        input_dim: Size of the input features.
        middle_dim: Size of the hidden representation.
        output_dim: Size of the output features.
        tau_lower: Constant stored on the module as ``tau_lower``.

    ``logit_scale`` is a scalar parameter initialised to ``log(tau_init)``.
    Neither it nor ``tau_lower`` influences the output of ``forward``.
    """

    def __init__(self, input_dim, middle_dim, output_dim, tau_lower):
        super().__init__()
        self.input_dim = input_dim
        self.middle_dim = middle_dim
        self.output_dim = output_dim
        self.tau_lower = tau_lower
        self.linear1 = nn.Linear(input_dim, middle_dim, bias=False)
        self.linear2 = nn.Linear(middle_dim, output_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(tau_init))

    def forward(self, x):
        """Return ``linear2(linear1(x))``."""
        # The previous version also computed exp(logit_scale) + tau_lower here
        # and threw the value away; that dead computation is dropped.
        return self.linear2(self.linear1(x))

class NonLinearNet(nn.Module):
    """One-hidden-layer ReLU network with a learnable log-temperature.

    Args:
        input_dim: Size of the input features.
        middle_dim: Size of the hidden layer.
        output_dim: Size of the output features.
        tau_lower: Constant stored on the module as ``tau_lower``.

    ``logit_scale`` is a scalar parameter initialised to ``log(tau_init)``.
    Neither it nor ``tau_lower`` influences the output of ``forward``.
    """

    def __init__(self, input_dim, middle_dim, output_dim, tau_lower):
        super().__init__()
        self.input_dim = input_dim
        self.middle_dim = middle_dim
        self.output_dim = output_dim
        self.tau_lower = tau_lower
        self.linear1 = nn.Linear(input_dim, middle_dim, bias=True)
        self.linear2 = nn.Linear(middle_dim, output_dim, bias=True)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(tau_init))

    def forward(self, x):
        """Return ``linear2(relu(linear1(x)))``."""
        # torch.relu is the functional form of the ReLU activation, so the
        # forward pass does not depend on a module-level global. The unused
        # temperature computation of the previous version is dropped.
        hidden = torch.relu(self.linear1(x))
        return self.linear2(hidden)
    
class NonLinearNetD(nn.Module):
    """Four-layer ReLU network (three hidden layers) with a log-temperature.

    Args:
        input_dim: Size of the input features.
        middle_dim: Width of each hidden layer.
        output_dim: Size of the output features.
        tau_lower: Constant stored on the module as ``tau_lower``.

    ``logit_scale`` is a scalar parameter initialised to ``log(tau_init)``.
    Neither it nor ``tau_lower`` influences the output of ``forward``.
    """

    def __init__(self, input_dim, middle_dim, output_dim, tau_lower):
        super().__init__()
        self.input_dim = input_dim
        self.middle_dim = middle_dim
        self.output_dim = output_dim
        self.tau_lower = tau_lower
        self.linear1 = nn.Linear(input_dim, middle_dim, bias=True)
        self.linear2 = nn.Linear(middle_dim, middle_dim, bias=True)
        self.linear3 = nn.Linear(middle_dim, middle_dim, bias=True)
        self.linear4 = nn.Linear(middle_dim, output_dim, bias=True)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(tau_init))

    def forward(self, x):
        """Apply linear1..linear3, each followed by ReLU, then linear4."""
        # The unused temperature computation of the previous version is
        # dropped; torch.relu keeps the class free of module-level globals.
        z = x
        for layer in (self.linear1, self.linear2, self.linear3):
            z = torch.relu(layer(z))
        return self.linear4(z)
    
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    Args:
        d_model: Feature dimension of the inputs (odd values are supported).
        max_len: Number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)              # (max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = pos * div                              # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so only the first d_model // 2 frequencies get a cosine slot.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Buffer: moves with the module across devices but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions."""
        # x: (seq_len, batch, d_model)
        return x + self.pe[:x.size(0), :].unsqueeze(1)

class TransformerEncoderNet(nn.Module):
    """Project inputs, run them through a Transformer encoder, project again.

    Args:
        input_dim: Size of the input features.
        embed_dim: Model dimension of the encoder.
        num_heads: Number of attention heads per encoder layer.
        ff_dim: Width of the feed-forward sublayer.
        num_layers: Number of stacked encoder layers.
        output_dim: Size of the output features.
        tau_lower: Constant stored on the module as ``tau_lower``.
        tau_init: ``logit_scale`` is initialised to ``log(tau_init)``.
        norm: Accepted for interface compatibility; not used by this class.

    ``logit_scale`` and ``tau_lower`` do not influence the output of
    ``forward``.
    """

    def __init__(self,
                 input_dim,
                 embed_dim=256,
                 num_heads=4,
                 ff_dim=128,
                 num_layers=3,
                 output_dim=128,
                 tau_lower=0.01,
                 tau_init=1e0,
                 norm=False):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, embed_dim)
        self.pos_enc = PositionalEncoding(embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=0.1,
            activation='relu')
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_proj = nn.Linear(embed_dim, output_dim)
        self.tau_lower = tau_lower
        self.tau_init = tau_init
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(self.tau_init))

    def forward(self, x):
        """Encode ``x`` and return the projected features.

        A leading axis of length 1 is inserted before the encoder and removed
        afterwards.
        """
        # assumes x is (batch, input_dim) - TODO confirm against callers
        z = self.input_proj(x)
        z = z.unsqueeze(0)
        z = self.pos_enc(z)
        z = self.transformer(z)
        z = z.squeeze(0)
        # The unused `batch` and `tau` locals of the previous version are gone.
        return self.output_proj(z)
    

class Transformer_mat(nn.Module):
    """Linear embedding followed by a 2-layer Transformer encoder.

    Args:
        n_features: Size of the last input dimension.
        dim: Embedding (model) dimension of the encoder.

    ``logit_scale`` (initialised to ``log(tau_init)``) and ``tau_lower`` are
    stored on the module but are not used by ``forward``.
    """

    def __init__(self, n_features, dim):

        super().__init__()
        self.embed_dim = dim
        self.linear = nn.Linear(n_features, self.embed_dim, bias=False)
        layer = nn.TransformerEncoderLayer(d_model=self.embed_dim, nhead=2, dim_feedforward=128)
        self.transformer = nn.TransformerEncoder(layer, num_layers=2)
        self.tau_lower = 1e-3
        self.tau_init = 1e0
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(self.tau_init))

    def forward(self, x):
        """Return the encoder output at the last index of the first axis.

        Args:
            x: Tensor, or a list whose first element is the tensor to encode.
        """
        # isinstance also accepts list subclasses, unlike `type(x) is list`.
        if isinstance(x, list):
            x = x[0]
        x = self.linear(x)
        # The encoder layer uses the default sequence-first layout, so [-1]
        # picks the final sequence position.
        x = self.transformer(x)[-1]

        return x



class Transformer_tens(nn.Module):
    """Pointwise Conv1d embedding followed by a 2-layer Transformer encoder.

    Args:
        n_features: Size of the last input dimension (Conv1d input channels).
        dim: Embedding (model) dimension of the encoder.

    ``logit_scale`` (initialised to ``log(tau_init)``) and ``tau_lower`` are
    stored on the module but are not used by ``forward``.
    """

    def __init__(self, n_features, dim):

        super().__init__()
        self.embed_dim = dim
        self.conv = nn.Conv1d(n_features, self.embed_dim,
                              kernel_size=1, padding=0, bias=False)
        layer = nn.TransformerEncoderLayer(d_model=self.embed_dim, nhead=2, dim_feedforward=128)
        self.transformer = nn.TransformerEncoder(layer, num_layers=2)
        self.tau_lower = 1e-3
        self.tau_init = 1e0
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(self.tau_init))

    def forward(self, x):
        """Return the encoder output at the last index of the sequence axis.

        Args:
            x: Tensor, or a list whose first element is the tensor to encode.
        """
        # isinstance also accepts list subclasses, unlike `type(x) is list`.
        if isinstance(x, list):
            x = x[0]
        # assumes x is (batch, seq, n_features) - TODO confirm against callers
        # Conv1d wants channels on axis 1, so swap the last two axes first.
        x = self.conv(x.permute([0, 2, 1]))
        # Move the former last axis to the front: the encoder is sequence-first.
        x = x.permute([2, 0, 1])
        x = self.transformer(x)[-1]
        return x
    


class Transformer_ma(nn.Module):
    """Linear embedding, 2-layer Transformer encoder, then a linear output map.

    Args:
        n_features: Size of the last input dimension.
        middim: Embedding (model) dimension of the encoder.
        dim: Size of the last output dimension.

    ``logit_scale`` (initialised to ``log(tau_init)``) and ``tau_lower`` are
    stored on the module but are not used by ``forward``.
    """

    def __init__(self, n_features, middim, dim):

        super().__init__()
        self.embed_dim = middim
        self.linear = nn.Linear(n_features, self.embed_dim, bias=False)
        layer = nn.TransformerEncoderLayer(d_model=self.embed_dim, nhead=2, dim_feedforward=128)
        self.transformer = nn.TransformerEncoder(layer, num_layers=2)
        self.linear2 = nn.Linear(self.embed_dim, dim, bias=False)
        self.tau_lower = 1e-3
        self.tau_init = 1e0
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(self.tau_init))

    def forward(self, x):
        """Return the projected encoder output for every position.

        Args:
            x: Tensor, or a list whose first element is the tensor to encode.
        """
        # isinstance also accepts list subclasses, unlike `type(x) is list`.
        if isinstance(x, list):
            x = x[0]
        x = self.linear(x)
        x = self.transformer(x)
        x = self.linear2(x)

        return x

class CLUB(nn.Module):
    """Mutual-information estimator built on a Gaussian model q(y | x).

    Two MLPs map ``x`` to the mean and the (tanh-bounded) log-variance of a
    diagonal Gaussian over ``y``.

    Args:
        x_dim: Feature size of the conditioning samples.
        y_dim: Feature size of the predicted samples.
        hidden_size: The hidden layers have width ``hidden_size // 2``.
    """

    def __init__(self, x_dim, y_dim, hidden_size):
        super().__init__()
        width = hidden_size // 2

        self.p_mu = nn.Sequential(
            nn.Linear(x_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, y_dim),
        )

        # The trailing Tanh keeps the log-variance inside (-1, 1).
        self.p_logvar = nn.Sequential(
            nn.Linear(x_dim, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, y_dim),
            nn.Tanh(),
        )

    def get_mu_logvar(self, x_samples):
        """Return the predicted mean and log-variance for ``x_samples``."""
        return self.p_mu(x_samples), self.p_logvar(x_samples)

    def forward(self, x_samples, y_samples):
        """Return the bound: matched-pair term minus all-pairs term, averaged."""
        mu, logvar = self.get_mu_logvar(x_samples)
        variance = logvar.exp()

        # Score of each y_i under the Gaussian predicted from its own x_i.
        matched = -(mu - y_samples)**2 / 2. / variance

        # Squared error of every y_j against every prediction mu_i,
        # shape [nsample, nsample, dim], averaged over j.
        sq_err_all = (y_samples.unsqueeze(0) - mu.unsqueeze(1))**2
        unmatched = -sq_err_all.mean(dim=1) / 2. / variance

        gap = matched.sum(dim=-1) - unmatched.sum(dim=-1)
        return gap.mean()

    def loglikeli(self, x_samples, y_samples):
        """Return the unnormalized Gaussian log-likelihood, averaged over samples."""
        mu, logvar = self.get_mu_logvar(x_samples)
        per_dim = -(mu - y_samples)**2 / logvar.exp() - logvar
        return per_dim.sum(dim=1).mean(dim=0)

    def learning_loss(self, x_samples, y_samples):
        """Return the negative log-likelihood used to fit the Gaussian model."""
        return -self.loglikeli(x_samples, y_samples)


def mlp(dim, hidden_dim, output_dim, layers, activation):
    """Build a fully connected network as an ``nn.Sequential``.

    Args:
        dim: Input feature size.
        hidden_dim: Width of every hidden layer.
        output_dim: Output feature size.
        layers: Number of extra hidden-to-hidden blocks after the first one.
        activation: ``'relu'`` or ``'tanh'``; any other key raises ``KeyError``.

    Returns:
        ``Linear -> act -> (Linear -> act) * layers -> Linear``.
    """
    activation_classes = {'relu': nn.ReLU, 'tanh': nn.Tanh}
    act_cls = activation_classes[activation]

    modules = [nn.Linear(dim, hidden_dim), act_cls()]
    for _ in range(layers):
        modules.append(nn.Linear(hidden_dim, hidden_dim))
        modules.append(act_cls())
    modules.append(nn.Linear(hidden_dim, output_dim))

    return nn.Sequential(*modules)

class CLUBInfoNCECritic(nn.Module):
    """Scalar critic f(y, x) giving a CLUB-style value and an InfoNCE loss.

    Args:
        A_dim: First feature size; the critic input has ``A_dim + B_dim``
            features.
        B_dim: Second feature size.
        hidden_dim: Hidden width of the critic MLP.
        layers: Number of extra hidden blocks passed to ``mlp``.
        activation: Activation name passed to ``mlp``.
        **extra_kwargs: Accepted and ignored.
    """

    def __init__(self, A_dim, B_dim, hidden_dim, layers, activation, **extra_kwargs):
        super().__init__()
        self._f = mlp(A_dim + B_dim, hidden_dim, 1, layers, activation)

    def _critic_scores(self, x_samples, y_samples):
        """Return critic scores for matched pairs and for all (y_i, x_j) pairs."""
        n = y_samples.shape[0]

        # all_x[i, j] = x_j and all_y[i, j] = y_i.
        all_x = x_samples.unsqueeze(0).repeat((n, 1, 1))
        all_y = y_samples.unsqueeze(1).repeat((1, n, 1))

        matched = self._f(torch.cat([y_samples, x_samples], dim=-1))
        crossed = self._f(torch.cat([all_y, all_x], dim=-1))
        return matched, crossed, n

    # CLUB loss
    def forward(self, x_samples, y_samples):
        """Return mean matched score minus mean all-pairs score."""
        matched, crossed, _ = self._critic_scores(x_samples, y_samples)
        return matched.mean() - crossed.mean()

    # InfoNCE loss
    def learning_loss(self, x_samples, y_samples):
        """Return the negated InfoNCE lower bound used to train the critic."""
        matched, crossed, n = self._critic_scores(x_samples, y_samples)
        log_partition = crossed.logsumexp(dim=1).mean() - np.log(n)
        lower_bound = matched.mean() - log_partition
        return -lower_bound
    

def pad_to_dim(x, K):
    """Force the last dimension of ``x`` to size ``K``.

    Args:
        x: Input tensor.
        K: Target size of the last dimension.

    Returns:
        ``x`` truncated to its first ``K`` trailing entries when it is already
        at least that wide; otherwise ``x`` with zeros appended on the last
        axis (same dtype and device as ``x``).
    """
    current = x.shape[-1]
    if current < K:
        filler = x.new_zeros(list(x.shape[:-1]) + [K - current])
        return torch.cat([x, filler], dim=-1)
    return x[..., :K]


class PadToDim(nn.Module):
    """Module wrapper around ``pad_to_dim`` with a fixed target size."""

    def __init__(self, target_dim):
        """Store ``target_dim``, the size passed to ``pad_to_dim`` in forward."""
        super().__init__()
        self.target_dim = target_dim

    def forward(self, x):
        """Return ``pad_to_dim(x, self.target_dim)``."""
        return pad_to_dim(x, self.target_dim)
    

def train_inner_mlp(h_w, x, num_steps=10, lr=1e-3):
    """Fit a fresh ``NonLinearNetD`` to regress ``x`` from ``h_w`` with Adam.

    Args:
        h_w: Input tensor; ``h_w.shape[1]`` is the network input size and the
            network is placed on ``h_w.device``.
        x: Target tensor; ``x.shape[1]`` is the network output size.
        num_steps: Number of full-batch Adam steps.
        lr: Adam learning rate.

    Returns:
        ``(model, loss)`` where ``loss`` is the MSE (a Python float) evaluated
        at the start of the last step, i.e. before the final parameter
        update. With ``num_steps <= 0`` it is the MSE of the untrained model.
    """
    # Named `model` so the sibling `mlp` builder function is not shadowed.
    model = NonLinearNetD(input_dim=h_w.shape[1], middle_dim=50,
                          output_dim=x.shape[1], tau_lower=1e-3).to(h_w.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    if num_steps <= 0:
        # No step is taken; previously this path failed because `loss` was
        # never assigned.
        with torch.no_grad():
            return model, loss_fn(model(h_w), x).item()

    for _ in range(num_steps):
        pred = model(h_w)
        loss = loss_fn(pred, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return model, loss.item()



def data_gen(n, n_test, d_x, d_y, d_z, eps=0.0, link='nonlinear', seed=0):
    """Generate synthetic (X, Y, F) train and test tensors.

    ``Y`` is standard normal, ``X = Y @ A + eps * noise`` with ``A`` a
    rectangular identity, and ``F`` is a function of ``X @ A_f`` with ``A_f``
    a rectangular identity.

    Args:
        n: Number of training samples.
        n_test: Number of test samples.
        d_x: Dimension of ``X``.
        d_y: Dimension of ``Y``.
        d_z: Dimension of ``F``.
        eps: Scale of the standard-normal noise added to ``X``.
        link: ``'linear'`` returns ``X @ A_f`` unchanged; any other value
            applies the nonlinear map below.
        seed: Seed for the random draws. The same seed gives the same data;
            ``None`` draws a fresh seed from the OS.

    Returns:
        Dict with float32 tensors ``X``, ``Y``, ``F``, ``X_test``, ``Y_test``,
        ``F_test`` and the NumPy matrix ``A_f``.
    """
    N = n + n_test

    # The seed argument used to be ignored. A RandomState seeded here yields
    # the same draws as np.random.seed(seed) followed by the global calls.
    rng = np.random.RandomState(seed)
    Y_all = rng.multivariate_normal(mean=np.zeros(d_y), cov=np.eye(d_y), size=N)
    noise_all = rng.multivariate_normal(mean=np.zeros(d_x), cov=np.eye(d_x), size=N)

    # Rectangular identity: copies the first min(d_x, d_y) coordinates of Y.
    A = np.eye(d_y, d_x)
    X_all = Y_all @ A + eps * noise_all

    # Rectangular identity: copies the first min(d_x, d_z) coordinates of X.
    A_f = np.eye(d_x, d_z)
    F0_all = X_all @ A_f

    if link == 'linear':
        F_all = F0_all
    else:
        # NOTE(review): the constant `+ 0.2` and the unscaled cubic are kept
        # exactly as written - confirm this is the intended link function.
        F_all = 0.5 * F0_all + 0.2 * np.sin(F0_all) + 0.2 + (F0_all)**3

    def _as_tensor(array):
        return torch.tensor(array, dtype=torch.float32)

    # Split into train/test
    return {
        "X": _as_tensor(X_all[:n]),
        "Y": _as_tensor(Y_all[:n]),
        "F": _as_tensor(F_all[:n]),
        "X_test": _as_tensor(X_all[n:]),
        "Y_test": _as_tensor(Y_all[n:]),
        "F_test": _as_tensor(F_all[n:]),
        "A_f": A_f
    }



def train_club_batch(club_model, x_samples, y_samples,
                     club_steps=10, club_lr=1e-3):
    """Fit ``club_model`` on one batch, then evaluate it on that batch.

    Args:
        club_model: Module exposing ``learning_loss(x, y)`` and callable as
            ``club_model(x, y)``.
        x_samples: First batch of samples; used detached.
        y_samples: Second batch of samples; used detached.
        club_steps: Number of Adam steps on ``learning_loss``.
        club_lr: Learning rate of the Adam optimizer created for this call.

    Returns:
        ``club_model(x, y)`` evaluated on the detached batch after training.
    """
    # Detach so that nothing upstream of the samples receives gradients.
    x_fixed, y_fixed = x_samples.detach(), y_samples.detach()
    optimizer = torch.optim.Adam(club_model.parameters(), lr=club_lr)

    for _step in range(club_steps):
        optimizer.zero_grad()
        fit_loss = club_model.learning_loss(x_fixed, y_fixed)
        fit_loss.backward()
        optimizer.step()

    return club_model(x_fixed, y_fixed)

