import logging
from collections import OrderedDict
import time
import torch
import torch.nn.functional as F
from attrdict import AttrDict

def get_logger(filename, mode='a'):
    """Configure the root logger to log to ``filename`` and to the console.

    Every handler currently attached to the root logger is detached and
    closed first, so repeated calls do not duplicate output or leak open
    log files. Two fresh handlers are then attached: a ``FileHandler`` for
    ``filename`` and a ``StreamHandler``.

    Args:
        filename: Path of the log file.
        mode: File open mode for the log file (e.g. ``'a'`` to append,
            ``'w'`` to truncate).

    Returns:
        The root ``logging.Logger``.
    """
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger()
    # Iterate over a snapshot: removing from ``logger.handlers`` while
    # iterating it directly skips every other handler.
    for hdlr in list(logger.handlers):
        logger.removeHandler(hdlr)
        hdlr.close()
    logger.addHandler(logging.FileHandler(filename, mode=mode))
    logger.addHandler(logging.StreamHandler())
    return logger

class RunningAverage(object):
    """Accumulate per-key sums and counts and report their running means.

    Keys keep insertion order. Values given to :meth:`update` may be plain
    numbers or single-element ``torch.Tensor`` objects (converted with
    ``.item()``). A wall-clock timer is started at construction and restarted
    by :meth:`reset` and :meth:`clear`; :meth:`info` can report the time
    elapsed since then.
    """

    def __init__(self, *keys):
        """Create an averager, pre-registering ``keys`` with zero sum and count."""
        self.sum = OrderedDict()
        self.cnt = OrderedDict()
        self.clock = time.time()
        for key in keys:
            self.sum[key] = 0
            self.cnt[key] = 0

    def update(self, key, val):
        """Record one observation ``val`` under ``key`` (registering it if new)."""
        if isinstance(val, torch.Tensor):
            val = val.item()
        if self.sum.get(key, None) is None:
            self.sum[key] = val
            self.cnt[key] = 1
        else:
            self.sum[key] = self.sum[key] + val
            self.cnt[key] += 1

    def reset(self):
        """Zero every sum and count (keys are kept) and restart the timer."""
        for key in self.sum.keys():
            self.sum[key] = 0
            self.cnt[key] = 0
        self.clock = time.time()

    def clear(self):
        """Forget all keys and restart the timer."""
        self.sum = OrderedDict()
        self.cnt = OrderedDict()
        self.clock = time.time()

    def keys(self):
        """Return a view of the registered keys, in insertion order."""
        return self.sum.keys()

    def get(self, key):
        """Return the mean of the observations recorded under ``key``.

        Raises:
            AssertionError: if ``key`` was never registered.
            ZeroDivisionError: if ``key`` is registered but has no
                observations yet.
        """
        assert(self.sum.get(key, None) is not None)
        return self.sum[key] / self.cnt[key]

    def info(self, show_et=True):
        """Format every key's mean as ``'<key> <mean> '``, floats to 4 decimals.

        Keys with no observations yet (freshly registered or after
        :meth:`reset`) are skipped, since they have no mean. When ``show_et``
        is true, the elapsed time since the timer was last started is
        appended as ``'(<seconds> secs)'``.
        """
        parts = []
        for key, total in self.sum.items():
            count = self.cnt[key]
            if count == 0:
                # Nothing recorded yet: dividing would raise ZeroDivisionError.
                continue
            val = total / count
            if isinstance(val, float):
                parts.append(f'{key} {val:.4f} ')
            else:
                parts.append(f'{key} {val} ')
        if show_et:
            parts.append(f'({time.time()-self.clock:.3f} secs)')
        return ''.join(parts)

# Channel names for dataset='seed' in get_batch (62 entries). Entry i names
# channel i of the input EEG tensor's channel axis. Names are compared
# case-insensitively against standard_channel_list.
seed_ch_list = ['Fp1', 'Fpz', 'Fp2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1',
               'FCz', 'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP7', 'CP5',
               'CP3','CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7',
               'PO5', 'PO3','POz', 'PO4', 'PO6', 'PO8', 'CB1', 'O1', 'Oz', 'O2', 'CB2']



# Channel names for dataset='motor' in get_batch (22 entries), in the order of
# the input tensor's channel axis.
motor_ch_list = ['Fz', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'CP3', 'CP1', 'CPz', 'CP2',
            'CP4', 'P1', 'Pz', 'P2', 'POz']


# Channel names for dataset='monitor' in get_batch (64 entries), in the order
# of the input tensor's channel axis.
monitor_ch_list = ['Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5', 'FC3', 'FC1', 'C1','C3','C5','T7'
                  ,'TP7','CP5','CP3','CP1','P1','P3','P5','P7','P9','PO7','PO3','O1','Iz','Oz','POz','Pz','CPz','Fpz','Fp2','AF8','AF4'
                  ,'AFz','Fz','F2','F4','F6','F8','FT8','FC6','FC4','FC2','FCz','Cz','C2','C4','C6','T8','TP8','CP6','CP4','CP2','P2','P4','P6','P8','P10','PO8','PO4','O2']



