import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import math
from tqdm import tqdm
import time
from torch.utils.data import Dataset, DataLoader

def kl_divergence(ax, bx, eps=1e-8):
    """Pairwise KL divergence KL(a || b) between two batches of (assumed) distributions.

    Args:
        ax: tensor of shape (G, A, L, C) or (G, A, L). Each ``ax[g, i, :, c]``
            is treated as a distribution over the L bins of dim 2.
        bx: tensor of shape (G, B, L, C) or (G, B, L), same layout as ``ax``.
        eps: added inside the logarithms to avoid ``log(0)``.

    Returns:
        Tensor of shape (G, B, A). ``out[g, j, i]`` is
        ``sum_l a * (log a - log b)`` for ``a = ax[g, i]`` and ``b = bx[g, j]``,
        averaged over the C channels. A 3-D input is treated as having a
        single channel.
    """
    # A 3-D input has no channel axis; add a singleton one so the channel mean
    # below is a no-op instead of an out-of-range dimension.
    if ax.dim() == 3:
        ax = ax.unsqueeze(-1)
    if bx.dim() == 3:
        bx = bx.unsqueeze(-1)

    a = ax.unsqueeze(2)  # G, A, 1, L, C
    b = bx.unsqueeze(1)  # G, 1, B, L, C
    log_ratio = torch.log(a + eps) - torch.log(b + eps)  # G, A, B, L, C
    kl = (a * log_ratio).sum(dim=3)  # G, A, B, C
    return kl.mean(dim=3).transpose(1, 2)  # G, B, A
def js_divergence(ax, bx, eps=1e-8):
    """Pairwise, unhalved Jensen-Shannon divergence between two batches of distributions.

    Computes ``KL(a || m) + KL(b || m)`` with ``m = (a + b) / 2``. That is twice
    the textbook Jensen-Shannon divergence. It is 0 for identical inputs and
    grows as they differ, so larger values mean LESS similar.

    Args:
        ax: tensor of shape (G, A, L, C) or (G, A, L). Each ``ax[g, i, :, c]``
            is treated as a distribution over the L bins of dim 2.
        bx: tensor of shape (G, B, L, C) or (G, B, L), same layout as ``ax``.
        eps: added inside the logarithms to avoid ``log(0)``.

    Returns:
        Tensor of shape (G, B, A). ``out[g, j, i]`` compares ``ax[g, i]`` with
        ``bx[g, j]``, summed over the L bins and averaged over the C channels.
        A 3-D input is treated as having a single channel.
    """
    # Give 3-D inputs a singleton channel axis so the channel mean is a no-op.
    if ax.dim() == 3:
        ax = ax.unsqueeze(-1)
    if bx.dim() == 3:
        bx = bx.unsqueeze(-1)

    a = ax.unsqueeze(2)  # G, A, 1, L, C
    b = bx.unsqueeze(1)  # G, 1, B, L, C
    log_m = torch.log((a + b) / 2 + eps)  # G, A, B, L, C
    div = (a * (torch.log(a + eps) - log_m)).sum(dim=3) + \
        (b * (torch.log(b + eps) - log_m)).sum(dim=3)  # G, A, B, C
    return div.mean(dim=3).transpose(1, 2)  # G, B, A

def js_divergence_CI(ax, bx, eps=1e-8):
    """Channel-wise, unhalved Jensen-Shannon divergence (no averaging over channels).

    Same quantity as ``js_divergence``: ``KL(a || m) + KL(b || m)`` with
    ``m = (a + b) / 2``, summed over the bins of dim 2. Each channel is kept
    separately.

    Args:
        ax: tensor of shape (G, A, L, C) (or (G, A, L)).
        bx: tensor of shape (G, B, L, C) (or (G, B, L)).
        eps: added inside the logarithms to avoid ``log(0)``.

    Returns:
        Tensor of shape (G, B, A, C) for 4-D inputs, or (G, B, A) for 3-D inputs.
        ``out[g, j, i]`` compares ``ax[g, i]`` with ``bx[g, j]``. Larger means
        less similar.
    """
    a = ax.unsqueeze(2)  # G, A, 1, L, C
    b = bx.unsqueeze(1)  # G, 1, B, L, C
    log_m = torch.log((a + b) / 2 + eps)
    div = (a * (torch.log(a + eps) - log_m)).sum(dim=3) + \
        (b * (torch.log(b + eps) - log_m)).sum(dim=3)  # G, A, B, C
    return div.transpose(1, 2)  # G, B, A, C

