from .graph_dataset import TransTUDataset
import torch
from torch.linalg import svd
from torch_geometric.data import Data
from torch_geometric.datasets import TUDataset
import tqdm
from .split import split_tu_dataset
import os.path as osp


def align_feat(datasets, root, dim, scales, device, save_path, test_name, gamma_scale=0.0001, iter_n=100,
               return_org=False):
    """Jointly align the per-scale node features of several datasets with rotations.

    For every scale ``j`` the column block ``x[:, j*dim:(j+1)*dim]`` of each
    dataset is summarised by its mean ``mu_m``. Starting from the identity,
    one ``dim x dim`` matrix ``R_m`` per dataset is refined for ``iter_n``
    rounds. Each round builds

        G_m = sum_n W[m, n] * K[m, n] * (R_n mu_n) mu_m^T

    and replaces ``R_m`` by ``U @ Vh`` taken from the SVD of ``G_m``. All
    ``R_m`` of a round are computed from the previous round's matrices.
    Finally every block is multiplied by the transpose of its matrix; columns
    beyond ``len(scales) * dim`` are appended unchanged.

    Side effects: writes ``pretrain_wo_{test_name}_Mus.pt``
    (``[M, len(scales), dim]``) and ``pretrain_wo_{test_name}_Rs.pt``
    (``[M, len(scales), dim, dim]``) into ``save_path`` and prints one line
    per dataset.

    Args:
        datasets: Sequence of dataset names, each passed to ``TransTUDataset``.
        root: Root directory passed to ``TransTUDataset``.
        dim: Width of one scale's column block.
        scales: Passed to ``TransTUDataset``; its length is the number of blocks.
        device: Device on which the iterative update runs.
        save_path: Directory receiving the two ``.pt`` files.
        test_name: Tag embedded in the saved file names.
        gamma_scale: Numerator of the kernel width; ``gamma`` is
            ``gamma_scale / mean_pairwise_distance ** 2`` of the means.
        iter_n: Number of update rounds per scale.
        return_org: If true, also return the un-aligned dataset objects.

    Returns:
        A list with one entry per dataset, each a list of ``Data`` objects
        carrying the rotated ``x``; preceded by the list of original
        ``TransTUDataset`` objects when ``return_org`` is true.

    NOTE(review): when all means coincide (e.g. a single dataset) the mean
    pairwise distance is zero and ``gamma`` becomes infinite.
    NOTE(review): each ``G_m`` is an outer product of two vectors (rank <= 1),
    so ``U @ Vh`` is only determined along those vectors; the remaining
    directions depend on the basis the SVD routine returns. Keep the order of
    floating-point operations stable if results must be reproducible.
    """
    M = len(datasets)
    n_scales = len(scales)

    datasets_list = [
        TransTUDataset(
            root,
            name,
            dim=dim,
            scales=scales,
            mode=None,
            align_feat=False,
        ) for name in datasets
    ]

    Rs = []     # one [M, dim, dim] tensor per scale, on `device`
    Mus = []    # one [M, dim] tensor per scale, on CPU
    for j in range(n_scales):
        cols = slice(j * dim, (j + 1) * dim)

        # Mean feature vector of every dataset for this scale.
        Mu = torch.empty((M, dim), dtype=torch.float)   # [M, dim]
        for idx, dataset in enumerate(datasets_list):
            Mu[idx, :] = torch.mean(dataset.x[:, cols], dim=0)
        Mus.append(Mu)

        Mu = Mu.to(device)  # [M, dim]
        norm2_Mu = torch.norm(Mu, dim=1, p=2)**2    # [M]
        gamma = gamma_scale / (torch.cdist(Mu, Mu).mean() ** 2)
        # Rotation-independent factor exp(-gamma * (|mu_m|^2 + |mu_n|^2)).
        W = torch.exp(
            - gamma * (norm2_Mu[:, None] + norm2_Mu[None, :])
            )    # [M, M]
        R = torch.eye(dim).unsqueeze(0).repeat(M, 1, 1).to(device)  # [M, dim, dim]

        for _ in tqdm.tqdm(range(iter_n)):
            R_Mu = torch.bmm(R, Mu.unsqueeze(2)).squeeze(2)  # [M, dim]
            # Rotation-dependent factor exp(2 * gamma * <R_m mu_m, R_n mu_n>).
            K = torch.exp(2 * gamma * (R_Mu @ R_Mu.t()))      # [M, M]
            # [1, M, dim, 1] * [M, 1, 1, dim] -> [M, M, dim, dim];
            # entry [m, n] is the outer product (R_n mu_n) mu_m^T.
            R_Mu_outer = R_Mu.unsqueeze(0).unsqueeze(3) * Mu.unsqueeze(1).unsqueeze(2)

            G = torch.sum((W * K).unsqueeze(2).unsqueeze(3) * R_Mu_outer, dim=1)    # [M, dim, dim]
            # torch.linalg.svd returns Vh (already transposed), so U @ Vh is
            # the orthogonal factor.
            U, _, Vh = svd(G)
            R = torch.bmm(U, Vh)    # [M, dim, dim]
        Rs.append(R)

    Mus = torch.stack(Mus, dim=0).transpose(0, 1)   # [M, len(scales), dim]
    torch.save(Mus, osp.join(save_path, f'pretrain_wo_{test_name}_Mus.pt'))

    # Move all rotations to the CPU once instead of once per (dataset, scale).
    Rs = torch.stack(Rs, dim=0).cpu()   # [len(scales), M, dim, dim]
    Rs = Rs.transpose(0, 1)             # [M, len(scales), dim, dim]

    return_dataset = []
    for idx, dataset in enumerate(datasets_list):
        print(f'Aligning {datasets[idx]}')
        Zs = [
            torch.mm(dataset.x[:, j*dim:(j+1)*dim], Rs[idx, j].t())
            for j in range(n_scales)
        ]
        Zs = torch.cat(Zs, dim=-1)
        # Columns after the scale blocks are carried over unrotated.
        Zs = torch.cat([Zs, dataset.x[:, n_scales * dim:]], dim=-1)

        # Store the rotated features as a List[Data], one entry per graph.
        # Rows are cut by a running offset, i.e. graphs are taken to appear in
        # `dataset.x` in iteration order.
        data_list = []
        start = 0
        for data in dataset:
            end = start + data.x.shape[0]
            data_list.append(Data(x=Zs[start: end], edge_index=data.edge_index, y=data.y))
            start = end

        return_dataset.append(data_list)

    torch.save(Rs, osp.join(save_path, f'pretrain_wo_{test_name}_Rs.pt'))
    if return_org:
        return datasets_list, return_dataset
    return return_dataset
    

