import torch
import torch.nn as nn
import torch.nn.functional as Fnn
import torch.optim as optim
import os

import numpy as np
import sys
from matplotlib import pyplot as plt
import warnings
from scipy.linalg import qr, sqrtm
import seaborn as sns
from tqdm import tqdm
from pytorch_metric_learning import losses
from sklearn.decomposition import PCA
import argparse
import math
import pandas as pd
from scipy.linalg import block_diag

from scipy.spatial.distance import cdist, pdist, squareform
from scipy.special import digamma
from sklearn.neighbors import NearestNeighbors
from sklearn.linear_model import LogisticRegression
from utils import *


import torch
import torch.nn as nn
import torch.nn.functional as Fnn
from torch import optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import accuracy_score, recall_score, classification_report



def permute(x):
  """Return ``x`` with its entries along dim 0 rearranged in a random order.

  The order comes from ``torch.randperm``; indexing with the resulting index
  tensor yields a new tensor, so ``x`` itself is left as it was.
  """
  shuffled_order = torch.randperm(x.size(0))
  return x[shuffled_order]

def noise(x):
  """Return ``x`` plus standard-normal noise scaled by 0.1."""
  # The sample is drawn with torch.randn's defaults and then moved to
  # wherever x lives before being added.
  jitter = torch.randn(x.shape) * 0.1
  jitter = jitter.to(x.device)
  return x + jitter

def drop(x):
  """Return a copy of ``x`` in which a random fifth of the dim-0 entries is zeroed.

  The number of zeroed entries is ``x.shape[0] // 5`` and they are chosen
  without replacement via ``np.random.choice``. ``x`` is not modified.
  """
  n_entries = x.shape[0]
  n_zeroed = n_entries // 5

  out = x.clone()
  zeroed = np.random.choice(n_entries, n_zeroed, replace=False)
  out[zeroed] = 0.0
  return out

def mixup(x, alpha=1.0):
    """Blend each dim-0 entry of ``x`` with a randomly chosen partner entry.

    Partners come from a random permutation of the entries, and a single
    mixing weight drawn from ``Beta(alpha, alpha)`` is shared by the whole
    batch: ``x * w + partner * (1 - w)``.
    """
    partner_rows = x[torch.randperm(x.shape[0])]
    weight = np.random.beta(alpha, alpha)
    return x * weight + partner_rows * (1 - weight)

def identity(x):
  """No-op augmentation: hand back the very same object that was passed in."""
  return x


def augment(x_batch):
  """Build two differently augmented views of a batch.

  For every sample, two distinct transforms are drawn without replacement
  from ``[permute, noise, drop, identity]``; the first is applied to that
  sample in view 1 and the second to the same sample in view 2.

  Args:
    x_batch: tensor whose dim 0 indexes the samples of the batch.

  Returns:
    Tuple ``(v1, v2)`` of tensors with the same shape as ``x_batch``.
  """
  # Both views are independent copies. Previously v1 was an alias of
  # x_batch, so the per-sample writes below silently overwrote the
  # caller's data in place.
  v1 = torch.clone(x_batch)
  v2 = torch.clone(x_batch)
  transforms = [permute, noise, drop, identity]

  for i in range(x_batch.shape[0]):
    first, second = np.random.choice(len(transforms), 2, replace=False)
    v1[i] = transforms[first](v1[i])
    v2[i] = transforms[second](v2[i])

  return v1, v2

def augment_single(x_batch):
  """Return one augmented copy of a batch, leaving ``x_batch`` untouched.

  Each sample (entry along dim 0) independently receives one transform drawn
  uniformly from ``[permute, noise, drop, identity]``.
  """
  candidates = [permute, noise, drop, identity]
  augmented = torch.clone(x_batch)

  for row in range(x_batch.shape[0]):
    pick = np.random.choice(4, 1, replace=False)[0]
    augmented[row] = candidates[pick](augmented[row])

  return augmented


def augment_embed_single(x_batch):
  """Return an augmented version of a whole batch using one shared transform.

  A single transform is drawn uniformly from ``[noise, mixup, identity]`` and
  applied to a clone of ``x_batch`` as a whole (not sample by sample).
  """
  candidates = [noise, mixup, identity]
  pick = np.random.choice(3, 1, replace=False)[0]
  chosen = candidates[pick]
  return chosen(torch.clone(x_batch))


def augment_mimic(x_batch):
  """Pick the augmentation routine from the rank of ``x_batch``.

  2-D input goes through ``augment_embed_single``; any other rank goes
  through ``augment_single``.
  """
  handler = augment_embed_single if x_batch.dim() == 2 else augment_single
  return handler(x_batch)





def train_clip(
    X, Y,
    model_x, model_y,
    max_epochs=800, batch_size=256, lr=1e-4, wd=1e-4,
    tau_fix=1.0, tau_tune=True, tau_lower=1e-4, tau_lr_fac=2, spectral=False, device='cpu'
):
    """Train two encoders jointly with a symmetric contrastive (CLIP-style) loss.

    Row ``i`` of ``X`` and row ``i`` of ``Y`` form the matching pair; the other
    rows of the same mini-batch act as negatives. Batches are consecutive,
    unshuffled slices of the data. Only full batches are used (trailing rows
    that do not fill a batch are skipped), except that a dataset smaller than
    ``batch_size`` is trained as one short batch.

    Args:
        X, Y: inputs accepted by ``torch.as_tensor``; indexed along dim 0.
        model_x, model_y: encoder modules, moved to ``device`` in place. When
            ``tau_tune`` is true, ``model_x`` must expose ``logit_scale``.
        max_epochs: number of passes over the data.
        batch_size: rows per mini-batch.
        lr, wd: Adam learning rate and weight decay (one optimizer per model).
        tau_fix: temperature used when ``tau_tune`` is false.
        tau_tune: if true, the temperature is
            ``exp(tau_lr_fac * model_x.logit_scale) + tau_lower``.
        tau_lower: additive floor for the tuned temperature.
        tau_lr_fac: multiplier on ``logit_scale`` inside the exponential.
        spectral: if true, use ``spectral_loss`` (defined outside this
            function) instead of the cross-entropy objective.
        device: device for the models and the data.

    Returns:
        dict with the trained ``model_x`` / ``model_y`` and per-epoch means
        ``loss_clip``, ``loss_align`` (negative mean cosine similarity of
        matching pairs), ``loss_x``, ``loss_y`` (left empty when ``spectral``)
        and ``tau_seq``.
    """
    n = X.shape[0]

    model_x.to(device)
    model_y.to(device)

    X_t = torch.as_tensor(X, device=device, dtype=torch.float32)
    Y_t = torch.as_tensor(Y, device=device, dtype=torch.float32)

    # Targets 0..B-1: the positive for row i sits on the diagonal of the logits.
    labels_buf = torch.arange(batch_size, device=device)

    optimizer_x = optim.Adam(model_x.parameters(), lr=lr, weight_decay=wd)
    optimizer_y = optim.Adam(model_y.parameters(), lr=lr, weight_decay=wd)

    loss_clip, loss_align, loss_x, loss_y, tau_seq = [], [], [], [], []

    num_batches = n // batch_size
    if num_batches == 0 and n > 0:
        # Fewer rows than one batch: previously no step was ever taken and
        # every logged mean was NaN. Train on the single short batch instead
        # (the labels are sliced to the actual batch length below).
        num_batches = 1

    for _ in tqdm(range(max_epochs)):
        losses_clip = []
        losses_align = []
        losses_x = []
        losses_y = []
        tau_seq_ep = []

        for batch_idx in range(num_batches):
            optimizer_x.zero_grad(set_to_none=True)
            optimizer_y.zero_grad(set_to_none=True)

            start = batch_idx * batch_size
            end = start + batch_size
            x = X_t[start:end]
            y = Y_t[start:end]

            h_x = model_x(x)
            h_y = model_y(y)

            if tau_tune:
                tau = (tau_lr_fac * model_x.logit_scale).exp() + tau_lower
                tau_seq_ep.append(tau.item())
            else:
                tau = torch.tensor(tau_fix, device=device, dtype=h_x.dtype)
                # Log the fixed temperature too; otherwise the epoch mean is
                # taken over an empty list and comes out as NaN.
                tau_seq_ep.append(float(tau_fix))

            if spectral:
                loss = spectral_loss(h_x, h_y, tau)
            else:
                logits = (h_x @ h_y.T) / tau
                bn = logits.shape[0]
                labels = labels_buf[:bn]
                loss_x_ = Fnn.cross_entropy(logits, labels)
                loss_y_ = Fnn.cross_entropy(logits.T, labels)
                loss = (loss_x_ + loss_y_) / 2.0

            # Logging only; nothing here contributes to the gradient.
            with torch.no_grad():
                losses_clip.append(loss.item())

                hx_u = Fnn.normalize(h_x, dim=1)
                hy_u = Fnn.normalize(h_y, dim=1)
                diag_align = (hx_u * hy_u).sum(dim=1).mean()
                losses_align.append(-diag_align.item())

                if not spectral:
                    diag_raw = (h_x * h_y).sum(dim=1).mean().item()
                    losses_x.append(loss_x_.item() / 2.0 - diag_raw)
                    losses_y.append(loss_y_.item() / 2.0 - diag_raw)

            loss.backward()
            optimizer_x.step()
            optimizer_y.step()

        tau_seq.append(np.mean(tau_seq_ep))
        loss_clip.append(np.mean(losses_clip))
        loss_align.append(np.mean(losses_align))
        if not spectral:
            loss_x.append(np.mean(losses_x))
            loss_y.append(np.mean(losses_y))

    return {
        'model_x': model_x,
        'model_y': model_y,
        'loss_clip': loss_clip,
        'loss_align': loss_align,
        'loss_x': loss_x,
        'loss_y': loss_y,
        'tau_seq': tau_seq,
    }