class RetrievalTool():
    """Retrieval-based forecaster over a cached training set (time-domain matching).

    Every window is summarised at several granularities. For each ``g`` in
    ``period_num``, the series is averaged over non-overlapping blocks of ``g``
    steps and expanded back to full length. For each granularity, queries are
    compared with all cached training windows by Pearson correlation. The
    ``topm`` best matches are softmax-weighted (``temperature``), and their
    decomposed target segments are averaged into a prediction.
    """

    def __init__(
        self,
        seq_len,
        pred_len,
        channels,
        n_period=3,
        temperature=0.1,
        topm=20,
        with_dec=False,
        return_key=False,
    ):
        """Store the retrieval configuration.

        Args:
            seq_len: length of the input windows.
            pred_len: length of the prediction horizon.
            channels: number of variables per time step.
            n_period: number of granularities, taken from the end of
                ``[16, 8, 4, 2, 1]`` (e.g. 3 -> ``[4, 2, 1]``).
            temperature: softmax temperature applied to the top-m similarities.
            topm: number of neighbours kept per query and granularity.
            with_dec: if True, cached targets also include the ``label_len``
                steps preceding the horizon.
            return_key: stored as ``self.return_key`` (not read by this class).

        Raises:
            ValueError: if ``n_period`` is not between 1 and 5.
        """
        period_num = [16, 8, 4, 2, 1]
        # period_num[-0:] would silently keep all five levels, and more than
        # five levels cannot be provided; either way len(period_num) would
        # disagree with n_period and the (G * B, ...) reshapes would fail later.
        if not 1 <= n_period <= len(period_num):
            raise ValueError(
                f'n_period must be between 1 and {len(period_num)}, got {n_period}'
            )
        period_num = period_num[-n_period:]

        self.seq_len = seq_len
        self.pred_len = pred_len
        self.channels = channels

        self.n_period = n_period
        self.period_num = sorted(period_num, reverse=True)

        self.temperature = temperature
        self.topm = topm

        self.with_dec = with_dec
        self.return_key = return_key

    def prepare_dataset(self, train_data):
        """Cache all training windows and targets together with their decompositions.

        Args:
            train_data: indexable dataset supporting ``len()`` with a
                ``pred_len`` attribute (and ``label_len`` when ``with_dec``).
                ``item[1]`` is the input window; the last steps of ``item[2]``
                form the target.
        """
        inputs = []
        targets = []
        for i in range(len(train_data)):
            item = train_data[i]
            inputs.append(item[1])
            if self.with_dec:
                targets.append(item[2][-(train_data.pred_len + train_data.label_len):])
            else:
                targets.append(item[2][-train_data.pred_len:])

        self.train_data_all = torch.tensor(np.stack(inputs, axis=0)).float()  # T, S, C
        self.train_data_all_mg, _ = self.decompose_mg(self.train_data_all)  # G, T, S, C

        self.y_data_all = torch.tensor(np.stack(targets, axis=0)).float()  # T, P, C
        self.y_data_all_mg, _ = self.decompose_mg(self.y_data_all)  # G, T, P, C

        self.n_train = self.train_data_all.shape[0]

    def decompose_mg(self, data_all, remove_offset=True):
        """Build the multi-granularity view of a batch of windows.

        Args:
            data_all: tensor (N, L, C) of N windows of length L. It is not
                modified.
            remove_offset: if True, subtract each level's last time step from
                that level.

        Returns:
            ``(mg, offset)``. ``mg`` is (G, N, L, C) with one level per entry
            of ``period_num`` (block means of ``g`` steps, each repeated ``g``
            times). ``offset`` is the (G, N, 1, C) subtracted last step, or
            None when ``remove_offset`` is False.
        """
        levels = []
        for g in self.period_num:
            pooled = data_all.unfold(dimension=1, size=g, step=g).mean(dim=-1)
            levels.append(pooled.repeat_interleave(repeats=g, dim=1))
        mg = torch.stack(levels, dim=0)  # G, N, L, C

        if not remove_offset:
            return mg, None
        # Subtract out of place, so `offset` (a view of `mg`) keeps the real
        # last-step values.
        offset = mg[:, :, -1:, :]
        return mg - offset, offset

    def periodic_batch_corr(self, data_all, key, in_bsz=512):
        """Pearson correlation between query and candidate vectors, per granularity.

        Vectors are mean-centred along the feature axis and L2-normalised, so
        their batched dot product is the Pearson correlation.

        Args:
            data_all: (G, T, D) candidate vectors. Moved to ``key.device`` one
                chunk at a time.
            key: (G, B, D) query vectors.
            in_bsz: number of candidates processed per chunk.

        Returns:
            (G, B, T) correlations.
        """
        query = F.normalize(key - torch.mean(key, dim=2, keepdim=True), dim=2)
        chunks = []
        for start in range(0, data_all.shape[1], in_bsz):
            cand = data_all[:, start:start + in_bsz].to(key.device)
            cand = F.normalize(cand - torch.mean(cand, dim=2, keepdim=True), dim=2)
            chunks.append(torch.bmm(query, cand.transpose(-1, -2)))
        return torch.cat(chunks, dim=2)

    def _overlap_mask(self, index, bsz, device):
        """Return a (bsz, n_train) bool mask of training windows near each query.

        For a query at dataset index ``i``, the mask covers training indices
        ``i - (seq_len + pred_len) + 1`` .. ``i + (seq_len + pred_len) - 1``,
        clamped to ``[0, n_train - 1]``. This presumably covers every window
        sharing time steps with the query, assuming consecutive indices are
        windows shifted by one step. TODO confirm against the dataset.
        """
        span = self.seq_len + self.pred_len
        offsets = torch.arange(2 * span - 1, device=device)
        neighbours = offsets.unsqueeze(0) + (index - span + 1).unsqueeze(1)
        neighbours = neighbours.clamp(0, self.n_train - 1)
        mask = torch.zeros((bsz, self.n_train), device=device)
        return mask.scatter_(1, neighbours, 1.).bool()

    def retrieve(self, x, index, train=True):
        """Predict each query's target from its most correlated training windows.

        Args:
            x: (B, seq_len, channels) query windows.
            index: (B,) dataset indices of the queries (used only when ``train``).
            train: if True, training windows overlapping each query (including
                itself) are excluded from retrieval.

        Returns:
            Tensor (G, B, L, channels) on ``x.device``: the softmax-weighted
            average of the decomposed cached targets of the top-m matches, with
            L the cached target length. ``topm`` is capped at ``n_train``.

        Raises:
            ValueError: if ``x`` does not have shape (B, seq_len, channels).
        """
        index = index.to(x.device)

        bsz, seq_len, channels = x.shape
        if seq_len != self.seq_len or channels != self.channels:
            raise ValueError(
                f'expected x of shape (B, {self.seq_len}, {self.channels}), '
                f'got {tuple(x.shape)}'
            )

        x_mg, _ = self.decompose_mg(x)  # G, B, S, C
        sim = self.periodic_batch_corr(
            self.train_data_all_mg.flatten(start_dim=2),  # G, T, S * C
            x_mg.flatten(start_dim=2),  # G, B, S * C
        )  # G, B, T

        if train:
            # The (B, T) mask broadcasts over the G granularities.
            sim = sim.masked_fill(self._overlap_mask(index, bsz, x.device), float('-inf'))

        # Keep only the top-m similarities per (granularity, query). Every
        # other entry is -inf, so it gets zero softmax weight.
        k = min(self.topm, self.n_train)
        top = torch.topk(sim, k, dim=2)
        ranking_sim = torch.full_like(sim, float('-inf')).scatter_(2, top.indices, top.values)

        ranking_prob = F.softmax(ranking_sim / self.temperature, dim=2)
        ranking_prob = ranking_prob.detach().cpu()  # G, B, T

        y_data_all = self.y_data_all_mg.flatten(start_dim=2)  # G, T, P * C
        pred_from_retrieval = torch.bmm(ranking_prob, y_data_all).reshape(
            self.n_period, bsz, -1, channels
        )
        return pred_from_retrieval.to(x.device)

    def retrieve_all(self, data, train=False, device=torch.device('cpu'),
                     batch_size=1024, num_workers=8):
        """Run :meth:`retrieve` over every sample of ``data``.

        Args:
            data: dataset whose batches unpack as
                ``(index, batch_x, batch_y, batch_x_mark, batch_y_mark)``.
            train: forwarded to :meth:`retrieve`.
            device: device the query batches are moved to.
            batch_size: DataLoader batch size.
            num_workers: DataLoader worker processes.

        Returns:
            ``(retrievals, None)``. ``retrievals`` is a CPU tensor (G, N, L, C)
            concatenated along the sample axis.

        Raises:
            RuntimeError: if :meth:`prepare_dataset` has not been called.
        """
        if getattr(self, 'train_data_all_mg', None) is None:
            raise RuntimeError('prepare_dataset() must be called before retrieve_all()')

        rt_loader = DataLoader(
            data,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
        )

        retrievals = []
        with torch.no_grad():
            for index, batch_x, _batch_y, _batch_x_mark, _batch_y_mark in tqdm(rt_loader):
                pred_from_retrieval = self.retrieve(batch_x.float().to(device), index, train=train)
                retrievals.append(pred_from_retrieval.cpu())

        return torch.cat(retrievals, dim=1), None

