import os
import math
import random
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn

from datasets import load_from_disk, concatenate_datasets


### M1
from brainlm_mae.modeling_brainlm_mask_next_1024_2 import BrainLMForPretraining as pre_series
### M2
# from brainlm_mae.modeling_brainlm_M2_test import BrainLMForPretraining as pre_series


from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import matplotlib.colors as mcol
import seaborn as sns

from torch.utils.data import Dataset
from random import randint

import nibabel as nib
import numpy as np


import warnings

import logging
# Suppress every logging call of severity WARNING and below, process-wide.
logging.disable(logging.WARNING)
# Hide warnings of category UserWarning.
warnings.filterwarnings("ignore", category=UserWarning)

   
    
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def standardize(data):
    """Robustly scale *data*: subtract the median and divide by the IQR.

    The statistics are computed over the whole input (all axes flattened),
    not per row.

    Args:
        data: array-like of numbers.

    Returns:
        A NumPy array with the same shape as *data*, holding
        ``(data - median) / (Q3 - Q1)``.

    A constant input has an interquartile range of zero; in that case the
    scale falls back to 1 so the result is all zeros instead of inf/NaN.
    """
    values = np.asarray(data)

    median_value = np.median(values)
    q1, q3 = np.percentile(values, [25, 75])
    iqr = q3 - q1

    # Guard the degenerate case: a flat signal would otherwise divide by zero.
    if iqr == 0:
        iqr = 1.0

    return (values - median_value) / iqr




TR = 0.72  # NOTE(review): presumably the fMRI repetition time in seconds; not referenced by the live code visible in this file - confirm before relying on it
    
from nitime.timeseries import TimeSeries
from nitime.analysis import SpectralAnalyzer, FilterAnalyzer, NormalizationAnalyzer






def read_xlsx(file_path):
    """Read the Excel workbook at *file_path* and return its table as a NumPy array.

    The file is parsed with ``pandas.read_excel`` using its defaults and the
    resulting DataFrame is converted with ``DataFrame.to_numpy``.
    """
    return pd.read_excel(file_path).to_numpy()

# Coordinate table loaded once at import time. The first three columns of
# each of the first 398 rows are packed into a float32 tensor of shape
# (398, 3); presumably one xyz triple per brain region - TODO confirm
# against the spreadsheet.
coords_ds = read_xlsx("./Coordinates-398.xlsx")
window_xyz_list = torch.stack(
    [
        torch.tensor(
            [
                coords_ds[region][0],
                coords_ds[region][1],
                coords_ds[region][2],
            ],
            dtype=torch.float32,
        )
        for region in range(398)
    ]
)


from scipy.signal import butter, filtfilt
def band_pass_filter(data, lowcut, highcut, fs, order=1):
    """Apply a zero-phase Butterworth band-pass filter along the last axis.

    Args:
        data: array-like signal(s); filtering runs over ``axis=-1``.
        lowcut: lower cutoff, in the same frequency unit as *fs*.
        highcut: upper cutoff, in the same frequency unit as *fs*.
        fs: sampling frequency.
        order: order passed to ``scipy.signal.butter`` (default 1).

    Returns:
        The filtered array produced by ``scipy.signal.filtfilt``.
    """
    # butter() expects the band edges as fractions of the Nyquist frequency.
    nyquist = 0.5 * fs
    band = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, band, btype='band')
    return filtfilt(numerator, denominator, data, axis=-1)

# Band-pass filter cutoffs
lowcut = 0.01  # low cutoff frequency (unit: Hz)
highcut = 0.1  # high cutoff frequency (unit: Hz)


def z_score_normalize(signal_vector):
    """Robustly normalise each signal along its last axis.

    Despite the historical name, this is median/IQR scaling rather than a
    mean/std z-score: for every row ``r`` the result is
    ``(r - median(r)) / (Q3(r) - Q1(r))``.

    The previous body indexed with an undefined name and returned nothing,
    so every call failed; this version normalises all rows and returns the
    result.

    Args:
        signal_vector: array-like of shape ``(..., n_samples)``. It is not
            modified.

    Returns:
        A new float NumPy array with the same shape as *signal_vector*.
        Rows whose interquartile range is zero are only centred (scale 1)
        instead of being divided by zero.
    """
    signals = np.asarray(signal_vector, dtype=float)

    median_value = np.median(signals, axis=-1, keepdims=True)
    q1, q3 = np.percentile(signals, [25, 75], axis=-1, keepdims=True)
    iqr = q3 - q1

    # Flat rows have IQR == 0; use a unit scale so they map to zeros.
    iqr = np.where(iqr == 0, 1.0, iqr)

    return (signals - median_value) / iqr



