import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import math
from tqdm import tqdm
import time
from torch.utils.data import Dataset, DataLoader

def kl_divergence(ax, bx):
    """Pairwise KL divergence KL(a || b) between two batches of distributions.

    Every element of ``ax`` is compared with every element of ``bx`` along
    dim 1.  The divergence sums over input dim 2 (treated as the distribution
    axis) and is then averaged over input dim 3, so both inputs need at
    least 4 dims (e.g. ``(G, Na, F, K)`` and ``(G, Nb, F, K)``; dim 0 must
    broadcast).  A 3-D input raises an IndexError at the averaging step.

    Returns:
        Tensor of shape ``(G, Nb, Na, ...)`` — note the ``bx`` axis comes first.
    """
    eps = 1e-8  # keeps log() finite for zero-probability bins
    p = ax.unsqueeze(2)                                           # G, Na, 1, F, K
    log_ratio = torch.log(p + eps) - torch.log(bx.unsqueeze(1) + eps)  # G, Na, Nb, F, K
    per_pair = (p * log_ratio).sum(dim=3)                         # G, Na, Nb, K
    return per_pair.mean(dim=3).transpose(1, 2)                   # G, Nb, Na
def js_divergence(ax, bx):
    """Pairwise (unhalved) Jensen-Shannon divergence between two batches.

    Computes KL(a || m) + KL(b || m) with m = (a + b) / 2, i.e. *twice* the
    Jensen-Shannon divergence (no 1/2 factor is applied).  Every element of
    ``ax`` is paired with every element of ``bx`` along dim 1; the KL terms
    sum over input dim 2 and the result is averaged over input dim 3, so
    inputs need at least 4 dims.

    Returns:
        Tensor of shape ``(G, Nb, Na, ...)`` — the ``bx`` axis comes first.
    """
    eps = 1e-8  # keeps log() finite for zero-probability bins
    pa = ax.unsqueeze(2)   # G, Na, 1, F, K
    pb = bx.unsqueeze(1)   # G, 1, Nb, F, K
    log_mid = torch.log((pa + pb) / 2 + eps)
    kl_a = (pa * (torch.log(pa + eps) - log_mid)).sum(dim=3)
    kl_b = (pb * (torch.log(pb + eps) - log_mid)).sum(dim=3)
    return (kl_a + kl_b).mean(dim=3).transpose(1, 2)

def js_divergence_CI(ax, bx):
    """Pairwise negative Jensen-Shannon divergence, kept per trailing axis.

    Computes -(KL(a || m) + KL(b || m)) / 2 with m = (a + b) / 2.  Every
    element of ``ax`` is paired with every element of ``bx`` along dim 1; the
    KL terms sum over input dim 2 only, so all trailing dims (e.g. frames
    and channels) are preserved.  Negation turns the divergence into a
    similarity: larger means more alike, 0 for identical inputs.

    Returns:
        Tensor of shape ``(G, Nb, Na, ...)`` — the ``bx`` axis comes first.
    """
    eps = 1e-8  # keeps log() finite for zero-probability bins
    pa = ax.unsqueeze(2)   # G, Na, 1, F, ...
    pb = bx.unsqueeze(1)   # G, 1, Nb, F, ...
    log_mid = torch.log((pa + pb) / 2 + eps)
    kl_a = (pa * (torch.log(pa + eps) - log_mid)).sum(dim=3)
    kl_b = (pb * (torch.log(pb + eps) - log_mid)).sum(dim=3)
    jsd = (kl_a + kl_b) / 2
    return (-jsd).transpose(1, 2)

    
