import gc
import os
# Point tooling at a specific Windows CUDA 11.8 toolkit install; this is set
# before the torch / torch_geometric imports below.
# NOTE(review): presumably needed by a CUDA extension (e.g. torchsort) — confirm
# which package actually reads CUDA_PATH.
os.environ['CUDA_PATH']='C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.8'
from torch_geometric.nn import  summary
from torch_geometric.utils import to_dense_adj, dense_to_sparse
import numpy as np
from torchmetrics.regression import SpearmanCorrCoef
import torchsort
from torch.utils.data import Sampler

import copy
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchaudio
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from os import path as pt
# Let the process continue when more than one Intel OpenMP runtime is loaded
# (otherwise it aborts with "OMP: Error #15"). This masks the conflict rather
# than resolving it.
# NOTE(review): several libraries are imported above this line; if one of them
# already initialised OpenMP, setting the variable here may be too late.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric_temporal.nn.recurrent import A3TGCN2
from torch_geometric_temporal.signal import temporal_signal_split
from data_v1 import get_data,transform_back,generate_cbm,kurtosis_torch,skew_torch
import tqdm
# GPU support
# Hard-coded to CUDA with no CPU fallback: any code placing tensors on DEVICE
# needs a CUDA-capable GPU.
DEVICE = torch.device('cuda') # cuda
# NOTE(review): module-level defaults; their consumers are not visible in this chunk.
shuffle=True
batch_size = 32
import ot
from torch_geometric.data import Data, Dataset, Batch
from torch_geometric.utils import dense_to_sparse
from torch_geometric.loader import DataLoader
from GNNGAN_v1 import ARFNN_Net,get_edge_index,jaccard_index
from plot_helper import plt_figures,evaluate
from torchmetrics.functional import jaccard_index as jaccard_index1
import pandas as pd
import datetime
# Autograd anomaly detection: makes backward errors (e.g. NaN gradients) point at
# the forward op that produced them, but slows every step considerably.
# Intended for debugging only — disable for normal training runs.
torch.autograd.set_detect_anomaly(True)
def _rbf(norm, sigma):
    return torch.exp(-norm / (2 * sigma ** 2))


import torch


def correlated_noise(correlation_matrix, Z):
    """Give independent noise ``Z`` the correlation structure of a target matrix.

    Eigen-decomposes ``correlation_matrix`` (symmetric, via ``torch.linalg.eigh``)
    and clips negative eigenvalues to zero. It then forms the square-root factor
    ``factor = V @ diag(sqrt(lambda))`` and returns ``factor @ Z``, so that
    ``factor @ factor.T`` reproduces the PSD-clipped correlation matrix.

    Args:
        correlation_matrix: square (optionally batched) symmetric matrix.
        Z: noise to be mixed; ``factor`` is moved to ``Z``'s device first.

    Returns:
        ``factor @ Z``.
    """
    eigenvalues, eigenvectors = torch.linalg.eigh(correlation_matrix)
    root_eigenvalues = eigenvalues.clamp(min=0).sqrt()
    factor = eigenvectors @ torch.diag_embed(root_eigenvalues)
    return torch.matmul(factor.to(Z.device), Z)
def batched_pairwise_noise_matrix(noise: torch.Tensor, flat_L: torch.Tensor) -> torch.Tensor:
    """
    Batched pairwise noise correlation.

    For every ordered pair ``(i, j)`` with ``i != j``::

        Z[b, i, j] = rho_ij[b] * noise[b, i] + sqrt(1 - rho_ij[b]**2 + 1e-6) * noise[b, j]

    where ``rho_ij == rho_ji`` is read from ``flat_L``. The diagonal is zero.

    Args:
        noise: [B, n] - batch of independent noise vectors
        flat_L: [B, (n*(n-1))//2] - batch of pairwise correlations in the
            order of ``torch.tril_indices(n, n, offset=-1)`` (row-major strictly
            lower triangle), i.e. the layout produced by
            ``get_batched_lower_triangular_values``.

    Returns:
        Z: [B, n, n] - noise matrices with both triangles filled and a zero
        diagonal, in ``noise``'s dtype.
    """
    B, n = noise.shape
    expected_flat_size = n * (n - 1) // 2
    assert flat_L.shape == (B, expected_flat_size), "Incorrect shape for flat_L"

    # Scatter the lower-triangle correlations into a symmetric matrix, so the
    # upper entry (i, j) uses the same rho as its mirror (j, i). The previous
    # version reused flat_L verbatim at triu_indices positions. That ordering
    # differs from tril_indices for n >= 4, so upper entries got the wrong rho.
    tril_i, tril_j = torch.tril_indices(n, n, offset=-1, device=noise.device)
    rho = torch.zeros(B, n, n, device=noise.device, dtype=flat_L.dtype)
    rho[:, tril_i, tril_j] = flat_L.to(noise.device)
    rho = rho + rho.transpose(1, 2)

    # Row i contributes noise_i with weight rho; column j contributes noise_j.
    Z = rho * noise.unsqueeze(2) + torch.sqrt(1 - rho ** 2 + 1e-6) * noise.unsqueeze(1)

    # The diagonal is not a pair; keep it exactly zero (as before).
    diagonal = torch.eye(n, dtype=torch.bool, device=noise.device)
    Z = Z.masked_fill(diagonal, 0.0)
    return Z.to(noise.dtype)


def get_batched_lower_triangular_values(Z: torch.Tensor) -> torch.Tensor:
    """
    Flatten the strictly-lower-triangular part of each matrix in a batch.

    Args:
        Z: [B, n, n] batch of square matrices.

    Returns:
        [B, (n*(n-1))//2] tensor whose entries follow row-major order below the
        diagonal ((1,0), (2,0), (2,1), ...), the same order as
        ``torch.tril_indices(n, n, offset=-1)``.
    """
    _, size, _ = Z.shape
    below_diagonal = torch.ones(size, size, dtype=torch.bool, device=Z.device).tril(diagonal=-1)
    return Z[:, below_diagonal]

def cov(m):
    """Batched sample covariance between the rows of each matrix in ``m``.

    Args:
        m: 3-D tensor indexed as ``[batch, variable, observation]``. The
            ``transpose(1, 2)`` below requires three dimensions. The input is
            not modified.

    Returns:
        ``[batch, variable, variable]`` covariance matrices normalised by
        ``n_observations - 1``. Size-1 dimensions are squeezed away, so a batch
        of one yields a plain ``[variable, variable]`` matrix.
    """
    fact = 1.0 / (m.shape[-1] - 1)  # unbiased estimator: 1 / (n_observations - 1)
    # Centre each variable on its own mean over the observations. The previous
    # version had two problems:
    # - it averaged over variables AND observations, which is only a covariance
    #   when all variables share the same mean;
    # - it subtracted in place, silently mutating the caller's tensor.
    centred = m - torch.mean(m, dim=-1, keepdim=True)
    return fact * centred.matmul(centred.transpose(1, 2)).squeeze()

def corrcoef(x, y):
    """Per-sample Pearson correlation between two batches of 1-D series.

    Args:
        x: tensor of shape ``[batch, length]``.
        y: tensor of the same shape as ``x``.

    Returns:
        ``[batch]`` tensor whose ``b``-th entry is corr(x[b], y[b]). A series
        with zero variance yields NaN or inf (no guarding, as before).
    """
    # Centre each series on its own mean. The previous implementation went
    # through ``cov``, which had two problems:
    # - it subtracted one mean shared by x and y (wrong whenever their means
    #   differ);
    # - it squeezed the batch dimension, which crashed for batch size 1.
    x_centred = x - x.mean(dim=-1, keepdim=True)
    y_centred = y - y.mean(dim=-1, keepdim=True)
    covariance = (x_centred * y_centred).sum(dim=-1)
    # The 1 / (length - 1) normalisation cancels between numerator and denominator.
    x_std = x_centred.pow(2).sum(dim=-1).sqrt()
    y_std = y_centred.pow(2).sum(dim=-1).sqrt()
    return covariance / (x_std * y_std)


def batch_corr(x,full=False):
    """Per-sample Pearson correlation matrix between the columns of ``x``.

    ``x`` must be 3-D (the shape is unpacked below). Each slice ``x[b]`` is
    transposed before ``torch.corrcoef``, so the correlations are between the
    entries of the last dimension.

    A uniform jitter in ``[0, 1e-6)`` is added first. NOTE(review): presumably
    this avoids zero-variance (NaN) columns — confirm intent.

    Args:
        x: tensor of shape ``[B, L, F]``.
        full: if truthy return full ``[B, F, F]`` matrices; otherwise the strict
            upper triangles flattened to ``[B, F*(F-1)//2]``.
    """
    _, _, n_features = x.shape
    jitter = torch.rand(x.size()).to(x.device) * 1e-6
    x = x + jitter
    matrices = torch.stack([torch.corrcoef(sample.T) for sample in x])
    if full:
        return matrices
    upper_rows, upper_cols = torch.triu_indices(n_features, n_features, 1)
    return matrices[:, upper_rows, upper_cols]
def pairwise_distance(X, Y=None):
    """Pairwise distance matrix between the samples of ``X`` and ``Y``.

    Samples are indexed along dim 0. All remaining dims are flattened into one
    feature vector per sample before calling ``ot.dist`` with its default
    metric (squared Euclidean, per the POT documentation). When ``Y`` is
    omitted, distances are computed within ``X``.

    Returns:
        ``[n_X, n_Y]`` distance matrix.
    """
    flat_x = X.contiguous().view(X.size(0), -1)
    flat_y = flat_x if Y is None else Y.contiguous().view(Y.size(0), -1)
    return ot.dist(flat_x, flat_y)


def median_pairwise_distance(X, Y=None):
    """
    Heuristic for the RBF bandwidth: median squared Euclidean distance between
    the samples of ``X`` and ``Y``.

    Samples are indexed along dim 0 and every remaining dim is flattened.
    - A 2-D input uses squared vector norms and a 3-D input squared Frobenius
      norms, as before.
    - Any other rank >= 1 is handled the same way instead of failing with
      ``UnboundLocalError``.

    If ``Y`` is missing it is taken to be ``X``. During training this allows a
    fixed bandwidth computed from one set, without recomputing it every time
    the MMD is evaluated. At the end of training the heuristic can be applied
    "correctly" to both ``X`` and ``Y``.

    Args:
        X: tensor of shape ``[n_x, ...]``.
        Y: optional tensor of shape ``[n_y, ...]`` with the same per-sample size.

    Returns:
        0-d tensor: ``torch.median`` of the ``[n_x, n_y]`` distance matrix
        (the lower of the two middle values for an even count).
    """
    if Y is None:
        Y = X
    x_flat = X.reshape(X.shape[0], -1)
    y_flat = Y.reshape(Y.shape[0], -1)
    x_sqnorms = torch.einsum('ia,ia->i', x_flat, x_flat)
    y_sqnorms = torch.einsum('ia,ia->i', y_flat, y_flat)
    cross = torch.einsum('ia,ja->ij', x_flat, y_flat)
    # ||x - y||^2 = ||x||^2 - 2 <x, y> + ||y||^2
    distances = x_sqnorms.reshape(-1, 1) - 2 * cross + y_sqnorms.reshape(1, -1)
    return torch.median(distances)


def _partial_mmd(X, Y=None, bandwidth=None, heuristic=True,multiplier=None):
    """One kernel-mean term of the MMD estimator: the mean RBF kernel value over all pairs.

    Args:
        X: first sample set.
        Y: second sample set; when ``None`` the term is computed within ``X``.
        bandwidth: RBF sigma, used only when ``heuristic`` is False.
        heuristic: if True, sigma is the detached median pairwise squared
            distance from ``median_pairwise_distance`` and ``bandwidth`` is ignored.
        multiplier: optional weights.
            - Heuristic mode: multiplied elementwise with the kernel matrix.
              ``None`` means no weighting; it previously raised ``TypeError``,
              which broke ``mmd_loss(..., heuristic=True)``.
            - Otherwise: applied as ``multiplier.T @ K``.

    Returns:
        0-d tensor with the (optionally weighted) mean kernel value.
    """
    l2_dist = pairwise_distance(X, Y)
    if heuristic:
        heuristic_sigma = median_pairwise_distance(X, Y).detach()
        kernel = _rbf(l2_dist, heuristic_sigma)
        if multiplier is None:
            return torch.mean(kernel)
        return torch.mean(multiplier * kernel)
    kernel = _rbf(l2_dist, bandwidth)
    if multiplier is None:
        return torch.mean(kernel)
    return torch.mean(multiplier.T @ kernel)


def mmd_loss(real_data, fake_data, bandwidths=None, heuristic=False,multiplier=None):
    """Squared MMD between ``real_data`` and ``fake_data`` with RBF kernels.

    Each kernel contributes ``E[k(r, r')] - 2 E[k(r, f)] + E[k(f, f')]``, with
    every expectation taken as the mean over all pairs (see ``_partial_mmd``).

    Args:
        real_data: samples along dim 0.
        fake_data: samples along dim 0, same per-sample size as ``real_data``.
        bandwidths: RBF sigmas used when ``heuristic`` is False. The result is
            the SUM of the MMD over all of them. Defaults to ``(0.5, 1, 5)``;
            previously this was a shared mutable list default.
        heuristic: if True, a single median-heuristic bandwidth is used and
            ``bandwidths`` / ``multiplier`` are ignored.
        multiplier: optional weights forwarded to ``_partial_mmd``
            (non-heuristic mode only).

    Returns:
        Scalar MMD estimate (a tensor; the int ``0`` if ``bandwidths`` is empty).
    """
    if bandwidths is None:
        bandwidths = (0.5, 1, 5)

    if heuristic:
        mmd_gen_real = _partial_mmd(real_data, fake_data, bandwidth=None, heuristic=heuristic)
        mmd_gen = _partial_mmd(fake_data, bandwidth=None, heuristic=heuristic)
        mmd_real = _partial_mmd(real_data, bandwidth=None, heuristic=heuristic)
        return mmd_real - 2 * mmd_gen_real + mmd_gen

    # (A former ``for div in [1]`` random-window branch was unreachable and has
    # been removed; the full tensors were always used.)
    mmd = 0
    for bandwidth in bandwidths:
        mmd_gen_real = _partial_mmd(real_data, fake_data, bandwidth=bandwidth, heuristic=heuristic, multiplier=multiplier)
        mmd_gen = _partial_mmd(fake_data, bandwidth=bandwidth, heuristic=heuristic, multiplier=multiplier)
        mmd_real = _partial_mmd(real_data, bandwidth=bandwidth, heuristic=heuristic, multiplier=multiplier)
        mmd = mmd + (mmd_real - 2 * mmd_gen_real + mmd_gen)
    return mmd