def train_disentangle(
    X, F,
    outdim, arch='deep', max_epochs=200, batch_size=256, is_fix=True, tau_tune=False,
    lr=1e-4, wd=1e-4, ss=20, aug=False,
    tau_lr_frac=4, tau_fix=1e-2, tau_lower=1e-4, lam=1e-1,
    objective='recons', device=None
):
    """Train an encoder ``w`` and a decoder ``u`` for ``X`` alongside features ``F``.

    Per batch: ``h_f = f(F)``, ``h_w = w(X)`` and ``h_u = u([h_f, h_w])``. The
    optimised loss is ``clip + lam * (recon + loss_ssl)``, where the two main
    terms depend on ``objective``:

    * ``'ours1'``  - clip: symmetric cross-entropy between normalised ``h_f``
      and ``h_w``; recon: MSE between ``h_u`` and ``X``.
    * ``'disen'``  - clip: Frobenius norm of ``h_f_s.T @ h_w_s``; recon:
      contrastive loss between ``h_u`` and ``X``.
    * ``'recons'`` - clip: value returned by ``train_club_batch``; recon: MSE.
    * anything else - clip: ``train_club_batch`` value; recon: contrastive.

    The network classes (``PadToDim``, ``LinearNet``, ...), ``CLUB`` and
    ``train_club_batch`` are defined outside this function.

    Args:
        X: 2-D data (rows are samples) to be encoded and reconstructed.
        F: features paired row-wise with ``X``.
        outdim: width of ``h_f`` and ``h_w``.
        arch: one of ``'linear'``, ``'nonlinear'``, ``'deep'``, ``'transformer'``.
        max_epochs, batch_size: training schedule (batches are shuffled).
        is_fix: if true, ``f`` is ``PadToDim(outdim)`` and gets no optimizer;
            otherwise ``f`` is a trainable network of the encoder class.
        tau_tune: if true, the temperature is
            ``exp(tau_lr_frac * w.logit_scale) + tau_lower``; else ``tau_fix``.
        lr, wd: Adam learning rate and weight decay (one optimizer per model).
        ss: unused; kept for interface compatibility.
        aug: if truthy, add a contrastive term between ``w(X)`` and
            ``w(augment_single(X))``.
        lam: weight of the reconstruction (+ augmentation) term.
        objective: see above.
        device: ``torch.device`` or device string; defaults to CUDA when
            available, else CPU.

    Returns:
        dict with ``models`` (keys ``f``, ``w``, ``u``, ``club``, ``club2``)
        and per-epoch means ``tau_seq``, ``loss_clip``, ``loss_recons``,
        ``loss_align``.

    Raises:
        ValueError: if ``arch`` is not a known architecture name.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        # Accept plain strings such as 'cpu' as well as torch.device objects.
        device = torch.device(device)

    X = torch.as_tensor(X, dtype=torch.float32, device=device)
    F = torch.as_tensor(F, dtype=torch.float32, device=device)

    dataset = TensorDataset(F, X)
    # The tensors already live on `device`, so no pinning: asking the loader
    # to pin CUDA tensors raises at iteration time.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    arch_map = {
        'linear':    (PadToDim, LinearNet, LinearNet),
        'nonlinear': (PadToDim, NonLinearNet, LinearNet),
        'deep':      (PadToDim, NonLinearNetD, NonLinearNetD),
        'transformer': (PadToDim, Transformer_ma, Transformer_ma),
    }
    if arch not in arch_map:
        # The old fallback was the mapping itself, which failed with an
        # unrelated "too many values to unpack" error.
        raise ValueError(f"Unknown arch {arch!r}; expected one of {sorted(arch_map)}.")
    _, Wnet, Unet = arch_map[arch]

    if is_fix:
        Fnet = PadToDim(outdim)
    else:
        Fnet = Wnet(F.shape[1], 50, outdim, tau_lower)

    d_x = X.shape[1]
    if arch == 'transformer':
        models = {
            'f': Fnet.to(device),
            'w': Wnet(d_x, middim=50, dim=outdim).to(device),
            'u': Unet(outdim * 2, middim=50, dim=d_x).to(device),
            'club': CLUB(outdim, outdim, 50).to(device),
            'club2': CLUB(outdim, outdim, 50).to(device),
        }
    else:
        models = {
            'f': Fnet.to(device),
            'w': Wnet(d_x, 50, outdim, tau_lower).to(device),
            'u': Unet(2*outdim, 50, d_x, tau_lower).to(device),
            'club': CLUB(outdim, outdim, 50).to(device),
            'club2': CLUB(outdim, outdim, 50).to(device),
        }

    # A fixed `f` is left without an optimizer.
    optimizers = {
        name: torch.optim.Adam(m.parameters(), lr=lr, weight_decay=wd)
        for name, m in models.items()
        if not (name == 'f' and is_fix)
    }
    mse_loss = torch.nn.MSELoss()

    def ce_loss(logits):
        # Cross-entropy with the diagonal as target (row i matches column i).
        return Fnn.cross_entropy(logits, torch.arange(logits.size(0), device=device))

    def contrastive_recon(recon_out, target, temperature):
        # Unit-length rows along dim=1. The previous calls passed 1 as the
        # second positional argument of normalize, which is the norm order
        # `p`, so rows were L1-normalised instead.
        sim = (Fnn.normalize(recon_out, dim=1) @ Fnn.normalize(target, dim=1).T) / temperature
        return 0.5 * (ce_loss(sim) + ce_loss(sim.T))

    tau_seq, loss_clip, loss_recons, loss_align = [], [], [], []

    for _ in tqdm(range(max_epochs), desc='Epoch'):
        losses = {'clip': [], 'recons': [], 'align': []}
        tau_vals = []

        for F_batch, X_batch in loader:
            h_f = models['f'](F_batch)
            h_w = models['w'](X_batch)
            h_u = models['u'](torch.cat([h_f, h_w], dim=1))

            h_w_s = Fnn.normalize(h_w, dim=-1)
            h_f_s = Fnn.normalize(h_f, dim=-1)

            tau = ((tau_lr_frac * models['w'].logit_scale).exp() + tau_lower if tau_tune else tau_fix)
            tau_vals.append(tau.item() if isinstance(tau, torch.Tensor) else tau)

            logits_fw = h_f_s @ h_w_s.T / tau
            loss_fw = (ce_loss(logits_fw) + ce_loss(logits_fw.T)) / 2.0

            loss_club = train_club_batch(models['club'], h_f_s, h_w_s, club_steps=10, club_lr=1e-3)
            loss_mse = mse_loss(h_u, X_batch)

            if aug:
                X_batch_aug = augment_single(X_batch)
                h_w_s_aug = Fnn.normalize(models['w'](X_batch_aug), dim=-1)
                logits_ssl = h_w_s @ h_w_s_aug.T
                loss_ssl = ce_loss(logits_ssl) + ce_loss(logits_ssl.T)
            else:
                loss_ssl = 0

            if objective == 'ours1':
                clip = loss_fw
                recon = loss_mse
            elif objective == 'disen':
                pos_sim = torch.norm(torch.matmul(h_f_s.T, h_w_s))
                clip = pos_sim.mean()
                recon = contrastive_recon(h_u, X_batch, tau)
            elif objective == 'recons':
                clip = loss_club
                recon = loss_mse
            else:
                clip = loss_club
                recon = contrastive_recon(h_u, X_batch, tau)

            loss = clip + lam * (recon + loss_ssl)

            losses['clip'].append(clip.item())
            losses['recons'].append(recon.item())
            losses['align'].append((h_f @ h_w.T).diag().mean().item())

            for opt in optimizers.values():
                opt.zero_grad()
            loss.backward()
            for opt in optimizers.values():
                opt.step()

        tau_seq.append(np.mean(tau_vals))
        loss_clip.append(np.mean(losses['clip']))
        loss_recons.append(np.mean(losses['recons']))
        loss_align.append(np.mean(losses['align']))

    return {
        'models': models,
        'tau_seq': tau_seq,
        'loss_clip': loss_clip,
        'loss_recons': loss_recons,
        'loss_align': loss_align
    }




def _product_of_linear_weights_only(net):
    ws = []
    for m in net.modules():
        if isinstance(m, nn.Linear):
            ws.append(m.weight.detach().cpu().numpy())
    if not ws:
        raise ValueError("No nn.Linear weights found to multiply.")
    return np.linalg.multi_dot(ws[::-1])

@torch.no_grad()
def _shape_info(x):
    b = x.shape[0]
    in_dim = int(np.prod(x.shape[1:]))
    return b, in_dim

def _jacobian_matrix(net, x, out_chunk=256, to_numpy=True):

    net.eval()
    x = x.detach().requires_grad_(True)
    y = net(x)                         
    if y.dim() == 1:
        y = y.unsqueeze(0)             
    B = y.shape[0]
    out_dim = int(np.prod(y.shape[1:]))
    _, in_dim = _shape_info(x)

    J = []
    for b in range(B):
        yb = y[b].reshape(-1)          
        row_chunks = []
        for start in range(0, out_dim, out_chunk):
            end = min(start + out_chunk, out_dim)
            e = torch.zeros_like(yb)
            e[start:end] = 1.0
            gx, = torch.autograd.grad(
                outputs=yb,
                inputs=x,
                grad_outputs=e,
                retain_graph=True,
                create_graph=False,
                allow_unused=False
            )
            gb = gx[b].reshape(-1)     
            rows = []
            for i in range(start, end):
                ei = torch.zeros_like(yb); ei[i] = 1.0
                gi, = torch.autograd.grad(yb, x, grad_outputs=ei, retain_graph=True)
                rows.append(gi[b].reshape(-1).detach())
            block = torch.stack(rows, dim=0)  
            row_chunks.append(block)
        Ji = torch.cat(row_chunks, dim=0)     
        J.append(Ji)
    J = torch.stack(J, dim=0)                 
    if to_numpy:
        return J.cpu().numpy()
    return J

def effective_weights(net, example_input=None, *, to_numpy=True, out_chunk=256):
    """Return an effective weight description of ``net``.

    Without ``example_input`` the result is the product of the net's linear
    weights (``_product_of_linear_weights_only``). With a tensor
    ``example_input`` it is the Jacobian evaluated at that input
    (``_jacobian_matrix``), to which ``to_numpy`` and ``out_chunk`` are
    forwarded.

    Raises:
        TypeError: if ``example_input`` is given but is not a ``torch.Tensor``.
    """
    if example_input is not None:
        if isinstance(example_input, torch.Tensor):
            return _jacobian_matrix(net, example_input, out_chunk=out_chunk, to_numpy=to_numpy)
        raise TypeError("example_input must be a torch.Tensor when provided.")
    return _product_of_linear_weights_only(net)




def train(X, Y, X_test, Y_test, lab1, lab1_test, lab2, lab2_test, model_x, model_y, arch_disentg, max_ep):
    """Run the full pipeline: CLIP training, per-modality disentangling, evaluation.

    Steps:
      1. ``train_clip`` on ``(X, Y)`` to obtain the two encoders.
      2. Encode train and test data of both modalities with those encoders.
      3. ``train_disentangle`` once per modality, pairing the raw data with
         its CLIP embedding.
      4. ``postprocess_disentangle`` per modality and
         ``postprocess_disentangle_joint`` on both.

    Everything runs on the CPU.

    Args:
        X, Y: training data of the two modalities (paired row-wise).
        X_test, Y_test: held-out data.
        lab1, lab1_test, lab2, lab2_test: two label sets for train / test;
            converted with ``np.array``.
        model_x, model_y: encoders handed to ``train_clip``.
        arch_disentg: ``arch`` argument forwarded to ``train_disentangle``.
        max_ep: number of epochs used for every training stage.

    Returns:
        dict with the results of each stage plus the inputs and the CLIP
        embeddings.
    """
    # One device for the whole pipeline. Previously train_disentangle picked
    # CUDA whenever available while the per-modality post-processing was
    # pinned to 'cpu' (device mismatch on GPU machines), and the joint
    # post-processing fell back to its "cuda" default (failure on CPU-only
    # machines). A torch.device object is passed where the callee reads
    # `device.type`.
    cpu = torch.device('cpu')

    print("\nExtracting shared components via CLIP...\n")

    clip_results = train_clip(
        X, Y,
        model_x, model_y,
        max_epochs=max_ep, batch_size=256, lr=1e-4, wd=1e-4,
        tau_fix=1.0, tau_tune=True, tau_lr_fac=5, spectral=False, device='cpu'
    )

    lab1, lab1_test = np.array(lab1), np.array(lab1_test)
    lab2, lab2_test = np.array(lab2), np.array(lab2_test)

    X_clip = clip_results['model_x'](torch.Tensor(X)).detach().numpy()
    Y_clip = clip_results['model_y'](torch.Tensor(Y)).detach().numpy()

    X_clip_test = clip_results['model_x'](torch.Tensor(X_test)).detach().numpy()
    Y_clip_test = clip_results['model_y'](torch.Tensor(Y_test)).detach().numpy()

    print("\nExtracting unique components for modality 1...\n")

    result_X = train_disentangle(
        X, X_clip,
        outdim=50, arch=arch_disentg, max_epochs=max_ep, batch_size=256, device=cpu
    )

    print("\nExtracting unique components for modality 2...\n")

    result_Y = train_disentangle(
        Y, Y_clip,
        outdim=50, arch=arch_disentg, max_epochs=max_ep, batch_size=256, device=cpu
    )

    result_X_post = postprocess_disentangle(
        result_X['models'], X, X_test, X_clip, X_clip_test, lab1, lab1_test, lab2, lab2_test, device='cpu'
    )
    result_Y_post = postprocess_disentangle(
        result_Y['models'], Y, Y_test, Y_clip, Y_clip_test, lab1, lab1_test, lab2, lab2_test, device='cpu'
    )

    result_post = postprocess_disentangle_joint(
        result_X['models'], result_Y['models'], X, X_test, X_clip, X_clip_test, Y, Y_test, Y_clip, Y_clip_test,
        lab1, lab1_test, lab2, lab2_test, device='cpu'
    )

    return {
        'clip_results': clip_results,
        'disent_X_results': result_X,
        'disent_Y_results': result_Y,
        'post_X': result_X_post,
        'post_Y': result_Y_post,
        'post_concat': result_post,
        'X': X,
        'Y': Y,
        'X_test': X_test,
        'Y_test': Y_test,
        'X_clip': X_clip,
        'Y_clip': Y_clip,
        'X_clip_test': X_clip_test,
        'Y_clip_test': Y_clip_test
    }




def train_clip_raw(
    X, Y,
    model_x, model_y,
    max_epochs=800, batch_size=256, lr=1e-4, wd=1e-4,
    tau_fix=1.0, tau_tune=True, tau_lower=1e-4, tau_lr_fac=2, spectral=False, device='cpu'
):
    """Train two encoders jointly with a symmetric contrastive (CLIP-style) loss.

    Entry ``i`` of ``X`` and entry ``i`` of ``Y`` (along dim 0) form the
    matching pair; the rest of the mini-batch acts as negatives. Batches are
    consecutive, unshuffled slices. Only full batches are used, except that a
    dataset smaller than ``batch_size`` is trained as one short batch.

    Args:
        X, Y: inputs accepted by ``torch.as_tensor``; indexed along dim 0.
        model_x, model_y: encoder modules, moved to ``device`` in place. When
            ``tau_tune`` is true, ``model_x`` must expose ``logit_scale``.
        max_epochs: number of passes over the data.
        batch_size: entries per mini-batch.
        lr, wd: Adam learning rate and weight decay (one optimizer per model).
        tau_fix: temperature used when ``tau_tune`` is false.
        tau_tune: if true, the temperature is
            ``exp(tau_lr_fac * model_x.logit_scale) + tau_lower``.
        tau_lower: additive floor for the tuned temperature.
        tau_lr_fac: multiplier on ``logit_scale`` inside the exponential.
        spectral: if true, use ``spectral_loss`` (defined outside this
            function) instead of the cross-entropy objective.
        device: device for the models and the data.

    Returns:
        dict with the trained ``model_x`` / ``model_y`` and per-epoch means
        ``loss_clip``, ``loss_align``, ``loss_x``, ``loss_y`` (the last two
        stay empty when ``spectral``) and ``tau_seq``.
    """
    model_x.to(device)
    model_y.to(device)

    X_t = torch.as_tensor(X, device=device, dtype=torch.float32)
    Y_t = torch.as_tensor(Y, device=device, dtype=torch.float32)
    # Targets 0..B-1: the positive for entry i sits on the diagonal.
    labels_buf = torch.arange(batch_size, device=device)

    optimizer_x = optim.Adam(model_x.parameters(), lr=lr, weight_decay=wd)
    optimizer_y = optim.Adam(model_y.parameters(), lr=lr, weight_decay=wd)

    loss_clip, loss_align, loss_x, loss_y, tau_seq = [], [], [], [], []

    n = X.shape[0]
    num_batches = n // batch_size
    if num_batches == 0 and n > 0:
        # Fewer entries than one batch: previously no step was taken and all
        # logged means were NaN. Train on the single short batch instead.
        num_batches = 1

    for _ in tqdm(range(max_epochs)):
        losses_clip, losses_align, losses_x, losses_y, tau_seq_ep = [], [], [], [], []

        for batch_idx in range(num_batches):
            optimizer_x.zero_grad(set_to_none=True)
            optimizer_y.zero_grad(set_to_none=True)

            start = batch_idx * batch_size
            end = start + batch_size
            x = X_t[start:end]
            y = Y_t[start:end]

            h_x = model_x(x)
            h_y = model_y(y)

            if tau_tune:
                tau = (tau_lr_fac * model_x.logit_scale).exp() + tau_lower
                tau_seq_ep.append(tau.item())
            else:
                tau = torch.tensor(tau_fix, device=device, dtype=h_x.dtype)
                # Log the fixed temperature too; otherwise the epoch mean is
                # taken over an empty list and comes out as NaN.
                tau_seq_ep.append(float(tau_fix))

            if spectral:
                loss = spectral_loss(h_x, h_y, tau)
            else:
                logits = (h_x @ h_y.T) / tau
                bn = logits.shape[0]
                labels = labels_buf[:bn]
                loss_x_ = Fnn.cross_entropy(logits, labels)
                loss_y_ = Fnn.cross_entropy(logits.T, labels)
                loss = (loss_x_ + loss_y_) / 2.0

            # Logging only; nothing here contributes to the gradient.
            with torch.no_grad():
                losses_clip.append(loss.item())
                hx_u = Fnn.normalize(h_x, dim=1)
                hy_u = Fnn.normalize(h_y, dim=1)
                diag_align = (hx_u * hy_u).sum(dim=1).mean()
                losses_align.append(-diag_align.item())

                if not spectral:
                    diag_raw = (h_x * h_y).sum(dim=1).mean().item()
                    losses_x.append(loss_x_.item()/2.0 - diag_raw)
                    losses_y.append(loss_y_.item()/2.0 - diag_raw)

            loss.backward()
            optimizer_x.step()
            optimizer_y.step()

        tau_seq.append(np.mean(tau_seq_ep))
        loss_clip.append(np.mean(losses_clip))
        loss_align.append(np.mean(losses_align))
        if not spectral:
            loss_x.append(np.mean(losses_x))
            loss_y.append(np.mean(losses_y))

    return {
        'model_x': model_x,
        'model_y': model_y,
        'loss_clip': loss_clip,
        'loss_align': loss_align,
        'loss_x': loss_x,
        'loss_y': loss_y,
        'tau_seq': tau_seq,
    }



def train_disentangle_raw(
    X, F,
    outdim,
    seq_len=50, arch='deep', max_epochs=200, batch_size=256, is_fix=True, tau_tune=False,
    lr=1e-4, wd=1e-4, ss=20, aug=False,
    tau_lr_frac=4, tau_fix=1e-2, tau_lower=1e-4, lam=1e-1,
    objective='recons', device=None
):
    """Train an encoder ``w`` and decoder ``u`` for sequence data ``X`` given features ``F``.

    Per batch: ``h_f = F`` (used as is; ``models['f']`` is built and returned
    but not applied during training), ``h_w = w(X)``, and
    ``h_u = u([h_f, h_w])`` viewed as ``(batch, seq_len, X.shape[2])``. The
    optimised loss is ``clip + lam * (recon + loss_ssl)``:

    * ``'ours1'``  - clip: symmetric cross-entropy between normalised ``h_f``
      and ``h_w``; recon: MSE between ``h_u`` and ``X``.
    * ``'disen'``  - clip: Frobenius norm of ``h_f_s.T @ h_w_s``; recon:
      contrastive loss between the flattened ``h_u`` and ``X``.
    * ``'recons'`` - clip: value returned by ``train_club_batch``; recon: MSE.
    * anything else - clip: ``train_club_batch`` value; recon: contrastive.

    The network classes, ``CLUB`` and ``train_club_batch`` are defined outside
    this function.

    Args:
        X: 3-D tensor; dim 0 indexes samples, ``X.shape[2]`` is the
            per-step feature size. NOTE(review): the view of ``h_u`` and the
            MSE presumably require ``X.shape[1] == seq_len`` - confirm.
        F: tensor of features paired with ``X`` along dim 0.
        outdim: width of ``h_w`` (and of the ``CLUB`` inputs).
        seq_len: sequence length used to reshape the decoder output.
        arch: ``'linear'``, ``'nonlinear'``, ``'deep'``, ``'transformer'`` or
            ``'transformer_ma'``.
        max_epochs, batch_size: training schedule (batches are shuffled).
        is_fix: if true, ``f`` is ``PadToDim(outdim)`` and gets no optimizer.
        tau_tune: if true, the temperature is
            ``exp(tau_lr_frac * w.logit_scale) + tau_lower``; else ``tau_fix``.
        lr, wd: Adam learning rate and weight decay (one optimizer per model).
        ss: unused; kept for interface compatibility.
        aug: if truthy, add a contrastive term between ``w(X)`` and
            ``w(augment_single(X))``.
        lam: weight of the reconstruction (+ augmentation) term.
        objective: see above.
        device: defaults to CUDA when available, else CPU.

    Returns:
        dict with ``models`` and per-epoch means ``tau_seq``, ``loss_clip``,
        ``loss_recons``, ``loss_align``.

    Raises:
        ValueError: if ``arch`` is not a known architecture name.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    X = X.clone().detach().to(dtype=torch.float32, device=device)
    F = F.clone().detach().to(dtype=torch.float32, device=device)

    dataset = TensorDataset(F, X)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    arch_map = {
        'linear':    (PadToDim, LinearNet, LinearNet),
        'nonlinear': (PadToDim, NonLinearNet, LinearNet),
        'deep':      (PadToDim, NonLinearNetD, NonLinearNetD),
        'transformer': (PadToDim, Transformer_tens, Transformer_ma),
        'transformer_ma': (PadToDim, Transformer_ma, Transformer_ma),
    }
    if arch not in arch_map:
        # The old fallback was the mapping itself, which failed with an
        # unrelated "too many values to unpack" error.
        raise ValueError(f"Unknown arch {arch!r}; expected one of {sorted(arch_map)}.")
    _, Wnet, Unet = arch_map[arch]

    if is_fix:
        Fnet = PadToDim(outdim)
    else:
        Fnet = Wnet(F.shape[1], 50, outdim, tau_lower)

    d_x = X.shape[2]
    if arch == 'transformer' or arch == 'transformer_ma':
        models = {
            'f': Fnet.to(device),
            'w': Wnet(d_x, outdim).to(device),
            'u': Unet(outdim * 2, middim=50, dim=d_x * seq_len).to(device),
            'club': CLUB(outdim, outdim, 50).to(device),
            'club2': CLUB(outdim, outdim, 50).to(device),
        }
    else:
        models = {
            'f': Fnet.to(device),
            'w': Wnet(d_x, 50, outdim, tau_lower).to(device),
            'u': Unet(2*outdim, 50, d_x, tau_lower).to(device),
            'club': CLUB(outdim, outdim, 50).to(device),
            'club2': CLUB(outdim, outdim, 50).to(device),
        }

    # A fixed `f` is left without an optimizer.
    optimizers = {
        name: torch.optim.Adam(m.parameters(), lr=lr, weight_decay=wd)
        for name, m in models.items()
        if not (name == 'f' and is_fix)
    }
    mse_loss = torch.nn.MSELoss()

    def ce_loss(logits):
        # Cross-entropy with the diagonal as target (row i matches column i).
        return Fnn.cross_entropy(logits, torch.arange(logits.size(0), device=device))

    def contrastive_recon(recon_flat, target_flat, temperature):
        # Unit-length rows along dim=1. The previous calls passed 1 as the
        # second positional argument of normalize, which is the norm order
        # `p`, so rows were L1-normalised instead.
        sim = (Fnn.normalize(recon_flat, dim=1) @ Fnn.normalize(target_flat, dim=1).T) / temperature
        return 0.5 * (ce_loss(sim) + ce_loss(sim.T))

    tau_seq, loss_clip, loss_recons, loss_align = [], [], [], []

    for _ in tqdm(range(max_epochs), desc='Epoch'):
        losses = {'clip': [], 'recons': [], 'align': []}
        tau_vals = []

        for F_batch, X_batch in loader:
            F_batch = F_batch.to(device, non_blocking=True)
            X_batch = X_batch.to(device, non_blocking=True)

            h_f = F_batch
            h_w = models['w'](X_batch)
            h_u = models['u'](torch.cat([h_f, h_w], dim=1)).view(X_batch.size(0), seq_len, d_x)

            X_flat = X_batch.reshape(X_batch.shape[0], -1)
            h_u_flat = h_u.reshape(h_u.shape[0], -1)

            h_w_s = Fnn.normalize(h_w, dim=-1)
            h_f_s = Fnn.normalize(h_f, dim=-1)

            tau = ((tau_lr_frac * models['w'].logit_scale).exp() + tau_lower if tau_tune else tau_fix)
            tau_vals.append(tau.item() if isinstance(tau, torch.Tensor) else tau)

            loss_club = train_club_batch(models['club'], h_f_s, h_w_s, club_steps=10, club_lr=1e-3)
            loss_mse = mse_loss(h_u, X_batch)

            if aug:
                X_batch_aug = augment_single(X_batch)
                h_w_s_aug = Fnn.normalize(models['w'](X_batch_aug), dim=-1)
                logits_ssl = h_w_s @ h_w_s_aug.T
                loss_ssl = ce_loss(logits_ssl) + ce_loss(logits_ssl.T)
            else:
                loss_ssl = 0

            if objective == 'ours1':
                # This branch used to read `loss_fw`, which was never defined
                # in this function (NameError). Compute it the same way
                # train_disentangle does.
                logits_fw = h_f_s @ h_w_s.T / tau
                clip = (ce_loss(logits_fw) + ce_loss(logits_fw.T)) / 2.0
                recon = loss_mse
            elif objective == 'disen':
                pos_sim = torch.norm(torch.matmul(h_f_s.T, h_w_s))
                clip = pos_sim.mean()
                recon = contrastive_recon(h_u_flat, X_flat, tau)
            elif objective == 'recons':
                clip = loss_club
                recon = loss_mse
            else:
                clip = loss_club
                recon = contrastive_recon(h_u_flat, X_flat, tau)

            loss = clip + lam * (recon + loss_ssl)

            losses['clip'].append(clip.item())
            losses['recons'].append(recon.item())
            losses['align'].append((h_f @ h_w.T).diag().mean().item())

            for opt in optimizers.values():
                opt.zero_grad()
            loss.backward()
            for opt in optimizers.values():
                opt.step()

        tau_seq.append(np.mean(tau_vals))
        loss_clip.append(np.mean(losses['clip']))
        loss_recons.append(np.mean(losses['recons']))
        loss_align.append(np.mean(losses['align']))

    return {
        'models': models,
        'tau_seq': tau_seq,
        'loss_clip': loss_clip,
        'loss_recons': loss_recons,
        'loss_align': loss_align
    }