class FreqL1RetrievalTool():
    """Retrieve and blend the targets of the training windows most similar to a query.

    Every series is smoothed at several granularities (block means of length
    ``g`` for each ``g`` in ``period_num``).  Queries and training inputs are
    compared through the non-DC rFFT coefficients of those smoothed series
    using an inverse mean-L1 distance; the ``topm`` best matches per
    granularity are softmax-weighted (``temperature``) to average their
    offset-removed targets.

    Call :meth:`prepare_dataset` before :meth:`retrieve` / :meth:`retrieve_all`.
    """

    def __init__(
        self,
        seq_len,
        pred_len,
        channels,
        n_period=3,
        temperature=0.1,
        topm=20,
        with_dec=False,
        return_key=False,
    ):
        # Use the n_period finest granularities, ordered coarsest first.
        period_num = [16, 8, 4, 2, 1]
        period_num = period_num[-1 * n_period:]

        self.seq_len = seq_len
        self.pred_len = pred_len
        self.channels = channels

        self.n_period = n_period
        self.period_num = sorted(period_num, reverse=True)

        self.temperature = temperature
        self.topm = topm

        self.with_dec = with_dec
        self.return_key = return_key  # stored but not read by this class

    def prepare_dataset(self, train_data):
        """Cache multi-granularity features and targets of every training sample.

        ``train_data[i]`` is indexed as ``(_, x, y, ...)``: ``x`` is the input
        window and the trailing ``pred_len`` (plus ``label_len`` when
        ``with_dec``) steps of ``y`` form the target.  ``train_data`` must
        expose ``pred_len`` (and ``label_len`` when ``with_dec``).
        """
        target_len = train_data.pred_len
        if self.with_dec:
            target_len += train_data.label_len

        train_data_all = []
        y_data_all = []
        # Index explicitly: Dataset objects are not guaranteed to stop iterating at len().
        for i in range(len(train_data)):
            sample = train_data[i]
            train_data_all.append(sample[1])
            y_data_all.append(sample[2][-target_len:])

        self.train_data_all = torch.tensor(np.stack(train_data_all, axis=0)).float()
        self.train_data_all_mg, _ = self.decompose_mg(self.train_data_all, label=False)

        self.y_data_all = torch.tensor(np.stack(y_data_all, axis=0)).float()
        self.y_data_all_mg, _ = self.decompose_mg(self.y_data_all)

        self.n_train = self.train_data_all.shape[0]

    def decompose_mg(self, data_all, remove_offset=True, label=True):
        """Smooth ``data_all`` (N, S, C) at every granularity in ``period_num``.

        Each series is replaced by its block means of length ``g`` (repeated
        back to full length; a trailing partial block is dropped by
        ``unfold``).

        Args:
            data_all: tensor of shape (N, S, C).
            remove_offset: when ``label`` is also true, subtract each series'
                last value so targets are expressed relative to it.
            label: if false, return the rFFT along time without the DC bin
                (complex, shape (G, N, S // 2, C)) instead of the smoothed
                series, and never remove offsets.

        Returns:
            ``(mg, offset)``: ``mg`` is (G, N, S', C); ``offset`` is the
            (G, N, 1, C) subtracted last values, or ``None``.
        """
        mg = []
        for g in self.period_num:
            cur = data_all.unfold(dimension=1, size=g, step=g).mean(dim=-1)
            cur = cur.repeat_interleave(repeats=g, dim=1)
            if not label:
                cur = torch.fft.rfft(cur, dim=1)[:, 1:, :]
            mg.append(cur)
        mg = torch.stack(mg, dim=0)  # G, N, S', C

        if not (remove_offset and label):
            return mg, None

        # clone(): a plain slice is a view of ``mg`` and would read back as zeros
        # once the subtraction below has been applied.
        offset = mg[:, :, -1:, :].clone()  # G, N, 1, C
        return mg - offset, offset

    def periodic_batch_corr(self, data_all, key, in_bsz=512):
        """Inverse mean-L1 similarity between every query and every training row.

        Args:
            data_all: training features of shape (G, T, D).
            key: query features of shape (G, B, D); ``data_all`` chunks are
                moved to ``key.device``.
            in_bsz: number of training rows scored per chunk (bounds memory).

        Returns:
            Tensor (G, B, T) equal to ``sqrt(D) / mean|key - data|``.  An exact
            match yields ``inf``.
        """
        train_len = data_all.shape[1]

        sim = []
        for start in range(0, train_len, in_bsz):
            ax = data_all[:, start:start + in_bsz].to(key.device)  # G, t, D
            # Broadcasting avoids materialising repeated copies of both operands.
            dist = (key.unsqueeze(2) - ax.unsqueeze(1)).abs().mean(dim=-1)  # G, B, t
            sim.append(1 / dist * math.sqrt(ax.shape[-1]))

        return torch.cat(sim, dim=2)

    def retrieve(self, x, index, train=True):
        """Blend the targets of the most similar training windows for queries ``x``.

        Args:
            x: query inputs of shape (B, seq_len, channels).
            index: per-query positions in the training set.  Used only when
                ``train`` is true, to hide training windows whose position lies
                within ``seq_len + pred_len - 1`` of the query (including itself).
            train: whether to apply that self-overlap mask.

        Returns:
            Tensor (G, B, P', C) on ``x.device``, where P' is the cached target length.
        """
        bsz, seq_len, channels = x.shape
        assert seq_len == self.seq_len and channels == self.channels, (
            f'expected input of shape (B, {self.seq_len}, {self.channels}), got {tuple(x.shape)}'
        )

        x_mg, _ = self.decompose_mg(x, label=False)  # G, B, F, C

        sim = self.periodic_batch_corr(
            self.train_data_all_mg.flatten(start_dim=2),  # G, T, F * C
            x_mg.flatten(start_dim=2),  # G, B, F * C
        )  # G, B, T

        if train:
            index = index.to(sim.device)
            window = 2 * (self.seq_len + self.pred_len) - 1
            sliding_index = torch.arange(window, device=sim.device).unsqueeze(0).repeat(len(index), 1)
            sliding_index = sliding_index + (index - self.seq_len - self.pred_len + 1).unsqueeze(1)
            sliding_index = sliding_index.clamp(0, self.n_train - 1)

            self_mask = torch.zeros((bsz, self.n_train), device=sim.device)
            self_mask = self_mask.scatter_(1, sliding_index, 1.).bool()
            sim = sim.masked_fill(self_mask.unsqueeze(0), float('-inf'))  # G, B, T

        # Keep only the top-k scores per (granularity, query); k cannot exceed the pool size.
        k = min(self.topm, self.n_train)
        top_sim, top_idx = torch.topk(sim, k, dim=2)
        ranking_sim = torch.full_like(sim, float('-inf')).scatter(2, top_idx, top_sim)

        ranking_prob = F.softmax(ranking_sim / self.temperature, dim=2)
        ranking_prob = ranking_prob.detach().cpu()  # G, B, T

        y_data_all = self.y_data_all_mg.flatten(start_dim=2)  # G, T, P' * C
        n_channels = self.train_data_all.shape[-1]

        pred_from_retrieval = torch.bmm(ranking_prob, y_data_all).reshape(self.n_period, bsz, -1, n_channels)
        return pred_from_retrieval.to(x.device)

    def retrieve_all(self, data, train=False, device=torch.device('cpu')):
        """Run :meth:`retrieve` over every sample of ``data`` in order.

        ``data`` items are unpacked as ``(index, batch_x, ...)``.

        Returns:
            ``(retrievals, None)`` where ``retrievals`` is (G, N, P', C) on CPU.
        """
        assert getattr(self, 'train_data_all_mg', None) is not None, (
            'prepare_dataset() must be called before retrieve_all()'
        )

        rt_loader = DataLoader(
            data,
            batch_size=128,
            shuffle=False,
            num_workers=8,
            drop_last=False
        )

        retrievals = []
        with torch.no_grad():
            for index, batch_x, *_ in tqdm(rt_loader):
                pred_from_retrieval = self.retrieve(batch_x.float().to(device), index, train=train)
                retrievals.append(pred_from_retrieval.cpu())

        retrievals = torch.cat(retrievals, dim=1)

        return retrievals, None
    


    




    


