import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
import math
from tqdm import tqdm
import time
from torch.utils.data import Dataset, DataLoader

def kl_divergence(ax, bx):
    """Pairwise KL divergence KL(a || b) between every row of ``ax`` and every row of ``bx``.

    ``ax`` has shape (G, A, K, ...) and ``bx`` has shape (G, B, K, ...); dim 2 of
    the inputs is the distribution axis that the divergence sums over. If the
    per-pair result still has more than three dims (i.e. the inputs have a
    trailing axis after the distribution axis), it is additionally averaged over
    the first such axis.

    Returns:
        Tensor indexed ``[g, b, a, ...]`` — shape (G, B, A) for 3-D or 4-D inputs.
    """
    log_ax = torch.log(ax + 1e-8)  # epsilon avoids log(0)
    log_bx = torch.log(bx + 1e-8)

    # Broadcast to (G, A, B, K, ...) so every a-row is compared with every b-row.
    sub = log_ax.unsqueeze(2) - log_bx.unsqueeze(1)
    cur_sim = torch.sum(ax.unsqueeze(2) * sub, dim=3)  # (G, A, B, ...)
    # Only average when a trailing axis exists; 3-D inputs used to crash here.
    if cur_sim.dim() > 3:
        cur_sim = cur_sim.mean(dim=3)
    return cur_sim.transpose(1, 2)
def js_divergence(ax, bx):
    """Pairwise symmetric divergence between every row of ``ax`` and every row of ``bx``.

    Computes ``KL(a || m) + KL(b || m)`` with ``m = (a + b) / 2`` along dim 2 of
    the inputs. This is twice the textbook Jensen-Shannon divergence: the 1/2
    factor is deliberately not applied, to keep the original scale.

    ``ax`` has shape (G, A, K, ...) and ``bx`` has shape (G, B, K, ...). If the
    per-pair result still has more than three dims, it is averaged over the first
    axis after the distribution axis.

    Returns:
        Tensor indexed ``[g, b, a, ...]`` — shape (G, B, A) for 3-D or 4-D inputs.
    """
    a = ax.unsqueeze(2)  # (G, A, 1, K, ...)
    b = bx.unsqueeze(1)  # (G, 1, B, K, ...)
    log_p = torch.log((a + b) / 2 + 1e-8)  # epsilon avoids log(0)

    cur_sim = (a * (torch.log(a + 1e-8) - log_p)).sum(dim=3) + \
        (b * (torch.log(b + 1e-8) - log_p)).sum(dim=3)
    # Only average when a trailing axis exists; 3-D inputs used to crash here.
    if cur_sim.dim() > 3:
        cur_sim = cur_sim.mean(dim=3)
    return cur_sim.transpose(1, 2)

def js_divergence_CI(ax, bx):
    """Pairwise symmetric divergence that keeps every trailing axis.

    For ``m = (a + b) / 2`` this returns ``KL(a || m) + KL(b || m)`` summed over
    dim 2 of the inputs (twice the textbook Jensen-Shannon divergence). Unlike an
    averaged variant, no axis after the distribution axis is reduced.

    Args:
        ax: tensor (G, A, K, ...).
        bx: tensor (G, B, K, ...), matching ``ax`` on every axis but dim 1.

    Returns:
        Tensor indexed ``[g, b, a, ...]``, i.e. (G, B, A, ...).
    """
    left = ax.unsqueeze(2)    # (G, A, 1, K, ...)
    right = bx.unsqueeze(1)   # (G, 1, B, K, ...)
    log_mid = torch.log((left + right) / 2 + 1e-8)  # epsilon avoids log(0)

    term_left = (left * (torch.log(left + 1e-8) - log_mid)).sum(dim=3)
    term_right = (right * (torch.log(right + 1e-8) - log_mid)).sum(dim=3)
    return (term_left + term_right).transpose(1, 2)