class FreqRetrievalTool():
    """Retrieval-based forecaster that compares windows by their amplitude spectra.

    Windows are decomposed into several granularities (block means over ``g``
    steps for each ``g`` in ``period_num``). Each level of an input window is
    turned into its amplitude spectrum without the DC bin, normalised to sum
    to 1 per channel. Queries are scored against the cached training windows
    by the negated ``js_divergence``. The ``topm`` best matches are
    softmax-weighted and their decomposed targets averaged.
    """

    def __init__(
        self,
        seq_len,
        pred_len,
        channels,
        n_period=3,
        temperature=0.1,
        topm=20,
        with_dec=False,
        return_key=False,
    ):
        """Store the retrieval configuration.

        Args:
            seq_len: length of the input windows.
            pred_len: length of the prediction horizon.
            channels: number of variables per time step.
            n_period: number of granularities, taken from the end of
                ``[16, 8, 4, 2, 1]``.
            temperature: accepted for interface compatibility but NOT used
                (see note below).
            topm: number of neighbours kept per query and granularity.
            with_dec: if True, cached targets also include the ``label_len``
                steps preceding the horizon.
            return_key: stored as ``self.return_key``.

        Raises:
            ValueError: if ``n_period`` is not between 1 and 5.
        """
        period_num = [16, 8, 4, 2, 1]
        # period_num[-0:] would keep all five levels, and more than five are
        # unavailable; both would desynchronise period_num from n_period.
        if not 1 <= n_period <= len(period_num):
            raise ValueError(
                f'n_period must be between 1 and {len(period_num)}, got {n_period}'
            )
        period_num = period_num[-n_period:]

        self.seq_len = seq_len
        self.pred_len = pred_len
        self.channels = channels
        self.step_size = 5  # NOTE(review): not read anywhere in this class.
        self.n_period = n_period
        self.period_num = sorted(period_num, reverse=True)

        # NOTE(review): the `temperature` argument is ignored and 0.1 is
        # hard-coded. Confirm this is intentional.
        self.temperature = 0.1
        self.topm = topm

        self.with_dec = with_dec
        self.return_key = return_key

        # Inputs are matched in the frequency domain; offsets are only
        # removed when matching in the time domain.
        self.sim_by_freq = True
        self.remove_offset = not self.sim_by_freq

    def prepare_dataset(self, train_data):
        """Cache all training windows and targets together with their decompositions.

        Args:
            train_data: indexable dataset supporting ``len()`` with a
                ``pred_len`` attribute (and ``label_len`` when ``with_dec``).
                ``item[1]`` is the input window; the last steps of ``item[2]``
                form the target.
        """
        train_data_all = []
        y_data_all = []
        for i in range(len(train_data)):
            td = train_data[i]
            train_data_all.append(td[1])
            if self.with_dec:
                y_data_all.append(td[2][-(train_data.pred_len + train_data.label_len):])
            else:
                y_data_all.append(td[2][-train_data.pred_len:])

        print('len of train_data_all: ', len(train_data_all))
        self.train_data_all = torch.tensor(np.stack(train_data_all, axis=0)).float()
        print('train_data_all shape: ', self.train_data_all.shape)
        self.train_data_all_mg, _ = self.decompose_mg(self.train_data_all, label=False)
        print('train_data_all_mg shape: ', self.train_data_all_mg.shape)

        self.y_data_all = torch.tensor(np.stack(y_data_all, axis=0)).float()
        self.y_data_all_mg, _ = self.decompose_mg(self.y_data_all, label=True)

        self.n_train = self.train_data_all.shape[0]

    @staticmethod
    def _spectrum_distribution(level):
        """Amplitude spectrum along dim 1 without the DC bin, normalised to sum to 1.

        Args:
            level: real tensor (N, L, C).

        Returns:
            (N, K, C) tensor with K = number of non-DC rfft bins. Each
            ``[n, :, c]`` sums to 1. A slice whose non-DC spectrum is all zero
            (e.g. a constant window) would give 0/0 = NaN; it gets a uniform
            distribution instead.
        """
        spec = torch.fft.rfft(level, dim=1).abs()[:, 1:, :]
        total = torch.sum(spec, dim=1, keepdim=True)
        uniform = torch.full_like(spec, 1.0 / max(spec.shape[1], 1))
        spec = torch.where(
            total > 0,
            spec / total.clamp_min(torch.finfo(spec.dtype).tiny),
            uniform,
        )
        if torch.isnan(spec).any():  # can only come from NaN inputs now
            print('NAN AT CUR')
        return spec

    def decompose_mg(self, data_all, remove_offset=None, label=False):
        """Build the multi-granularity view of a batch of windows.

        For each ``g`` in ``period_num``, block means over ``g`` steps are
        computed and the pooled sequence is tiled ``g`` times
        (``[a, b] -> [a, b, a, b]``). NOTE(review): this differs from
        ``RetrievalTool``, which repeats each value (``[a, a, b, b]``). That
        matters for the time ordering of the label levels; confirm it is
        intended. Input levels (``label=False``) are then mapped to
        normalised amplitude spectra when ``sim_by_freq``.

        Args:
            data_all: tensor (N, L, C); not modified.
            remove_offset: subtract each level's last step. None (default)
                uses ``self.remove_offset``.
            label: True for target windows. Their levels are kept in the time
                domain and L2-normalised along dim 1 of the stacked tensor.
                NOTE(review): that is the sample axis N, not time; confirm.

        Returns:
            ``(mg, offset)``. ``mg`` is (G, N, L', C), where L' is the number
            of frequency bins for spectra and L otherwise. ``offset`` is the
            (G, N, 1, C) subtracted last step, or None.
        """
        if remove_offset is None:
            remove_offset = self.remove_offset

        levels = []
        for g in self.period_num:
            pooled = data_all.unfold(dimension=1, size=g, step=g).mean(dim=-1)
            level = torch.cat([pooled] * g, dim=1)
            if not label and self.sim_by_freq:
                level = self._spectrum_distribution(level)
            levels.append(level)
        mg = torch.stack(levels, dim=0)  # G, N, L', C

        offset = None
        if remove_offset:
            offset = mg[:, :, -1:, :]
            mg = mg - offset

        if label:
            mg = F.normalize(mg, dim=1)
        return mg, offset

    def periodic_batch_corr(self, data_all, key, in_bsz=512):
        """Pearson correlation between query and candidate vectors, per granularity.

        Args:
            data_all: (G, T, D) candidates, moved to ``key.device`` chunk by chunk.
            key: (G, B, D) queries.
            in_bsz: candidates per chunk.

        Returns:
            (G, B, T) correlations.
        """
        query = F.normalize(key - torch.mean(key, dim=2, keepdim=True), dim=2)
        chunks = []
        for start in range(0, data_all.shape[1], in_bsz):
            cand = data_all[:, start:start + in_bsz].to(key.device)
            cand = F.normalize(cand - torch.mean(cand, dim=2, keepdim=True), dim=2)
            chunks.append(torch.bmm(query, cand.transpose(-1, -2)))
        return torch.cat(chunks, dim=2)

    def periodic_batch_corr_by_freq(self, data_all, key, in_bsz=512):
        """Similarity of spectral distributions: the negated ``js_divergence``.

        Args:
            data_all: (G, T, K * C) flattened candidate spectra.
            key: (G, B, K * C) flattened query spectra.
            in_bsz: candidates per chunk.

        Returns:
            (G, B, T) similarities, channel-averaged. Higher means more
            similar; identical spectra score 0, the maximum.
        """
        n_levels, n_query, _ = key.shape
        query = key.reshape(n_levels, n_query, -1, self.channels)  # G, B, K, C
        chunks = []
        for start in range(0, data_all.shape[1], in_bsz):
            cand = data_all[:, start:start + in_bsz].to(key.device)
            cand = cand.reshape(n_levels, cand.shape[1], -1, self.channels)  # G, t, K, C
            # js_divergence is a distance (0 = identical). Negate it, because
            # retrieve() keeps the LARGEST scores via topk/softmax.
            chunks.append(-js_divergence(cand, query))  # G, B, t
        return torch.cat(chunks, dim=2)

    def _overlap_mask(self, index, bsz, device):
        """Return a (bsz, n_train) bool mask of training windows near each query.

        Covers indices ``i - (seq_len + pred_len) + 1`` ..
        ``i + (seq_len + pred_len) - 1`` around each query index ``i``,
        clamped to the valid range. This presumably covers the windows sharing
        time steps with the query; TODO confirm the dataset uses stride-1 indices.
        """
        span = self.seq_len + self.pred_len
        offsets = torch.arange(2 * span - 1, device=device)
        neighbours = offsets.unsqueeze(0) + (index - span + 1).unsqueeze(1)
        neighbours = neighbours.clamp(0, self.n_train - 1)
        mask = torch.zeros((bsz, self.n_train), device=device)
        return mask.scatter_(1, neighbours, 1.).bool()

    def retrieve(self, x, index, train=True):
        """Predict each query's target from the training windows with the closest spectra.

        Args:
            x: (B, seq_len, channels) query windows.
            index: (B,) dataset indices of the queries (used only when ``train``).
            train: if True, training windows overlapping each query are excluded.

        Returns:
            ``(pred, ranking_prob)``. ``pred`` is (G, B, L, channels) on
            ``x.device``. ``ranking_prob`` is the (G, B, T) CPU tensor of
            retrieval weights. ``topm`` is capped at ``n_train``.

        Raises:
            ValueError: if ``x`` does not have shape (B, seq_len, channels).
        """
        index = index.to(x.device)

        bsz, seq_len, channels = x.shape
        if seq_len != self.seq_len or channels != self.channels:
            raise ValueError(
                f'expected x of shape (B, {self.seq_len}, {self.channels}), '
                f'got {tuple(x.shape)}'
            )

        x_mg, _ = self.decompose_mg(x, label=False)  # G, B, K, C
        sim = self.periodic_batch_corr_by_freq(
            self.train_data_all_mg.flatten(start_dim=2),  # G, T, K * C
            x_mg.flatten(start_dim=2),  # G, B, K * C
            in_bsz=256,  # batch size for similarity computation
        )  # G, B, T

        if train:
            # The (B, T) mask broadcasts over the G granularities.
            sim = sim.masked_fill(self._overlap_mask(index, bsz, x.device), float('-inf'))

        if torch.isnan(sim).any():
            print('sim shape: ', torch.sum(sim))

        # Keep only the top-m similarities; all other entries get zero weight.
        k = min(self.topm, self.n_train)
        top = torch.topk(sim, k, dim=2)
        ranking_sim = torch.full_like(sim, float('-inf')).scatter_(2, top.indices, top.values)

        if torch.isnan(ranking_sim).any():
            print('ranking_sim shape: ', torch.sum(ranking_sim))

        ranking_prob = F.softmax(ranking_sim / self.temperature, dim=2)
        ranking_prob = ranking_prob.detach().cpu()  # G, B, T

        y_data_all = self.y_data_all_mg.flatten(start_dim=2)  # G, T, P * C
        pred_from_retrieval = torch.bmm(ranking_prob, y_data_all).reshape(
            self.n_period, bsz, -1, channels
        )
        return pred_from_retrieval.to(x.device), ranking_prob

    def retrieve_all(self, data, train=False, device=torch.device('cpu'),
                     batch_size=256, num_workers=16):
        """Run :meth:`retrieve` over every sample of ``data``.

        Args:
            data: dataset whose batches unpack as
                ``(index, batch_x, batch_y, batch_x_mark, batch_y_mark)``.
            train: forwarded to :meth:`retrieve`.
            device: device the query batches are moved to.
            batch_size: DataLoader batch size.
            num_workers: DataLoader worker processes.

        Returns:
            ``(retrievals, ranking_prob_list)``. ``retrievals`` is a CPU tensor
            (G, N, L, C); the list holds each batch's (G, B, T) weights.

        Raises:
            RuntimeError: if :meth:`prepare_dataset` has not been called.
        """
        if getattr(self, 'train_data_all_mg', None) is None:
            raise RuntimeError('prepare_dataset() must be called before retrieve_all()')

        rt_loader = DataLoader(
            data,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
        )
        retrievals = []
        ranking_prob_list = []
        with torch.no_grad():
            for index, batch_x, _batch_y, _batch_x_mark, _batch_y_mark in tqdm(rt_loader):
                pred_from_retrieval, ranking_prob = self.retrieve(
                    batch_x.float().to(device), index, train=train
                )
                retrievals.append(pred_from_retrieval.cpu())
                ranking_prob_list.append(ranking_prob.cpu())

        retrievals = torch.cat(retrievals, dim=1)
        print('retrievals: ', torch.sum(retrievals))
        return retrievals, ranking_prob_list