def postprocess_disentangle(models, X, X_test, F, F_test, lab1_train, lab1_test, lab2_train, lab2_test, device="cuda"):
    """Evaluate one modality's disentangled representations with linear probes.

    For each of the two label sets, four logistic-regression classifiers are
    fitted on the training split and scored on the test split, using as
    features: the raw data (``"X"``), the given features ``F`` (``"F"``), the
    encoder output ``models['w'](X)`` (``"W"``), and ``W`` and ``F``
    concatenated (``"Comb"``).

    Args:
        models: dict with modules under ``'f'``, ``'w'`` and ``'u'``; they are
            expected to live on ``device`` already.
        X, X_test: raw train / test data (NumPy arrays or tensors).
        F, F_test: train / test features paired with the data.
        lab1_train, lab1_test, lab2_train, lab2_test: labels (NumPy arrays or
            tensors).
        device: device on which input tensors are created.

    Returns:
        dict with
          * ``W_w``, ``W_u``: ``effective_weights`` of encoder and decoder,
          * ``s_w``: singular values of ``W_w``,
          * ``accs``, ``per_class``, ``reports``: per label set (``"lab1"``,
            ``"lab2"``), then per feature set,
          * ``preds``: predictions for the last label set only (``"lab2"``),
            kept for backward compatibility,
          * ``preds_by_label``: predictions for every label set.
    """
    import torch
    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score, recall_score, classification_report

    def _as_numpy(a):
        return a if isinstance(a, np.ndarray) else a.detach().cpu().numpy()

    def _as_input(a):
        # as_tensor accepts arrays and tensors alike, without the copy
        # warning torch.tensor emits for tensor inputs.
        return torch.as_tensor(a, dtype=torch.float32, device=device)

    # Inference only: no autograd graph is needed, and building one over the
    # full training set just wastes memory.
    with torch.no_grad():
        F_t = models['f'](_as_input(F_test))
        W_t = models['w'](_as_input(X_test))
        # Decoder pass kept as a check that it accepts the concatenated
        # codes; its output is not used.
        models['u'](torch.cat([F_t, W_t], dim=1))
        W_train = models['w'](_as_input(X))

    W_w = effective_weights(models['w'])
    W_u = effective_weights(models['u'])
    s_w = np.linalg.svd(W_w, full_matrices=False)[1] if W_w is not None else None

    X_train_np, X_test_np = _as_numpy(X), _as_numpy(X_test)
    F_train_np, F_test_np = _as_numpy(F), _as_numpy(F_test)
    W_train_np = W_train.detach().cpu().numpy()
    W_test_np = W_t.detach().cpu().numpy()

    accs, per_class, reports, preds_by_label = {}, {}, {}, {}
    preds = {}

    for label_name, lab_train, lab_test in [
        ("lab1", lab1_train, lab1_test),
        ("lab2", lab2_train, lab2_test)
    ]:
        lab_train_np = _as_numpy(lab_train)
        lab_test_np = _as_numpy(lab_test)

        clf_x    = LogisticRegression(max_iter=2000).fit(X_train_np, lab_train_np)
        clf_f    = LogisticRegression(max_iter=2000).fit(F_train_np, lab_train_np)
        clf_w    = LogisticRegression(max_iter=2000).fit(W_train_np, lab_train_np)
        clf_comb = LogisticRegression(max_iter=2000).fit(np.concatenate([W_train_np, F_train_np], axis=1), lab_train_np)

        preds = {
            "X": clf_x.predict(X_test_np),
            "F": clf_f.predict(F_test_np),
            "W": clf_w.predict(W_test_np),
            "Comb": clf_comb.predict(np.concatenate([W_test_np, F_test_np], axis=1)),
        }
        preds_by_label[label_name] = preds

        accs[label_name] = {
            k: accuracy_score(lab_test_np, v) for k, v in preds.items()
        }

        per_class[label_name] = {
            k: recall_score(lab_test_np, v, average=None) for k, v in preds.items()
        }

        reports[label_name] = {
            k: classification_report(lab_test_np, v, digits=4, output_dict=True) for k, v in preds.items()
        }

    return {
        "W_w": W_w,
        "W_u": W_u,
        "s_w": s_w,
        "accs": accs,
        "preds": preds,
        "preds_by_label": preds_by_label,
        "per_class": per_class,
        "reports": reports,
    }