class MyDataset(Dataset):
    """Dataset over the ``.npy`` files of one folder, split by position.

    The folder listing is sorted by file name; the first ``train_size``
    entries form the "train" split and the remainder the "val" split.
    Each item is ``(features, label)`` where ``features`` is the dict built
    by :meth:`preprocess` and ``label`` is an ``int`` class id.

    Args:
        data_folder: directory containing the sample files.
        flag: ``"train"`` or ``"val"``; any other value yields an empty
            dataset.
        train_size: number of leading (sorted) files assigned to the train
            split. Defaults to 12000.
    """

    # File-name substring -> class id. Keys are tried in this order and the
    # first one found in the file name wins. The WM_0bk_* and WM_2bk_* keys
    # share ids 5-8.
    LABELS = {
        'MOTOR_lf': 0, 'MOTOR_rf': 1, 'MOTOR_lh': 2, 'MOTOR_rh': 3, 'MOTOR_t': 4,
        'WM_0bk_body': 5, 'WM_0bk_faces': 6, 'WM_0bk_places': 7, 'WM_0bk_tools': 8,
        'WM_2bk_body': 5, 'WM_2bk_faces': 6, 'WM_2bk_places': 7, 'WM_2bk_tools': 8,
        'EMOTION_fear': 9, 'EMOTION_neut': 10, 'GAMBLING_loss': 11, 'GAMBLING_win': 12,
        'LANGUAGE_math': 13, 'LANGUAGE_story': 14,
        'RELATIONAL_match': 15, 'RELATIONAL_relation': 16,
        'SOCIAL_mental': 17, 'SOCIAL_rnd': 18, 'rfMRI': 19,
    }

    def __init__(self, data_folder, flag="train", train_size=12000):
        self.data_folder = data_folder
        self.flag = flag

        # Sort on the full name. This keeps the original primary order
        # (first six characters) but also fixes the order of files sharing
        # that prefix, which os.listdir() returns in arbitrary order, so the
        # train/val split is reproducible.
        file_list = sorted(os.listdir(self.data_folder))

        if self.flag == "train":
            selected = file_list[:train_size]
        elif self.flag == "val":
            selected = file_list[train_size:]
        else:
            selected = []

        self.filenames = [os.path.join(self.data_folder, name) for name in selected]

    def __len__(self):
        return len(self.filenames)

    @classmethod
    def _resolve_label(cls, path):
        """Return the class id for *path*, or -1 when no label key matches.

        Only the file name is inspected, so a label substring that happens
        to occur in a parent directory cannot mislabel every sample.
        """
        name = os.path.basename(path)
        for key, class_id in cls.LABELS.items():
            if key in name:
                return class_id
        return -1

    def __getitem__(self, index):
        """Load sample *index* and return ``(features, label)``.

        ``label`` is always a plain ``int`` (``-1`` for unmatched files) so
        that samples collate uniformly into an int64 tensor.
        """
        path = self.filenames[index]
        data_label = self._resolve_label(path)

        image = np.load(path)
        data1 = self.preprocess(image)

        return data1, data_label

    def preprocess(self, examples):
        """Filter and normalise one recording and package it for the model.

        Args:
            examples: 2-D array; presumably (timepoints, regions) - TODO
                confirm against the data files.

        Returns:
            dict with keys ``signal_vectors`` and ``signal_vectors1`` (the
            same float32 tensor, in the input's axis order),
            ``xyz_vectors`` (the module-level coordinate tensor) and
            ``label`` (constant int64 tensor 1).
        """
        # Constant placeholder label; it is not the task label returned by
        # __getitem__.
        label = torch.tensor(1, dtype=torch.int64)

        # Recordings with exactly 1200 rows are cut to their first 100 rows;
        # every other length is used in full.
        n_rows = examples.shape[0]
        if n_rows == 1200:
            n_rows = 100
        signal_vector = examples[:n_rows, :].transpose()

        # Fourier band-pass (0.01-0.15) followed by z-scoring, via nitime.
        # NOTE(review): sampling_interval is 1.5 here while the module-level
        # TR constant is 0.72 - confirm which one matches the data.
        T = TimeSeries(signal_vector, sampling_interval=1.5)
        F = FilterAnalyzer(T, ub=0.15, lb=0.01)
        signal_vector4 = NormalizationAnalyzer(F.filtered_fourier).z_score.data

        # Back to the axis order of the input array.
        signal_window = torch.tensor(signal_vector4, dtype=torch.float32)
        signal_window = signal_window.transpose(1, 0)

        examples_o = {}
        examples_o["signal_vectors"] = signal_window
        examples_o["signal_vectors1"] = signal_window
        examples_o["xyz_vectors"] = window_xyz_list
        examples_o["label"] = label

        return examples_o
    
    