class FreqRetrievalToolCI():
    """Channel-independent frequency-domain retrieval.

    Works like ``FreqRetrievalTool``, but similarities, top-m selection and
    softmax weights are computed separately for every channel. Each channel of
    the prediction is therefore retrieved from its own set of neighbours.
    """

    def __init__(
        self,
        seq_len,
        pred_len,
        channels,
        n_period=3,
        temperature=0.1,
        topm=20,
        with_dec=False,
        return_key=False
        ):
        """Store the retrieval configuration.

        Args:
            seq_len: length of the input windows.
            pred_len: length of the prediction horizon.
            channels: number of variables per time step.
            n_period: number of granularities, taken from the end of
                ``[16, 8, 4, 2, 1]``.
            temperature: accepted for interface compatibility but NOT used
                (see note below).
            topm: number of neighbours kept per query, granularity and channel.
            with_dec: if True, cached targets also include the ``label_len``
                steps preceding the horizon.
            return_key: stored as ``self.return_key``.

        Raises:
            ValueError: if ``n_period`` is not between 1 and 5.
        """
        period_num = [16, 8, 4, 2, 1]
        # period_num[-0:] would keep all five levels, and more than five are
        # unavailable; both would desynchronise period_num from n_period.
        if not 1 <= n_period <= len(period_num):
            raise ValueError(
                f'n_period must be between 1 and {len(period_num)}, got {n_period}'
            )
        period_num = period_num[-n_period:]

        self.seq_len = seq_len
        self.pred_len = pred_len
        self.channels = channels
        self.step_size = 5  # NOTE(review): not read anywhere in this class.
        self.n_period = n_period
        self.period_num = sorted(period_num, reverse=True)

        # NOTE(review): the `temperature` argument is ignored and 0.1 is
        # hard-coded. Confirm this is intentional.
        self.temperature = 0.1
        self.topm = topm

        self.with_dec = with_dec
        self.return_key = return_key

        # Inputs are matched in the frequency domain; offsets are only
        # removed when matching in the time domain.
        self.sim_by_freq = True
        self.remove_offset = not self.sim_by_freq

    def prepare_dataset(self, train_data):
        """Cache all training windows and targets together with their decompositions.

        Args:
            train_data: indexable dataset supporting ``len()`` with a
                ``pred_len`` attribute (and ``label_len`` when ``with_dec``).
                ``item[1]`` is the input window; the last steps of ``item[2]``
                form the target.
        """
        train_data_all = []
        y_data_all = []
        for i in range(len(train_data)):
            td = train_data[i]
            train_data_all.append(td[1])
            if self.with_dec:
                y_data_all.append(td[2][-(train_data.pred_len + train_data.label_len):])
            else:
                y_data_all.append(td[2][-train_data.pred_len:])

        print('len of train_data_all: ', len(train_data_all))
        self.train_data_all = torch.tensor(np.stack(train_data_all, axis=0)).float()
        print('train_data_all shape: ', self.train_data_all.shape)
        self.train_data_all_mg, _ = self.decompose_mg(self.train_data_all, label=False)
        print('train_data_all_mg shape: ', self.train_data_all_mg.shape)

        self.y_data_all = torch.tensor(np.stack(y_data_all, axis=0)).float()
        self.y_data_all_mg, _ = self.decompose_mg(self.y_data_all, label=True)

        self.n_train = self.train_data_all.shape[0]

    @staticmethod
    def _spectrum_distribution(level):
        """Amplitude spectrum along dim 1 without the DC bin, normalised to sum to 1.

        Args:
            level: real tensor (N, L, C).

        Returns:
            (N, K, C) tensor with K = number of non-DC rfft bins. Each
            ``[n, :, c]`` sums to 1. A slice whose non-DC spectrum is all zero
            (e.g. a constant window) would give 0/0 = NaN; it gets a uniform
            distribution instead.
        """
        spec = torch.fft.rfft(level, dim=1).abs()[:, 1:, :]
        total = torch.sum(spec, dim=1, keepdim=True)
        uniform = torch.full_like(spec, 1.0 / max(spec.shape[1], 1))
        spec = torch.where(
            total > 0,
            spec / total.clamp_min(torch.finfo(spec.dtype).tiny),
            uniform,
        )
        if torch.isnan(spec).any():  # can only come from NaN inputs now
            print('NAN AT CUR')
        return spec

    def decompose_mg(self, data_all, remove_offset=None, label=False):
        """Build the multi-granularity view of a batch of windows.

        For each ``g`` in ``period_num``, block means over ``g`` steps are
        computed and the pooled sequence is tiled ``g`` times
        (``[a, b] -> [a, b, a, b]``). NOTE(review): this differs from
        ``RetrievalTool``, which repeats each value (``[a, a, b, b]``); confirm
        it is intended for the label levels. Input levels (``label=False``)
        are mapped to normalised amplitude spectra when ``sim_by_freq``.

        Args:
            data_all: tensor (N, L, C); not modified.
            remove_offset: subtract each level's last step. None (default)
                uses ``self.remove_offset``.
            label: True for target windows. Their levels are kept in the time
                domain and L2-normalised along dim 1 of the stacked tensor.
                NOTE(review): that is the sample axis N, not time; confirm.

        Returns:
            ``(mg, offset)``. ``mg`` is (G, N, L', C). ``offset`` is the
            (G, N, 1, C) subtracted last step, or None.
        """
        if remove_offset is None:
            remove_offset = self.remove_offset

        levels = []
        for g in self.period_num:
            pooled = data_all.unfold(dimension=1, size=g, step=g).mean(dim=-1)
            level = torch.cat([pooled] * g, dim=1)
            if not label and self.sim_by_freq:
                level = self._spectrum_distribution(level)
            levels.append(level)
        mg = torch.stack(levels, dim=0)  # G, N, L', C

        offset = None
        if remove_offset:
            offset = mg[:, :, -1:, :]
            mg = mg - offset

        if label:
            mg = F.normalize(mg, dim=1)
        return mg, offset

    def periodic_batch_corr(self, data_all, key, in_bsz=512):
        """Pearson correlation between query and candidate vectors, per granularity.

        Args:
            data_all: (G, T, D) candidates, moved to ``key.device`` chunk by chunk.
            key: (G, B, D) queries.
            in_bsz: candidates per chunk.

        Returns:
            (G, B, T) correlations.
        """
        query = F.normalize(key - torch.mean(key, dim=2, keepdim=True), dim=2)
        chunks = []
        for start in range(0, data_all.shape[1], in_bsz):
            cand = data_all[:, start:start + in_bsz].to(key.device)
            cand = F.normalize(cand - torch.mean(cand, dim=2, keepdim=True), dim=2)
            chunks.append(torch.bmm(query, cand.transpose(-1, -2)))
        return torch.cat(chunks, dim=2)

    def periodic_batch_corr_by_freq(self, data_all, key, in_bsz=512):
        """Per-channel similarity of spectral distributions: the negated ``js_divergence_CI``.

        Args:
            data_all: (G, T, K * C) flattened candidate spectra.
            key: (G, B, K * C) flattened query spectra.
            in_bsz: candidates per chunk.

        Returns:
            (G, B, T, C) similarities. Higher means more similar; identical
            spectra score 0, the maximum.
        """
        n_levels, n_query, _ = key.shape
        query = key.reshape(n_levels, n_query, -1, self.channels)  # G, B, K, C
        chunks = []
        for start in range(0, data_all.shape[1], in_bsz):
            cand = data_all[:, start:start + in_bsz].to(key.device)
            cand = cand.reshape(n_levels, cand.shape[1], -1, self.channels)  # G, t, K, C
            # js_divergence_CI is a distance (0 = identical). Negate it, because
            # retrieve() keeps the LARGEST scores via topk/softmax.
            chunks.append(-js_divergence_CI(cand, query))  # G, B, t, C
        return torch.cat(chunks, dim=2)

    def _overlap_mask(self, index, bsz, device):
        """Return a (bsz, n_train) bool mask of training windows near each query.

        Covers indices ``i - (seq_len + pred_len) + 1`` ..
        ``i + (seq_len + pred_len) - 1`` around each query index ``i``,
        clamped to the valid range. This presumably covers the windows sharing
        time steps with the query; TODO confirm the dataset uses stride-1 indices.
        """
        span = self.seq_len + self.pred_len
        offsets = torch.arange(2 * span - 1, device=device)
        neighbours = offsets.unsqueeze(0) + (index - span + 1).unsqueeze(1)
        neighbours = neighbours.clamp(0, self.n_train - 1)
        mask = torch.zeros((bsz, self.n_train), device=device)
        return mask.scatter_(1, neighbours, 1.).bool()

    def retrieve(self, x, index, train=True):
        """Predict each query's target channel by channel from the closest spectra.

        Args:
            x: (B, seq_len, channels) query windows.
            index: (B,) dataset indices of the queries (used only when ``train``).
            train: if True, training windows overlapping each query are excluded.

        Returns:
            ``(pred, ranking_prob)``. ``pred`` is (G, B, P, C) on ``x.device``.
            ``ranking_prob`` is the (G, B, T, C) CPU tensor of per-channel
            retrieval weights. ``topm`` is capped at ``n_train``.

        Raises:
            ValueError: if ``x`` does not have shape (B, seq_len, channels).
        """
        index = index.to(x.device)

        bsz, seq_len, channels = x.shape
        if seq_len != self.seq_len or channels != self.channels:
            raise ValueError(
                f'expected x of shape (B, {self.seq_len}, {self.channels}), '
                f'got {tuple(x.shape)}'
            )

        x_mg, _ = self.decompose_mg(x, label=False)  # G, B, K, C
        sim = self.periodic_batch_corr_by_freq(
            self.train_data_all_mg.flatten(start_dim=2),  # G, T, K * C
            x_mg.flatten(start_dim=2),  # G, B, K * C
            in_bsz=256,  # batch size for similarity computation
        )  # G, B, T, C

        if train:
            # The (B, T, 1) mask broadcasts over granularities and channels.
            mask = self._overlap_mask(index, bsz, x.device).unsqueeze(-1)
            sim = sim.masked_fill(mask, float('-inf'))

        if torch.isnan(sim).any():
            print('sim shape: ', torch.sum(sim))

        # Keep the top-m candidates (dim 2) independently for every
        # (granularity, query, channel); all other entries get zero weight.
        k = min(self.topm, self.n_train)
        top = torch.topk(sim, k, dim=2)
        ranking_sim = torch.full_like(sim, float('-inf')).scatter_(2, top.indices, top.values)

        if torch.isnan(ranking_sim).any():
            print('ranking_sim shape: ', torch.sum(ranking_sim))

        ranking_prob = F.softmax(ranking_sim / self.temperature, dim=2)
        ranking_prob = ranking_prob.detach().cpu()  # G, B, T, C

        # Per-channel weighted sum of cached targets: (G, B, T, C) x (G, T, P, C).
        pred_from_retrieval = torch.einsum('gbtc,gtpc->gbpc', ranking_prob, self.y_data_all_mg)
        return pred_from_retrieval.to(x.device), ranking_prob

    def retrieve_all(self, data, train=False, device=torch.device('cpu'),
                     batch_size=256, num_workers=16):
        """Run :meth:`retrieve` over every sample of ``data``.

        Args:
            data: dataset whose batches unpack as
                ``(index, batch_x, batch_y, batch_x_mark, batch_y_mark)``.
            train: forwarded to :meth:`retrieve`.
            device: device the query batches are moved to.
            batch_size: DataLoader batch size.
            num_workers: DataLoader worker processes.

        Returns:
            ``(retrievals, ranking_prob_list)``. ``retrievals`` is a CPU tensor
            (G, N, P, C); the list holds each batch's (G, B, T, C) weights.

        Raises:
            RuntimeError: if :meth:`prepare_dataset` has not been called.
        """
        if getattr(self, 'train_data_all_mg', None) is None:
            raise RuntimeError('prepare_dataset() must be called before retrieve_all()')

        rt_loader = DataLoader(
            data,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
        )
        retrievals = []
        ranking_prob_list = []
        with torch.no_grad():
            for index, batch_x, _batch_y, _batch_x_mark, _batch_y_mark in tqdm(rt_loader):
                pred_from_retrieval, ranking_prob = self.retrieve(
                    batch_x.float().to(device), index, train=train
                )
                retrievals.append(pred_from_retrieval.cpu())
                ranking_prob_list.append(ranking_prob.cpu())

        retrievals = torch.cat(retrievals, dim=1)
        print('retrievals: ', torch.sum(retrievals))
        return retrievals, ranking_prob_list