class STFT1RetrievalToolCI():
    """Channel-independent retrieval tool that matches windows by their STFTs.

    Every series is smoothed at several granularities (block means of length
    ``g`` for each ``g`` in ``period_num``).  Inputs are described by the
    non-DC short-time Fourier transform of each smoothed series.  Training
    windows are scored against a query per channel by a mix of two terms:
    a negative Jensen-Shannon term on the per-frame normalised magnitude
    spectra, and a phase-cosine term.  Scores are averaged over STFT frames
    with exponential recency weights (``args.alpha``).  For each granularity
    and channel, the ``topm`` best windows are softmax-weighted
    (``temperature``) to average their offset-removed targets.

    ``args`` must provide ``window_size``, ``hop_size``, ``alpha`` and
    ``bz2``; ``bz1`` (training-chunk size for scoring) is optional.
    Call :meth:`prepare_dataset` before :meth:`retrieve` / :meth:`retrieve_all`.
    """

    def __init__(
        self,
        args,
        seq_len,
        pred_len,
        channels,
        n_period=3,
        temperature=0.1,
        topm=20,
        with_dec=False,
        return_key=False,
    ):
        # Use the n_period finest granularities, ordered coarsest first.
        period_num = [16, 8, 4, 2, 1]
        period_num = period_num[-1 * n_period:]

        self.seq_len = seq_len
        self.pred_len = pred_len
        self.channels = channels

        self.n_period = n_period
        self.period_num = sorted(period_num, reverse=True)

        # Honour the caller's temperature (an earlier revision silently reset it to 0.1).
        self.temperature = temperature
        self.topm = topm

        self.with_dec = with_dec
        self.return_key = return_key  # stored but not read by this class

        self.window_size = args.window_size
        self.hop_size = args.hop_size
        self.args = args
        # Rectangular STFT window; kept on CPU because decompose_mg runs the STFT on CPU.
        self.window = torch.ones(self.window_size)

    def prepare_dataset(self, train_data):
        """Cache multi-granularity features and targets of every training sample.

        ``train_data[i]`` is indexed as ``(_, x, y, ...)``: ``x`` is the input
        window and the trailing ``pred_len`` (plus ``label_len`` when
        ``with_dec``) steps of ``y`` form the target.  ``train_data`` must
        expose ``pred_len`` (and ``label_len`` when ``with_dec``).
        """
        target_len = train_data.pred_len
        if self.with_dec:
            target_len += train_data.label_len

        train_data_all = []
        y_data_all = []
        # Index explicitly: Dataset objects are not guaranteed to stop iterating at len().
        for i in range(len(train_data)):
            sample = train_data[i]
            train_data_all.append(sample[1])
            y_data_all.append(sample[2][-target_len:])

        self.train_data_all = torch.tensor(np.stack(train_data_all, axis=0)).float()
        self.train_data_all_mg, _ = self.decompose_mg(self.train_data_all, label=False)

        self.y_data_all = torch.tensor(np.stack(y_data_all, axis=0)).float()
        self.y_data_all_mg, _ = self.decompose_mg(self.y_data_all)

        self.n_train = self.train_data_all.shape[0]

    def _stft(self, series):
        """Return the complex STFT of ``series`` (B, S, C) without the DC bin, as (B, F - 1, N, C)."""
        B, S, C = series.shape
        spec = torch.stft(
            series.permute(0, 2, 1).reshape(-1, S),
            n_fft=self.window_size,
            hop_length=self.hop_size,
            window=self.window,
            return_complex=True,
        )  # B * C, F (freq bins), N (frames)
        spec = spec.reshape(B, C, spec.shape[-2], spec.shape[-1]).permute(0, 2, 3, 1)  # B, F, N, C
        return spec[:, 1:]

    def decompose_mg(self, data_all, remove_offset=True, label=True):
        """Smooth ``data_all`` (N, S, C) at every granularity in ``period_num``.

        The input is moved to CPU first (the STFT window lives there).

        Args:
            data_all: tensor of shape (N, S, C).
            remove_offset: when ``label`` is also true, subtract each series'
                last value so targets are expressed relative to it.
            label: if false, return STFT features of shape (G, N, F - 1, frames, C)
                (complex) instead of the smoothed series, and never remove offsets.

        Returns:
            ``(mg, offset)``: ``offset`` is the (G, N, 1, C) subtracted last
            values, or ``None``.
        """
        data_all = data_all.cpu()
        mg = []
        for g in self.period_num:
            cur = data_all.unfold(dimension=1, size=g, step=g).mean(dim=-1)
            cur = cur.repeat_interleave(repeats=g, dim=1)
            if not label:
                cur = self._stft(cur)
            mg.append(cur)
        mg = torch.stack(mg, dim=0)

        if not (remove_offset and label):
            return mg, None

        # clone(): a plain slice is a view of ``mg`` and would read back as zeros
        # once the subtraction below has been applied.
        offset = mg[:, :, -1:, :].clone()  # G, N, 1, C
        return mg - offset, offset

    def periodic_batch_corr(self, data_all, key, in_bsz=8):
        """Score every training window against every query, per channel.

        Args:
            data_all: complex STFT features of the training windows, (G, T, F, N, C).
            key: complex STFT features of the queries, (G, B, F, N, C).
            in_bsz: training windows scored per chunk; ``self.args.bz1``
                takes precedence when present.

        Returns:
            Similarity tensor (G, B, T, C).  It is computed on the GPU when one
            is available, otherwise on ``key.device``.
        """
        train_len = data_all.shape[1]
        in_bsz = getattr(self.args, 'bz1', in_bsz)
        device = torch.device('cuda') if torch.cuda.is_available() else key.device
        key = key.to(device)

        # Query-side features are chunk-invariant: compute them once.
        key_mag = key.abs()
        key_dist = key_mag / (key_mag.sum(dim=2, keepdim=True) + 1e-6)  # spectrum per frame sums to ~1
        key_phase = torch.angle(key)

        # Exponential recency weights over frames: the last frame weighs most.
        n_frames = key.shape[3]
        steps = torch.arange(1, n_frames + 1, device=device)
        weights = (1 - self.args.alpha) ** (n_frames - steps)
        weights = (weights / weights.sum()).view(1, 1, 1, -1, 1)

        sim = []
        for start in range(0, train_len, in_bsz):
            ax = data_all[:, start:start + in_bsz].to(device)  # G, t, F, N, C
            ax_mag = ax.abs()
            ax_dist = ax_mag / (ax_mag.sum(dim=2, keepdim=True) + 1e-6)

            mag_sim = js_divergence_CI(ax_dist, key_dist)  # G, B, t, N, C
            # cos of the phase difference averaged over frequency; shifted so identical phases give 0.
            phase_sim = torch.cos(
                torch.angle(ax).unsqueeze(2) - key_phase.unsqueeze(1)
            ).mean(dim=3).transpose(1, 2) - 1  # G, B, t, N, C

            cur_sim = mag_sim / 5 + phase_sim / 5
            sim.append((weights * cur_sim).sum(dim=3))  # G, B, t, C

        return torch.cat(sim, dim=2)

    def retrieve(self, x, index, train=True):
        """Blend, per channel, the targets of the most similar training windows.

        Args:
            x: query inputs of shape (B, seq_len, channels).
            index: per-query positions in the training set.  Used only when
                ``train`` is true, to hide training windows whose position lies
                within ``seq_len + pred_len - 1`` of the query (including itself).
            train: whether to apply that self-overlap mask.

        Returns:
            ``(pred, ranking_prob, y_data_all)``: ``pred`` (G, B, P', C) on
            ``x.device``; the (G, B, T, C) retrieval weights and the
            (G, T, P', C) cached targets, both on CPU.
        """
        bsz, seq_len, channels = x.shape
        assert seq_len == self.seq_len and channels == self.channels, (
            f'expected input of shape (B, {self.seq_len}, {self.channels}), got {tuple(x.shape)}'
        )

        x_mg, _ = self.decompose_mg(x, label=False)  # G, B, F, N, C
        sim = self.periodic_batch_corr(self.train_data_all_mg, x_mg)  # G, B, T, C
        device = sim.device

        if train:
            index = index.to(device)
            window = 2 * (self.seq_len + self.pred_len) - 1
            sliding_index = torch.arange(window, device=device).unsqueeze(0).repeat(len(index), 1)
            sliding_index = sliding_index + (index - self.seq_len - self.pred_len + 1).unsqueeze(1)
            sliding_index = sliding_index.clamp(0, self.n_train - 1)

            self_mask = torch.zeros((bsz, self.n_train), device=device)
            self_mask = self_mask.scatter_(1, sliding_index, 1.).bool()
            sim = sim.masked_fill(self_mask[None, :, :, None], float('-inf'))

        # Keep only the top-k scores per (granularity, query, channel); k cannot exceed the pool size.
        k = min(self.topm, self.n_train)
        top_sim, top_idx = torch.topk(sim, k, dim=2)
        ranking_sim = torch.full_like(sim, float('-inf')).scatter(2, top_idx, top_sim)

        if torch.isnan(ranking_sim).any():
            print('ranking_sim shape: ', torch.sum(ranking_sim))

        ranking_prob = F.softmax(ranking_sim / self.temperature, dim=2).detach()  # G, B, T, C
        y_data_all = self.y_data_all_mg.to(device)  # G, T, P', C
        pred_from_retrieval = torch.einsum('bikp,bkjp->bijp', ranking_prob, y_data_all)  # G, B, P', C

        return pred_from_retrieval.to(x.device), ranking_prob.cpu(), y_data_all.cpu()

    def retrieve_all(self, data, train=False, device=torch.device('cpu')):
        """Run :meth:`retrieve` over every sample of ``data`` in order.

        ``data`` items are unpacked as ``(index, batch_x, ...)``; the batch
        size is ``self.args.bz2``.

        Returns:
            ``(retrievals, ranking_prob_list, y_data_all, batch_x_list)``:
            ``retrievals`` is (G, N, P', C) on CPU, and the lists hold one
            entry per batch.  ``y_data_all`` is from the last batch, or
            ``None`` if ``data`` is empty.
        """
        assert getattr(self, 'train_data_all_mg', None) is not None, (
            'prepare_dataset() must be called before retrieve_all()'
        )

        rt_loader = DataLoader(
            data,
            batch_size=self.args.bz2,
            shuffle=False,
            num_workers=8,
            drop_last=False
        )

        retrievals = []
        ranking_prob_list = []
        batch_x_list = []
        y_data_all = None
        with torch.no_grad():
            for index, batch_x, *_ in tqdm(rt_loader):
                pred_from_retrieval, ranking_prob, y_data_all = self.retrieve(
                    batch_x.float().to(device), index, train=train
                )
                retrievals.append(pred_from_retrieval.cpu())
                ranking_prob_list.append(ranking_prob)
                batch_x_list.append(batch_x.cpu())

        retrievals = torch.cat(retrievals, dim=1)

        return retrievals, ranking_prob_list, y_data_all, batch_x_list
    