def postprocess_disentangle_joint(model_x, model_y, X, X_test, F, F_test, Y, Y_test, G, G_test, lab1_train, lab1_test, lab2_train, lab2_test, device="cuda"):
    """Score linear probes on the joint representations of both modalities.

    For each label set, two logistic-regression classifiers are fitted on the
    training split and scored on the test split:

    * ``"clip"``   - features ``[F, G]``,
    * ``"concat"`` - features ``[W_x, W_y, F, G]``, where ``W_x`` / ``W_y`` are
      the outputs of ``model_x['w']`` / ``model_y['w']`` on ``X`` / ``Y``.

    Args:
        model_x, model_y: per-modality model dicts; only the ``'w'`` entries
            are used, and they are expected to live on ``device`` already.
        X, X_test, Y, Y_test: raw train / test data of the two modalities.
        F, F_test, G, G_test: train / test features paired with ``X`` / ``Y``.
        lab1_train, lab1_test, lab2_train, lab2_test: labels (NumPy arrays or
            tensors).
        device: device on which input tensors are created.

    Returns:
        ``{"accs": {label_name: {"clip": acc, "concat": acc}}}`` for
        ``label_name`` in ``"lab1"``, ``"lab2"``.
    """
    import torch
    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score

    def _as_numpy(a):
        return a if isinstance(a, np.ndarray) else a.detach().cpu().numpy()

    def _encode(encoder, data):
        # Inference only, so no autograd graph is built.
        with torch.no_grad():
            out = encoder(torch.as_tensor(data, dtype=torch.float32, device=device))
        return out.detach().cpu().numpy()

    # Only the encoder outputs feed the classifiers. The decoder passes,
    # effective-weight products and SVD computed here before were never used
    # or returned, and could fail for reasons unrelated to the accuracies.
    W_test_np = _encode(model_x['w'], X_test)
    Wy_test_np = _encode(model_y['w'], Y_test)
    W_train_np = _encode(model_x['w'], X)
    Wy_train_np = _encode(model_y['w'], Y)

    F_train_np, F_test_np = _as_numpy(F), _as_numpy(F_test)
    G_train_np, G_test_np = _as_numpy(G), _as_numpy(G_test)

    clip_train = np.concatenate([F_train_np, G_train_np], axis=1)
    clip_test = np.concatenate([F_test_np, G_test_np], axis=1)
    concat_train = np.concatenate([W_train_np, Wy_train_np, F_train_np, G_train_np], axis=1)
    concat_test = np.concatenate([W_test_np, Wy_test_np, F_test_np, G_test_np], axis=1)

    accs = {}
    for label_name, lab_train, lab_test in [
        ("lab1", lab1_train, lab1_test),
        ("lab2", lab2_train, lab2_test)
    ]:
        lab_train_np = _as_numpy(lab_train)
        lab_test_np = _as_numpy(lab_test)

        clf_clip = LogisticRegression(max_iter=2000).fit(clip_train, lab_train_np)
        clf_concat = LogisticRegression(max_iter=2000).fit(concat_train, lab_train_np)

        accs[label_name] = {
            "clip": accuracy_score(lab_test_np, clf_clip.predict(clip_test)),
            "concat": accuracy_score(lab_test_np, clf_concat.predict(concat_test)),
        }

    return {
        "accs": accs,
    }