class TextRetrievalTool():
    """Retrieve training windows using series correlation plus text similarity.

    :meth:`prepare_dataset` stores every training sample's input window
    (``td[1]``), target window (tail of ``td[2]``) and one text embedding.

    :meth:`retrieve` works in four steps:

    1. Score each stored window against each query with a mean-centred
       cosine (Pearson) similarity at several temporal granularities.
    2. Add one tenth of the text cosine similarity.
    3. Keep the ``topm`` best matches.
    4. Return the softmax-weighted average of their targets.
    """

    def __init__(
        self,
        seq_len,
        pred_len,
        channels,
        n_period=3,
        temperature=0.1,
        topm=20,
        with_dec=False,
        return_key=False,
    ):
        """
        Args:
            seq_len: expected length of each query window (dim 1 of ``x``).
            pred_len: horizon length, used by the training-time self mask.
            channels: expected number of variables (dim 2 of ``x``).
            n_period: number of granularities, taken from the end of
                ``[16, 8, 4, 2, 1]`` (e.g. 3 -> block sizes 4, 2, 1).
            temperature: softmax temperature applied to the top-m scores.
            topm: neighbours kept per query and granularity.
            with_dec: if True the stored targets also include ``label_len``
                steps before the horizon.
            return_key: stored only; not read by this class.
        """
        period_num = [16, 8, 4, 2, 1]
        period_num = period_num[-1 * n_period:]

        self.seq_len = seq_len
        self.pred_len = pred_len
        self.channels = channels

        self.n_period = n_period
        self.period_num = sorted(period_num, reverse=True)

        self.temperature = temperature
        self.topm = topm

        self.with_dec = with_dec
        self.return_key = return_key

    def prepare_dataset(self, train_data):
        """Build the key/target database and the text database.

        ``train_data[i]`` is indexed as ``(index, x, y, ...)``; ``x`` is the
        key window and the last ``pred_len`` (+ ``label_len`` when
        ``with_dec``) steps of ``y`` are the target.

        Sets ``train_data_all``/``train_data_all_mg``,
        ``y_data_all``/``y_data_all_mg`` and ``n_train``, then calls
        :meth:`prepare_text`.
        """
        keys = []
        targets = []
        for i in range(len(train_data)):
            td = train_data[i]
            keys.append(td[1])
            if self.with_dec:
                targets.append(td[2][-(train_data.pred_len + train_data.label_len):])
            else:
                targets.append(td[2][-train_data.pred_len:])

        self.train_data_all = torch.tensor(np.stack(keys, axis=0)).float()
        self.train_data_all_mg, _ = self.decompose_mg(self.train_data_all)

        self.y_data_all = torch.tensor(np.stack(targets, axis=0)).float()
        self.y_data_all_mg, _ = self.decompose_mg(self.y_data_all)

        self.n_train = self.train_data_all.shape[0]

        self.prepare_text(train_data)

    def prepare_text(self, train_data):
        """Store one text embedding per training sample.

        Each embedding is the last position (``[:, -1, :]``) of
        ``train_data.get_text_embedding(tensor([index]))``.

        Results:
            ``train_data_all_text`` of shape (T, D).
            ``train_data_all_mg_text`` of shape (1, T, D).

        Both live on CUDA when available, else CPU.
        """
        # Fall back to CPU so the tool also works on hosts without a GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        embeddings = []
        for i in range(len(train_data)):
            index = train_data[i][0]
            emb = train_data.get_text_embedding(torch.tensor([index]))
            emb = torch.as_tensor(emb, dtype=torch.float32, device=device)
            embeddings.append(emb[:, -1, :])

        self.train_data_all_text = torch.cat(embeddings, dim=0)  # T, D
        # Leading singleton axis so it lines up with the (G, ., .) layout.
        self.train_data_all_mg_text = self.train_data_all_text.unsqueeze(0)  # 1, T, D

    def decompose_mg(self, data_all, remove_offset=True):
        """Build multi-granularity views of a batch of windows.

        For every block size ``g`` in ``self.period_num``, the time axis
        (dim 1) is split into non-overlapping blocks of ``g`` steps. Each block
        is replaced by its mean, repeated ``g`` times.

        Args:
            data_all: tensor of shape (N, S, C).
            remove_offset: if True, subtract each window's last time step
                (per granularity) from the whole window.

        Returns:
            ``(mg, offset)``:

            - ``mg`` has shape (G, N, S', C), where S' is the length kept by
              the block split. It equals S when S is divisible by every ``g``;
              otherwise stacking fails.
            - ``offset`` is the subtracted last step, of shape (G, N, 1, C),
              or ``None`` when ``remove_offset`` is False.
        """
        grains = []
        for g in self.period_num:
            block_mean = data_all.unfold(dimension=1, size=g, step=g).mean(dim=-1)
            grains.append(block_mean.repeat_interleave(repeats=g, dim=1))
        mg = torch.stack(grains, dim=0)  # G, N, S', C

        if not remove_offset:
            return mg, None

        # Clone so the returned offset is not a view that the subtraction
        # below would zero out.
        offset = mg[:, :, -1:, :].clone()  # G, N, 1, C
        return mg - offset, offset

    def periodic_batch_corr_text(self, data_all, key, in_bsz = 512):
        """Cosine similarity between query text embeddings and the text database.

        Args:
            data_all: text database of shape (1, T, D).
            key: query embeddings. Assumed (B, 1, D), as produced by
                ``retrieve_all``, so that it broadcasts against (1, chunk, D).
            in_bsz: database rows moved to ``key.device`` per chunk.

        Returns:
            Similarity of shape (1, B, T) for a (B, 1, D) ``key``.
        """
        _, train_len, _ = data_all.shape
        sim = []
        for start in range(0, train_len, in_bsz):
            cur_data = data_all[:, start:start + in_bsz].to(key.device)
            # (B, 1, D) vs (1, chunk, D) -> (B, chunk)
            sim.append(F.cosine_similarity(key, cur_data, dim=2))
        return torch.cat(sim, dim=1).unsqueeze(0)

    def periodic_batch_corr(self, data_all, key, in_bsz = 512):
        """Correlation between query windows and database windows.

        Both sides are mean-centred along the feature axis and L2-normalised,
        so each score is the Pearson correlation of the flattened windows.

        Args:
            data_all: database of shape (G, T, F).
            key: queries of shape (G, B, F).
            in_bsz: database rows moved to ``key.device`` per chunk.

        Returns:
            Similarity of shape (G, B, T).
        """
        _, train_len, _ = data_all.shape
        # Loop-invariant: normalise the queries once.
        bx = F.normalize(key - key.mean(dim=2, keepdim=True), dim=2)

        sim = []
        for start in range(0, train_len, in_bsz):
            cur_data = data_all[:, start:start + in_bsz].to(key.device)
            ax = F.normalize(cur_data - cur_data.mean(dim=2, keepdim=True), dim=2)
            sim.append(torch.bmm(bx, ax.transpose(-1, -2)))
        return torch.cat(sim, dim=2)

    def retrieve(self, x, text, index, train=True):
        """Predict targets for query windows from their nearest training windows.

        Args:
            x: query windows of shape (B, seq_len, channels).
            text: query text embeddings. With the (B, 1, D) layout built by
                ``retrieve_all``, the text score broadcasts over granularities.
            index: dataset indices of the queries, shape (B,). Used only when
                ``train`` is True.
            train: if True, mask database rows whose index is within
                ``seq_len + pred_len - 1`` of the query index.

        Returns:
            Tensor of shape (G, B, L, C) on ``x.device``. For each granularity
            it is the softmax(top-m score / temperature)-weighted average of
            the stored targets; L is the stored target length and C the target
            channel count.

            If every candidate of a row is masked, the softmax yields NaN.
        """
        index = index.to(x.device)

        bsz, seq_len, channels = x.shape
        # Separate conditions with `and`: `assert (a, b)` tests a non-empty
        # tuple and can never fail.
        assert seq_len == self.seq_len and channels == self.channels, (
            f'expected x of shape (B, {self.seq_len}, {self.channels}), '
            f'got {tuple(x.shape)}'
        )

        x_mg, _ = self.decompose_mg(x)  # G, B, S, C

        sim = self.periodic_batch_corr(
            self.train_data_all_mg.flatten(start_dim=2),  # G, T, S * C
            x_mg.flatten(start_dim=2),  # G, B, S * C
        )  # G, B, T
        sim_text = self.periodic_batch_corr_text(
            self.train_data_all_mg_text,  # 1, T, D
            text,
        )  # 1, B, T for a (B, 1, D) text
        # Text similarity is down-weighted by 10 and broadcast over G.
        sim = sim + sim_text / 10

        if train:
            # Presumably keeps windows overlapping the query out of its own
            # neighbour set; indices are clamped to [0, n_train - 1].
            offsets = torch.arange(2 * (self.seq_len + self.pred_len) - 1, device=x.device)
            start = index - self.seq_len - self.pred_len + 1
            sliding_index = (offsets.unsqueeze(0) + start.unsqueeze(1)).clamp(0, self.n_train - 1)

            self_mask = torch.zeros((bsz, self.n_train), device=x.device)
            self_mask.scatter_(1, sliding_index, 1.)
            sim = sim.masked_fill_(self_mask.bool().unsqueeze(0), float('-inf'))  # G, B, T

        # Keep the top-m scores per (granularity, query); everything else -> -inf.
        k = min(self.topm, self.n_train)
        flat = sim.reshape(self.n_period * bsz, self.n_train)  # G x B, T
        topm_index = torch.topk(flat, k, dim=1).indices
        ranking_sim = torch.full_like(flat, float('-inf'))
        ranking_sim.scatter_(1, topm_index, flat.gather(1, topm_index))
        ranking_sim = ranking_sim.reshape(self.n_period, bsz, self.n_train)  # G, B, T

        ranking_prob = F.softmax(ranking_sim / self.temperature, dim=2)
        ranking_prob = ranking_prob.detach().cpu()  # G, B, T

        # Reshape with the target's own channel count so the result is
        # (G, B, L, C).
        y_channels = self.y_data_all.shape[-1]
        y_data_all = self.y_data_all_mg.flatten(start_dim=2)  # G, T, L * C
        pred_from_retrieval = torch.bmm(ranking_prob, y_data_all).reshape(
            self.n_period, bsz, -1, y_channels
        )
        return pred_from_retrieval.to(x.device)

    def retrieve_all(self, data, train=False, device=torch.device('cpu')):
        """Run :meth:`retrieve` over every sample of ``data``, in order.

        Args:
            data: dataset yielding
                ``(index, batch_x, batch_y, batch_x_mark, batch_y_mark)`` and
                exposing ``get_text_embedding(index)``.
            train: forwarded to :meth:`retrieve`.
            device: device queries are moved to before retrieval.

        Returns:
            Predictions of all batches concatenated along dim 1, on CPU.
        """
        rt_loader = DataLoader(
            data,
            batch_size=16,
            shuffle=False,
            num_workers=8,
            drop_last=False
        )
        retrievals = []
        with torch.no_grad():
            for index, batch_x, _, _, _ in tqdm(rt_loader):
                batch_text = torch.as_tensor(data.get_text_embedding(index), dtype=torch.float32)
                # NOTE(review): the query text is mean-pooled over dim 1,
                # whereas prepare_text keeps the last position ([:, -1, :]) for
                # the database. Confirm this asymmetry is intended.
                batch_text = batch_text.mean(dim=1, keepdim=True)
                pred_from_retrieval = self.retrieve(
                    batch_x.float().to(device), batch_text.to(device), index, train=train
                )
                retrievals.append(pred_from_retrieval.cpu())

        return torch.cat(retrievals, dim=1)
    