def tensor_to_geometric(tensor, batch_size, num_nodes):
    """Pack stacked node features into a PyG ``Batch`` of fully connected graphs.

    Args:
        tensor: 2-D tensor of shape ``[batch_size * num_nodes, features]``. It is
            viewed (not copied) as ``[batch_size, num_nodes, features]``, so the
            rows of one graph must be contiguous.
        batch_size: number of graphs.
        num_nodes: nodes per graph.

    Returns:
        ``Batch`` of ``batch_size`` graphs. Each has ``x`` of shape
        ``[num_nodes, features]`` and a directed edge in both directions between
        every pair of distinct nodes (no self-loops). ``edge_index`` is built on
        the CPU.
    """
    n_features = tensor.size(1)
    node_features = tensor.view(batch_size, num_nodes, n_features)

    # The connectivity is identical for every graph, so build it once instead of
    # once per graph.
    pairs = torch.combinations(torch.arange(num_nodes), r=2).t()
    edge_index = torch.cat([pairs, pairs.flip(0)], dim=1)

    data_list = [Data(x=node_features[i], edge_index=edge_index) for i in range(batch_size)]
    return Batch.from_data_list(data_list)


class MyGeometricDataset(Dataset):
    """PyG dataset that attaches a correlation-thresholded graph to each sample.

    Each ``tensor[i]`` is read as ``(l, N)`` and transposed to node-major
    ``(N, l)``. Edges are precomputed once in ``__init__`` from the Pearson
    correlation of the first ``inst`` node rows over the first ``corr_window``
    columns.
    """

    def __init__(self, tensor, inst, corr_window=40, thresh=0.15):
        """
        Args:
            tensor: stacked samples, indexed along dim 0.
            inst: number of leading node rows used for the correlation graph.
            corr_window: number of leading columns (time steps) used for the
                correlation (previously hard-coded to 40).
            thresh: an edge (i, j) is kept when corr[i, j] > thresh
                (previously hard-coded to 0.15).
        """
        super().__init__()
        self.tensor = tensor
        self.num_graphs = tensor.shape[0]
        self.inst = inst
        self.corr_window = corr_window
        self.thresh = thresh
        self.init_corr_matrix()

    def init_corr_matrix(self):
        """Precompute ``self.edge_index`` and ``self.edge_attr`` for every sample.

        An edge ``(i, j)`` with ``i != j`` is kept when ``corr[i, j] > thresh``.
        Edges are listed in row-major order, matching the former nested loop.
        ``edge_attr`` is ``corr * 10`` truncated toward zero to ``long``.
        ``edge_index`` is stored on the CPU; ``edge_attr`` stays on the
        correlation matrix's device.
        """
        self.edge_index = []
        self.edge_attr = []
        for sample_idx in tqdm.tqdm(range(self.tensor.size(0))):
            x = self.tensor[sample_idx].transpose(0, 1).float()  # (l, N) -> (N, l)
            corr_matrix = torch.corrcoef(x[:self.inst, :self.corr_window])
            n = corr_matrix.size(0)

            off_diagonal = ~torch.eye(n, dtype=torch.bool, device=corr_matrix.device)
            keep = (corr_matrix > self.thresh) & off_diagonal
            rows, cols = keep.nonzero(as_tuple=True)  # row-major order

            self.edge_index.append(torch.stack([rows, cols]).cpu())
            # Works for zero edges too: the old ``torch.stack([])`` raised a
            # RuntimeError when no pair exceeded the threshold.
            self.edge_attr.append((corr_matrix[rows, cols] * 10).long())

    def len(self):
        """Number of samples (graphs)."""
        return self.num_graphs

    def get(self, idx):
        """Return sample ``idx`` as ``Data``.

        ``x`` holds node-major features ``(N, l)`` and ``x2`` its last 13 node
        rows; the precomputed ``edge_index``/``edge_attr`` are attached.
        """
        x = self.tensor[idx].transpose(0, 1).float()  # (l, N) -> (N, l)
        return Data(x=x, x2=x[-13:, :], edge_index=self.edge_index[idx], edge_attr=self.edge_attr[idx])

def compute_loss(d_out, target):
    """Mean binary cross-entropy (with logits) of ``d_out`` against a constant label.

    Args:
        d_out: raw discriminator logits of any shape.
        target: scalar label broadcast to ``d_out``'s shape (this file uses 1
            for real and 0 for generated samples).

    Returns:
        0-d loss tensor.
    """
    labels = torch.full_like(d_out, fill_value=target)
    return torch.nn.functional.binary_cross_entropy_with_logits(d_out, labels)
def toggle_grad(model, requires_grad):
    """Turn gradient tracking on or off for every parameter of ``model``, in place.

    Returns ``None``.
    """
    params = model.parameters()
    for weight in params:
        weight.requires_grad_(requires_grad)

class RepeatSingleIndexSampler(Sampler):
    """Sampler that yields one fixed dataset index ``repeat`` times."""

    def __init__(self, index, repeat):
        self.index = index
        self.repeat = repeat

    def __iter__(self):
        # Snapshot the index now; range() is evaluated eagerly as well.
        chosen = self.index
        return (chosen for _ in range(self.repeat))

    def __len__(self):
        return self.repeat
def G_trainstep(Generator,Discriminator,G_optimizer, x_fake,bs,validation=False):
    """Run one adversarial generator update against ``Discriminator``.

    ``x_fake`` is packed into fully connected graphs, scored by the
    discriminator, and trained toward the "real" label (BCE target 1).

    Args:
        Generator: generator module; put in train mode.
        Discriminator: called as ``Discriminator(x, edge_index, batch)``.
        G_optimizer: optimizer over the generator's parameters.
        x_fake: generator output, shape ``[bs * n_nodes, features]``.
        bs: number of graphs in the batch.
        validation: unused; kept for interface compatibility.

    Returns:
        The generator loss tensor.
    """
    Generator.train()
    G_optimizer.zero_grad()
    # Derive the node count from the input instead of reading a module-level
    # ``N`` that is not defined here. tensor_to_geometric views x_fake as
    # [bs, N, features], so this is the only consistent value.
    n_nodes = x_fake.size(0) // bs
    x_fake = tensor_to_geometric(x_fake, bs, n_nodes).to('cuda').clone()
    d_fake = Discriminator(x_fake.x, x_fake.edge_index, x_fake.batch)
    gloss = compute_loss(d_fake, 1)
    gloss.backward()
    # Optimizer first, then scheduler (PyTorch's documented order).
    # NOTE(review): ``scheduler`` is still a module-level global that is not
    # defined in this block; consider passing it in explicitly.
    G_optimizer.step()
    scheduler.step(gloss)
    return gloss
def D_trainstep(Discriminator,D_optimizer, x_fake,x_real,bs, validation=False):
    """One discriminator update: real graphs labelled 1, generated graphs labelled 0.

    Args:
        Discriminator: called as ``Discriminator(x, edge_index, batch)``.
        D_optimizer: optimizer over the discriminator's parameters.
        x_fake: generator output, shape ``[bs * n_nodes, features]``. It is
            detached here, so the discriminator loss neither backpropagates into
            the generator nor frees its autograd graph.
        x_real: PyG batch. Only node-feature columns from index 40 onward are
            scored (``x_real.x[:, 40:]``). NOTE(review): 40 matches the
            correlation window in MyGeometricDataset — confirm the intended split.
        bs: number of graphs in the batch.
        validation: unused; kept for interface compatibility.

    Returns:
        Sum of the real and fake BCE losses.
    """
    Discriminator.train()
    D_optimizer.zero_grad()
    d_real = Discriminator(x_real.x[:,40:],x_real.edge_index,x_real.batch)
    dloss_real = compute_loss(d_real, 1)
    # Node count derived from the input rather than an undefined module-level ``N``.
    n_nodes = x_fake.size(0) // bs
    x_fake = tensor_to_geometric(x_fake.detach(), bs, n_nodes).to('cuda').clone()
    d_fake = Discriminator(x_fake.x, x_fake.edge_index, x_fake.batch)
    dloss_fake = compute_loss(d_fake, 0)
    dloss = dloss_fake + dloss_real
    dloss.backward()
    D_optimizer.step()
    return dloss



def corr(x_batch: torch.Tensor, full: bool) -> torch.Tensor:
    """Correlation matrices (or their upper triangles) for a batch of series.

    A 2-D input is treated as a batch of one. The work is delegated to
    ``batch_corr``, which correlates the last-dimension entries of each
    ``x_batch[b]``.

    Args:
        x_batch: ``[B, L, F]`` tensor, or ``[L, F]`` for a single sample.
        full: if True return full ``[B, F, F]`` matrices; otherwise the strict
            upper triangles flattened to ``[B, F*(F-1)//2]``.
            (The annotation was previously the literal ``True`` instead of a type.)
    """
    if x_batch.dim() == 2:
        x_batch = x_batch.unsqueeze(0)
    return batch_corr(x_batch, full)