def postprocess_disentangle_raw(models, X, X_test, F, F_test, lab1_train, lab1_test, lab2_train, lab2_test, device="cpu"):
    """Evaluate one modality's representations (raw-sequence variant) with linear probes.

    ``X``, ``X_test`` and ``F_test`` are fed to the models as given, so they
    must already be tensors on the models' device. For the ``"X"`` probe the
    raw data is flattened to ``(n_samples, -1)``.

    For each label set, four logistic-regression classifiers (``saga``
    solver, L2 penalty) are fitted on the training split and scored on the
    test split, using as features: flattened raw data (``"X"``), ``F``
    (``"F"``), ``models['w'](X)`` (``"W"``), and ``W`` and ``F`` concatenated
    (``"Comb"``).

    Args:
        models: dict with modules under ``'f'``, ``'w'`` and ``'u'``.
        X, X_test: raw train / test data.
        F, F_test: train / test features paired with the data.
        lab1_train, lab1_test, lab2_train, lab2_test: labels (NumPy arrays or
            tensors).
        device: not used by this function; kept for signature compatibility.

    Returns:
        dict with ``W_w``, ``W_u`` (``effective_weights`` of encoder and
        decoder), ``s_w`` (singular values of ``W_w``), ``accs``,
        ``per_class`` and ``reports`` (per label set, then per feature set),
        ``preds`` (last label set only, kept for backward compatibility) and
        ``preds_by_label`` (every label set).
    """
    def _as_numpy(a):
        return a if isinstance(a, np.ndarray) else a.detach().cpu().numpy()

    # Inference only: skip autograd bookkeeping, and encode the test set once
    # (it used to be pushed through the encoder twice).
    with torch.no_grad():
        F_t = models['f'](F_test)
        W_t = models['w'](X_test)
        # Decoder pass kept as a check that it accepts the concatenated
        # codes; its output is not used.
        models['u'](torch.cat([F_t, W_t], dim=1))
        W_train = models['w'](X)

    W_w = effective_weights(models['w'])
    W_u = effective_weights(models['u'])
    s_w = np.linalg.svd(W_w, full_matrices=False)[1] if W_w is not None else None

    X_train_np = _as_numpy(X)
    X_train_np = X_train_np.reshape(X_train_np.shape[0], -1)
    X_test_np = _as_numpy(X_test)
    X_test_np = X_test_np.reshape(X_test_np.shape[0], -1)
    F_train_np, F_test_np = _as_numpy(F), _as_numpy(F_test)
    W_train_np = W_train.detach().cpu().numpy()
    W_test_np = W_t.detach().cpu().numpy()

    def _probe():
        return LogisticRegression(max_iter=2000, penalty="l2", solver="saga")

    accs, per_class, reports, preds_by_label = {}, {}, {}, {}
    preds = {}
    for label_name, lab_train, lab_test in [("lab1", lab1_train, lab1_test), ("lab2", lab2_train, lab2_test)]:
        lab_train_np = _as_numpy(lab_train)
        lab_test_np = _as_numpy(lab_test)

        clf_x = _probe().fit(X_train_np, lab_train_np)
        clf_f = _probe().fit(F_train_np, lab_train_np)
        clf_w = _probe().fit(W_train_np, lab_train_np)
        clf_comb = _probe().fit(np.concatenate([W_train_np, F_train_np], axis=1), lab_train_np)

        preds = {
            "X": clf_x.predict(X_test_np),
            "F": clf_f.predict(F_test_np),
            "W": clf_w.predict(W_test_np),
            "Comb": clf_comb.predict(np.concatenate([W_test_np, F_test_np], axis=1)),
        }
        preds_by_label[label_name] = preds

        accs[label_name] = {k: accuracy_score(lab_test_np, v) for k, v in preds.items()}
        per_class[label_name] = {k: recall_score(lab_test_np, v, average=None) for k, v in preds.items()}
        reports[label_name] = {k: classification_report(lab_test_np, v, digits=4, output_dict=True) for k, v in preds.items()}

    return {
        "W_w": W_w,
        "W_u": W_u,
        "s_w": s_w,
        "accs": accs,
        "preds": preds,
        "preds_by_label": preds_by_label,
        "per_class": per_class,
        "reports": reports,
    }