class TextRetrievalTool2():
    """Retrieval using series correlation plus text similarity.

    This variant's :meth:`retrieve` returns the ranking probabilities in
    addition to the prediction.

    :meth:`prepare_dataset` stores every training sample's input window
    (``td[1]``), target window (tail of ``td[2]``) and one text embedding.

    :meth:`retrieve` works in four steps:

    1. Score each stored window against each query with a mean-centred
       cosine (Pearson) similarity at several temporal granularities.
    2. Add one tenth of the text cosine similarity.
    3. Keep the ``topm`` best matches.
    4. Return the softmax-weighted average of their targets.
    """

    def __init__(
        self,
        seq_len,
        pred_len,
        channels,
        n_period=3,
        temperature=0.1,
        topm=20,
        with_dec=False,
        return_key=False,
    ):
        """
        Args:
            seq_len: expected length of each query window (dim 1 of ``x``).
            pred_len: horizon length, used by the training-time self mask.
            channels: expected number of variables (dim 2 of ``x``).
            n_period: number of granularities, taken from the end of
                ``[16, 8, 4, 2, 1]`` (e.g. 3 -> block sizes 4, 2, 1).
            temperature: softmax temperature applied to the top-m scores.
            topm: neighbours kept per query and granularity.
            with_dec: if True the stored targets also include ``label_len``
                steps before the horizon.
            return_key: stored only; not read by this class.
        """
        period_num = [16, 8, 4, 2, 1]
        period_num = period_num[-1 * n_period:]

        self.seq_len = seq_len
        self.pred_len = pred_len
        self.channels = channels

        self.n_period = n_period
        self.period_num = sorted(period_num, reverse=True)

        self.temperature = temperature
        self.topm = topm

        self.with_dec = with_dec
        self.return_key = return_key

    def prepare_dataset(self, train_data):
        """Build the key/target database and the text database.

        ``train_data[i]`` is indexed as ``(index, x, y, ...)``; ``x`` is the
        key window and the last ``pred_len`` (+ ``label_len`` when
        ``with_dec``) steps of ``y`` are the target.

        Sets ``train_data_all``/``train_data_all_mg``,
        ``y_data_all``/``y_data_all_mg`` and ``n_train``, then calls
        :meth:`prepare_text`.
        """
        keys = []
        targets = []
        for i in range(len(train_data)):
            td = train_data[i]
            keys.append(td[1])
            if self.with_dec:
                targets.append(td[2][-(train_data.pred_len + train_data.label_len):])
            else:
                targets.append(td[2][-train_data.pred_len:])

        self.train_data_all = torch.tensor(np.stack(keys, axis=0)).float()
        self.train_data_all_mg, _ = self.decompose_mg(self.train_data_all)

        self.y_data_all = torch.tensor(np.stack(targets, axis=0)).float()
        self.y_data_all_mg, _ = self.decompose_mg(self.y_data_all)

        self.n_train = self.train_data_all.shape[0]

        self.prepare_text(train_data)

    def prepare_text(self, train_data):
        """Store one text embedding per training sample.

        Each embedding is the last position (``[:, -1, :]``) of
        ``train_data.get_text_embedding(tensor([index]))``.

        Results:
            ``train_data_all_text`` of shape (T, D).
            ``train_data_all_mg_text`` of shape (1, T, D).

        Both live on CUDA when available, else CPU.
        """
        # Fall back to CPU so the tool also works on hosts without a GPU.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        embeddings = []
        for i in range(len(train_data)):
            index = train_data[i][0]
            emb = train_data.get_text_embedding(torch.tensor([index]))
            emb = torch.as_tensor(emb, dtype=torch.float32, device=device)
            embeddings.append(emb[:, -1, :])

        self.train_data_all_text = torch.cat(embeddings, dim=0)  # T, D
        # Leading singleton axis so it lines up with the (G, ., .) layout.
        self.train_data_all_mg_text = self.train_data_all_text.unsqueeze(0)  # 1, T, D

    def decompose_mg(self, data_all, remove_offset=True):
        """Build multi-granularity views of a batch of windows.

        For every block size ``g`` in ``self.period_num``, the time axis
        (dim 1) is split into non-overlapping blocks of ``g`` steps. Each block
        is replaced by its mean, repeated ``g`` times.

        Args:
            data_all: tensor of shape (N, S, C).
            remove_offset: if True, subtract each window's last time step
                (per granularity) from the whole window.

        Returns:
            ``(mg, offset)``:

            - ``mg`` has shape (G, N, S', C), where S' is the length kept by
              the block split. It equals S when S is divisible by every ``g``;
              otherwise stacking fails.
            - ``offset`` is the subtracted last step, of shape (G, N, 1, C),
              or ``None`` when ``remove_offset`` is False.
        """
        grains = []
        for g in self.period_num:
            block_mean = data_all.unfold(dimension=1, size=g, step=g).mean(dim=-1)
            grains.append(block_mean.repeat_interleave(repeats=g, dim=1))
        mg = torch.stack(grains, dim=0)  # G, N, S', C

        if not remove_offset:
            return mg, None

        # Clone so the returned offset is not a view that the subtraction
        # below would zero out.
        offset = mg[:, :, -1:, :].clone()  # G, N, 1, C
        return mg - offset, offset

    def periodic_batch_corr_text(self, data_all, key, in_bsz = 512):
        """Cosine similarity between query text embeddings and the text database.

        Args:
            data_all: text database of shape (1, T, D).
            key: query embeddings. Assumed (B, 1, D), as produced by
                ``retrieve_all``, so that it broadcasts against (1, chunk, D).
            in_bsz: database rows moved to ``key.device`` per chunk.

        Returns:
            Similarity of shape (1, B, T) for a (B, 1, D) ``key``.
        """
        _, train_len, _ = data_all.shape
        sim = []
        for start in range(0, train_len, in_bsz):
            cur_data = data_all[:, start:start + in_bsz].to(key.device)
            # (B, 1, D) vs (1, chunk, D) -> (B, chunk)
            sim.append(F.cosine_similarity(key, cur_data, dim=2))
        return torch.cat(sim, dim=1).unsqueeze(0)

    def periodic_batch_corr(self, data_all, key, in_bsz = 512):
        """Correlation between query windows and database windows.

        Both sides are mean-centred along the feature axis and L2-normalised,
        so each score is the Pearson correlation of the flattened windows.

        Args:
            data_all: database of shape (G, T, F).
            key: queries of shape (G, B, F).
            in_bsz: database rows moved to ``key.device`` per chunk.

        Returns:
            Similarity of shape (G, B, T).
        """
        _, train_len, _ = data_all.shape
        # Loop-invariant: normalise the queries once.
        bx = F.normalize(key - key.mean(dim=2, keepdim=True), dim=2)

        sim = []
        for start in range(0, train_len, in_bsz):
            cur_data = data_all[:, start:start + in_bsz].to(key.device)
            ax = F.normalize(cur_data - cur_data.mean(dim=2, keepdim=True), dim=2)
            sim.append(torch.bmm(bx, ax.transpose(-1, -2)))
        return torch.cat(sim, dim=2)

    def retrieve(self, x, text, index, train=True):
        """Predict targets for query windows and return the ranking weights.

        Args:
            x: query windows of shape (B, seq_len, channels).
            text: query text embeddings. With the (B, 1, D) layout built by
                ``retrieve_all``, the text score broadcasts over granularities.
            index: dataset indices of the queries, shape (B,). Used only when
                ``train`` is True.
            train: if True, mask database rows whose index is within
                ``seq_len + pred_len - 1`` of the query index.

        Returns:
            ``(pred_from_retrieval, ranking_prob)``:

            - ``pred_from_retrieval`` has shape (G, B, L, C) on ``x.device``.
              It is the softmax(top-m score / temperature)-weighted average of
              the stored targets.
            - ``ranking_prob`` holds those weights, shape (G, B, T), on CPU.

            If every candidate of a row is masked, the softmax yields NaN.
        """
        index = index.to(x.device)

        bsz, seq_len, channels = x.shape
        # Separate conditions with `and`: `assert (a, b)` tests a non-empty
        # tuple and can never fail.
        assert seq_len == self.seq_len and channels == self.channels, (
            f'expected x of shape (B, {self.seq_len}, {self.channels}), '
            f'got {tuple(x.shape)}'
        )

        x_mg, _ = self.decompose_mg(x)  # G, B, S, C

        sim = self.periodic_batch_corr(
            self.train_data_all_mg.flatten(start_dim=2),  # G, T, S * C
            x_mg.flatten(start_dim=2),  # G, B, S * C
        )  # G, B, T
        sim_text = self.periodic_batch_corr_text(
            self.train_data_all_mg_text,  # 1, T, D
            text,
        )  # 1, B, T for a (B, 1, D) text
        # Text similarity is down-weighted by 10 and broadcast over G.
        sim = sim + sim_text / 10

        if train:
            # Presumably keeps windows overlapping the query out of its own
            # neighbour set; indices are clamped to [0, n_train - 1].
            offsets = torch.arange(2 * (self.seq_len + self.pred_len) - 1, device=x.device)
            start = index - self.seq_len - self.pred_len + 1
            sliding_index = (offsets.unsqueeze(0) + start.unsqueeze(1)).clamp(0, self.n_train - 1)

            self_mask = torch.zeros((bsz, self.n_train), device=x.device)
            self_mask.scatter_(1, sliding_index, 1.)
            sim = sim.masked_fill_(self_mask.bool().unsqueeze(0), float('-inf'))  # G, B, T

        # Keep the top-m scores per (granularity, query); everything else -> -inf.
        k = min(self.topm, self.n_train)
        flat = sim.reshape(self.n_period * bsz, self.n_train)  # G x B, T
        topm_index = torch.topk(flat, k, dim=1).indices
        ranking_sim = torch.full_like(flat, float('-inf'))
        ranking_sim.scatter_(1, topm_index, flat.gather(1, topm_index))
        ranking_sim = ranking_sim.reshape(self.n_period, bsz, self.n_train)  # G, B, T

        ranking_prob = F.softmax(ranking_sim / self.temperature, dim=2)
        ranking_prob = ranking_prob.detach().cpu()  # G, B, T

        # Reshape with the target's own channel count so the result is
        # (G, B, L, C).
        y_channels = self.y_data_all.shape[-1]
        y_data_all = self.y_data_all_mg.flatten(start_dim=2)  # G, T, L * C
        pred_from_retrieval = torch.bmm(ranking_prob, y_data_all).reshape(
            self.n_period, bsz, -1, y_channels
        )
        return pred_from_retrieval.to(x.device), ranking_prob

    def retrieve_all(self, data, train=False, device=torch.device('cpu')):
        """Run :meth:`retrieve` over every sample of ``data``, in order.

        Args:
            data: dataset yielding
                ``(index, batch_x, batch_y, batch_x_mark, batch_y_mark)`` and
                exposing ``get_text_embedding(index)``.
            train: forwarded to :meth:`retrieve`.
            device: device queries are moved to before retrieval.

        Returns:
            Predictions of all batches concatenated along dim 1, on CPU.
        """
        rt_loader = DataLoader(
            data,
            batch_size=256,
            shuffle=False,
            num_workers=8,
            drop_last=False
        )
        retrievals = []
        with torch.no_grad():
            for index, batch_x, _, _, _ in tqdm(rt_loader):
                batch_text = torch.as_tensor(data.get_text_embedding(index), dtype=torch.float32)
                # NOTE(review): the query text is mean-pooled over dim 1,
                # whereas prepare_text keeps the last position ([:, -1, :]) for
                # the database. Confirm this asymmetry is intended.
                batch_text = batch_text.mean(dim=1, keepdim=True)
                # retrieve returns (prediction, ranking_prob); only the
                # prediction is collected here.
                pred_from_retrieval, _ = self.retrieve(
                    batch_x.float().to(device), batch_text.to(device), index, train=train
                )
                retrievals.append(pred_from_retrieval.cpu())

        return torch.cat(retrievals, dim=1)
    