def align_feat_downstream(name, root, dim, scales, device, save_path, test_name, gamma_scale=0.01, iter_n=100, k_shot=50):
    """Align a few-shot downstream dataset to previously aligned pretraining datasets.

    The dataset is split with ``split_tu_dataset(..., mode='fewshot')``. For
    every scale, the mean of the train split's column block is appended to the
    saved pretraining means, and a new ``dim x dim`` matrix (initialised to
    the identity) is appended to the saved pretraining matrices. Only that
    last matrix is updated during the ``iter_n`` rounds; the pretraining
    matrices stay fixed. The resulting matrix is applied to both the train and
    the test split; columns beyond ``len(scales) * dim`` are appended
    unchanged.

    Side effects: reads ``pretrain_wo_{test_name}_Mus.pt`` and
    ``pretrain_wo_{test_name}_Rs.pt`` from ``save_path``, writes
    ``{test_name}_align_Rs.pt`` (``[len(scales), dim, dim]``) there, and
    prints ``gamma`` per scale plus the output path.

    Args:
        name: Dataset name passed to ``split_tu_dataset`` / ``TransTUDataset``.
        root: Root directory passed to ``split_tu_dataset``.
        dim: Width of one scale's column block.
        scales: Passed to ``TransTUDataset``; its length is the number of blocks.
        device: Device on which the iterative update runs.
        save_path: Directory holding the pretraining files and the output file.
        test_name: Tag embedded in the file names.
        gamma_scale: Numerator of the kernel width; ``gamma`` is
            ``gamma_scale / mean_pairwise_distance ** 2`` of the pretraining
            means only.
        iter_n: Number of update rounds per scale.
        k_shot: Passed to ``split_tu_dataset``.

    Returns:
        ``(train_dataset, test_dataset)``: two lists of ``Data`` objects with
        rotated ``x``.

    NOTE(review): ``G`` below is an outer product of two vectors (rank <= 1),
    so ``U @ Vh`` is only determined along those vectors; the remaining
    directions depend on the basis the SVD routine returns.
    """
    n_scales = len(scales)

    train_path, test_path, _, _ = split_tu_dataset(
        name,
        root,
        mode='fewshot',
        force_reload=True,
        k_shot=k_shot,
        return_indices=True)
    train = TransTUDataset(
                    train_path,
                    name,
                    dim=dim,
                    scales=scales,
                    mode='train')
    test = TransTUDataset(
                    test_path,
                    name,
                    dim=dim,
                    scales=scales,
                    mode='test',
                    align_feat=True,
                    force_reload=True)

    pretrain_mu = torch.load(osp.join(save_path, f'pretrain_wo_{test_name}_Mus.pt'), weights_only=True)
    R_pre = torch.load(osp.join(save_path, f'pretrain_wo_{test_name}_Rs.pt'), weights_only=True)

    Rs_train = []
    for j in range(n_scales):
        pre_mu_j = pretrain_mu[:, j, :]                                     # [M, dim]
        new_mu_j = torch.mean(train.x[:, j*dim:(j+1)*dim], dim=0)           # [dim]
        train_mu = torch.cat([pre_mu_j, new_mu_j.unsqueeze(0)], dim=0)      # [M+1, dim]
        train_mu = train_mu.to(device)

        # Kernel width comes from the pretraining means only.
        gamma = gamma_scale / (torch.cdist(pre_mu_j, pre_mu_j).mean()**2)
        print(gamma)

        norm2_Mu_train = torch.norm(train_mu, dim=1, p=2)**2    # [M+1]
        W = torch.exp(
            - gamma * (norm2_Mu_train[:, None] + norm2_Mu_train[None, :])
            )    # [M+1, M+1]

        # Pretraining matrices followed by an identity for the new dataset.
        R_new = torch.eye(dim).unsqueeze(0).to(device)                      # [1, dim, dim]
        R_train = torch.cat([R_pre[:, j, :, :].to(device), R_new], dim=0)   # [M+1, dim, dim]

        w_last = W[-1]                                      # [M+1]
        mu_last = train_mu[-1].unsqueeze(0).unsqueeze(0)    # [1, 1, dim]
        for _ in range(iter_n):
            R_Mu = torch.bmm(R_train, train_mu.unsqueeze(2)).squeeze(2)     # [M+1, dim]
            K = torch.exp(2 * gamma * (R_Mu @ R_Mu.t()))                    # [M+1, M+1]

            # [M+1, dim, 1] * [1, 1, dim] -> [M+1, dim, dim]: (R_n mu_n) mu_last^T
            R_Mu_outer = R_Mu.unsqueeze(2) * mu_last
            G = torch.sum((w_last * K[-1]).unsqueeze(1).unsqueeze(2) * R_Mu_outer, dim=0)
            # torch.linalg.svd returns Vh (already transposed).
            U, _, Vh = svd(G)
            # Only the new dataset's matrix is updated.
            R_train[-1, :, :] = torch.mm(U, Vh)
        Rs_train.append(R_train[-1])

    Rs_train_torch = torch.stack(Rs_train, dim=0)
    rs_file = osp.join(save_path, f'{test_name}_align_Rs.pt')
    torch.save(Rs_train_torch, rs_file)
    print(f'Save Rs_train_torch at {rs_file}')

    # Transposed rotations on the CPU, shared by the train and test splits.
    rotations_t = [R.t().cpu() for R in Rs_train]

    def _rotate_and_split(dataset):
        # Rotate every scale block, re-attach the trailing columns and cut the
        # stacked node matrix back into one Data object per graph.
        blocks = [
            torch.mm(dataset.x[:, j*dim:(j+1)*dim], rot_t)
            for j, rot_t in enumerate(rotations_t)
        ]
        z = torch.cat(blocks, dim=-1)
        z = torch.cat([z, dataset.x[:, n_scales * dim:]], dim=-1)
        graphs = []
        start = 0
        for data in dataset:
            end = start + data.x.shape[0]
            graphs.append(Data(x=z[start:end, :], edge_index=data.edge_index, y=data.y))
            start = end
        return graphs

    train_dataset = _rotate_and_split(train)
    test_dataset = _rotate_and_split(test)

    return train_dataset, test_dataset