# Superset of 95 channel names defining the canonical channel order.
# get_batch reorders each dataset's channels into this order. The one-hot
# channel embedding has one entry per name here, so changing this list changes
# that embedding's size and meaning. Every name in the dataset lists above must
# appear here.
standard_channel_list = ['Fp1', 'Fpz', 'Fp2', 'Fp9', 'Fp10', 'Nz', 'AF1', 'AF2', 'AFz', 'AF3', 'AF4', 'AF5', 'AF6', 
                         'AF7', 'AF8', 'AF9', 'AF10', 'F1', 'F2', 'Fz', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8', 'F9', 'F10',
                         'FC1', 'FC2', 'FCz', 'FC3', 'FC4', 'FC5', 'FC6', 'FT7', 'FT8', 'FT9', 'FT10', 'C1', 'C2', 'Cz',
                         'C3', 'C4', 'C5', 'C6', 'T7', 'T8', 'T9', 'T10', 'I1', 'I2', 'CP1', 'CP2', 'CPz', 'CP3', 'CP4',
                         'CP5', 'CP6', 'TP7', 'TP8', 'TP9', 'TP10', 'P1', 'P2', 'Pz', 'P3', 'P4', 'P5', 'P6', 'P7', 'P8',
                         'P9', 'P10', 'PO1', 'PO2', 'POz', 'PO3', 'PO4', 'PO5', 'PO6', 'PO7', 'PO8', 'PO9', 'PO10',
                         'O1', 'O2', 'Oz', 'O9', 'O10', 'Iz', 'CB1', 'CB2', 'A1', 'A2']




def get_batch(eeg, mask_ratio, dataset, device=None):
    """Reorder EEG to the standard montage and randomly split channels into kept/masked.

    Channel ``i`` of ``eeg`` is taken to be channel ``i`` of the dataset's
    channel list (``seed_ch_list`` / ``motor_ch_list`` / ``monitor_ch_list``,
    matched case-insensitively). The channels are reordered to follow
    ``standard_channel_list``. Then, independently for every sample,
    ``int(mask_ratio * C)`` randomly chosen channels are marked as missing and
    the rest are kept.

    Args:
        eeg: Tensor of shape ``(B, C, L)``.
        mask_ratio: Fraction of channels to mask per sample.
        dataset: One of ``'seed'``, ``'motor'``, ``'monitor'``.
        device: Device for the channel embeddings. ``None`` (default) means
            CUDA when available, otherwise ``eeg.device``.

    Returns:
        AttrDict with the fields:
            xt: ``eeg`` with channels in standard-montage order, ``(B, C, L)``.
            t_index: int64 tensor ``(C,)`` holding each ``xt`` channel's
                position in ``standard_channel_list``.
            xc: kept channels of ``xt`` per sample, ``(B, n_keep, L)``.
            yc: sorted indices into ``xt``'s channel axis of the kept
                channels, ``(B, n_keep)``.
            yt: sorted indices of the masked channels, ``(B, C - n_keep)``.
            right_embedding: one-hot channel embeddings,
                ``(B, C, len(standard_channel_list))``.
            left_embedding: ``right_embedding`` transposed,
                ``(B, len(standard_channel_list), C)``.

    Raises:
        ValueError: if ``dataset`` is unknown, or if ``eeg`` does not have
            exactly one channel per entry of the dataset's channel list.
    """
    dataset_channels = {
        'seed': seed_ch_list,
        'motor': motor_ch_list,
        'monitor': monitor_ch_list,
    }
    if dataset not in dataset_channels:
        raise ValueError(
            f"unknown dataset {dataset!r}; expected one of {list(dataset_channels)}")

    B, C, L = eeg.shape
    ch_list = [s.upper() for s in dataset_channels[dataset]]
    if C != len(ch_list):
        raise ValueError(
            f"eeg has {C} channels but dataset {dataset!r} defines {len(ch_list)}")

    batch = AttrDict()

    # Walk the standard montage once. For each channel present in this
    # dataset, record its standard position (t_index) and its position in the
    # incoming eeg tensor (used to reorder eeg).
    all_channel_list = [s.upper() for s in standard_channel_list]
    t_index = []
    reordered_index_map = []
    for std_pos, ch in enumerate(all_channel_list):
        if ch in ch_list:
            t_index.append(std_pos)
            reordered_index_map.append(ch_list.index(ch))
    new_eeg = eeg[:, reordered_index_map, :]

    batch.xt = new_eeg
    batch.t_index = torch.tensor(t_index)

    # Random per-sample split of the reordered channels into kept / masked;
    # both index sets are sorted so they follow standard-montage order.
    n_channels = new_eeg.shape[1]
    n_keep = n_channels - int(mask_ratio * n_channels)
    perms = torch.stack([torch.randperm(n_channels) for _ in range(B)])
    reserved_channels, _ = torch.sort(perms[:, :n_keep], dim=1)
    missing_channels, _ = torch.sort(perms[:, n_keep:], dim=1)
    batch.xc = torch.stack(
        [torch.index_select(new_eeg[i], 0, reserved_channels[i]) for i in range(B)])
    batch.yc = reserved_channels
    batch.yt = missing_channels

    # One-hot embedding: row k of the identity matrix selects standard channel k.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else eeg.device
    all_channel_embedding = torch.eye(len(standard_channel_list), device=device)
    channel_embedding = all_channel_embedding[batch.t_index].unsqueeze(0).expand(B, -1, -1)
    batch.right_embedding = channel_embedding
    batch.left_embedding = channel_embedding.transpose(2, 1)
    return batch