def postprocess_disentangle_joint_raw(model_x, model_y, X, X_test, F, F_test, Y, Y_test, G, G_test, lab1_train, lab1_test, lab2_train, lab2_test, device="cpu"):
    """Linear-probe the given and the learned features of both modalities.

    For each of the two label sets, two logistic-regression classifiers are
    fitted on the training split and scored (accuracy) on the test split:

    * ``"clip"``   -- features ``[F, G]`` concatenated along axis 1;
    * ``"concat"`` -- features ``[model_x['w'](X), model_y['w'](Y), F, G]``
      concatenated along axis 1.

    Args:
        model_x, model_y: Mappings whose ``'w'`` entry is a callable encoder;
            it is applied to ``X``/``X_test`` and ``Y``/``Y_test`` respectively
            and must return a torch tensor.
        X, X_test, Y, Y_test: Encoder inputs, passed to the encoders unchanged.
        F, F_test, G, G_test: Feature matrices (NumPy arrays or torch tensors)
            that can be concatenated along axis 1.
        lab1_train, lab1_test, lab2_train, lab2_test: Label vectors (NumPy
            arrays or torch tensors) for the two probing tasks.
        device: Unused; kept so existing call sites keep working.

    Returns:
        ``{"accs": {"lab1": {"clip": acc, "concat": acc}, "lab2": {...}}}``.
    """

    def to_np(arr):
        # detach() first: Tensor.numpy() raises on tensors that require grad.
        return arr if isinstance(arr, np.ndarray) else arr.detach().cpu().numpy()

    def encode(model, inputs):
        # Evaluation only, so do not record an autograd graph for the forward.
        with torch.no_grad():
            return model['w'](inputs).detach().cpu().numpy()

    F_train_np, G_train_np = to_np(F), to_np(G)
    F_test_np, G_test_np = to_np(F_test), to_np(G_test)

    # Same call order as before (train x, train y, test x, test y) so that any
    # stochastic layers inside the encoders consume random numbers identically.
    W_train_np = encode(model_x, X)
    Wy_train_np = encode(model_y, Y)
    W_test_np = encode(model_x, X_test)
    Wy_test_np = encode(model_y, Y_test)

    # Shape summary of the test split (raw X shape, before any flattening).
    print([tuple(X_test.shape), F_test_np.shape, G_test_np.shape, W_test_np.shape])

    # The design matrices do not depend on the label set: build them once.
    clip_train = np.concatenate([F_train_np, G_train_np], axis=1)
    clip_test = np.concatenate([F_test_np, G_test_np], axis=1)
    feat_train = np.concatenate([W_train_np, Wy_train_np, F_train_np, G_train_np], axis=1)
    feat_test = np.concatenate([W_test_np, Wy_test_np, F_test_np, G_test_np], axis=1)

    accs = {}
    for label_name, lab_train, lab_test in [("lab1", lab1_train, lab1_test), ("lab2", lab2_train, lab2_test)]:
        lab_train_np = to_np(lab_train)
        lab_test_np = to_np(lab_test)
        clf_clip = LogisticRegression(max_iter=2000, penalty="l2", solver="saga").fit(clip_train, lab_train_np)
        clf_feat = LogisticRegression(max_iter=2000, penalty="l2", solver="saga").fit(feat_train, lab_train_np)
        accs[label_name] = {
            "clip": accuracy_score(lab_test_np, clf_clip.predict(clip_test)),
            "concat": accuracy_score(lab_test_np, clf_feat.predict(feat_test)),
        }

    return {"accs": accs}