def train_step(ninst,p,q,Generator, G_optimizer,scheduler,x_fake, x_real,layers_g,ave_grads_g,mm_list,ident=0,weightstab=None,corr_loss=True,epoch=0,step_number=0,option=0,desired_corr_matrix=None,full_desired_corr_matrix=None,desired_vol_mat=None,seed=0):

    desired_corr_matrix=desired_corr_matrix.to(x_real.x.device)
    Generator.train()
    G_optimizer.zero_grad()
    N=ninst
    bs=desired_corr_matrix.size(0)
    NL=int((ninst+13)*q)
    NL2=int((ninst)*q)
    NLix=int((2*13)*q)
    multq=torch.tensor(1., device=x_fake.device)
    N=ninst
    if weightstab is not None:
        xf = x_fake.reshape(bs, N, -1).permute(0, 2, 1)
        indexvalues = torch.matmul(xf[:, :, :-13], weightstab.to(x_fake.device).T)
        xf=torch.cat((xf,indexvalues),dim=2)
        x_fake=xf.permute(0, 2, 1).reshape(-1,q)
        xreal=torch.cat((x_real.x.reshape(bs, N, -1).permute(0, 2, 1)[:,p:p+q],x_real.x2.reshape(bs, 13, -1).permute(0, 2, 1)[:,p:p+q]),dim=2)
        xreal=xreal.permute(0, 2, 1).reshape(-1,q)
        xreal_past=torch.cat((x_real.x.reshape(bs, N, -1).permute(0, 2, 1)[:,:p],x_real.x2.reshape(bs, 13, -1).permute(0, 2, 1)[:,:p]),dim=2)
        xreal_past=xreal_past.permute(0, 2, 1).reshape(-1,p)

        constraint_loss=torch.mean(torch.pow(indexvalues-xf[:, :, -26:-13],2))
    else:
        xreal=x_real.x[:, p:p+q]


        xreal_past=x_real.x[:, :p]


    jump_div = xreal_past.std(dim=1).reshape(-1,1) + 1e-2
    xreal_jump=xreal/jump_div
    xreal_jump[xreal_jump.abs()<5]=0
    xfake_jump=x_fake/jump_div
    xfake_jump[xfake_jump.abs()<5]=0

    if step_number == 0:

            a1=desired_corr_matrix.reshape(bs, -1)
            a2=corr(x_fake.reshape(bs, N+13, -1).permute(0,2,1)[:,:,:-13],False).reshape(bs, -1)
            b_3 = torch.median(ot.dist(a1, a2)).detach().item()


            Generator.b_3 = [b_3 * 0.05, b_3 * .1, b_3 * .5]
            acorr_fake=corr(x_fake.reshape(bs, N+13, -1),True).reshape(bs, -1)
            ac_real=corr(xreal.reshape(bs, N+13, -1),True).reshape(bs, -1)
            b_4 = torch.median(ot.dist(ac_real,
                                       acorr_fake)).detach().item()
            Generator.b_4 = [b_4 * 0.5, b_4 * 1, b_4 * 5]

            b_1s = torch.median(ot.dist(xreal.reshape(bs, N + 13, -1).permute(0, 2, 1).sum(axis=1).reshape(-1, N + 13),
                                 x_fake.reshape(bs, N + 13, -1).permute(0, 2, 1).sum(axis=1).reshape(-1, N + 13))).detach().item()
            b_1sa =torch.median(ot.dist(xreal.reshape(bs, N + 13, -1).permute(0, 2, 1).abs().sum(axis=1).reshape(-1, N + 13),
                                 x_fake.reshape(bs, N + 13, -1).permute(0, 2, 1).abs().sum(axis=1).reshape(-1, N + 13))).detach().item()
            Generator.b_1s = [b_1s * 0.5, b_1s * 1, b_1s * 5]
            Generator.b_1sa = [b_1sa * 0.5, b_1sa * 1, b_1sa * 5]


            b_1six = torch.median(ot.dist(xreal.reshape(bs, N + 13, -1)[:,-26:].permute(0, 2, 1).sum(axis=1).reshape(-1, 2*13),
                                 x_fake.reshape(bs, N + 13, -1)[:,-26:].permute(0, 2, 1).sum(axis=1).reshape(-1, 2*13))).detach().item()
            b_1saix =torch.median(ot.dist(xreal.reshape(bs, N + 13, -1)[:,-26:].permute(0, 2, 1).abs().sum(axis=1).reshape(-1, 2*13),
                                 x_fake.reshape(bs, N + 13, -1)[:,-26:].permute(0, 2, 1).abs().sum(axis=1).reshape(-1, 2*13))).detach().item()
            Generator.b_1six = [b_1six * 0.5, b_1six * 1, b_1six * 5]
            Generator.b_1saix = [b_1saix * 0.5, b_1saix * 1, b_1saix * 5]


            kurtosis_fakeat = kurtosis_torch(x_fake.reshape(bs, N+13, -1))
            kurtosis_realat = kurtosis_torch(xreal.reshape(bs, N+13, -1))

            b_5 = torch.median(ot.dist(kurtosis_fakeat.reshape(bs, -1),
                                       kurtosis_realat.reshape(bs, -1))).detach().item()
            Generator.b_5 = [b_5 * 0.5, b_5 * 1, b_5 * 5]
            acorr_fake_abs = corr(x_fake.abs().reshape(bs, N+13, -1), True)[:, 1:5,:5].reshape(bs, -1)
            ac_real_abs = corr(xreal.abs().reshape(bs, N+13, -1), True)[:, 1:5,:5].reshape(bs, -1)
            b_7 = torch.median(ot.dist(ac_real_abs,
                                       acorr_fake_abs)).detach().item()
            Generator.b_7 = [b_7 * 0.5, b_7 * 1, b_7 * 5]

            b_8 = torch.median(ot.dist(xfake_jump.reshape(bs, N+13, -1).permute(0,2,1).reshape(-1,NL),
                                       xreal_jump.reshape(bs, N+13, -1).permute(0,2,1).reshape(-1,NL))).detach().item()
            Generator.b_8 = [b_8 * 0.5, b_8 * 1, b_8 * 5]


            b_9 = torch.median(ot.dist(xfake_jump.abs().reshape(bs,N+13, -1).permute(0,2,1).reshape(-1,NL),
                                       xreal_jump.abs().reshape(bs,N+13, -1).permute(0,2,1).reshape(-1,NL))).detach().item()
            Generator.b_9 = [b_9 * 0.5, b_9 * 1, b_9 * 5]


            b_10 = torch.median(ot.dist(x_fake.std(dim=1).unsqueeze(1),
                                       xreal.std(dim=1).unsqueeze(1))).detach().item()
            Generator.b_10 = [b_10 * 0.5, b_10 * 1, b_10 * 5]

    corr_fake= corr(x_fake.reshape(bs,N+13, -1).permute(0,2,1)[:,:,:-13],False)

    acorr_real = corr(xreal.reshape(bs,N+13, -1),True)
    acorr_fake = corr(x_fake.reshape(bs,N+13, -1),True)

    acorr_real_abs = corr(xreal.abs().reshape(bs,N+13, -1),True)
    acorr_fake_abs = corr(x_fake.abs().reshape(bs,N+13, -1),True)
    if torch.sum(torch.isnan(mmd_loss(desired_corr_matrix.reshape(bs, -1), corr_fake.reshape(bs, -1),
                  bandwidths=Generator.b_3, heuristic=False)))<0:
        print('nan for dcm')
    c_loss = mmd_loss(desired_corr_matrix.reshape(bs, -1), corr_fake.reshape(bs, -1),
                  bandwidths=Generator.b_3, heuristic=False)


    stdev_loss = mmd_loss(x_fake.std(dim=1).unsqueeze(1), xreal.std(dim=1).unsqueeze(1),
                  bandwidths=Generator.b_10, heuristic=False)

    ac_loss = mmd_loss(acorr_real, acorr_fake, bandwidths=Generator.b_4, heuristic=False)
    ac_loss2=torch.mean(torch.pow(acorr_real-acorr_fake, 2))



    ac_loss_abs = mmd_loss(acorr_real_abs, acorr_fake_abs, bandwidths=Generator.b_7, heuristic=False)
    ac_loss2_abs=torch.mean(torch.pow(acorr_real_abs-acorr_fake_abs, 2))
    corr_real1 = full_desired_corr_matrix.to(x_fake.device)
    corr_fake1 = corr(x_fake.reshape(bs,N+13, -1).permute(0, 2, 1)[:,:,:-13],True)
    if weightstab is not None:

        C_desired_ix = full_desired_corr_matrix[:,-13:,-13:].to(x_fake.device) #torch.clamp(C_index_batch, -1.0, 1.0)
        indices = torch.triu_indices(C_desired_ix.size(2), C_desired_ix.size(2), 1)
        correlations_ix = C_desired_ix[:, indices[0], indices[1]]

        corr_fake1ix_acts = corr(x_fake.reshape(bs, N + 13, -1).permute(0, 2, 1)[:, :, -26:-13], True)
        corr_fake1ix = corr(x_fake.reshape(bs, N+13, -1).permute(0, 2, 1)[:, :, -13:], True)
        correlations_fake =corr_fake1ix_acts[:, indices[0], indices[1]]
        correlations_fake1 = corr_fake1ix[:, indices[0], indices[1]]
        c_loss_ix = torch.pow(torch.mean(torch.pow(correlations_ix - correlations_fake, 2)), 0.5)
        if step_number == 0:
            b_3a = torch.median(ot.dist(correlations_ix.reshape(bs, -1),
                                        correlations_fake.reshape(bs, -1)
                                        )).detach().item()
            Generator.b_3a = [b_3a * 0.5, b_3a * 1, b_3a * 5]

            b_3b = torch.median(ot.dist(correlations_ix.reshape(bs, -1),
                                        correlations_fake1.reshape(bs, -1)
                                        )).detach().item()
            Generator.b_3b = [b_3b * 0.5, b_3b * 1, b_3b * 5]

        c_loss_ix1 = mmd_loss(correlations_ix.reshape(bs, -1), correlations_fake.reshape(bs, -1),
                          bandwidths=Generator.b_3a, heuristic=False)
        c_loss_ix2 = mmd_loss(correlations_ix.reshape(bs, -1), correlations_fake1.reshape(bs, -1),
                          bandwidths=Generator.b_3b, heuristic=False)

        eig_realix, eigenvectors_real_ix = torch.linalg.eigh(C_desired_ix)
        eig_fakeix, eigenvectors_fake_ix = torch.linalg.eigh(corr_fake1ix_acts)
        eig_realix = torch.linalg.eigvalsh(C_desired_ix)
        eig_fakeix = torch.linalg.eigvalsh(corr_fake1ix_acts)

        eigenvectors_real_detach = eig_realix.cpu().detach().numpy()
        explained_var = eigenvectors_real_detach / np.sum(eigenvectors_real_detach, axis=1).reshape(-1, 1)
        mask = explained_var < 1e-2
        explained_var[mask] = 0
        idx_scale = np.argmax(np.mean(explained_var, axis=0) != 0)


        eig_lossix = torch.norm(eig_realix[:, :idx_scale] - eig_fakeix[:, :idx_scale], p=2, dim=-1)
        eig_lossix2 = torch.norm(eig_realix[:, idx_scale:] - eig_fakeix[:, idx_scale:], p=2, dim=-1)


        idx = torch.argsort(eig_realix, descending=True)
        idx_expanded = idx.unsqueeze(1).expand(-1, eigenvectors_real_ix.shape[1], -1)  # (B, N, N)

        eigenvectors_real_ix = torch.gather(eigenvectors_real_ix, dim=2, index=idx_expanded)

        standardized_returns_ix_fake = (x_fake.reshape(bs, N+13, -1).permute(0, 2, 1)[:, :, ninst:] - x_fake.reshape(bs, N+13, -1).permute(0, 2, 1)[:, :, ninst:].mean(dim=1, keepdim=True)) / (x_fake.reshape(bs, N+13, -1).permute(0, 2, 1)[:, :, ninst:].std(dim=1, keepdim=True)+1e-5)
        standardized_returns_ix_real = (xreal.reshape(bs, N+13, -1).permute(0, 2, 1)[:, :, ninst:] - xreal.reshape(bs, N+13, -1).permute(0, 2, 1)[:, :, ninst:].mean(dim=1, keepdim=True)) / (xreal.reshape(bs, N+13, -1).permute(0, 2, 1)[:, :, ninst:].std(dim=1, keepdim=True)+1e-5)

        returns_transposed_fake = standardized_returns_ix_fake.transpose(1, 2)  # (B, N, T)
        returns_transposed_real = standardized_returns_ix_real.transpose(1, 2)  # (B, N, T)


        factor_returns_fake_ix = torch.matmul(eigenvectors_fake_ix.transpose(1, 2), returns_transposed_fake).transpose(1, 2)
        factor_returns_real_ix = torch.matmul(eigenvectors_real_ix.transpose(1, 2), returns_transposed_real).transpose(1, 2)

    cf=0.5*(corr_fake1[:,:-13,:-13]+corr_fake1[:,:-13,:-13].transpose(-1, -2))+ 1e-3 * torch.randn(corr_fake1[:,:-13,:-13].size()).to(corr_fake1.device)
    eye = torch.eye(cf.size(1), device=cf.device).expand(cf.size(0), -1, -1)
    cf = cf * (1 - eye) + eye
    eig_real, eigenvectors_real = torch.linalg.eigh(corr_real1[:,:-13,:-13])
    eig_fake, eigenvectors_fake = torch.linalg.eigh(cf)

    eigvals, indices = torch.sort(eig_fake, descending=True)
    eigvecs = eigenvectors_fake[:, indices]

    k = 10
    eigvals_k = eigvals[:k]                        # shape: (k,)
    eigvecs_k = eigvecs[:, :k]                     # shape: (N, k)

    eiglosscheck=torch.isnan(eig_fake).sum().item()
    if eiglosscheck>0:
        print('Nans in fake corr',eiglosscheck)

    eigenvectors_real_detach = eig_real.cpu().detach().numpy()
    explained_var = eigenvectors_real_detach / np.sum(eigenvectors_real_detach,axis=1).reshape(-1,1)
    mask=explained_var<1e-2
    explained_var[mask]=0
    idx_scale = np.argmax(np.mean(explained_var,axis=0) != 0)
    if eiglosscheck==0:
        eig_loss = torch.norm(eig_real[:, :idx_scale] - eig_fake[:, :idx_scale], p=2, dim=-1)
        eig_loss2 = torch.norm(eig_real[:, idx_scale:] - eig_fake[:, idx_scale:], p=2, dim=-1)
    idx = torch.argsort(eig_real, descending=True)
    idx_expanded = idx.unsqueeze(1).expand(-1, eigenvectors_real.shape[1], -1)  # (B, N, N)

    eigenvectors_real = torch.gather(eigenvectors_real, dim=2, index=idx_expanded)

    standardized_returns_fake = (x_fake.reshape(bs, N+13, -1).permute(0, 2, 1)[:, :, :-26] - x_fake.reshape(bs, N+13,-1).permute(0,2,1)[ :, :, :-26].mean(dim=1,keepdim=True)) /( x_fake.reshape(bs, N+13, -1).permute(0, 2, 1)[:, :, :-26].std(dim=1, keepdim=True)+1e-5)
    standardized_returns_real = (xreal.reshape(bs, N+13, -1).permute(0, 2, 1)[:, :, :-26] - xreal.reshape(bs, N+13,     -1).permute(0,  2,    1)[     :, :, :-26].mean(dim=1, keepdim=True)) / (xreal.reshape(bs, N+13, -1).permute(0, 2, 1)[:, :, :-26].std(dim=1, keepdim=True)+1e-5)

    X = standardized_returns_real  # shape (T, N)
    returns_transposed_fake = standardized_returns_fake.transpose(1, 2)  # (B, N, T)
    returns_transposed_real = standardized_returns_real.transpose(1, 2)



    factor_returns_fake = torch.matmul(eigenvectors_fake.transpose(1, 2), returns_transposed_fake).transpose(1, 2)
    factor_returns_real = torch.matmul(eigenvectors_real.transpose(1, 2), returns_transposed_real).transpose(1, 2)
    factor_returns_real=torch.cat((factor_returns_real,factor_returns_real_ix),dim=2)
    factor_returns_fake=torch.cat((factor_returns_fake,factor_returns_fake_ix),dim=2)
    k=4
    topkloss=0
    for k in [1,5,10,20,30,40,50,60,70,80,100]:
        id=torch.eye(corr_real1.size(2)).reshape((1, corr_real1.size(2), corr_real1.size(2))).repeat(corr_real1.size(0), 1, 1).to(x_fake.device)
        corr_real1=corr_real1-id
        corr_fake1=corr_fake1-id
        real_topk_values, real_topk_indices = torch.topk(corr_real1.abs(), k, dim=1)
        fake_topk_values, fake_topk_indices = torch.topk(corr_fake1.abs(), k, dim=1)
        mask_real = torch.zeros_like(corr_real1)
        mask_real.scatter_(1, real_topk_indices, 1)


        mask_fake = torch.zeros_like(corr_fake1)
        mask_fake.scatter_(1, fake_topk_indices, 1)
        mask_fake=mask_fake*corr_fake1+mask_fake-mask_fake
        mask_real=mask_real*corr_real1+mask_real-mask_real
        topkloss1=torch.mean(torch.sign(torch.abs(mask_fake-mask_real)))
        if k==1:
            mstloss=topkloss1.item()
        topkloss+=topkloss1/(np.log(k)+1)


    topklossix=0
    for k in range(1,14,1):
        id=torch.eye(13).reshape((1, 13, 13)).repeat(corr_real1.size(0), 1, 1).to(x_fake.device)
        corr_real1=corr_real1[:,-13:,-13:]-id
        corr_fake1=corr_fake1[:,-13:,-13:]-id
        real_topk_values, real_topk_indices = torch.topk(corr_real1.abs(), k, dim=1)
        fake_topk_values, fake_topk_indices = torch.topk(corr_fake1.abs(), k, dim=1)
        mask_real = torch.zeros_like(corr_real1)
        mask_real.scatter_(1, real_topk_indices, 1)


        mask_fake = torch.zeros_like(corr_fake1)
        mask_fake.scatter_(1, fake_topk_indices, 1)
        mask_fake=mask_fake*corr_fake1+mask_fake-mask_fake
        mask_real=mask_real*corr_real1+mask_real-mask_real
        topkloss1x=torch.mean(torch.sign(torch.abs(mask_fake-mask_real)))
        if k==1:
            mstloss=topkloss1x.item()
        topklossix+=topkloss1x/(np.log(k)+1)


    c_loss2=torch.pow(torch.mean(torch.pow(desired_corr_matrix - corr_fake, 2)),0.5)



    c_loss3=torch.pow(torch.mean(torch.pow(F.relu(desired_corr_matrix-.3) - F.relu(corr_fake-.3), 2)),0.5)
    c_loss3=c_loss3+torch.pow(torch.mean(torch.pow(F.relu(-1*desired_corr_matrix-.3) - F.relu(-1*corr_fake-.3), 2)),0.5)

    ioui=topkloss.item()

    if step_number==0:
        b_1=torch.median(ot.dist(xreal.reshape(bs, N+13, -1).permute(0,2,1).reshape(-1,NL), x_fake.reshape(bs, N+13, -1).permute(0,2,1).reshape(-1,NL))).detach().item()
        if desired_vol_mat is not None:
            b_2=torch.median(ot.dist(desired_vol_mat.float().reshape(-1,NL), torch.abs(x_fake.reshape(bs, N+13, -1).permute(0,2,1).reshape(-1,NL)))).detach().item()

        else:
            b_2=torch.median(ot.dist(torch.abs(xreal.reshape(bs, N+13, -1).permute(0,2,1).reshape(-1,NL)), torch.abs(x_fake.reshape(bs, N+13, -1).permute(0,2,1).reshape(-1,NL)))).detach().item()
        Generator.b_1=[b_1*0.5,b_1*1,b_1*5]
        Generator.b_2=[b_2*0.5,b_2*1,b_2*5]

        b_1ix=torch.median(ot.dist(xreal.reshape(bs, N+13, -1)[:,-26:].permute(0,2,1).reshape(-1,NLix), x_fake.reshape(bs, N+13, -1)[:,-26:].permute(0,2,1).reshape(-1,NLix))).detach().item()

        b_2ix=torch.median(ot.dist(torch.abs(xreal.reshape(bs, N+13, -1)[:,-26:].permute(0,2,1).reshape(-1,NLix)), torch.abs(x_fake.reshape(bs, N+13, -1)[:,-26:].permute(0,2,1).reshape(-1,NLix)))).detach().item()
        Generator.b_1ix=[b_1ix*0.5,b_1ix*1,b_1ix*5]
        Generator.b_2ix=[b_2ix*0.5,b_2ix*1,b_2ix*5]


        b_11=torch.median(ot.dist(factor_returns_real.reshape(bs, N, -1).permute(0,2,1).reshape(-1,NL2),factor_returns_fake.reshape(bs, N, -1).permute(0,2,1).reshape(-1,NL2))).detach().item()

        b_22=torch.median(ot.dist(torch.abs(factor_returns_real.reshape(bs, N, -1).permute(0,2,1).reshape(-1,NL2)), torch.abs(factor_returns_fake.reshape(bs, N, -1).permute(0,2,1).reshape(-1,NL2)))).detach().item()
        Generator.b_11=[b_11*0.005,b_11*.01,b_11*.05]
        Generator.b_22=[b_22*0.005,b_22*.01,b_22*.05]



    gloss = mmd_loss(xreal.reshape(bs, N+13, -1).permute(0,2,1).reshape(-1,NL), x_fake.reshape(bs, N+13, -1).permute(0,2,1).reshape(-1,NL),bandwidths=Generator.b_1,heuristic=False)*multq


    glossix = mmd_loss(xreal.reshape(bs, N + 13, -1)[:,-26:].permute(0, 2, 1).reshape(-1, NLix),
                 x_fake.reshape(bs, N + 13, -1)[:,-26:].permute(0, 2, 1).reshape(-1, NLix), bandwidths=Generator.b_1ix,
                 heuristic=False) * multq

    gloss_eig = mmd_loss(factor_returns_real.reshape(bs, N, -1).permute(0, 2, 1).reshape(-1, NL2),
                 factor_returns_fake.reshape(bs, N, -1).permute(0, 2, 1).reshape(-1, NL2), bandwidths=Generator.b_11, heuristic=False)


    glosssum = mmd_loss(xreal.reshape(bs, N + 13, -1).permute(0, 2, 1).sum(axis=1).reshape(-1,  N + 13),
                     x_fake.reshape(bs, N + 13, -1).permute(0, 2, 1).sum(axis=1).reshape(-1,  N + 13), bandwidths=Generator.b_1s,
                     heuristic=False) * multq
    glossasum = mmd_loss(xreal.reshape(bs, N + 13, -1).permute(0, 2, 1).abs().sum(axis=1).reshape(-1,  N + 13),
                     x_fake.reshape(bs, N + 13, -1).permute(0, 2, 1).abs().sum(axis=1).reshape(-1,  N + 13), bandwidths=Generator.b_1sa,
                     heuristic=False) * multq

    glossa = mmd_loss(torch.abs(xreal.reshape(bs, N+13, -1).permute(0,2,1).reshape(-1,NL)),
                  torch.abs(x_fake.reshape(bs, N+13, -1).permute(0, 2, 1).reshape(-1, NL)),
                  bandwidths=Generator.b_2, heuristic=False)*multq


    glossaix = mmd_loss(torch.abs(xreal.reshape(bs, N + 13, -1)[:,-26:].permute(0, 2, 1).reshape(-1, NLix)),
                  torch.abs(x_fake.reshape(bs, N + 13, -1)[:,-26:].permute(0, 2, 1).reshape(-1, NLix)),
                  bandwidths=Generator.b_2ix, heuristic=False) * multq

    glossa_eig = mmd_loss(torch.abs(factor_returns_real.reshape(bs, N, -1).permute(0, 2, 1).reshape(-1, NL2)),
                  torch.abs(factor_returns_fake.reshape(bs, N, -1).permute(0, 2, 1).reshape(-1, NL2)),
                  bandwidths=Generator.b_22, heuristic=False)

    jump_loss_euclid=torch.mean(torch.pow((xfake_jump-xreal_jump),2))
    jump_loss_euclid_abs=torch.mean(torch.pow((xfake_jump.abs()-xreal_jump.abs()),2))
    jump_loss = mmd_loss(
        xfake_jump.reshape(bs, N+13, -1).permute(0, 2, 1).reshape(-1, NL),
        xreal_jump.reshape(bs, N+13, -1).permute(0, 2, 1).reshape(-1, NL), bandwidths=Generator.b_8, heuristic=False)


    del factor_returns_fake, factor_returns_real, factor_returns_real_ix, factor_returns_fake_ix

    jump_lossa = mmd_loss(
        torch.abs(xfake_jump).reshape(bs, N+13, -1).permute(0, 2, 1).reshape(-1, NL),
        torch.abs(xreal_jump).reshape(bs, N+13, -1).permute(0, 2, 1).reshape(-1, NL), bandwidths=Generator.b_9, heuristic=False)
    std_loss = torch.pow(torch.mean(torch.pow(xreal.reshape(bs, N+13, -1).permute(0,2,1).std(axis=1)-x_fake.reshape(bs, N+13, -1).permute(0, 2, 1).std(axis=1), 2)), 0.5)


    std_loss_ix = torch.pow(torch.mean(torch.pow(
    xreal.reshape(bs, N + 13, -1)[:,-26:,].permute(0, 2, 1).std(axis=1) - x_fake.reshape(bs, N + 13, -1)[:,-26:,].permute(0, 2, 1).std(
        axis=1), 2)), 0.5)

    kurtosis_fakeat = kurtosis_torch(x_fake.reshape(bs, N+13, -1))
    kurtosis_realat = kurtosis_torch(xreal.reshape(bs, N+13, -1))

    kurtosis_fake = kurtosis_torch(x_fake.reshape(bs, N+13, -1).permute(0,2,1))
    kurtosis_real = kurtosis_torch(xreal.reshape(bs, N+13, -1).permute(0,2,1))


    kurtosis_fake_ix = kurtosis_torch(x_fake.reshape(bs, N + 13, -1)[:,-26:].permute(0, 2, 1))
    kurtosis_real_ix = kurtosis_torch(xreal.reshape(bs, N + 13, -1)[:,-26:].permute(0, 2, 1))


    k_loss3_ix = torch.pow(torch.mean(torch.pow(kurtosis_fake_ix - kurtosis_real_ix, 2)), 0.5)


    k_loss = torch.mean(torch.stack([mmd_loss(kurtosis_fakeat[i].reshape(-1, 1),
                                         kurtosis_realat[i].reshape(-1, 1),
                                         bandwidths=Generator.b_5, heuristic=False) for i in range(bs)]))

    k_loss2 = torch.pow(torch.mean(torch.pow(kurtosis_fakeat - kurtosis_realat, 2)), 0.5)


    k_loss3 = torch.pow(torch.mean(torch.pow(kurtosis_fake - kurtosis_real, 2)), 0.5)

    skew_fakeat = skew_torch(x_fake.reshape(bs, N+13, -1))
    skew_realat = skew_torch(xreal.reshape(bs, N+13, -1))


    s_lossat = torch.pow(torch.mean(torch.pow(skew_fakeat - skew_realat, 2)), 0.5)


    skew_fake = skew_torch(x_fake.reshape(bs, N+13, -1).permute(0,2,1))
    skew_real = skew_torch(xreal.reshape(bs, N+13, -1).permute(0,2,1))


    s_loss = torch.pow(torch.mean(torch.pow(skew_fake - skew_real, 2)), 0.5)


    generated_flat = x_fake.reshape(-1, 1)
    var = generated_flat.std(dim=0, unbiased=False)  # per-feature variance

    total_var = var.mean()  # or .sum() depending on how strong you want penalty

    diversity_loss = -total_var

    del var,generated_flat,total_var
    gl=gloss.item()
    gla=glossa.item()
    gl_e=gloss_eig.item()
    gla_e=glossa_eig.item()
    cl2=c_loss2.item()
    cl=c_loss.item()
    kl=k_loss.item()
    skl=s_loss.item()
    kl2=k_loss2.item()
    stdl=k_loss.item()
    acl=ac_loss2.item()
    sumloss=torch.sum(F.relu(x_fake.sum(axis=1).abs()-(x_fake.size(1)*.05)))-F.relu(torch.sum(F.relu(x_fake.sum(axis=1).abs()-(x_fake.size(1)*.05)))-1000)
    sumloss2=torch.sum(F.relu(x_fake.sum(axis=1)*-0.99-1))-F.relu(torch.sum(F.relu(x_fake.sum(axis=1)*-0.99-1))-1000)
    eigl=eig_loss.mean().item()
    eiglix=eig_lossix.mean().item()
    div_loss=diversity_loss.item()
    c_l_ix=c_loss_ix.item()
    j_l= jump_loss.item()
    j_la= jump_lossa.item()
    j_l_e=jump_loss_euclid.item()
    j_l_e_a=jump_loss_euclid_abs.item()
    s_lossat1=s_lossat.item()
    topklossixa=topklossix.item()
    c_loss_ix1_=c_loss_ix1.item()
    c_loss_ix2_=c_loss_ix2.item()

    sl1=sumloss.item()
    sl2=sumloss2.item()
    glix=glossix.item()
    glaix=glossaix.item()
    glsum=glosssum.item()
    glasum=glossasum.item()
    if corr_loss:
        if (step_number+1)%2==0:
            print('ident', ident)
            print('q',q)
            print('desired_corr_matrix',torch.isnan(desired_corr_matrix).sum().item())
            print('Gloss:',gloss.item())
            print('glossa:', glossa.item())
            print('Gloss_e:', gloss_eig.item())
            print('glossa_e:', glossa_eig.item())
            print('glossix:', glossix.item())
            print('glossaix:', glossaix.item())
            print('glosssum:', glosssum.item())
            print('glossasum:', glossasum.item())
            print('c_loss2:', c_loss2.item())
            print('c_loss3:', 10*c_loss3.item())
            print('j_loss:',  jump_loss.item())
            print('j_loss_a:', jump_lossa.item())
            print('jump_loss_euclid:',jump_loss_euclid.item())
            print('jump_loss_euclid_abs:',jump_loss_euclid_abs.item())
            # print('Gloss:', gloss.item())
            print('topkloss:', topkloss.item())
            print('topklossixa:', topklossixa)
            print('std_loss:', std_loss.item())
            print('std_loss_ix',std_loss_ix.item())
            print('stdev_loss',stdev_loss.item())
            # print('mstloss:', mstloss)
            print('ac_loss2:', ac_loss2.item())
            print('ac_loss:', ac_loss.item())
            # print('ac_loss2_abs:', ac_loss2_abs.item())
            # print('ac_loss_abs:', ac_loss_abs.item())
            print('c_loss:', c_loss.item())
            print('constraint_loss',constraint_loss.item())
            print('eig_loss',eig_loss.mean().item())
            print('eig_loss2', eig_loss2.mean().item())
            print('diversity_loss', diversity_loss.item())
            if weightstab is not None:
                print('c_loss ix:', c_loss_ix.item())
                print('c_loss ix1:', c_loss_ix1.item())
                print('c_loss ix2:', c_loss_ix2.item())
                print('eig_loss ix', eig_lossix.mean().item())
                print('eig_los2s ix', eig_lossix2.mean().item())
            print('k_loss2:', k_loss2.item())
            print('k_loss3:', k_loss3.item())
            print('k_loss3:', k_loss3_ix.item())
            print('s_loss:', s_loss.item())

    mults=[1 for x in range(30)]


    gloss =  torch.clamp(
    constraint_loss * mults[2] * 1e4, max=1. * mults[2]) +torch.clamp(gloss * mults[8] * 30, max=3.0 * mults[8]) + torch.clamp(
    c_loss2 * mults[9] * 8., max=3.0 * mults[9]) + torch.clamp(c_loss * 500. * mults[10],
                                                               max=3.0 * mults[10]) + sumloss + sumloss2 + torch.clamp(
    10 * c_loss3 * mults[11], max=3.0 * mults[11])
    gloss=gloss+torch.clamp(glossa_eig,max=1.0)+torch.clamp(gloss_eig*10,max=1.0)+torch.clamp(eig_lossix.mean(),max=1.0)+torch.clamp(eig_lossix2.mean(),max=1.0)+torch.clamp(eig_loss2.mean()*.01,max=2.5)+torch.clamp(eig_loss.mean()*5,max=2.5)
    gloss=2*topkloss*mults[19]+gloss+torch.clamp(glossix*mults[20]*.5,max=1.*mults[20])+torch.clamp(glossaix*mults[21]*.5,max=1.*mults[21])+torch.clamp(glossa*mults[22]*5,max=4.0*mults[22])+torch.clamp(k_loss3_ix*5*mults[23],max=2.0*mults[23])+torch.clamp(k_loss3*3.5*mults[24],max=7.0*mults[24])+torch.clamp(s_lossat*mults[25],max=1.0*mults[25])+torch.clamp(20.*(2*jump_loss+2*jump_lossa+jump_loss_euclid+jump_loss_euclid_abs)*mults[26],max=2.0*mults[26])#+diversity_loss
    gloss=torch.clamp(glosssum*.4,max=2.)+torch.clamp(glossasum*.4,max=2.)+torch.clamp(constraint_loss*1e4,max=1.)+torch.clamp(std_loss_ix*100,max=1.)+torch.clamp(std_loss*10,max=1.)+torch.clamp(stdev_loss*20,max=1.)+torch.clamp(topkloss*25,max=8.0)+torch.clamp(topklossix,max=4.5)+torch.clamp(gloss*30,max=3.0)+torch.clamp(c_loss2*8.,max=3.0)+torch.clamp(c_loss*500.,max=3.0)+sumloss+sumloss2+torch.clamp(10*c_loss3,max=3.0) +torch.clamp(ac_loss*20,max=1.25)+torch.clamp(ac_loss2*20.,max=1.25)+torch.clamp(ac_loss_abs,max=1.0)+torch.clamp(ac_loss2_abs*20.,max=1.0)+torch.clamp(k_loss*500.,max=1.0)+torch.clamp(k_loss2*.5,max=2.0)+torch.clamp(s_loss,max=1.0)

    # loss ablation

    # basket_index_loss=torch.clamp(glosssum * mults[0] * .4, max=2. * mults[0]) + torch.clamp(glossasum * mults[1] * .4,
    #                                                                                max=2. * mults[1]) + torch.clamp(
    #     constraint_loss * mults[2] * 1e4, max=1. * mults[2])
    # standard_dev_loss=torch.clamp(std_loss * mults[4] * 10, max=1. * mults[4]) + torch.clamp(
    #     stdev_loss * mults[5] * 20, max=1. * mults[5])
    # corr_loss_excl_topk=torch.clamp(
    #     c_loss2 * mults[9] * 8., max=3.0 * mults[9]) + torch.clamp(c_loss * 500. * mults[10],
    #                                                                max=3.0 * mults[10]) +torch.clamp(
    #     10 * c_loss3 * mults[11], max=3.0 * mults[11])
    # acfloss= torch.clamp(ac_loss * mults[12] * 20,
    #                                                                  max=1.25 * mults[12]) + torch.clamp(
    #     ac_loss2 * mults[13] * 20., max=1.25 * mults[13]) + torch.clamp(ac_loss_abs * mults[14],
    #                                                                     max=1.0 * mults[14]) + torch.clamp(
    #     ac_loss2_abs * mults[15] * 20., max=1.0 * mults[15])
    # highermomentsloss=torch.clamp(k_loss * mults[16] * 500.,
    #                                                                        max=1.0 * mults[16]) + torch.clamp(
    #     k_loss2 * mults[17] * .5, max=2.0 * mults[17]) + torch.clamp(s_loss * mults[18], max=1.0 * mults[18])+torch.clamp(k_loss3*3.5,max=7.0)+torch.clamp(s_lossat,max=1.0)+torch.clamp(20.*(2*jump_loss+2*jump_lossa+jump_loss_euclid+jump_loss_euclid_abs),max=2.0)
    # topkloss=torch.clamp(topkloss * mults[6] * 25,
    #                                                                  max=8.0 * mults[6])
    # gloss =    torch.clamp(gloss * mults[8] * 30, max=3.0 * mults[8]) + sumloss + sumloss2 +torch.clamp(glossa*5,max=4.0)
    # eig_loss_all=+torch.clamp(glossa_eig,max=1.0)+torch.clamp(gloss_eig*10,max=1.0)+torch.clamp(eig_loss2.mean()*.01,max=2.5)+torch.clamp(eig_loss.mean()*5,max=2.5)
    # total_loss=gloss+corr_loss_excl_topk
    # ixonlyloss=torch.clamp(eig_lossix.mean(),max=1.0)+torch.clamp(eig_lossix2.mean(),max=1.0)+ torch.clamp(
    #     topklossix * mults[7], max=4.5 * mults[7])+torch.clamp(k_loss3_ix*5,max=2.0)+c_loss_ix
    # if seed%6!=0:
    #     total_loss=total_loss+topkloss
    # elif seed%6!=1:
    #     total_loss = total_loss + acfloss
    # elif seed % 6 != 2:
    #     total_loss = total_loss + eig_loss_all
    # elif seed % 6 != 3:
    #     total_loss = total_loss + standard_dev_loss
    # elif seed % 6 != 4:
    #     total_loss = total_loss + basket_index_loss
    # elif seed % 6 != 5:
    #     total_loss = total_loss + ixonlyloss
    #
    #
    # gloss=total_loss

    save_gradients = False
    print('gloss:', gloss.item())

    gloss.backward()

    if save_gradients:
        mm_list[step_number]=[]
        mm=[[x.min().item(), x.max().item(), x.sum().item()] for x in Generator.parameters()]
        mm_list[step_number].append(mm)
        # print('saving gradients of generator')
        ave_grads_g[step_number] = []
        layers_g[step_number] = []
        layer_number = 0
        for n, p in Generator.named_parameters():
            if (p.requires_grad) and ("bias" not in n):
                try:
                    average_gradient = p.grad.abs().mean()
                except:
                    average_gradient = torch.tensor(np.nan)
                layers_g[step_number].append(n)
                ave_grads_g[step_number].append(average_gradient.item())
                # writer.add_scalar(str(n) + '_g', average_gradient.item(), step)
                # writer.add_scalar('Avg_gradient_layer_g_' + str(layer_number), average_gradient.item(), step)
                layer_number = layer_number + 1
    torch.nn.utils.clip_grad_norm_(Generator.parameters(), max_norm=2.)

    G_optimizer.step()
    scheduler.step(gloss)

    constraint_loss1=constraint_loss.item()
    stdev_loss1=stdev_loss.item()
    # ema_generator = copy.deepcopy(Generator)
    # for param in ema_generator.parameters():
    #     param.requires_grad = False
    std_loss1=std_loss.item()
    alpha=0.99
    # update_ema(ema_generator,Generator , alpha)
    del c_loss,glossa,corr_fake,ac_loss,gloss,k_loss3,acorr_fake,acorr_real,c_loss2,ac_loss2,topkloss,topkloss1,k_loss,k_loss2,constraint_loss,std_loss,s_loss,s_lossat,eig_loss,jump_loss,jump_lossa,jump_loss_euclid,jump_loss_euclid_abs,stdev_loss
    # del c_loss,glossa,corr_fake1,corr_fake,corr_real1,ac_loss,gloss,acorr_fake,acorr_real
    if weightstab is not  None:
        del eig_lossix,eig_realix,eig_fakeix,correlations_ix,correlations_fake,c_loss_ix,glossa_eig,gloss_eig,ac_loss_abs
    del eig_fake,eig_real,eigenvectors_real,  standardized_returns_real, standardized_returns_fake


    torch.cuda.empty_cache()
    gc.collect(generation=2)


    step_number = step_number + 1

    return [ioui,gl,gla,cl,cl2,kl,mstloss,stdl,skl,sl1,sl2,acl,kl2,eigl,eiglix,div_loss,c_l_ix,j_l,j_la,j_l_e,j_l_e_a,s_lossat1,gl_e,gla_e,constraint_loss1,std_loss1,stdev_loss1,glix,glaix ,glsum ,glasum,c_loss_ix1_,c_loss_ix2_], ave_grads_g, layers_g,mm_list, step_number