def align_single(name, root, dim, scales, device, save_path, test_name, force_reload=False,gamma_scale=0.1, iter_n=100):
    """Align one whole dataset to previously aligned pretraining datasets.

    For every scale, the mean of the dataset's column block is appended to the
    saved pretraining means, and a new ``dim x dim`` matrix (initialised to
    the identity) is appended to the saved pretraining matrices. Only that
    last matrix is updated during the ``iter_n`` rounds. The dataset's scale
    blocks are then multiplied by the transpose of their matrix.

    Side effects: reads ``pretrain_wo_{test_name}_Mus.pt`` and
    ``pretrain_wo_{test_name}_Rs.pt`` from ``save_path``. Nothing is written.

    Args:
        name: Dataset name passed to ``TransTUDataset``.
        root: Root directory passed to ``TransTUDataset``.
        dim: Width of one scale's column block.
        scales: Passed to ``TransTUDataset``; its length is the number of blocks.
        device: Device on which the iterative update runs.
        save_path: Directory holding the pretraining files.
        test_name: Tag embedded in the pretraining file names.
        force_reload: Passed to ``TransTUDataset``.
        gamma_scale: Numerator of the kernel width; ``gamma`` is
            ``gamma_scale / mean_pairwise_distance ** 2`` of the pretraining
            means only.
        iter_n: Number of update rounds per scale.

    Returns:
        A list of ``Data`` objects, one per graph, whose ``x`` holds the
        rotated scale blocks.

    NOTE(review): unlike ``align_feat`` and ``align_feat_downstream``, columns
    of ``x`` beyond ``len(scales) * dim`` are NOT appended to the output here.
    Confirm whether that is intended before relying on the feature width.
    NOTE(review): ``G`` below is an outer product of two vectors (rank <= 1),
    so ``U @ Vh`` is only determined along those vectors; the remaining
    directions depend on the basis the SVD routine returns.
    """
    n_scales = len(scales)

    train = TransTUDataset(
                    root,
                    name,
                    dim=dim,
                    scales=scales,
                    mode=None,
                    force_reload=force_reload)

    pretrain_mu = torch.load(osp.join(save_path, f'pretrain_wo_{test_name}_Mus.pt'), weights_only=True)
    R_pre = torch.load(osp.join(save_path, f'pretrain_wo_{test_name}_Rs.pt'), weights_only=True)

    Rs_train = []
    for j in range(n_scales):
        pre_mu_j = pretrain_mu[:, j, :]                                     # [M, dim]
        new_mu_j = torch.mean(train.x[:, j*dim:(j+1)*dim], dim=0)           # [dim]
        train_mu = torch.cat([pre_mu_j, new_mu_j.unsqueeze(0)], dim=0)      # [M+1, dim]
        train_mu = train_mu.to(device)

        # Kernel width comes from the pretraining means only.
        gamma = gamma_scale / (torch.cdist(pre_mu_j, pre_mu_j).mean()**2)
        norm2_Mu_train = torch.norm(train_mu, dim=1, p=2)**2    # [M+1]
        W = torch.exp(
            - gamma * (norm2_Mu_train[:, None] + norm2_Mu_train[None, :])
            )    # [M+1, M+1]

        # Pretraining matrices followed by an identity for the new dataset.
        R_new = torch.eye(dim).unsqueeze(0).to(device)                      # [1, dim, dim]
        R_train = torch.cat([R_pre[:, j, :, :].to(device), R_new], dim=0)   # [M+1, dim, dim]

        w_last = W[-1]                                      # [M+1]
        mu_last = train_mu[-1].unsqueeze(0).unsqueeze(0)    # [1, 1, dim]
        for _ in tqdm.tqdm(range(iter_n)):
            R_Mu = torch.bmm(R_train, train_mu.unsqueeze(2)).squeeze(2)     # [M+1, dim]
            K = torch.exp(2 * gamma * (R_Mu @ R_Mu.t()))                    # [M+1, M+1]

            # [M+1, dim, 1] * [1, 1, dim] -> [M+1, dim, dim]: (R_n mu_n) mu_last^T
            R_Mu_outer = R_Mu.unsqueeze(2) * mu_last
            G = torch.sum((w_last * K[-1]).unsqueeze(1).unsqueeze(2) * R_Mu_outer, dim=0)
            # torch.linalg.svd returns Vh (already transposed).
            U, _, Vh = svd(G)
            # Only the new dataset's matrix is updated.
            R_train[-1, :, :] = torch.mm(U, Vh)
        Rs_train.append(R_train[-1])

    z_train = torch.cat(
        [
            torch.mm(train.x[:, j*dim:(j+1)*dim], Rs_train[j].t().cpu())
            for j in range(n_scales)
        ],
        dim=-1,
    )

    # Cut the stacked node matrix back into graphs using the dataset's own
    # row offsets for `x`.
    train_dataset = []
    for idx, data in enumerate(train):
        start = train.slices['x'][idx]
        end = train.slices['x'][idx+1]
        train_dataset.append(Data(x=z_train[start:end, :], edge_index=data.edge_index, y=data.y))
    return train_dataset