def train_disentangle_bimodal(
    X, Y, F, G,
    outdim,
    seq_len_x=50, seq_len_y=50, arch='deep', max_epochs=200, batch_size=256,
    is_fix=True, tau_tune=False,
    lr=1e-4, wd=1e-4, ss=20, aug=False,
    tau_lr_frac=4, tau_fix=1e-2, tau_lower=1e-4,
    lam=1e-1,
    lam_ind=1.0,
    objective='recons', device=None
):
    """Jointly train encoder/decoder/CLUB networks for two modalities.

    For modality x (and symmetrically y) an encoder ``wx`` maps ``X`` to a
    code ``h_wx``; a decoder ``ux`` maps ``[h_f, h_wx]`` back to a tensor that
    is reshaped to ``(batch, seq_len_x, X.shape[2])`` and compared with ``X``.
    ``h_f`` is ``F`` itself when ``is_fix`` is true, otherwise ``fx(F)``.

    The per-batch loss is ``clip_term + lam * (recon_term + loss_ssl)`` where

    * ``objective == 'disen'``:  ``clip_term`` is
      ``||h_f_s^T h_wx_s|| + ||h_g_s^T h_wy_s||`` (``torch.norm`` of the
      products of the row-normalised codes) and ``recon_term`` is a symmetric
      InfoNCE loss between the flattened reconstruction and the input;
    * ``objective == 'recons'``: ``clip_term`` is the sum of the two values
      returned by ``train_club_batch`` and ``recon_term`` is the MSE
      reconstruction error of both modalities;
    * any other value: CLUB term as above with the InfoNCE reconstruction.

    ``loss_ssl`` is ``0.0`` unless ``aug`` is true, in which case it adds
    contrastive terms between the codes of a batch and of its augmentation.

    Args:
        X, Y: Torch tensors with at least three dimensions; ``shape[2]`` is
            used as the per-step feature size of each modality.
        F, G: Torch tensors of side features; ``shape[1]`` is used as the
            input width of ``fx``/``gy`` when ``is_fix`` is false.
        outdim: Code dimension of the encoders (decoders take ``2 * outdim``).
        seq_len_x, seq_len_y: Sequence lengths used to reshape decoder output.
        arch: Key into the architecture table; unknown keys fall back to
            ``'deep'``.
        max_epochs: Number of passes over the data.
        batch_size: Mini-batch size of the shuffled loader.
        is_fix: If true, ``F``/``G`` are used directly and ``fx``/``gy`` get
            no optimizer.
        tau_tune: If true, the temperature is
            ``exp(tau_lr_frac * wx.logit_scale) + tau_lower``; else ``tau_fix``.
        lr, wd: Adam learning rate and weight decay for every optimised module.
        ss: Unused; kept for backward compatibility.
        aug: Enable the augmentation-based contrastive terms.
        tau_lr_frac, tau_fix, tau_lower: Temperature parameters (see above);
            ``tau_lower`` is also forwarded to the network constructors.
        lam: Weight of the reconstruction + SSL part of the loss.
        lam_ind: Unused; kept for backward compatibility.
        objective: ``'disen'``, ``'recons'`` or anything else (see above).
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        ``(result_X, result_Y)``; each is ``{'models': {...}}`` holding the
        modality's ``'f'``, ``'w'``, ``'u'`` and ``'club'`` modules.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    X = X.clone().detach().to(dtype=torch.float32, device=device)
    Y = Y.clone().detach().to(dtype=torch.float32, device=device)
    F = F.clone().detach().to(dtype=torch.float32, device=device)
    G = G.clone().detach().to(dtype=torch.float32, device=device)

    dataset = TensorDataset(F, X, G, Y)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    arch_map = {
        'linear':       (PadToDim, LinearNet,       LinearNet),
        'nonlinear':    (PadToDim, NonLinearNet,    LinearNet),
        'deep':         (PadToDim, NonLinearNetD,   NonLinearNetD),
        'transformer':  (PadToDim, Transformer_tens, Transformer_ma),
        'transformer_ma': (PadToDim, Transformer_ma, Transformer_ma),
    }
    # Unknown `arch` values silently fall back to the 'deep' configuration.
    _, Wnet, Unet = arch_map.get(arch, arch_map['deep'])

    Fx = PadToDim(outdim) if is_fix else Wnet(F.shape[1], 50, outdim, tau_lower)
    Gy = PadToDim(outdim) if is_fix else Wnet(G.shape[1], 50, outdim, tau_lower)

    d_x = X.shape[2]
    d_y = Y.shape[2]

    if arch in ('transformer', 'transformer_ma'):
        models = {
            'fx': Fx.to(device),
            'gy': Gy.to(device),

            'wx': Wnet(d_x, outdim).to(device),
            'wy': Wnet(d_y, outdim).to(device),

            'ux': Unet(2 * outdim, middim=50, dim=d_x * seq_len_x).to(device),
            'uy': Unet(2 * outdim, middim=50, dim=d_y * seq_len_y).to(device),

            'club_x': CLUB(outdim, outdim, 50).to(device),  # I(F; Wx)
            'club_y': CLUB(outdim, outdim, 50).to(device),  # I(G; Wy)
        }
    else:
        # NOTE(review): the decoders here are built with d_x / d_y as third
        # argument while the loop below reshapes their output to
        # (B, seq_len, d) -- confirm these agree for non-transformer archs.
        models = {
            'fx': Fx.to(device),
            'gy': Gy.to(device),

            'wx': Wnet(d_x, 50, outdim, tau_lower).to(device),
            'wy': Wnet(d_y, 50, outdim, tau_lower).to(device),

            'ux': Unet(2 * outdim, 50, d_x, tau_lower).to(device),
            'uy': Unet(2 * outdim, 50, d_y, tau_lower).to(device),

            'club_x': CLUB(outdim, outdim, 50).to(device),
            'club_y': CLUB(outdim, outdim, 50).to(device),
        }

    # The fixed feature maps get no optimizer when is_fix is set.
    optimizers = {
        name: torch.optim.Adam(m.parameters(), lr=lr, weight_decay=wd)
        for name, m in models.items()
        if not ((name in ('fx', 'gy')) and is_fix)
    }

    mse_loss = torch.nn.MSELoss()

    def ce_loss(logits):
        # Row i's positive is column i (matching pairs sit on the diagonal).
        return Fnn.cross_entropy(
            logits, torch.arange(logits.size(0), device=device)
        )

    def symmetric_nce(recon, target, tau):
        # Flatten per sample and L2-normalise along the feature axis so the
        # logits are cosine similarities. `dim` must be given by keyword: the
        # second positional argument of F.normalize is the norm order `p`.
        recon_n = Fnn.normalize(recon.reshape(recon.shape[0], -1), dim=1)
        target_n = Fnn.normalize(target.reshape(target.shape[0], -1), dim=1)
        logits = (recon_n @ target_n.T) / tau
        return ce_loss(logits) + ce_loss(logits.T)

    for _ in tqdm(range(max_epochs), desc='Epoch'):
        for F_b, X_b, G_b, Y_b in loader:
            F_b = F_b.to(device, non_blocking=True)
            X_b = X_b.to(device, non_blocking=True)
            G_b = G_b.to(device, non_blocking=True)
            Y_b = Y_b.to(device, non_blocking=True)

            h_f = F_b if is_fix else models['fx'](F_b)
            h_g = G_b if is_fix else models['gy'](G_b)

            h_wx = models['wx'](X_b)
            h_wy = models['wy'](Y_b)

            h_u_x = models['ux'](torch.cat([h_f, h_wx], dim=1)).view(X_b.size(0), seq_len_x, d_x)
            h_u_y = models['uy'](torch.cat([h_g, h_wy], dim=1)).view(Y_b.size(0), seq_len_y, d_y)

            # Row-normalised codes used by the CLUB / alignment / SSL terms.
            h_wx_s = Fnn.normalize(h_wx, dim=-1)
            h_wy_s = Fnn.normalize(h_wy, dim=-1)
            h_f_s = Fnn.normalize(h_f, dim=-1)
            h_g_s = Fnn.normalize(h_g, dim=-1)

            if tau_tune:
                tau = (tau_lr_frac * models['wx'].logit_scale).exp() + tau_lower
            else:
                tau = tau_fix

            # Only the reconstruction loss the objective actually uses is built.
            if objective == 'recons':
                recon_term = mse_loss(h_u_x, X_b) + mse_loss(h_u_y, Y_b)
            else:
                recon_term = 0.5 * (
                    symmetric_nce(h_u_x, X_b, tau) + symmetric_nce(h_u_y, Y_b, tau)
                )

            # Always called, whatever the objective, as in the original flow
            # (train_club_batch is defined outside this function).
            loss_club_x = train_club_batch(models['club_x'], h_f_s, h_wx_s, club_steps=10, club_lr=1e-3)
            loss_club_y = train_club_batch(models['club_y'], h_g_s, h_wy_s, club_steps=10, club_lr=1e-3)
            loss_club = loss_club_x + loss_club_y

            if aug:
                with torch.no_grad():
                    # NOTE(review): augment_single starts by aliasing its input
                    # (`v1 = x_batch`); a clone is passed so a possible in-place
                    # edit cannot touch tensors saved for backward -- confirm.
                    X_b_aug = augment_single(X_b.clone())
                    Y_b_aug = augment_single(Y_b.clone())

                h_wx_aug_s = Fnn.normalize(models['wx'](X_b_aug), dim=-1)
                h_wy_aug_s = Fnn.normalize(models['wy'](Y_b_aug), dim=-1)

                # Within-modality: original vs. augmented codes.
                logits_xx = h_wx_s @ h_wx_aug_s.T
                logits_yy = h_wy_s @ h_wy_aug_s.T
                loss_ssl_1 = 0.5 * (
                    ce_loss(logits_xx) + ce_loss(logits_xx.T) +
                    ce_loss(logits_yy) + ce_loss(logits_yy.T)
                )

                # Cross-modality: stacked (original, augmented) x vs. y codes.
                logits_xy = torch.cat([h_wx_s, h_wx_aug_s]) @ torch.cat([h_wy_s, h_wy_aug_s]).T
                loss_ssl_2 = 0.5 * (ce_loss(logits_xy) + ce_loss(logits_xy.T))

                loss_ssl = loss_ssl_1 + loss_ssl_2
            else:
                loss_ssl = 0.0

            if objective == 'disen':
                # Penalise cross-correlation between given and learned codes.
                clip_term = (
                    torch.norm(torch.matmul(h_f_s.T, h_wx_s))
                    + torch.norm(torch.matmul(h_g_s.T, h_wy_s))
                )
            else:
                clip_term = loss_club

            loss = clip_term + lam * (recon_term + loss_ssl)

            for opt in optimizers.values():
                opt.zero_grad()
            loss.backward()
            for opt in optimizers.values():
                opt.step()

    # Built once after training so the return value also exists when
    # max_epochs == 0 (no epoch body ever runs).
    models_x = {
        'f': models['fx'],
        'w': models['wx'],
        'u': models['ux'],
        'club': models['club_x'],
    }
    models_y = {
        'f': models['gy'],
        'w': models['wy'],
        'u': models['uy'],
        'club': models['club_y'],
    }

    result_X = {'models': models_x}
    result_Y = {'models': models_y}

    return result_X, result_Y