def update_ema(generator, ema_generator, alpha=0.999):
    """Blend ``generator``'s parameters into ``ema_generator`` in place (one EMA step).

    For every parameter pair this computes::

        ema_param <- alpha * ema_param + (1 - alpha) * param

    Argument order matters: ``generator`` is the live (trained) model and is
    only read; ``ema_generator`` is the averaged copy and is modified in place.
    Only ``parameters()`` are averaged; buffers (e.g. normalization running
    statistics) are not touched.

    Args:
        generator: Source ``torch.nn.Module`` whose parameters are read.
        ema_generator: Destination ``torch.nn.Module`` updated in place. It must
            expose the same number of parameters, in the same order and with
            compatible shapes (presumably it was created as a copy of
            ``generator``).
        alpha: Decay factor in ``[0, 1]``. Larger values make the average move
            more slowly; ``1`` leaves the EMA unchanged.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]`` or the two modules have a
            different number of parameters.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha}")
    ema_params = list(ema_generator.parameters())
    src_params = list(generator.parameters())
    # zip() would silently stop at the shorter list and leave the remaining
    # EMA parameters stale, so a structural mismatch is reported instead.
    if len(ema_params) != len(src_params):
        raise ValueError(
            f"parameter count mismatch: ema_generator has {len(ema_params)}, "
            f"generator has {len(src_params)}"
        )
    with torch.no_grad():
        # Inside no_grad the in-place ops are not recorded by autograd, so the
        # discouraged `.data` indirection is unnecessary.
        for ema_param, param in zip(ema_params, src_params):
            ema_param.mul_(alpha).add_(param, alpha=1 - alpha)
if __name__ == '__main__':
    bin_dict = {5: [[-1.01, -.65], [-0.65, -.43],[-.43,  -.2],[-.2,.2], [0.2, .34], [0.34, .5], [0.5, 0.75], [0.75, 1.01]],
                10: [[-1.01, -.5], [-.5, -.3], [-.3,  -.12],[-.12,.12], [0.12, .25], [0.25, 0.4], [0.4, .55], [0.55, 1.01]],
                20: [[-1.01, -.34], [-.34, -.2], [-.2,  -.1],[-.1,.1], [0.1, .21], [0.21, 0.35], [0.35, 0.52], [0.52, 1.01]],
                40: [[-1.01, -.25], [-.25, -.15], [-.15, -.08],[-.08,.08], [0.08, .2], [0.2, .3], [0.3, .45], [0.45, 1.01]]}
    ss=101
    for seed in range(ss,ss+42,1):
        correlatedn=True
        torch.manual_seed(seed)
        np.random.seed(seed)
        generate=False
        corrtruesplit=False
        ninst=392
        ninst2=392
        algo_id='FiLMConv'
        base_dir='./TensorboardFiles'
        p=sl2=20
        datasetname='sandi'
        actual=0
        notreal=False
        N=ninst2
        qmax=250
        T=25000
        zm=.1
        zsharedm=0.1
        zedgem=0.1
        learning_rate=1e-3
        beta=(0.9,0.99)
        mu=0.0
        includeevents=False
        harris=False
        return_FiLM = False
        trump=False
        data_params={}
        key_settings = datasetname + 'N_' + str(N) + '_sl2_' + str(sl2)

        experiment_directory = pt.join(base_dir, datasetname,  'seed={}'.format(seed), algo_id,key_settings)
        tensorboard_directory = pt.join(base_dir,datasetname, 'seed={}'.format(seed), algo_id, 'TB')
        plt_directory = pt.join(base_dir, 'PLOTS', datasetname, 'seed={}'.format(seed),
                                algo_id)
        gen_dir = pt.join(base_dir, 'Generations', datasetname, 'seed={}'.format(seed),
                                algo_id)
        model_dir = pt.join(base_dir, 'Models', datasetname, 'seed={}'.format(seed),
                                algo_id)

        if not pt.exists(gen_dir):
            os.makedirs(gen_dir)

        if not pt.exists(model_dir):
            os.makedirs(model_dir)
        if not pt.exists(experiment_directory):
            os.makedirs(experiment_directory)
        if not pt.exists(tensorboard_directory):
            os.makedirs(tensorboard_directory)
        if not pt.exists(plt_directory):
            # if the experiment directory does not exist we create the directory
            os.makedirs(plt_directory)
        if actual==1:
            actst='_act_'
        else:
            actst=''

        if not pt.exists(datasetname + 'x_real312lg'+actst+'.pt'):
            if datasetname not in ['MGBM']:
                    if datasetname in ['NGARCH2']:
                        x_real, x_real1, real_correlation_matrix = get_data(datasetname,  p, q,actual,
                                                                            **data_params)
                    else:
                        x_real,x_real1,xcorr1 = get_data(datasetname, p, qmax,actual, **data_params)
                            # xcorr1=None

            x_real1=x_real1[:,:15000]
            x_real=x_real[:15000]
            torch.save(x_real, datasetname + 'x_real312lg'+actst+'.pt')
        else:

            x_real=torch.load( datasetname+'x_real312lg'+actst+'.pt')

        sprint=False
        if sprint:
            addsprint='_sprint'
        else:
            addsprint=''
        if actual!=0:
            addsprint =addsprint+ '_act'
        if not pt.exists(datasetname + 'ixTrainMyGeometricDataset1121lg'+str(ninst)+addsprint+'.pt'):
            if sprint:
                trainlen=int(x_real.size(0)*.005)
                dataset_test = MyGeometricDataset(x_real[-trainlen - 260:, :, :ninst], ninst2)
            else:
                if actual!=0:
                    dataset_test = MyGeometricDataset(x_real[:, :, :ninst], ninst2)
                    dataset = MyGeometricDataset(x_real[:, :, :ninst], ninst2)
                else:
                    trainlen = int(x_real.size(0) * .75)
                    dataset_test = MyGeometricDataset(x_real[trainlen + 260:, :, :ninst], ninst2)
            torch.save(dataset_test, datasetname + 'ixTestMyGeometricDataset1121lg' + str(ninst)+addsprint + '.pt')
            if actual == 0:
                dataset = MyGeometricDataset(x_real[:trainlen,:,:ninst],ninst2)

            torch.save(dataset, datasetname+'ixTrainMyGeometricDataset1121lg'+str(ninst)+addsprint+'.pt')
        else:
            dataset=torch.load( datasetname+'ixTrainMyGeometricDataset1121lg'+str(ninst)+addsprint+'.pt')
            dataset_test = torch.load(datasetname + 'ixTestMyGeometricDataset1121lg' + str(ninst) + addsprint + '.pt')
        if generate:


            dataset = torch.load(datasetname + 'ixTestMyGeometricDataset1121lg' + str(ninst)+addsprint+'.pt')
            if actual != 0:

                bs = 5  # How many times to repeat it
                ix1 = 0
                sampler11 = RepeatSingleIndexSampler(ix1, bs)
                loader_rep = DataLoader( dataset, sampler=sampler11, batch_size=bs)
                loader = iter(loader_rep)


            else:
                loader = DataLoader(dataset, batch_size=20, shuffle=False)
        else:
            loader = DataLoader(dataset, batch_size=20, shuffle=True)

        if generate:
            num_epochs=1
        else:
            num_epochs=20
        num_features=sl2+1
        binlist=bin_dict[10]
        bins = torch.tensor(
            binlist)
        generator = ARFNN_Net(in_channels=num_features, hidden_channels=64,
                    out_channels=1, num_layers=1 ,
                    dropout=0.0,corr_thresh=.2,num_relations=len(bins)+1).to('cuda')
        if generate:
            modelfiles = [x for x in os.listdir(model_dir) if '.pth' in x]
            last_weights =   [x for x in modelfiles if int(x.split('_')[0]) ==np.max(
                                [int(x.split('_')[0]) for x in modelfiles])]

            generator_path = model_dir + '\\' + last_weights[0]
            generator = torch.load(generator_path)
            generator.eval()
        def count_parameters(model):
            return sum(p.numel() for p in model.parameters() if p.requires_grad)


        print(f"Number of parameters: {count_parameters(generator)}")
        print(f"Number of parameters: {count_parameters(generator)}")

        for name, param in generator.named_parameters():
            if param.requires_grad:
                print(f"Parameter: {name}, Shape: {param.shape}, Number of elements: {param.numel()}")

        G_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=beta)
        scheduler = ReduceLROnPlateau(G_optimizer, 'min', factor=.99, patience=50, min_lr=5e-5)

        D_steps_per_G_step = 2
        generator.to('cuda')
        training_starttime = datetime.datetime.now()
        layers_g = {}
        ave_grads_g = {}
        mm_list = {}
        step_number = 0
        desf_dict={}
        des_dict = {}
        desf_dicta={}
        des_dicta = {}
        desf_dictb={}
        des_dictb = {}
        # includeevents=True
        # harris=False
        # return_FiLM = False
        # trump=False


        des_idx_dict={}
        if not includeevents:
            for t in [5,10,20,40]:
                desf_dict[t]=torch.load('disp_range_pcorr_mat_'+str(t)+'.pt').float()
                des_dict[t] = torch.load('disp_range_corr_mat_' + str(t) + '.pt').float()

                a = pd.DataFrame(des_dict[t].mean(dim=1).detach().cpu().numpy())
                a['Type'] = pd.qcut(a[0], 6, labels=range(6))
                a['Type'] = a['Type'].astype(float)
                des_idx_dict[t]={'Type':a}
        else:
            t=10
            eventtime=12
            if harris:
                desf_dict[20] = torch.load('harris_a_pcorr15_mat_' + str(t) + '.pt').float()
                des_dict[20] = torch.load('harris_a_corr15_mat_' + str(t) + '.pt').float()
            elif trump:
                desf_dict[20] = torch.load('trump_a_pcorr15_mat_' + str(t) + '.pt').float()
                des_dict[20] = torch.load('trump_a_corr15_mat_' + str(t) + '.pt').float()
            else:
                desf_dict[20] = torch.load('all_pcorr.pt').float()
                des_dict[20] = torch.load('all_corr.pt').float()
                desf_dictb[20] = torch.load('before_pcorr.pt').float()
                des_dictb[20] = torch.load('before_corr.pt').float()
                desf_dicta[20] = torch.load('after_pcorr.pt').float()
                des_dicta[20] = torch.load('after_corr.pt').float()


        t=8
        desired_vol=torch.load(r"C:\Users\username\PycharmProjects\GPS\timeseriessynth\EarningsAndElectionsVol.pt")#desired_volsnormflat
        metric_dict={}
        metric_dict1 = {}
        mcounter=0
        weightstab=torch.tensor(pd.read_hdf('weightstable.h5').values).float().to(x_real.device)
        barrier_dict={5:0.05,10:.03,20:0.02,40:0.01}#got from random matric 35/65 percenitles
        rands=[]
        for epoch in range(num_epochs):

            start_time=datetime.datetime.now()
            ioulist=[]
            goodchoices=[]

            if not generate:
                generator.train()
            usetrainloader=True
            nstepstrain=len(loader)-1
            loadertouse=loader

            for step,batch in tqdm.tqdm(enumerate(loadertouse)):
                if epoch % 2 == 0:
                    stepper = 3
                else:
                    stepper = 5
                if torch.isnan(batch.x).sum().item()!=0:
                    print('nan')
                x_real=batch.to('cuda')
                bs=x_real.batch_size
                if includeevents:
                    q=20
                else:
                    if epoch==0:
                        q = np.random.choice([10,20])
                    else:
                        if nsteps==step:
                            q= np.random.choice([20])
                        elif q<=10:
                            q = np.random.choice([20])
                        else:
                            q=np.random.choice([10,20])
                if q>40:
                    q2=40
                else:
                    q2=q
                barrier = barrier_dict[q2]
                print('Barrier',barrier)
                bins = torch.tensor(
                  binlist)
                print('Len Bins',len(bins))
                q111 = q
                NL = int(N * q)
                z = torch.randn(bs * N, q).to('cuda')
                z_jump = torch.randn(bs * N, q).to('cuda')
                chi2_dist = torch.distributions.Chi2(df=2.5)
                z_jump = chi2_dist.sample((bs * N, q)).to('cuda')
                z_jump = .5*(z_jump+4)
                z_shared = torch.randn(bs, len(bins)+1, q).to('cuda')

                current_corr_matrix_holder = torch.zeros((bs, N, N))
                desired_corr_matrix_holder = torch.zeros((bs, N, N))
                current_corr_matrix = corr(x_real.x[:, :p].reshape(bs, ninst2, -1).permute(0, 2, 1),
                                             False).clone().to(desired_corr_matrix_holder.device)

                print('if acloss bad, set shared noise to 0')
                desired_corr_matrix_holder = torch.zeros((bs,N, N))
                indices2 = torch.triu_indices(N, N, 1)

                desired_corr_matrix_all=des_dict[q111]
                if step == len(loader) - 1:
                    corrtype = np.random.choice(range(6))
                elif step == len(loader):
                    corrtype=np.random.choice(range(6))
                else:
                    corrtype=(step+1)%6
                if not includeevents:
                    ini_poss_indices=np.where(des_idx_dict[q111]['Type']['Type']==corrtype)[0]
                else:
                    if desired_corr_matrix_all.size(0)<bs:
                        ini_poss_indices=np.array([x for x in range(bs)])
                    else:
                        ini_poss_indices = np.random.choice(desired_corr_matrix_all.size(0), bs, replace=False)
                if len(rands)>0:
                    if desired_corr_matrix_all.size(0) < bs:
                        poss_choice =np.array([x for x in range(bs)])
                    else:
                        poss_choice=[x for x in ini_poss_indices if x not in np.hstack(rands)]
                    try:
                        random_choice = np.random.choice(poss_choice, bs, replace=False)
                    except:
                        random_choice = np.random.choice(poss_choice, bs, replace=True)
                    rands.append(random_choice)
                    if len(poss_choice)<bs*3:
                        rands=[]
                else:
                  random_choice = np.random.choice(ini_poss_indices, bs,replace=False)
                  rands.append(random_choice)
                if includeevents:
                    ident = 0
                    if ( trump) or ( harris):
                        desired_corr_matrix = corr(x_real.x[:, :p].reshape(bs, ninst2, -1).permute(0, 2, 1), False).clone().to(desired_corr_matrix_holder.device)
                        post_event_desired_corr_matrix=desired_corr_matrix_all[random_choice]

                    else:
                        if corrtruesplit:
                            cts='CTS'
                            desired_corr_matrix = des_dictb[20].to(desired_corr_matrix_holder.device).repeat(bs,1)
                            post_event_desired_corr_matrix=des_dicta[20].to(desired_corr_matrix_holder.device).repeat(bs,1)
                        else:
                            cts = ''
                            desired_corr_matrix = des_dict[20].to(desired_corr_matrix_holder.device).repeat(bs,1)
                            post_event_desired_corr_matrix = des_dict[20].to(desired_corr_matrix_holder.device).repeat(bs,1)
                else:
                    if epoch%10==0:
                        if step==len(loader)-1:
                            ident = 0
                            desired_corr_matrix = desired_corr_matrix_all[random_choice]#.repeat(bs,1)
                        elif step==len(loader):
                            ident = 0
                            desired_corr_matrix = desired_corr_matrix_all[random_choice]#.repeat(bs,1)

                        elif step % stepper == 0:
                            ident = 1
                            desired_corr_matrix=corr(x_real.x[:, p:p+q].reshape(bs, ninst2, -1).permute(0, 2, 1), False).clone().to(desired_corr_matrix_holder.device)
                        else:
                            ident = 0
                            desired_corr_matrix=desired_corr_matrix_all[random_choice]#.repeat(bs,1)
                    elif epoch % 5 != 0:
                        if step == len(loader) - 1:
                            ident = 0
                            desired_corr_matrix = desired_corr_matrix_all[random_choice]#.repeat(bs,1)
                        elif step == len(loader):
                            ident = 0
                            desired_corr_matrix = desired_corr_matrix_all[random_choice]#.repeat(bs,1)

                        elif step % stepper == 0:
                            ident = 1
                            desired_corr_matrix = corr(x_real.x[:, p:p + q].reshape(bs, ninst2, -1).permute(0, 2, 1),
                                                       False).clone().to(desired_corr_matrix_holder.device)
                        else:
                            ident = 0
                            desired_corr_matrix = desired_corr_matrix_all[random_choice]#.repeat(bs,1)


                    else:
                        if step == len(loader) - 1:
                            ident = 1
                            desired_corr_matrix = corr(x_real.x[:, p:p + q].reshape(bs, ninst2, -1).permute(0, 2, 1),
                                                       False).clone().to(desired_corr_matrix_holder.device)
                        elif step == len(loader):
                            ident = 1
                            desired_corr_matrix = corr(x_real.x[:, p:p + q].reshape(bs, ninst2, -1).permute(0, 2, 1),
                                                       False).clone().to(desired_corr_matrix_holder.device)

                        elif step % stepper == 0:
                            ident = 1
                            desired_corr_matrix = corr(x_real.x[:, p:p + q].reshape(bs, ninst2, -1).permute(0, 2, 1),
                                                       False).clone().to(desired_corr_matrix_holder.device)
                        else:
                            ident = 0
                            desired_corr_matrix = desired_corr_matrix_all[random_choice]#.repeat(bs,1)
                flat_dcm=desired_corr_matrix.clone()
                if includeevents:
                    random_choice2 = np.random.choice(desired_vol.size(0), bs)
                    reformed = x_real.x.reshape(bs, ninst2, -1)[:, :, :p + q]
                    desired_vol_mat = desired_vol[random_choice2].to('cuda')
                    desired_vol_mat_ix = torch.zeros(bs, desired_vol.size(1), 13).to('cuda')
                    desired_vol_mat=torch.cat((desired_vol_mat, desired_vol_mat_ix), dim=2)
                    desired_vol_mat = torch.transpose(desired_vol_mat, 2, 1)
                    reformed = reformed + torch.rand(reformed.size()).to(reformed.device) * 1e-5

                    earningsandelectionedits = torch.abs(desired_vol_mat / reformed) * reformed
                    mask = earningsandelectionedits != 0
                    reformed[mask] = earningsandelectionedits[mask].float()
                    if actual==0:
                        x_real.x[:, :p + eventtime] = reformed[:,:,:p + eventtime].reshape(x_real.x[:, :p + eventtime].size())
                    elif (actual!=0) & (notreal):
                        x_real.x[:, p:p + eventtime] = reformed[:,:,p:p + eventtime].reshape(x_real.x[:, p:p + eventtime].size())

                else:
                    desired_vol_mat=None
                desired_corr_matrix_holder[:, indices2[0], indices2[1]] = desired_corr_matrix
                desired_corr_matrix_holder=desired_corr_matrix_holder+desired_corr_matrix_holder.permute(0,2,1)
                desired_corr_matrix_holder=desired_corr_matrix_holder+torch.eye(N).repeat(bs,1,1)
                desired_corr_matrix=desired_corr_matrix_holder
                if correlatedn:
                    z=correlated_noise(desired_corr_matrix,z.reshape(bs,-1,q)).reshape(-1,q)
                batchnoise = z.reshape(bs, -1, q).to(desired_corr_matrix.device)
                z_edges = torch.stack(
                  [batched_pairwise_noise_matrix(batchnoise[:, :, i], flat_dcm) for i in
                   range(q)]).transpose(1, 0)
                if includeevents:
                    post_event_desired_corr_matrix_holder = torch.zeros((bs, N, N))
                    z_post_edges = torch.stack(
                        [batched_pairwise_noise_matrix(batchnoise[:, :, i], post_event_desired_corr_matrix) for i in
                         range(q)]).transpose(1, 0)
                    post_event_desired_corr_matrix_holder[:, indices2[0], indices2[1]] = post_event_desired_corr_matrix
                    post_event_desired_corr_matrix_holder = post_event_desired_corr_matrix_holder + post_event_desired_corr_matrix_holder.permute(0, 2, 1)
                    post_event_desired_corr_matrix_holder = post_event_desired_corr_matrix_holder + torch.eye(N).repeat(bs, 1, 1)
                    post_event_desired_corr_matrix = post_event_desired_corr_matrix_holder
                locorr=-.999
                maxbs=.999
                k=40
                k2=40
                cid = torch.eye(desired_corr_matrix.size(2)).reshape((1, desired_corr_matrix.size(2), desired_corr_matrix.size(2))).repeat(
                      desired_corr_matrix.size(0), 1, 1).to(desired_corr_matrix.device)
                ctopk_values, curindices = torch.topk(torch.abs(desired_corr_matrix)-cid, k, dim=1)
                curmask_dcm = torch.zeros_like(desired_corr_matrix)
                indices_prob = torch.rand(curindices.size())
                ctopk_values2, indices22 = torch.topk(indices_prob, k2, dim=1)
                curmask_index = torch.zeros_like(curindices)
                curmask_index.scatter_(1, indices22, 1)
                curmask_index = curmask_index * curindices + curmask_index  # therefore anything that's not index 0 is good?
                curmask_dcm2 = torch.zeros((desired_corr_matrix.size(0), desired_corr_matrix.size(1) + 1,
                                          desired_corr_matrix.size(2) + 1))
                curmask_dcm2.scatter_(1, curmask_index, 1)
                curmask_dcm = curmask_dcm2[:, 1:, :-1]
                curmask_dcm2 = curmask_dcm.clone() * desired_corr_matrix
                curmask_dcm = curmask_dcm * desired_corr_matrix
                indices2 = ((curmask_dcm >= locorr) & (curmask_dcm <= maxbs)& (curmask_dcm.abs() >barrier)).nonzero(as_tuple=False).long().T
                n = curmask_dcm[0].size(1)
                tril_i, tril_j = torch.tril_indices(n, n, offset=-1)
                combined_mask = torch.ones_like(curmask_dcm, dtype=torch.bool)
                combined_mask =((curmask_dcm >= locorr) & (curmask_dcm <= maxbs)& (curmask_dcm.abs() >barrier))
                valid_mask = combined_mask[:,tril_i, tril_j]
                indices2[1] = indices2[1] + indices2[0] * N
                indices2[2] = indices2[0] * N + indices2[2]
                z_edge_index=z_edges[combined_mask.repeat(q,1,1,1).transpose(0,1)].reshape(-1,q).to(x_real.x.device)
                extra_edge_index = torch.index_select(indices2, 0, torch.tensor([1, 2]))
                extra_edge_index = extra_edge_index.to(x_real.x.device)
                extra_edge_attr = curmask_dcm2[(curmask_dcm >= locorr) & (curmask_dcm <= maxbs)& (curmask_dcm.abs() >barrier)]
                bin_indices = torch.empty_like(extra_edge_attr)
                if includeevents:
                    id = torch.eye(post_event_desired_corr_matrix.size(2)).reshape(
                        (1, post_event_desired_corr_matrix.size(2), post_event_desired_corr_matrix.size(2))).repeat(
                        post_event_desired_corr_matrix.size(0), 1, 1).to(post_event_desired_corr_matrix.device)
                    topk_values, indices = torch.topk(torch.abs(post_event_desired_corr_matrix) - id, k, dim=1)
                    mask_dcm = torch.zeros_like(post_event_desired_corr_matrix)
                    indices_prob = torch.rand(indices.size())
                    topk_values2, indices2 = torch.topk(indices_prob, k2, dim=1)
                    mask_index = torch.zeros_like(indices)
                    mask_index.scatter_(1, indices2, 1)
                    mask_index = mask_index * indices + mask_index
                    mask_dcm2 = torch.zeros(
                        (post_event_desired_corr_matrix.size(0), post_event_desired_corr_matrix.size(1) + 1, post_event_desired_corr_matrix.size(2) + 1))
                    mask_dcm2.scatter_(1, mask_index, 1)
                    mask_dcm = mask_dcm2[:, 1:, 1:]
                    mask_dcm2 = mask_dcm.clone() * post_event_desired_corr_matrix
                    mask_dcm = mask_dcm * post_event_desired_corr_matrix
                    combined_mask = torch.ones_like(mask_dcm, dtype=torch.bool)
                    combined_mask = ((mask_dcm >= locorr) & (mask_dcm <= maxbs) & (mask_dcm.abs() > barrier))
                    indices2 = ((mask_dcm >= locorr) & (mask_dcm <= maxbs)& (mask_dcm.abs() >barrier)).nonzero(as_tuple=False).long().T
                    indices2[1] = indices2[1] + indices2[0] * N
                    indices2[2] = indices2[0] * N + indices2[2]
                    post_event_extra_edge_index = torch.index_select(indices2, 0, torch.tensor([1, 2]))
                    post_event_extra_edge_index = post_event_extra_edge_index.to(x_real.x.device)
                    post_event_extra_edge_attr = mask_dcm2[(mask_dcm >= locorr) & (mask_dcm <= maxbs)& (mask_dcm.abs() >barrier)]
                    post_event_bin_indices = torch.empty_like(post_event_extra_edge_attr)
                    z_post_edge_index = z_post_edges[combined_mask.repeat(q, 1, 1, 1).transpose(0, 1)].reshape(-1, q).to(
                        x_real.x.device)
                    if includeevents:
                        for i, (lower, upper) in enumerate(bins):
                            mask = (post_event_extra_edge_attr >= lower) & (post_event_extra_edge_attr < upper)
                            post_event_bin_indices[mask] = i
                    post_event_extra_edge_attr = post_event_bin_indices.to(x_real.x.device).long()
                for i, (lower, upper) in enumerate(bins):
                  mask = (extra_edge_attr >= lower) & (extra_edge_attr < upper)
                  bin_indices[mask] = i
                extra_edge_attr=bin_indices.to(x_real.x.device).long()
                indices1 = torch.triu_indices(N, N, 1)
                desired_corr_matrix1 = desired_corr_matrix.clone()[:,indices1[0], indices1[1]]
                print('MAX',x_real.x.max().item())
                print('MIN', x_real.x.min().item())
                try:
                    del(x_fake)
                    torch.cuda.empty_cache()
                    gc.collect(generation=2)
                except:
                    pass
                if includeevents:
                    event_t=12
                    x_fake,bls,gls = generator(x_real.clone(), z*zm,z_shared[batch.batch]*zsharedm, z_edge_index*zedgem,z_jump, extra_edge_index, extra_edge_attr, desired_vol_mat,
                                       post_event_extra_edge_index, post_event_extra_edge_attr, z_post_edge_index*zedgem, event_t,ninst2,return_FiLM=return_FiLM,notreal=notreal)


                else:
                    x_fake,bls,gls = generator(x_real.clone(),z*zm,z_shared[batch.batch]*zsharedm, z_edge_index*zedgem,z_jump, extra_edge_index, extra_edge_attr, desired_vol_mat,
                                           ninst2=ninst2,return_FiLM=return_FiLM)

                    if epoch%4==0:
                        if not generate:
                            if (step+1)%20==0:
                                x_fake_gs1= x_fake.reshape(bs, N, q).permute(0, 2, 1)
                                zrealgs1=x_real.x.reshape(bs, N, -1).permute(0, 2, 1)[:,p:p+q]
                                torch.save(zrealgs1, gen_dir + '\\' + datasetname + '_' + str(
                                    epoch * len(loader) + step) + '_generation_real_train.pt')
                                torch.save(x_fake_gs1, gen_dir + '\\' + datasetname + '_' + str(
                                    epoch * len(loader) + step) + '_generation_train.pt')
                                torch.save(desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                                    epoch * len(loader) + step) + '_dcm_train.pt')

                if generate:
                    x_fake_gs = x_fake.reshape(bs, N, q).permute(0, 2, 1)
                    if not includeevents:
                        if return_FiLM:
                            torch.save(torch.stack(bls).reshape(-1,q,bs,N), gen_dir + '\\' + datasetname + '_' + str(
                              epoch * len(loader) + step) + '_betas.pt')
                            torch.save(torch.stack(gls).reshape(-1,q,bs,N), gen_dir + '\\' + datasetname + '_' + str(
                              epoch * len(loader) + step) + '_gammas.pt')
                            torch.save(x_real.x,
                                 gen_dir + '\\' + datasetname + '_' + str(
                                     epoch * len(loader) + step) + '_realcondition.pt')


                            torch.save(x_fake_gs, gen_dir + '\\' + datasetname + '_' + str(
                                  epoch * len(loader) + step) + '_generation512.pt')

                            torch.save(desired_corr_matrix,
                                         gen_dir + '\\' + datasetname + '_' + str(
                                             epoch * len(loader) + step) + '_dcm512.pt')
                    elif harris:
                      if actual == 1:
                          if notreal:
                              torch.save(x_fake_gs, gen_dir + '\\' + datasetname + '_' + str(
                                  epoch * len(loader) + step) + '_harris_generation512r.pt')
                              torch.save(post_event_desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                                  epoch * len(loader) + step) + '_harris_pedcm512r.pt')
                              torch.save(desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                                  epoch * len(loader) + step) + '_harris_dcm512r.pt')

                          else:
                              torch.save(x_fake_gs, gen_dir + '\\' + datasetname + '_' + str(
                                  epoch * len(loader) + step) + '_harris_generation512nr.pt')
                              torch.save(post_event_desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                                  epoch * len(loader) + step) + '_harris_pedcm512nr.pt')
                              torch.save(desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                                  epoch * len(loader) + step) + '_harris_dcm512nr.pt')

                      else:
                          torch.save(x_fake_gs, gen_dir + '\\' + datasetname + '_' + str(
                              epoch * len(loader) + step) + '_harris_generation512.pt')
                          torch.save(post_event_desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                              epoch * len(loader) + step) + '_harris_pedcm512.pt')
                          torch.save(desired_corr_matrix,
                                     gen_dir + '\\' + datasetname + '_' + str(
                                         epoch * len(loader) + step) + '_harris_dcm512.pt')
                    elif trump:
                        if actual==1:
                         if  notreal:
                            torch.save(x_fake_gs, gen_dir + '\\' + datasetname + '_' + str(
                                epoch * len(loader) + step) + '_generation_trump512r.pt')
                            torch.save(post_event_desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                                epoch * len(loader) + step) + '_pedcm_trump512r.pt')
                            torch.save(desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                                epoch * len(loader) + step) + '_dcm_trump512r.pt')
                         else:
                          torch.save(x_fake_gs, gen_dir + '\\' + datasetname + '_' + str(
                            epoch * len(loader) + step) + '_generation_trump512nr.pt')
                          torch.save(post_event_desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                            epoch * len(loader) + step) + '_pedcm_trump512nr.pt')
                          torch.save(desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                            epoch * len(loader) + step) + '_dcm_trump512nr.pt')
                        else:
                          torch.save(x_fake_gs, gen_dir + '\\' + datasetname + '_' + str(
                              epoch * len(loader) + step) + '_generation_trump512.pt')
                          torch.save(post_event_desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                              epoch * len(loader) + step) + '_pedcm_trump512.pt')
                          torch.save(desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                              epoch * len(loader) + step) + '_dcm_trump512.pt')
                    else:

                        if actual == 1:
                            if notreal:
                                torch.save(x_fake_gs, gen_dir + '\\' + datasetname + '_' + str(
                                    epoch * len(loader) + step) +'_'+cts+ '_generation_act512r.pt')
                                torch.save(post_event_desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                                    epoch * len(loader) + step) + '_'+cts+ '_pedcm_act512r.pt')
                                torch.save(desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                                    epoch * len(loader) + step) +'_'+cts+  '_dcm_act512r.pt')

                            else:
                                torch.save(x_fake_gs, gen_dir + '\\' + datasetname + '_' + str(
                                    epoch * len(loader) + step) + '_'+cts+ '_generation_act512nr.pt')
                                torch.save(post_event_desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                                    epoch * len(loader) + step) +'_'+cts+  '_pedcm_act512nr.pt')
                                torch.save(desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                                    epoch * len(loader) + step) +'_'+cts+  '_dcm_act512nr.pt')
                        else:
                            torch.save(x_fake_gs, gen_dir + '\\' + datasetname + '_' + str(
                                epoch * len(loader) + step) +'_'+cts+  '_generation_act512.pt')
                            torch.save(post_event_desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                                epoch * len(loader) + step) +'_'+cts+  '_pedcm_act512.pt')
                            torch.save(desired_corr_matrix, gen_dir + '\\' + datasetname + '_' + str(
                                epoch * len(loader) + step) + '_'+cts+ '_dcm_act512.pt')

                            x_future_g = xreal[:, p:p + q].reshape(bs, -1, q).permute(0, 2, 1)
                            x_past_g = xreal[:, :p].reshape(bs, -1, p).permute(0, 2, 1)
                            plt_figures(x_fake_gs, x_future_g, x_past_g, desired_corr_matrix, desired_vol_mat, epoch,
                                        plt_directory, datasetname, gan_algo='GNN', weightstab=weightstab)
                            if weightstab is not None:
                                plt_figures([x_fake_gs[0][:, :, -13:]], x_future_g[:, :, -13:], x_past_g[:, :, -13:],
                                            C_desired_ix, desired_vol_mat, epoch,
                                            plt_directory, datasetname, gan_algo='GNN', weightstab=weightstab,
                                            sectors=True)
                            if weightstab is not None:
                                plt_figures([x_fake_gs[0][:, :, -26:-13]], x_future_g[:, :, -13:], x_past_g[:, :, -13:],
                                            C_desired_ix, desired_vol_mat, epoch,
                                            plt_directory, datasetname, gan_algo='GNA', weightstab=weightstab,
                                            sectors=True)

                if not generate:
                    mlist, ave_grads_g, layers_g,mm_list, step_number = train_step(ninst2,p,q,generator, G_optimizer, scheduler, x_fake, x_real,
                                                                        layers_g, ave_grads_g,mm_list,ident=ident,weightstab=weightstab,
                                                                        corr_loss=True, epoch=epoch,
                                                                        step_number=step_number,desired_corr_matrix=desired_corr_matrix1,full_desired_corr_matrix=desired_corr_matrix,desired_vol_mat=desired_vol_mat,seed=seed)
                    print('LR:',G_optimizer.param_groups[0]['lr'])
                    ioui,gl,gla,cl,cl2,kl,mstloss,stdl,skl,sl1,sl2,acl,kl2,eigl,eiglix,div_loss,c_l_ix,j_l,j_la,j_l_e,j_l_e_a,s_lossat1,gl_e,gla_e,constraint_loss,std_loss,stdev_loss,glix,glaix ,glsum ,glasum,c_loss_ix1_,c_loss_ix2_=mlist
                    metric_dict[mcounter]={'tk':ioui,'gl':gl,'cl':cl,'gla':gla,'cl2':cl2,'kl':kl,'mstloss':mstloss,'stdl':stdl,'skl':skl,'sl1':sl1,'sl2':sl2,'acl':acl,'kl2':kl2,'eigl':eigl,'eiglix':eiglix,'div_loss':div_loss,'c_l_ix':c_l_ix,'j_l':j_l,'j_la':j_la,'j_l_e':j_l_e,'j_l_e_a':j_l_e_a,'s_lossa':s_lossat1,'gl_e':gl_e,'gla_e':gla_e,'constraint_loss':constraint_loss,'std_loss':std_loss,'stdev_loss':stdev_loss,'glix':glix,'glaix':glaix ,'glsum':glsum ,'glasum':glasum,',c_loss_ix1_':c_loss_ix1_,'c_loss_ix2_':c_loss_ix2_}
                    mcounter+=1
                    ioulist.append(ioui)
                    torch.cuda.empty_cache()
                    gc.collect(generation=2)
                    endtime=datetime.datetime.now()
                    nsteps=step
                    print(goodchoices)
                    print('iou',np.mean(ioulist))
                    print(epoch)
                    print('EpochTrainingLength',endtime-start_time)
                    if epoch%2==0:
                        torch.save(generator, model_dir+'\\'+str(epoch)+'_model_weights.pth')

            if epoch%2==0:
                if weightstab is not None:

                    xf = x_fake.reshape(bs, N , -1).permute(0, 2, 1)
                    indexvalues = torch.matmul(xf[:, :, :-13], weightstab.to(x_fake.device).T)
                    xf = torch.cat((xf, indexvalues), dim=2)
                    x_fake = xf.permute(0,2,1).reshape(-1,q).unsqueeze(2)
                    xreal = torch.cat((x_real.x.reshape(bs, N , -1).permute(0, 2, 1),
                                       x_real.x2.reshape(bs, 13, -1).permute(0, 2, 1)), dim=2)
                    xreal = xreal.permute(0, 2, 1).reshape(x_fake.size(0),-1)
                    C_desired_ix = desired_corr_matrix[:, -13:, -13:].to(
                        x_fake.device)
                else:
                    xreal=x_real.x
                if not generate:
                    x_fake_gs=[x_fake.reshape(bs,-1,q).permute(0,2,1)]
                    x_future_g=xreal[:,p:p+q].reshape(bs,-1,q).permute(0,2,1)
                    x_past_g=xreal[:,:p].reshape(bs,-1,p).permute(0,2,1)
                    plt_figures(x_fake_gs, x_future_g, x_past_g,desired_corr_matrix,desired_vol_mat, epoch, plt_directory, datasetname,gan_algo='GNN',weightstab=weightstab)
                    if weightstab is not None:
                        plt_figures([x_fake_gs[0][:,:,-13:]], x_future_g[:,:,-13:], x_past_g[:,:,-13:], C_desired_ix, desired_vol_mat, epoch,
                                    plt_directory, datasetname, gan_algo='GNN', weightstab=weightstab,sectors=True)
                    if weightstab is not None:
                        plt_figures([x_fake_gs[0][:,:,-26:-13]], x_future_g[:,:,-13:], x_past_g[:,:,-13:], C_desired_ix, desired_vol_mat, epoch,
                                    plt_directory, datasetname, gan_algo='GNA', weightstab=weightstab,sectors=True)
                    loss_dict=evaluate( x_fake_gs, x_future_g, epoch)
                    loss_df=pd.DataFrame(loss_dict)
                    loss_df.to_hdf(experiment_directory+'\\loss_df_'+str(epoch)+'.h5', key='loss_df')
                    metric_df=pd.DataFrame(metric_dict)
                    metric_df.to_hdf(experiment_directory + '\\metric_df_' + str(epoch) + '.h5', key='metric_df')

            if (epoch) % 2 == 0:
                if not generate:
                    desired_corr_matrix_all = des_dict[q]
                    corrtype = (epoch + 1) % 6
                    ini_poss_indices = np.where(des_idx_dict[q]['Type']['Type'] == corrtype)[0]
                    random_choice = np.random.choice(ini_poss_indices, 1, replace=False)
                    for ident in [0,1]:
                        index11 = np.random.choice([x for x in range(20)])
                        bs = 10
                        ix1=np.random.choice(range(len(dataset)-22),20)
                        sampler11 = RepeatSingleIndexSampler(index11, bs)
                        loader_rep = DataLoader([x for x in dataset[list(ix1)]], sampler=sampler11, batch_size=bs)
                        loader_iter = iter(loader_rep)
                        x_real_repeat = next(loader_iter)
                        del loader_iter, loader_rep
                        torch.cuda.empty_cache()
                        gc.collect(2)
                        z = torch.randn(bs * N, q).to('cuda')   # Shape: (B * N, l)
                        z_jump = torch.randn(bs * N, q).to('cuda')
                        chi2_dist = torch.distributions.Chi2(df=2.5)
                        z_jump = chi2_dist.sample((bs * N, q)).to('cuda')
                        z_jump = .5 * (z_jump + 4)
                        z_shared = torch.randn(bs, len(bins) + 1, q).to('cuda')
                        desired_corr_matrix_holder = torch.zeros((bs, N, N))#.to('cuda')
                        indices2 = torch.triu_indices(N, N, 1)
                        if includeevents:
                            desired_corr_matrix = corr(x_real_repeat.x[:, :p].reshape(bs, ninst, -1).permute(0, 2, 1)[:, :, :ninst2],
                                                       False).clone().to(desired_corr_matrix_holder.device)
                            post_event_desired_corr_matrix = desired_corr_matrix_all[random_choice].repeat(bs,1)
                        else:
                            if ident==0:
                                desired_corr_matrix = desired_corr_matrix_all[random_choice].to(desired_corr_matrix_holder.device).repeat(bs,1)

                            else:
                                desired_corr_matrix = corr(x_real_repeat.x[:, p:p + q].reshape(bs, ninst2, -1).permute(0, 2, 1),
                                                           False).clone().to(desired_corr_matrix_holder.device)

                        flat_dcm = desired_corr_matrix.clone()

                        batchnoise = z.reshape(bs, -1, q).to(desired_corr_matrix.device)

                        if includeevents:
                            random_choice2 = np.random.choice(desired_vol.size(0), bs)
                            reformed = x_real_repeat.x.reshape(bs, ninst2, -1)[:, :, :p + q]
                            desired_vol_mat = desired_vol[random_choice2].to('cuda')
                            desired_vol_mat_ix = torch.zeros(bs, 2 * q, 13).to('cuda')
                            desired_vol_mat = torch.cat((desired_vol_mat, desired_vol_mat_ix), dim=2)
                            desired_vol_mat = torch.transpose(desired_vol_mat, 2, 1)
                            earningsandelectionedits = torch.abs(desired_vol_mat / reformed) * reformed
                            mask = earningsandelectionedits != 0
                            reformed[mask] = earningsandelectionedits[mask].float()
                            x_real_repeat.x[:, :p + q] = reformed.reshape(x_real_repeat.x[:, :p + q].size())
                        else:
                            desired_vol_mat = None

                        desired_corr_matrix_holder[:, indices2[0], indices2[1]] = desired_corr_matrix
                        desired_corr_matrix_holder = desired_corr_matrix_holder + desired_corr_matrix_holder.permute(0, 2, 1)
                        desired_corr_matrix_holder = desired_corr_matrix_holder + torch.eye(N).repeat(bs, 1, 1)
                        desired_corr_matrix = desired_corr_matrix_holder
                        if correlatedn:
                            z = correlated_noise(desired_corr_matrix, z.reshape(bs, -1, q)).reshape(-1, q)
                        batchnoise = z.reshape(bs, -1, q).to(desired_corr_matrix.device)
                        z_edges = torch.stack(
                            [batched_pairwise_noise_matrix(batchnoise[:, :, i], flat_dcm) for i in
                             range(q)]).transpose(1, 0)
                        if includeevents:
                            post_event_desired_corr_matrix_holder = torch.zeros((bs, N, N))

                            post_event_desired_corr_matrix_holder[:, indices2[0], indices2[1]] = post_event_desired_corr_matrix
                            post_event_desired_corr_matrix_holder = post_event_desired_corr_matrix_holder + post_event_desired_corr_matrix_holder.permute(
                                0, 2, 1)
                            post_event_desired_corr_matrix_holder = post_event_desired_corr_matrix_holder + torch.eye(N).repeat(
                                bs, 1, 1)
                            post_event_desired_corr_matrix = post_event_desired_corr_matrix_holder



                        locorr = -.999
                        maxbs = .999
                        id = torch.eye(desired_corr_matrix.size(2)).reshape(
                            (1, desired_corr_matrix.size(2), desired_corr_matrix.size(2))).repeat(
                            desired_corr_matrix.size(0), 1, 1).to(desired_corr_matrix.device)
                        topk_values, indices = torch.topk(torch.abs(desired_corr_matrix) - id, k, dim=1)
                        mask_dcm = torch.zeros_like(desired_corr_matrix)
                        indices_prob = torch.rand(indices.size())
                        topk_values2, indices2 = torch.topk(indices_prob, k2, dim=1)
                        mask_index = torch.zeros_like(indices)
                        mask_index.scatter_(1, indices2, 1)
                        mask_index = mask_index * indices + mask_index
                        mask_dcm2 = torch.zeros(
                            (desired_corr_matrix.size(0), desired_corr_matrix.size(1) + 1, desired_corr_matrix.size(2) + 1))
                        mask_dcm2.scatter_(1, mask_index, 1)
                        mask_dcm = mask_dcm2[:, 1:, 1:]
                        mask_dcm2 = mask_dcm.clone() * desired_corr_matrix
                        mask_dcm = mask_dcm *desired_corr_matrix
                        indices2 = ((mask_dcm >= locorr) & (mask_dcm <= maxbs)& (mask_dcm.abs() >barrier)).nonzero(as_tuple=False).long().T
                        indices2[1] = indices2[1] + indices2[0] * N
                        indices2[2] = indices2[0] * N + indices2[2]
                        extra_edge_index = torch.index_select(indices2, 0, torch.tensor([1, 2]))
                        extra_edge_index = extra_edge_index.to(x_real.x.device)
                        extra_edge_attr = mask_dcm2[(mask_dcm >= locorr) & (mask_dcm <= maxbs)& (mask_dcm.abs()>barrier)]
                        bin_indices = torch.empty_like(extra_edge_attr)
                        if includeevents:
                            id = torch.eye(post_event_desired_corr_matrix.size(2)).reshape(
                                (1, post_event_desired_corr_matrix.size(2), post_event_desired_corr_matrix.size(2))).repeat(
                                post_event_desired_corr_matrix.size(0), 1, 1).to(post_event_desired_corr_matrix.device)
                            topk_values, indices = torch.topk(torch.abs(post_event_desired_corr_matrix) - id, k, dim=1)
                            mask_dcm = torch.zeros_like(post_event_desired_corr_matrix)
                            indices_prob = torch.rand(indices.size())
                            topk_values2, indices2 = torch.topk(indices_prob, k2, dim=1)
                            mask_index = torch.zeros_like(indices)
                            mask_index.scatter_(1, indices2, 1)
                            mask_index = mask_index * indices + mask_index
                            mask_dcm2 = torch.zeros(
                                (post_event_desired_corr_matrix.size(0), post_event_desired_corr_matrix.size(1) + 1,
                                 post_event_desired_corr_matrix.size(2) + 1))
                            mask_dcm2.scatter_(1, mask_index, 1)
                            mask_dcm = mask_dcm2[:, 1:, 1:]
                            mask_dcm2 = mask_dcm.clone() * post_event_desired_corr_matrix
                            mask_dcm = mask_dcm * post_event_desired_corr_matrix
                            indices2 = ((mask_dcm >= locorr) & (mask_dcm <= maxbs)& (mask_dcm.abs() >barrier)).nonzero(as_tuple=False).long().T
                            indices2[1] = indices2[1] + indices2[0] * N
                            indices2[2] = indices2[0] * N + indices2[2]
                            post_event_extra_edge_index = torch.index_select(indices2, 0, torch.tensor([1, 2]))
                            post_event_extra_edge_index = post_event_extra_edge_index.to(x_real.x.device)
                            post_event_extra_edge_attr = mask_dcm2[(mask_dcm >= locorr) & (mask_dcm <= maxbs)& (mask_dcm.abs() >barrier)]
                            post_event_bin_indices = torch.empty_like(post_event_extra_edge_attr)
                        if includeevents:
                            for i, (lower, upper) in enumerate(bins):
                                mask = (post_event_extra_edge_attr >= lower) & (post_event_extra_edge_attr < upper)
                                post_event_bin_indices[mask] = i
                            post_event_extra_edge_attr = post_event_bin_indices.to(x_real.x.device).long()

                        for i, (lower, upper) in enumerate(bins):
                            mask = (extra_edge_attr >= lower) & (extra_edge_attr < upper)
                            bin_indices[mask] = i

                        extra_edge_attr = bin_indices.to(x_real.x.device).long()
                        indices1 = torch.triu_indices(N, N, 1)

                        n = curmask_dcm[0].size(1)
                        tril_i, tril_j = torch.tril_indices(n, n, offset=-1)
                        combined_mask = torch.ones_like(mask_dcm, dtype=torch.bool)
                        combined_mask = (
                                    (mask_dcm >= locorr) & (mask_dcm <= maxbs) & (mask_dcm.abs() > barrier))

                        valid_mask = combined_mask[:, tril_i, tril_j]
                        z_edge_index = z_edges[combined_mask.repeat(q, 1, 1, 1).transpose(0, 1)].reshape(-1, q)
                        desired_corr_matrix1 = desired_corr_matrix.clone()[:, indices1[0], indices1[1]]

                        generator.eval()
                        with torch.no_grad():
                            x_fake_rep,bls,gls = generator(x_real_repeat.to('cuda').clone(), z*zm,z_shared[x_real_repeat.batch]*zsharedm, z_edge_index*zedgem, z_jump, extra_edge_index, extra_edge_attr,
                                               desired_vol_mat, ninst2=ninst2)
                        if weightstab is not None:

                            xf = x_fake_rep.reshape(bs, N, -1).permute(0, 2, 1)
                            indexvalues = torch.matmul(xf[:, :, :-13], weightstab.to(x_fake.device).T)
                            xf = torch.cat((xf, indexvalues), dim=2)
                            # x_fake=xf.reshape(x_real.x[:,p:p+q].size())
                            x_fake_rep = xf.permute(0, 2, 1).reshape(-1, q).unsqueeze(2)
                            xreal = torch.cat((x_real_repeat.x.reshape(bs, N, -1).permute(0, 2, 1),
                                               x_real_repeat.x2.reshape(bs, 13, -1).permute(0, 2, 1)), dim=2)
                            xreal = xreal.permute(0, 2, 1).reshape(x_fake_rep.size(0), -1)
                            C_desired_ix=desired_corr_matrix[:,-13:,-13:]



                        else:
                            xreal = x_real_repeat.x
                        if not generate:
                            x_fake_gs = [x_fake_rep.reshape(bs, -1, q).permute(0, 2, 1)]
                            x_future_g = xreal[:, p:p + q].reshape(bs, -1, q).permute(0, 2, 1)
                            x_past_g = xreal[:, :p].reshape(bs, -1, p).permute(0, 2, 1)
                            plt_figures(x_fake_gs, x_future_g, x_past_g, desired_corr_matrix, desired_vol_mat, epoch, plt_directory,
                                        datasetname, gan_algo='GNN', weightstab=weightstab,rep=True,ident=ident)
                            loss_dict1 = evaluate(x_fake_gs, x_future_g, epoch)
                            loss_df1 = pd.DataFrame(loss_dict)
                            loss_df1.to_hdf(experiment_directory + '\\loss_df_repeat' + str(epoch) + '.h5', key='loss_df')

        try:
            del (x_fake)
            torch.cuda.empty_cache()
            gc.collect(generation=2)
        except:
            pass
        training_endtime = datetime.datetime.now()
        print('TotalTrainingLength', training_endtime - training_starttime)