def align_graph_feat(node_features, node_slice, batch, dim, mode,
                     load_path=None, prefix=None, device=None,
                     gamma=0.1, iter_n=100,
                     scale=None):
    """Rotate each graph's node features so that the graph means align.

    Every graph ``i`` (rows ``node_slice[i]:node_slice[i+1]`` of
    ``node_features``) is summarised by its mean feature vector and receives
    its own ``dim x dim`` matrix, initialised to the identity and refined for
    ``iter_n`` rounds. Each round replaces ``R_m`` by ``U @ Vh`` from the SVD
    of ``G_m = sum_n W[m, n] * K[m, n] * (R_n mu_n) mu_m^T``.

    * ``mode == 'train'``: all graphs are optimised together and the final
      matrices and means are saved into ``load_path``.
    * ``mode == 'test'``: the saved train means/matrices are loaded and
      prepended; only the matrices of the graphs passed in are updated.
    * any other mode: graphs are optimised together, nothing is loaded or
      saved.

    When more than 2000 matrices are involved, ``G`` is built one graph at a
    time instead of as a single ``[M, M, dim, dim]`` tensor.

    Args:
        node_features: ``[num_nodes, dim]`` tensor of stacked node features.
            It is modified in place. The rotations are moved to the CPU before
            being applied, so this is expected to be a CPU tensor.
        node_slice: 1-D tensor of row offsets with one more entry than there
            are graphs.
        batch: Only ``batch[-1]`` is read; ``batch[-1] + 1`` is used as the
            number of graphs.
        dim: Feature width.
        mode: ``'train'``, ``'test'`` or anything else (see above).
        load_path: Directory used for both saving (train) and loading (test).
            It is joined with the file name unconditionally in those modes, so
            it must be set for them.
        prefix: Tag embedded in the file names.
        device: Device on which the iterative update runs.
        gamma: Kernel width, used as given.
        iter_n: Number of update rounds.
        scale: Optional tag; when not ``None`` the file names are prefixed
            with ``scale{scale}_``.

    Returns:
        ``node_features`` (the same object, rotated in place).

    NOTE(review): each ``G_m`` is an outer product of two vectors (rank <= 1),
    so ``U @ Vh`` is only determined along those vectors; the remaining
    directions depend on the basis the SVD routine returns.
    """
    M = node_slice.shape[0] - 1
    n_graphs = batch[-1] + 1

    # Mean feature vector of every graph.
    Mu = torch.empty((n_graphs, dim))
    for i in range(n_graphs):
        start = node_slice[i]
        end = node_slice[i+1]
        Mu[i, :] = torch.mean(node_features[start:end, :], dim=0)

    R = torch.eye(dim).unsqueeze(0).repeat(M, 1, 1).to(device)  # [M, dim, dim]
    Mu = Mu.to(device)  # [M, dim]

    # Shared stem of the train-state file names.
    stem = f'{prefix}' if scale is None else f'scale{scale}_{prefix}'

    if mode == 'test':
        train_Mu = torch.load(osp.join(load_path, f'{stem}_train_Mu.pt'), weights_only=True, map_location=device)
        train_R = torch.load(osp.join(load_path, f'{stem}_train_R.pt'), weights_only=True, map_location=device)
        # Train entries come first; the graphs passed in start at idx_base.
        Mu = torch.cat([train_Mu, Mu], dim=0)
        R = torch.cat([train_R, R], dim=0)
        idx_base = len(train_Mu)
    else:
        idx_base = 0

    norm2_Mu = torch.norm(Mu, dim=1, p=2)**2    # [M_total]
    # Rotation-independent factor exp(-gamma * (|mu_m|^2 + |mu_n|^2)).
    W = torch.exp(
        - gamma * (norm2_Mu[:, None] + norm2_Mu[None, :])
        )    # [M_total, M_total]

    if R.shape[0] <= 2000:
        # Batched update: materialises an [M_total, M_total, dim, dim] tensor.
        for _ in tqdm.tqdm(range(iter_n)):
            R_Mu = torch.bmm(R, Mu.unsqueeze(2)).squeeze(2)  # [M_total, dim]
            K = torch.exp(2 * gamma * (R_Mu @ R_Mu.t()))      # [M_total, M_total]
            # [1, M, d, 1] * [M, 1, 1, d] -> [M, M, d, d]
            R_Mu_outer = R_Mu.unsqueeze(0).unsqueeze(3) * Mu.unsqueeze(1).unsqueeze(2)

            G = torch.sum((W * K).unsqueeze(2).unsqueeze(3) * R_Mu_outer, dim=1)
            # torch.linalg.svd returns Vh (already transposed).
            U, _, Vh = svd(G)
            # Loaded train matrices (rows before idx_base) are left untouched.
            R[idx_base:, :, :] = torch.bmm(U, Vh)[idx_base:, :, :]
    else:
        # Per-graph update; R_Mu is computed once per round, so every graph
        # sees the previous round's matrices, as in the batched branch.
        for _ in tqdm.tqdm(range(iter_n)):
            R_Mu = torch.bmm(R, Mu.unsqueeze(2)).squeeze(2)  # [M_total, dim]
            for idx in range(idx_base, idx_base + M):
                K = torch.exp(2 * gamma * (R_Mu[idx].unsqueeze(0) @ R_Mu.t())).squeeze(0)      # [M_total]
                R_Mu_outer = R_Mu.unsqueeze(2) * Mu[idx].unsqueeze(0)   # [M_total, d, d]
                G = torch.sum((W[idx] * K).unsqueeze(1).unsqueeze(2) * R_Mu_outer, dim=0)
                U, _, Vh = svd(G)
                R[idx, :, :] = torch.mm(U, Vh)

    if mode == 'train':
        torch.save(R, osp.join(load_path, f'{stem}_train_R.pt'))
        torch.save(Mu, osp.join(load_path, f'{stem}_train_Mu.pt'))

    # Copy the matrices of the graphs passed in to the CPU once, rather than
    # once per graph.
    R_cpu = R[idx_base:].cpu()
    for i in range(n_graphs):
        start = node_slice[i]
        end = node_slice[i+1]
        node_features[start: end, :] = torch.mm(node_features[start:end, :], R_cpu[i].t())
    return node_features