# Pretrained backbone, loaded from a local checkpoint and moved to the
# selected device.
# Alternative checkpoint used previously:
#   "training-runs/movie_story_r_1000_256_M3/checkpoint-27000/"
model_series = pre_series.from_pretrained("training-runs/story/checkpoint-100000/")
model_series = model_series.to(device)

# Experiment notes (linear probe), kept from the original author:
#   space+time9   step 16000 -> 77.37
#   space+time10  (tr=1, 0.01-0.1, 50, 0.2, no trend_loss)
#                 step 10000 -> 79.66%, step 22000 -> 80%
#   space+time11  step 56400
#   space+time12  step 10000

# Freeze the whole backbone; only the classifier head defined later is
# trained. (A variant that unfroze model_series.vit.encoder.layer[-1] was
# tried and left disabled.)
for frozen_param in model_series.parameters():
    frozen_param.requires_grad = False

# Turn masking off on both the embeddings module and its config.
for mask_holder in (model_series.vit.embeddings, model_series.vit.embeddings.config):
    mask_holder.mask_ratio = 0.0

# Seed torch, Python's random and NumPy, in that order, with the same value.
for seed_everything in (torch.manual_seed, random.seed, np.random.seed):
    seed_everything(1234)





def collate_fn(examples):
    """Re-stack the fields of a batch dict and rename them for the backbone.

    Args:
        examples: mapping with the keys ``signal_vectors``,
            ``signal_vectors1``, ``label`` and ``xyz_vectors``; each value
            is an iterable of equally shaped tensors (a batched tensor
            iterates over its first dimension).

    Returns:
        dict with ``signal_vectors``, ``signal_vectors1``, ``xyz_vectors``,
        ``input_ids`` (the very same tensor object as ``signal_vectors``)
        and ``labels``.
    """

    def _restack(key):
        # Iterate the field and stack the pieces along a new leading axis.
        return torch.stack(list(examples[key]), dim=0)

    signal_vectors = _restack("signal_vectors")
    signal_vectors1 = _restack("signal_vectors1")
    labels = _restack("label")
    xyz_vectors = _restack("xyz_vectors")

    batch = {}
    batch["signal_vectors"] = signal_vectors
    batch["signal_vectors1"] = signal_vectors1
    batch["xyz_vectors"] = xyz_vectors
    batch["input_ids"] = signal_vectors
    batch["labels"] = labels
    return batch





# NOTE(review): the label map in MyDataset uses ids 0..19 and the head below
# emits 20 logits, while this constant says 19 - confirm the intended count.
num_classes = 19