class FreqRetrievalToolCI:
    """Retrieve training targets whose per-band input spectra match a query.

    ``prepare_dataset`` caches, for every training sample, a per-period,
    per-band normalised magnitude spectrum of the input window and a per-band
    time-domain reconstruction of the target window. ``retrieve`` scores a batch
    of query windows against that cache with ``js_divergence_CI`` separately for
    every (period, band, channel), keeps the ``topm`` best training samples,
    and returns the softmax-weighted average of their target reconstructions.
    """

    def __init__(
        self,
        seq_len,
        pred_len,
        channels,
        n_period=3,
        temperature=0.1,
        topm=20,
        with_dec=False,
        return_key=False,
        n_bands=3
    ):
        """Configure the retrieval tool.

        Args:
            seq_len: length of the windows passed to ``retrieve``; also sizes the
                self-overlap mask.
            pred_len: used with ``seq_len`` to size the self-overlap mask.
            channels: number of channels of the windows passed to ``retrieve``.
            n_period: how many of the periods [16, 8, 4, 2, 1] to use, taken from
                the end of that list (e.g. 3 -> [4, 2, 1]); must be in 1..5.
            temperature: currently ignored (see the note below).
            topm: number of training samples kept per ranking.
            with_dec: if True, stored targets keep the last
                ``pred_len + label_len`` steps instead of the last ``pred_len``.
            return_key: stored only; not read by this class.
            n_bands: number of frequency bands the spectra are split into.

        Raises:
            ValueError: if ``n_period`` is outside 1..5.
        """
        all_periods = [16, 8, 4, 2, 1]
        # A [-0:] slice would silently pick all five periods and n_period > 5
        # would disagree with the number actually used, breaking later reshapes.
        if not 1 <= n_period <= len(all_periods):
            raise ValueError(
                f'n_period must be between 1 and {len(all_periods)}, got {n_period}'
            )

        self.seq_len = seq_len
        self.pred_len = pred_len
        self.channels = channels
        self.step_size = 5
        self.n_period = n_period
        self.period_num = sorted(all_periods[-n_period:], reverse=True)

        # NOTE(review): the ``temperature`` argument is ignored and 0.1 is always
        # used; confirm this override is still intended before honouring it.
        self.temperature = 0.1
        self.topm = topm

        self.with_dec = with_dec
        self.return_key = return_key

        self.sim_by_freq = True
        self.remove_offset = not self.sim_by_freq
        self.n_bands = n_bands

        # Filled in by prepare_dataset(); retrieve_all() checks it.
        self.train_data_all_mg = None

    def prepare_dataset(self, train_data):
        """Cache the training inputs/targets and their decompositions.

        ``train_data[i]`` must return a sequence whose element 1 is the input
        window and element 2 the target window (presumably the same
        ``(index, x, y, x_mark, y_mark)`` tuples ``retrieve_all`` iterates).
        ``train_data`` must expose ``pred_len`` (and ``label_len`` when
        ``with_dec`` is set).

        Sets ``train_data_all``, ``train_data_all_mg``, ``y_data_all``,
        ``y_data_all_mg`` and ``n_train``.
        """
        if self.with_dec:
            y_len = train_data.pred_len + train_data.label_len
        else:
            y_len = train_data.pred_len

        train_data_all = []
        y_data_all = []
        for i in range(len(train_data)):
            sample = train_data[i]
            train_data_all.append(sample[1])
            y_data_all.append(sample[2][-y_len:])

        print('len of train_data_all: ', len(train_data_all))
        self.train_data_all = torch.tensor(np.stack(train_data_all, axis=0)).float()
        print('train_data_all shape: ', self.train_data_all.shape)
        self.train_data_all_mg, _ = self.decompose_mg(self.train_data_all, label=False)
        print('train_data_all_mg shape: ', self.train_data_all_mg.shape)

        self.y_data_all = torch.tensor(np.stack(y_data_all, axis=0)).float()
        self.y_data_all_mg, _ = self.decompose_mg_value(self.y_data_all, label=False)

        self.n_train = self.train_data_all.shape[0]

    def _period_average(self, data_all, g):
        """Average non-overlapping g-step windows along dim 1, then tile the result g times."""
        cur = data_all.unfold(dimension=1, size=g, step=g).mean(dim=-1)
        # NOTE(review): this tiles the downsampled series g times (cat); it does
        # not repeat each value g times (repeat_interleave). Confirm intent.
        return torch.cat([cur] * g, dim=1)

    def _finalize_mg(self, mg_list, label):
        """Stack per-period tensors on dim 0, optionally subtract the last step, and L2-normalise labels."""
        mg = torch.stack(mg_list, dim=0)
        if self.remove_offset:
            # Subtract each series' last value along its time axis (dim 2 of mg).
            mg = mg - mg[:, :, -1:]
        if label:
            mg = F.normalize(mg, dim=1)
        return mg

    def _band_spectrum(self, series):
        """Per-band normalised magnitude spectrum of ``series`` (N, S, C).

        Returns a tensor (N, n_freq // n_bands, n_bands, C), where n_freq is the
        number of rfft bins without the DC bin; each band sums to ~1 over dim 1.
        Requires n_freq to be divisible by n_bands.
        """
        # The series is not rescaled before the FFT: the per-band normalisation
        # below already cancels any per-series scale factor, and dividing by the
        # window sum produced NaN/inf for zero-sum windows.
        spec = torch.fft.rfft(series, dim=1).abs()[:, 1:, :]
        n_freq = spec.shape[1]
        # NOTE(review): this reshape puts frequency index k into band
        # k % n_bands (strided), whereas decompose_mg_value uses contiguous
        # bands. Confirm which grouping is intended.
        spec = spec.reshape(spec.shape[0], n_freq // self.n_bands, self.n_bands, spec.shape[2])
        spec = spec / (torch.sum(spec, dim=1, keepdim=True) + 1e-8)
        if torch.isnan(spec).any():
            print('NAN AT CUR')
        return spec

    def _band_signals(self, series):
        """Split ``series`` (N, S, C) into n_bands zero-phase band signals, (N, S, n_bands, C).

        The magnitude spectrum (DC removed, phase discarded by ``abs``) is cut
        into contiguous bands of ``n_freq // n_bands`` bins; the last band also
        takes any remainder. Each band is inverted with a zero DC term.
        """
        seq = series.shape[1]
        spec = torch.fft.rfft(series, dim=1).abs()[:, 1:, :]  # (N, n_freq, C)
        n_freq = spec.shape[1]
        band_size = n_freq // self.n_bands

        mask = torch.zeros(self.n_bands, n_freq, device=spec.device, dtype=spec.dtype)
        for i in range(self.n_bands):
            start = i * band_size
            end = (i + 1) * band_size if i < self.n_bands - 1 else n_freq
            mask[i, start:end] = 1.0

        # (N, 1, n_freq, C) * (1, n_bands, n_freq, 1) -> (N, n_bands, n_freq, C)
        banded = spec.unsqueeze(1) * mask.view(1, self.n_bands, n_freq, 1)
        banded = banded.transpose(1, 2)  # (N, n_freq, n_bands, C)
        # Re-insert a zero DC bin so the inverse transform sees n_freq + 1 bins.
        banded = torch.cat([torch.zeros_like(banded[:, :1]), banded], dim=1)
        # n=seq keeps the original length; the default would drop a step for odd seq.
        out = torch.fft.irfft(banded, n=seq, dim=1)
        if torch.isnan(out).any():
            print('NAN AT CUR')
        return out

    def decompose_mg(self, data_all, remove_offset=True, label=False):
        """Multi-period spectral decomposition used for similarity search.

        For each period g in ``self.period_num`` the series is window-averaged
        (see ``_period_average``). Unless ``label`` is True, each result is then
        turned into a per-band normalised magnitude spectrum.

        Args:
            data_all: tensor (N, S, C).
            remove_offset: ignored; ``self.remove_offset`` is used instead
                (kept for backward compatibility).
            label: if True, skip the spectral step and L2-normalise the stacked
                result along dim 1.

        Returns:
            ``(mg, None)`` where mg is stacked over periods on dim 0, e.g.
            (G, N, n_freq // n_bands, n_bands, C) for the spectral path.
        """
        mg = []
        for g in self.period_num:
            cur = self._period_average(data_all, g)
            if not label and self.sim_by_freq:
                cur = self._band_spectrum(cur)
            mg.append(cur)
        return self._finalize_mg(mg, label), None

    def decompose_mg_value(self, data_all, remove_offset=True, label=False):
        """Multi-period, per-band time-domain decomposition used for retrieved targets.

        For each period g in ``self.period_num`` the series is window-averaged
        (see ``_period_average``). Unless ``label`` is True, each result is then
        split into ``n_bands`` zero-phase band signals (see ``_band_signals``).

        Args:
            data_all: tensor (N, L, C).
            remove_offset: ignored; ``self.remove_offset`` is used instead
                (kept for backward compatibility).
            label: if True, skip the band split and L2-normalise the stacked
                result along dim 1.

        Returns:
            ``(mg, None)`` where mg is stacked over periods on dim 0, e.g.
            (G, N, L, n_bands, C) for the band path.
        """
        mg = []
        for g in self.period_num:
            cur = self._period_average(data_all, g)
            if not label and self.sim_by_freq:
                cur = self._band_signals(cur)
            mg.append(cur)
        return self._finalize_mg(mg, label), None

    def periodic_batch_corr(self, data_all, key, in_bsz=512):
        """Cosine similarity of mean-centred rows, computed in chunks of ``in_bsz`` along dim 1.

        Args:
            data_all: tensor (G, T, D); chunks are moved to ``key.device``.
            key: tensor (G, B, D).

        Returns:
            Tensor (G, B, T). Not called within this class.
        """
        train_len = data_all.shape[1]
        # The key side is the same for every chunk, so normalise it once.
        bx = F.normalize(key - torch.mean(key, dim=2, keepdim=True), dim=2)

        sim = []
        for start_idx in range(0, train_len, in_bsz):
            cur_data = data_all[:, start_idx:start_idx + in_bsz].to(key.device)
            ax = cur_data - torch.mean(cur_data, dim=2, keepdim=True)
            sim.append(torch.bmm(bx, F.normalize(ax, dim=2).transpose(-1, -2)))
        return torch.cat(sim, dim=2)

    def periodic_batch_corr_by_freq(self, data_all, key, in_bsz=512):
        """Score ``key`` against ``data_all`` with ``js_divergence_CI``, in chunks along dim 1.

        Args:
            data_all: tensor (G, T, K, P, C); chunks are moved to ``key.device``.
            key: tensor (G, B, K, P, C).

        Returns:
            Tensor (G, B, T, P, C).

        NOTE(review): ``js_divergence_CI`` is a divergence (0 for identical
        spectra, larger when they differ), yet ``retrieve`` treats larger values
        as better (top-k, softmax, ``-inf`` masking). As written, the most
        dissimilar training windows get the highest weight; negating the
        divergence would reverse that. Confirm intent before changing it.
        """
        train_len = data_all.shape[1]
        sim = []
        for start_idx in range(0, train_len, in_bsz):
            ax = data_all[:, start_idx:start_idx + in_bsz].to(key.device)
            sim.append(js_divergence_CI(ax, key))
        return torch.cat(sim, dim=2)

    def _overlap_mask(self, index, bsz, device):
        """Bool mask (1, bsz, n_train, 1, 1) of training positions within
        ``seq_len + pred_len - 1`` steps of each query ``index``."""
        window = 2 * (self.seq_len + self.pred_len) - 1
        offsets = torch.arange(window, device=device)
        sliding_index = offsets.unsqueeze(0) + (index - self.seq_len - self.pred_len + 1).unsqueeze(1)
        sliding_index = sliding_index.clamp(0, self.n_train - 1)

        mask = torch.zeros((bsz, self.n_train), device=device)
        mask.scatter_(1, sliding_index, 1.)
        # Broadcast over periods, bands and channels instead of materialising them.
        return mask.bool()[None, :, :, None, None]

    def retrieve(self, x, index, train=True):
        """Retrieve a weighted forecast for a batch of query windows.

        Args:
            x: query windows (bsz, seq_len, channels).
            index: positions of the queries in the training set (1-D tensor).
            train: if True, training windows within ``seq_len + pred_len - 1``
                steps of each query's ``index`` are excluded from retrieval.

        Returns:
            ``(pred_from_retrieval, ranking_prob)``: the forecast on ``x.device``
            with shape (n_period, bsz, L, n_bands, channels), where L is the
            length of the stored target reconstructions, and the retrieval
            weights on CPU with shape (n_period, bsz, n_train, n_bands, channels).

        Raises:
            ValueError: if ``x`` does not match the configured seq_len/channels.
        """
        bsz, seq_len, channels = x.shape
        # The original assert(a, b) checked a non-empty tuple and never fired.
        if seq_len != self.seq_len or channels != self.channels:
            raise ValueError(
                f'expected x of shape (bsz, {self.seq_len}, {self.channels}), '
                f'got {tuple(x.shape)}'
            )

        x_mg, _ = self.decompose_mg(x, label=False)  # G, B, K, P, C
        index = index.to(x.device)

        sim = self.periodic_batch_corr_by_freq(
            self.train_data_all_mg,  # G, T, K, P, C
            x_mg,
            in_bsz=256,  # chunk size over training samples
        )  # G, B, T, P, C

        if train:
            sim.masked_fill_(self._overlap_mask(index, bsz, x.device), float('-inf'))

        sim = sim.reshape(self.n_period * bsz, self.n_train, self.n_bands, self.channels)
        if torch.isnan(sim).any():
            print('NaN in sim, sum: ', torch.sum(sim))

        # Keep only the top-k scores per (period*batch, band, channel); the rest
        # become -inf so they get zero softmax weight.
        k = min(self.topm, self.n_train)
        top = torch.topk(sim, k, dim=1)
        ranking_sim = torch.full_like(sim, float('-inf')).scatter_(1, top.indices, top.values)
        ranking_sim = ranking_sim.reshape(
            self.n_period, bsz, self.n_train, self.n_bands, self.channels
        )
        if torch.isnan(ranking_sim).any():
            print('NaN in ranking_sim, sum: ', torch.sum(ranking_sim))

        ranking_prob = F.softmax(ranking_sim / self.temperature, dim=2).detach().cpu()

        # Weighted sum over training samples:
        # (G, B, T, P, C) x (G, T, L, P, C) -> (G, B, L, P, C)
        pred_from_retrieval = torch.einsum(
            'bikdp,bkjdp->bijdp', ranking_prob, self.y_data_all_mg
        )
        return pred_from_retrieval.to(x.device), ranking_prob

    def retrieve_all(self, data, train=False, device=torch.device('cpu')):
        """Run ``retrieve`` over every sample of ``data``.

        ``data`` must yield ``(index, x, y, x_mark, y_mark)`` tuples through a
        DataLoader.

        Returns:
            ``(retrievals, ranking_prob_list)``: all forecasts concatenated along
            dim 1 (the batch axis) on CPU, and the per-batch retrieval weights.

        Raises:
            RuntimeError: if ``prepare_dataset`` has not been called.
        """
        if getattr(self, 'train_data_all_mg', None) is None:
            raise RuntimeError('prepare_dataset() must be called before retrieve_all()')

        rt_loader = DataLoader(
            data,
            batch_size=256,
            shuffle=False,
            num_workers=16,
            drop_last=False
        )
        retrievals = []
        ranking_prob_list = []
        with torch.no_grad():
            for index, batch_x, batch_y, batch_x_mark, batch_y_mark in tqdm(rt_loader):
                pred_from_retrieval, ranking_prob = self.retrieve(
                    batch_x.float().to(device), index, train=train
                )
                retrievals.append(pred_from_retrieval.cpu())
                ranking_prob_list.append(ranking_prob.cpu())

        return torch.cat(retrievals, dim=1), ranking_prob_list