def compute_pcc(ori, recon):
    """Mean Pearson correlation coefficient between matching rows of two tensors.

    For every (batch, channel) pair, the PCC between ``ori[b, c, :]`` and
    ``recon[b, c, :]`` is computed along the last axis. The result is the mean
    over all pairs.

    Args:
        ori: Reference tensor of shape ``(B, C, L)``.
        recon: Reconstruction with the same shape as ``ori``.

    Returns:
        The average PCC as a Python float. A row with zero variance makes
        its PCC (and hence the result) NaN.
    """
    B, C, L = ori.shape
    # reshape (not view) so non-contiguous inputs, e.g. transposed or sliced
    # tensors, are accepted.
    ori = ori.reshape(-1, L)
    recon = recon.reshape(-1, L)

    mean_ori = torch.mean(ori, dim=1, keepdim=True)
    mean_recon = torch.mean(recon, dim=1, keepdim=True)

    ori_centered = ori - mean_ori
    recon_centered = recon - mean_recon

    # Sample covariance and standard deviations (the L-1 factors cancel).
    conv = (ori_centered * recon_centered).sum(dim=1) / (L - 1)
    std_ori = torch.sqrt((ori_centered ** 2).sum(dim=1) / (L - 1))
    std_recon = torch.sqrt((recon_centered ** 2).sum(dim=1) / (L - 1))

    pcc = conv / (std_ori * std_recon)
    pcc = pcc.view(B, C)
    pcc = pcc.mean(dim=0).mean().item()

    return pcc

def compute_rrmse(ori, recon):
    """Relative root-mean-square error of ``recon`` with respect to ``ori``.

    Computed as ``RMS(ori - recon) / RMS(ori)``, where each RMS is taken over
    every element. Returned as a Python float.
    """
    residual = ori - recon
    rms_error = residual.pow(2).mean().sqrt()
    rms_reference = ori.pow(2).mean().sqrt()
    return (rms_error / rms_reference).item()

def compute_rsnr(ori, recon):
    """Reconstruction signal-to-noise ratio in decibels.

    Computed as ``10 * log10(sum(ori**2) / sum((ori - recon)**2))`` over all
    elements and returned as a Python float. The result is ``inf`` when
    ``recon`` equals a non-zero ``ori`` exactly.
    """
    signal_power = ori.pow(2).sum()
    noise_power = (ori - recon).pow(2).sum()
    ratio = signal_power / noise_power
    return (10 * torch.log10(ratio)).item()

def compute_ssim(ori, recon, window_size=11, C1=1e-4, C2=9e-4):
    """Mean 1-D structural similarity (SSIM) between ``ori`` and ``recon``.

    Local means, variances and covariance are computed per channel with a
    uniform (box) window of ``window_size`` samples. This is implemented as a
    depthwise ``conv1d`` with zero padding of ``window_size // 2`` on each
    side. The resulting SSIM map is averaged over all batch, channel and time
    positions.

    Args:
        ori: Reference tensor of shape ``(B, C, L)``.
        recon: Reconstruction with the same shape and dtype as ``ori``.
        window_size: Length of the averaging window. An odd value keeps the
            SSIM map at length ``L``; an even value makes it one sample longer.
        C1: Stabilising constant added to the local-mean terms.
        C2: Stabilising constant added to the variance/covariance terms.

    Returns:
        The mean SSIM as a Python float.
    """
    _, C, _ = ori.shape
    pad = window_size // 2
    # Build the kernel in the input's dtype. A default-dtype (float32) kernel
    # against float64 or half inputs makes conv1d raise a dtype mismatch.
    window = torch.ones(C, 1, window_size, device=ori.device, dtype=ori.dtype) / window_size

    def local_mean(x):
        # groups=C applies the window to each channel independently.
        return F.conv1d(x, window, padding=pad, groups=C)

    mu_x = local_mean(ori)
    mu_y = local_mean(recon)

    mu_x2 = mu_x ** 2
    mu_y2 = mu_y ** 2

    sigma_x2 = local_mean(ori ** 2) - mu_x2
    sigma_y2 = local_mean(recon ** 2) - mu_y2
    sigma_xy = local_mean(ori * recon) - mu_x * mu_y

    numerator = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)
    denominator = (mu_x2 + mu_y2 + C1) * (sigma_x2 + sigma_y2 + C2)
    return (numerator / denominator).mean().item()


if __name__ == '__main__':
    # Quick sanity check: print how many channels the 'seed' montage defines.
    print(len(seed_ch_list))