class MyModel(nn.Module):
    """Two-layer classification head: Linear(256, 256) -> LeakyReLU(0.1) -> Linear(256, 20).

    Args:
        input_size: accepted for interface compatibility; currently ignored
            (the input width is fixed at 256).
        output_size: accepted for interface compatibility; currently ignored
            (the head always produces 20 logits).

    Several extra layers are constructed but never used by ``forward``.
    They are kept, in their original creation order, so that the parameter
    list, ``state_dict`` keys and random-initialisation sequence stay
    unchanged.
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        # Unused by forward().
        self.bn = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(128)

        # Layers used by forward().
        self.linear = nn.Linear(256, 256)
        self.linear2 = nn.Linear(256, 20)

        # Unused by forward(); leftovers from earlier head variants.
        self.linear5 = nn.Linear(256, 20)
        self.linear3 = nn.Linear(256, 23)
        self.linear4 = nn.Linear(32, 23)

        # Activations: only ``r`` is used by forward().
        self.r = nn.LeakyReLU(0.1)
        self.r3 = nn.GELU()
        self.r1 = nn.Tanh()

    def forward(self, x):
        """Map features whose last dimension is 256 to 20 class logits."""
        hidden = self.linear(x)
        activated = self.r(hidden)
        return self.linear2(activated)
    
    

class double_Model(nn.Module):
    """Run two inputs through one shared ``MyModel`` and classify the pair.

    Both inputs go through the same ``MyModel`` instance (shared weights);
    the two outputs are concatenated along dim 1 and mapped to 7 logits by
    a final linear layer.

    Previously the constructor called ``MyModel()`` without its two
    required arguments (a ``TypeError``), and the final layer expected 16
    input features although the concatenated branch outputs are wider.
    """

    def __init__(self):
        super().__init__()

        # MyModel currently ignores its size arguments; the values passed
        # here mirror its fixed 256 -> 20 shape.
        self.mymodel = MyModel(256, 20)

        # Size the fusion layer from the branch's real output width so the
        # concatenation of the two branches always fits.
        branch_width = self.mymodel.linear2.out_features
        self.linear = nn.Linear(2 * branch_width, 7)
        self.r = nn.LeakyReLU(0.1)  # not used by forward()

    def forward(self, x1, x2):
        """Return 7 logits per row for the pair ``(x1, x2)``."""
        x1 = self.mymodel(x1)
        x2 = self.mymodel(x2)
        out = self.linear(torch.cat([x1, x2], dim=1))
        return out
    
    
# Downstream-task data (HCP task fMRI): one .npy file per sample.
# (An earlier run used the CHCP_tfMRI_origion folder instead.)
data_folder = "/root/autodl-tmp/BrainLM-main/DataSet/HCP_1000t_split"
train_dataset = MyDataset(data_folder=data_folder, flag="train")
val_dataset = MyDataset(data_folder=data_folder, flag="val")
batch_size = 1

# The training loader was missing, so the loop below failed with a
# NameError. Training batches are shuffled; validation keeps file order.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=12, pin_memory=True
)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=12, pin_memory=True
)

# NOTE: MyModel ignores both constructor arguments (fixed 256 -> 20 head).
m_model = MyModel(512, num_classes).to(device)

learning_rate = 0.001
optimizer = torch.optim.Adam(m_model.parameters(), lr=learning_rate, weight_decay=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()
num_epochs = 100

# Gradients are summed over this many batches before each optimizer step.
accumulation_steps = 64

best_acc = 0


def _encode_batch(inputs):
    """Run the frozen backbone on one loader batch and return (batch, 256) features.

    Every tensor handed to the backbone is moved to ``device`` (previously
    only the two signal tensors were). The feature is index 0 along dim 1
    of ``outputs["logits"][1]`` - presumably a CLS-style token; TODO
    confirm against the backbone implementation.
    """
    model_inputs = {key: value.to(device) for key, value in collate_fn(inputs).items()}
    outputs_series = model_series(
        signal_vectors=model_inputs["signal_vectors"],
        signal_vectors1=model_inputs["signal_vectors1"],
        labels=model_inputs["labels"],
        input_ids=model_inputs["input_ids"],
        xyz_vectors=model_inputs["xyz_vectors"],
    )
    return outputs_series["logits"][1][:, 0, :].reshape(-1, 256)


for epoch in range(num_epochs):
    # ---- training ----
    m_model.train()
    optimizer.zero_grad()
    pending = 0  # batches accumulated since the last optimizer step

    for inputs, labels in train_loader:
        features = _encode_batch(inputs)
        outputs = m_model(features)
        loss = loss_fn(outputs, labels.to(device))
        loss.backward()

        pending += 1
        if pending == accumulation_steps:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

    # Apply the gradients of a trailing partial group; otherwise they would
    # be dropped after the last epoch and leak into the first step of the
    # next one.
    if pending:
        optimizer.step()
        optimizer.zero_grad()

    # ---- validation ----
    m_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            features = _encode_batch(inputs)
            outputs = m_model(features)

            # One prediction per sample, compared with one label per sample.
            # The labels used to be repeat_interleave'd, which only lined up
            # through broadcasting when batch_size was 1.
            predicted = outputs.argmax(dim=1)
            labels = labels.to(device)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty validation split reports 0% instead of dividing by zero.
    accuracy = 100 * correct / total if total else 0.0
    if accuracy > best_acc:
        best_acc = accuracy
        # To keep the best head, save it here, e.g.:
        # torch.save(m_model, '/root/autodl-tmp/BrainLM-main/training-runs/Test_HCP_Task/'+str(accuracy)+'_HCP_Task.pt')

    print('Validation accuracy: {:.2f}%'.format(accuracy))