from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from torch.nn.utils import clip_grad_norm_
import gc
from sklearn.metrics import f1_score, roc_auc_score

from src.models.exist_models import *
from src.models.mae import *

from src.models.utils import *

# from src.models.crossvit import *

# print("Device:", DEVICE)

# load physiological model
# # vanilla vit
# physio_model = ViTAdjust()

# # cross vit
# physio_model = CrossSignalViT(device='cuda')
# stat_dict = torch.load('../data/model_checkpoint_cross_freeze_vit100_99.pth', map_location=torch.device('cpu'))['model']
# physio_model.load_state_dict(stat_dict)

# mae
# Build the masked-autoencoder ViT that serves as the physiological-signal
# encoder for the rest of this file.
# NOTE(review): the constructor arguments presumably have to match the ones
# the checkpoint below was trained with — confirm before changing them.
physio_model = MaskedAutoencoderViT(img_size=(387,65), patch_size=(9,5),mask_scheme='random',mask_prob=0.8,use_cwt=True,nvar=4, comb_freq=True)
# The checkpoint is deserialized onto the CPU; the weights live under its 'model' key.
stat_dict = torch.load('../data/results/model_mae_checkpoint-140.pth', map_location=torch.device('cpu'))['model']
physio_model.load_state_dict(stat_dict)
print("Model load successfull.")
# exit()

# basic init setting
# Cast the encoder to bfloat16, move it to DEVICE and switch it to eval mode.
# Every call to it further down in this file is made under torch.no_grad().
physio_model = physio_model.to(torch.bfloat16).to(DEVICE)
physio_model.eval()

# load nlp model
from transformers import AutoTokenizer, AutoModel

# Load model from HuggingFace Hub
# Sentence encoder used by get_embedding() to embed the question / answer
# prompts. Tokenizer and model are loaded from the same checkpoint name;
# only the model is moved to DEVICE (tokenizer outputs are moved per call).
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
nlp_model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to(DEVICE)

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence axis, ignoring masked tokens.

    Args:
        model_output: indexable whose first element holds the token
            embeddings; the mask is expanded to that tensor's size.
        attention_mask: mask with one entry per token (non-zero = keep).

    Returns:
        The masked sum over dimension 1 divided by the per-feature mask
        count, with the divisor clamped to at least 1e-9 so a fully masked
        row yields zeros instead of a division by zero.
    """
    hidden_states = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_sum = (hidden_states * weights).sum(1)
    token_counts = weights.sum(1).clamp(min=1e-9)
    return weighted_sum / token_counts

def get_embedding(texts):
    """Embed one or more sentences with the module-level sentence encoder.

    `texts` is passed straight to the module-level `tokenizer` (padding and
    truncation enabled, PyTorch tensors returned) and the encoded batch is
    moved to DEVICE. `nlp_model` is run without gradient tracking and its
    token embeddings are mean-pooled under the attention mask.

    Returns:
        The result of `mean_pooling` for the batch.
    """
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(DEVICE)

    # Inference only: no autograd graph is needed for the encoder pass.
    with torch.no_grad():
        encoder_out = nlp_model(**batch)

    return mean_pooling(encoder_out, batch['attention_mask'])

def eval_res(y_preds, y_trues, task='reg'):
    """Score predictions against ground truth.

    Args:
        y_preds: numpy array of predictions. For classification it is indexed
            as ``y_preds[:, 1]`` in the two-class case and passed whole to
            ``roc_auc_score`` otherwise.
        y_trues: numpy array of ground-truth values / class labels.
        task: ``"reg"`` for regression; any other value means classification.

    Returns:
        Regression: ``1 - mean(|y_true - y_pred| / y_true)``.
        Classification: ROC-AUC (one-vs-one, macro-averaged when more than
        two distinct classes appear in ``y_trues``).
    """
    if task == "reg":
        relative_error = np.absolute(y_trues - y_preds) / y_trues
        return 1 - np.mean(relative_error)

    print("Classes in Test:", set(y_trues))
    if len(set(y_trues)) > 2:
        return roc_auc_score(y_trues, y_preds, multi_class="ovo", average="macro")

    # Two classes or fewer: score with the second probability column.
    return roc_auc_score(y_trues, y_preds[:, 1])

# ============================================================================================================================

def get_prompt(key, value, use_var=False):
    """Render the answer sentence(s) describing one label.

    Args:
        key: label name selecting the template (e.g. ``'NASA_TLX'``,
            ``'Emotion'``, ``'min_bp'``).
        value: label value. For the numeric label names it is truncated with
            ``int()`` before formatting; otherwise it is formatted as-is.
        use_var: when False return a single sentence; when True return a
            list with every paraphrase available for ``key``.

    Returns:
        ``str`` if ``use_var`` is False, otherwise ``list[str]``.

    Raises:
        KeyError: if ``key`` has no template in the selected table (the
            paraphrase table has no ``'regression'`` entry).
    """
    single_templates = {
        'NASA_TLX': 'NASA Task Load Index of {}%.',
        'PSQI': 'Pittsburgh Sleep Quality Index of {}%.',
        'regression': 'The next value is {}.',
        'Emotion': 'This subject is {}.',
        'Valence': 'Valence level of {}%.',
        'Arousal': 'Arousal level of {}%.',
        'min_bp': 'Diastolic blood pressure of {}.',
        'max_bp': 'Systolic blood pressure of {}.',
        'activity': 'This subject is {}.',
    }

    paraphrase_templates = {
        'NASA_TLX': [
            'NASA Task Load Index recorded at {}%.',
            'The workload according to NASA TLX is {}%.',
            'Measured NASA TLX is {}%.',
            'NASA Task Load Index score is {}%.',
        ],
        'PSQI': [
            'Pittsburgh Sleep Quality Index stands at {}%.',
            'PSQI score is {}%.',
            'The PSQI result is {}%.',
            'Sleep quality (PSQI) is at {}%.',
        ],
        'Emotion': [
            'The emotion detected is {}.',
            'This subject is feeling {}.',
            'The emotional state is {}.',
            'The identified emotion is {}.',
        ],
        'Valence': [
            'Valence level detected is {}%.',
            'The valence level is {}%.',
            'Valence is measured at {}%.',
            'Measured valence level is {}%.',
        ],
        'Arousal': [
            'Arousal level detected is {}%.',
            'The arousal level is {}%.',
            'Arousal is measured at {}%.',
            'Measured arousal level is {}%.',
        ],
        'min_bp': [
            'Diastolic blood pressure recorded at {}.',
            'The diastolic pressure is {}.',
            'Diastolic BP is {}.',
            'Measured diastolic pressure is {}.',
        ],
        'max_bp': [
            'Systolic blood pressure recorded at {}.',
            'The systolic pressure is {}.',
            'Systolic BP is {}.',
            'Measured systolic pressure is {}.',
        ],
        'activity': [
            'This subject is currently {}.',
            'The subject is engaged in {}.',
            'Current activity is {}.',
            "Subject's activity is {}.",
        ],
    }

    # Numeric labels are reported as whole numbers in the sentences.
    integer_valued = ("NASA_TLX", "PSQI", "Valence", "Arousal", "min_bp", "max_bp")
    if key in integer_valued:
        value = int(value)
    rendered = str(value)

    if use_var:
        return [template.format(rendered) for template in paraphrase_templates[key]]
    return single_templates[key].format(rendered)

def preprocess_to_get_embeddings(use_var=False, saved_folder_name='embeddings_mae'):
    """Pre-compute and cache the embeddings the matcher is trained on.

    For every pickle under ``../data/pretrain/<dataset>/`` the ``'cwt'`` array
    is pushed through the module-level ``physio_model`` and every entry of the
    ``'label'`` mapping is turned into a question sentence and an answer
    sentence (or, with ``use_var``, several paraphrases of each) which are
    embedded with ``get_embedding``. One record per label is written to
    ``../data/<saved_folder_name>/<dataset>_<label index>_<file name>`` with
    the keys ``"phsio"``, ``"question"`` and ``"answers"``.

    A source file is skipped when its label-0 record already exists, so the
    function can be re-run to resume. Files that cannot be loaded, fail in
    the forward pass, or produce NaN output are reported and skipped.

    Args:
        use_var: if True, store all paraphrases (2-D arrays) instead of a
            single question / answer embedding (1-D arrays).
        saved_folder_name: output folder name under ``../data``; it is
            created if it does not exist yet.
    """
    question_template_raw = {
        'NASA_TLX': 'What is the NASA Task Load Index in Percentage?',
        'PSQI':'What is the Pittsburgh Sleep Quality Index in Percentage?',
        'regression':'What is the next value?',
        'Emotion': 'What is current emotion?',
        'Valence': 'What is the valence level in Percentage?',
        'Arousal': 'What is the arousal level in Percentage?',
        'min_bp':'What is the diastolic blood pressure?',
        'max_bp':'What is the systolic blood pressure?',
        'activity':'What is the current activity?',
    }

    question_template_var = {
        'NASA_TLX': [
            'What is the NASA Task Load Index in Percentage?',
            'Can you provide the NASA TLX score as a percentage?',
            'What percentage does the NASA Task Load Index indicate?',
            'Please tell me the NASA Task Load Index in percentage terms.',
            'How much is the NASA TLX score expressed in percentage?'
        ],
        'PSQI': [
            'What is the Pittsburgh Sleep Quality Index in Percentage?',
            'Can you give me the PSQI score as a percentage?',
            'What percentage is the Pittsburgh Sleep Quality Index?',
            'Please provide the PSQI as a percentage.',
            'How is the Pittsburgh Sleep Quality Index represented in percentage?'
        ],
        'Emotion': [
            'What is the current emotion?',
            'Can you identify the current emotion?',
            'What emotion is being expressed now?',
            'What is the emotion right now?',
            'Could you specify the current emotion?'
        ],
        'Valence': [
            'What is the valence level in Percentage?',
            'Can you provide the valence level as a percentage?',
            'What percentage represents the valence level?',
            'Please tell me the valence level in percentage terms.',
            'How much is the valence level expressed in percentage?'
        ],
        'Arousal': [
            'What is the arousal level in Percentage?',
            'Can you give me the arousal level as a percentage?',
            'What percentage represents the arousal level?',
            'Please provide the arousal level in percentage.',
            'How much is the arousal level expressed in percentage?'
        ],
        'min_bp': [
            'What is the diastolic blood pressure?',
            'Can you provide the diastolic blood pressure reading?',
            'What is the current diastolic BP?',
            'Please tell me the diastolic blood pressure.',
            'What is the diastolic pressure value?'
        ],
        'max_bp': [
            'What is the systolic blood pressure?',
            'Can you provide the systolic blood pressure reading?',
            'What is the current systolic BP?',
            'Please tell me the systolic blood pressure.',
            'What is the systolic pressure value?'
        ],
        'activity': [
            'What is the current activity?',
            'Can you identify the current activity?',
            'What activity is happening right now?',
            'Please tell me what the current activity is.',
            'What is the activity being performed currently?'
        ],
    }

    question_template = question_template_var if use_var else question_template_raw

    # Make sure the output folder exists; without this the first write below
    # fails with FileNotFoundError on a fresh checkout.
    out_dir = "../data/{}".format(saved_folder_name)
    os.makedirs(out_dir, exist_ok=True)

    for ds_name in [
        "dalia",
        "cf",
        "mendeley",
        "auditory",
        "phyatt",
        "maus",
        # "physionet", # autoregression task, would be better to remove this
    ]:
        print("Processing {}...".format(ds_name))
        src_dir = "../data/pretrain/{}".format(ds_name)
        for fn in tqdm(sorted(os.listdir(src_dir))):
            # Skip hidden files and files whose first record was already written.
            if fn[0] == "." or os.path.isfile("{}/{}_{}_{}".format(out_dir, ds_name, 0, fn)):
                continue

            src_path = "{}/{}".format(src_dir, fn)
            # NOTE: pickle must only be used on trusted, locally produced files.
            try:
                with open(src_path, 'rb') as f:
                    data = pickle.load(f) # tss, cwt, sensor, label
            except Exception as err:
                # `Exception` (not a bare except) so Ctrl-C still stops the run.
                print("Error loading:", src_path, repr(err))
                continue

            # process through physiological model
            with torch.no_grad():
                try:
                    # data['cwt'] is permuted (0, 3, 1, 2) and given a leading
                    # batch axis before the forward pass.
                    cwt = torch.from_numpy(data['cwt']).to(torch.bfloat16)
                    cwt = cwt.permute(0, 3, 1, 2).unsqueeze(0).to(DEVICE)
                    physio_out = physio_model.forward_all(cwt)
                    # Stored as float16 to keep the cache small.
                    # NOTE(review): EmbMatch_Dataset indexes [0] on this array,
                    # so it presumably keeps the batch axis, (1, C, L, H) — confirm.
                    physio_out = physio_out.cpu().float().numpy().astype(np.float16)
                except Exception as err:
                    print("Forward error:", fn, repr(err))
                    continue

            if np.isnan(np.sum(physio_out)):
                print("nan output:", fn)
                continue

            # process through nlp model: one record per label of this file
            for c_i, label_name in enumerate(data['label']):
                label_value = data['label'][label_name]
                if not use_var:
                    sentences = [question_template[label_name], get_prompt(label_name, label_value)]
                    nlp_embeds = get_embedding(sentences).cpu().numpy().astype(np.float16)
                    question, answers = nlp_embeds[0], nlp_embeds[1]
                else:
                    questions_text = question_template[label_name]
                    sentences = questions_text + get_prompt(label_name, label_value, use_var=True)
                    nlp_embeds = get_embedding(sentences).cpu().numpy().astype(np.float16)
                    # The first rows belong to the questions, the rest to the answers.
                    n_questions = len(questions_text)
                    question, answers = nlp_embeds[:n_questions, :], nlp_embeds[n_questions:, :]

                curr_record = {
                    "phsio": physio_out,
                    "question": question, # 1-D if not use_var, else 2-D (one row per paraphrase)
                    "answers": answers # 1-D if not use_var, else 2-D (one row per paraphrase)
                }

                # save
                with open("{}/{}_{}_{}".format(out_dir, ds_name, c_i, fn), 'wb') as f:
                    pickle.dump(curr_record, f)

            # clear cache
            torch.cuda.empty_cache()
            gc.collect()

# ============================================================================================================================

class EmbMatch_Dataset(Dataset):
    """Dataset over cached embedding records, one pickle file per item.

    Each file holds a dict with the keys ``"phsio"`` (signal embeddings,
    numpy, shaped (1, C, L, H) or (C, L, H)), ``"question"`` and
    ``"answers"`` (sentence embeddings). Items are returned as a dict of
    tensors with a fixed token count:

    * ``"embed_in"``: float tensor ``(max_len, H)``, left-padded with zeros;
    * ``"mask"``: float tensor ``(max_len,)``, 0 on padding and 1 on tokens;
    * ``"questions"`` / ``"answers"``: float tensors; with ``use_var`` one
      paraphrase row is drawn at random for each, independently.

    Args:
        embed_path: folder containing the record files.
        one_ds_only: if given, keep only files whose name contains it.
        max_seconds: each channel keeps at most ``65 * max_seconds`` tokens
            plus its first token.
        max_channels: maximum number of channels kept per record.
        use_var: records store several paraphrases per question / answer.
    """

    def __init__(self, embed_path="data/embeddings", one_ds_only=None, max_seconds=6, max_channels=6, use_var=False):
        self.embed_path = embed_path
        # Hidden files (e.g. .DS_Store) are ignored; order is deterministic.
        self.fnames = [fn for fn in sorted(os.listdir(embed_path)) if fn[0] != '.']
        self.max_seconds = max_seconds
        self.max_channels = max_channels
        # Total token budget: per channel 65 * max_seconds tokens plus one leading token.
        self.max_len = (65*max_channels*max_seconds)+max_channels
        # Token budget of one channel, not counting its leading token.
        self.max_len_each = 65*max_seconds
        self.use_var = use_var

        # filter
        if one_ds_only is not None:
            self.fnames = [fn for fn in self.fnames if one_ds_only in fn]

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, idx):
        # read data (trusted, locally generated pickles only)
        f_path = os.path.join(self.embed_path, self.fnames[idx].replace("\\", "/"))
        with open(f_path, 'rb') as f:
            curr_sample = pickle.load(f)

        # Drop the leading batch axis when present: (1, C, L, H) -> (C, L, H).
        phsio = np.asarray(curr_sample["phsio"])
        if phsio.ndim == 4:
            phsio = phsio[0]
        C, L, H = phsio.shape

        # clip channels (a no-op when there are fewer than max_channels)
        sample = phsio[:self.max_channels, :, :]

        # clip length of each channel: keep the first token and the most
        # recent max_len_each tokens. `sample` is still a numpy array here,
        # so this must be np.concatenate (torch.cat rejects ndarrays).
        if L-1 >= self.max_len_each:
            sample = np.concatenate(
                (sample[:, :1, :], sample[:, -self.max_len_each:, :]), axis=1
            )

        # flatten channels into one token sequence and left-pad to max_len
        C, L, H = sample.shape
        n_tokens = C*L
        sample = torch.from_numpy(sample.reshape(n_tokens, H))
        mask = torch.ones(n_tokens)
        n_pad = self.max_len - n_tokens
        if n_pad > 0:
            sample = torch.cat((torch.zeros(n_pad, H, dtype=sample.dtype), sample), dim=0)
            mask = torch.cat((torch.zeros(n_pad), mask), dim=0)
        # Safety clip in case the record is still longer than the budget.
        sample = sample[-self.max_len:, :]
        mask = mask[-self.max_len:]

        # get sentences
        if not self.use_var:
            questions = curr_sample["question"]
            answers = curr_sample["answers"]
        else:
            # Draw one paraphrase for the question and one for the answer.
            questions = curr_sample["question"][np.random.choice(np.arange(len(curr_sample["question"])), 1)[0]]
            answers = curr_sample["answers"][np.random.choice(np.arange(len(curr_sample["answers"])), 1)[0]]

        return {
            "embed_in": sample.float(),
            "questions": torch.from_numpy(questions).float(),
            "answers": torch.from_numpy(answers).float(),
            "mask": mask
        }
    
def optim_helper(model, optimizer, loss_f, batch_out, batch_label, scheduler):
    """Run one optimisation step and return the loss as a Python float.

    Computes ``loss_f(batch_out, batch_label)``, back-propagates, replaces
    NaN gradient entries with 0, clips the global gradient norm to 5, then
    steps both the optimizer and the scheduler (so the scheduler advances
    once per call).

    Args:
        model: module whose parameters are being optimised.
        optimizer: optimizer over ``model``'s parameters.
        loss_f: callable ``(output, label) -> scalar loss tensor``.
        batch_out: model output that is still attached to the autograd graph.
        batch_label: target passed to ``loss_f``.
        scheduler: learning-rate scheduler bound to ``optimizer``.

    Returns:
        The detached loss value as ``float``.
    """
    loss_value = loss_f(batch_out, batch_label)
    optimizer.zero_grad()
    loss_value.backward()

    # Sanitise gradients: NaN entries would otherwise poison the update.
    with_grads = (
        param for param in model.parameters()
        if param.requires_grad and param.grad is not None
    )
    for param in with_grads:
        param.grad = torch.nan_to_num(param.grad, nan=0.0)

    # Guard against exploding gradients.
    clip_grad_norm_(model.parameters(), 5)

    optimizer.step()
    scheduler.step()

    scalar = loss_value.detach().cpu().float()
    return scalar.item()

def train_matcher(
        lr=1e-3,
        epochs=40,
        batch_size=32,
        step_size=10,
        gamma=0.997,
        embed_folder='embeddings',
        one_ds_only=None,
        use_var=False
):
    """Train a ``TemporalFusion`` matcher on the cached embeddings.

    The matcher receives the signal embeddings, the question embedding and
    the padding mask, and is trained to output the answer embedding using an
    equal mix of L1 loss and cosine-embedding loss (target +1). After every
    epoch the loss history and the model weights are written to disk and the
    ``kubectl cp`` commands for downloading them are printed.

    Args:
        lr: Adam learning rate.
        epochs: number of passes over the dataset.
        batch_size: DataLoader batch size.
        step_size: StepLR period. The scheduler is stepped once per batch
            (inside ``optim_helper``), so this counts optimizer steps.
        gamma: StepLR decay factor.
        embed_folder: folder name under ``data/`` holding the records; also
            used in the output file names.
        one_ds_only: if given, train only on files whose name contains it,
            and save loss / weights in the working directory under a name
            that includes it.
        use_var: records contain several paraphrases per question / answer.
    """
    torch.cuda.empty_cache()
    gc.collect()
    # matcher = EmbedMatch().to(torch.bfloat16).to(DEVICE)
    matcher = TemporalFusion()
    # matcher.load_state_dict(torch.load('../data/matcher_trained_weights_{}.pt'.format(embed_folder), map_location=torch.device('cpu')))
    matcher = matcher.to(torch.bfloat16).to(DEVICE)

    dataloader = DataLoader(
        EmbMatch_Dataset(embed_path="data/{}".format(embed_folder), one_ds_only=one_ds_only, use_var=use_var),
        batch_size=batch_size,
        shuffle=True,
        # num_workers=4,
    )

    # construct optimizer
    optimizer = optim.Adam(
        matcher.parameters(),
        lr=lr,
        weight_decay=5e-6
    )
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    # loss: half L1 distance, half cosine distance to the answer embedding
    loss_l1 = nn.L1Loss()
    loss_cos = nn.CosineEmbeddingLoss()

    def loss_f(x, y):
        # Target +1 for every pair: prediction and answer should be similar.
        return 0.5*loss_l1(x, y) + 0.5*loss_cos(x, y, torch.ones(len(y)).to(DEVICE))

    # Output locations do not change between epochs, so resolve them once.
    if one_ds_only is not None:
        saved_loss_path = "matcher_loss_{}_{}.pkl".format(embed_folder, one_ds_only)
        saved_weight_path = 'matcher_trained_weights_{}_{}.pt'.format(embed_folder, one_ds_only)
    else:
        saved_loss_path = "matcher_loss_{}_cf.pkl".format(embed_folder)
        saved_weight_path = '../data/matcher_trained_weights_{}.pt'.format(embed_folder)

    # Path shown in the download hint: drop a leading "../" only when it is
    # actually there. Slicing three characters unconditionally mangled the
    # file name whenever `one_ds_only` was set.
    if saved_weight_path.startswith("../"):
        hint_weight_path = saved_weight_path[3:]
    else:
        hint_weight_path = saved_weight_path

    # train matcher
    gc.collect()
    train_losses = list()
    for e in tqdm(range(epochs)):
        matcher.train()
        for data_pack in tqdm(dataloader):
            out = matcher(
                data_pack["embed_in"].to(torch.bfloat16).to(DEVICE),
                data_pack["questions"].to(torch.bfloat16).to(DEVICE),
                mask=data_pack['mask'].to(torch.bfloat16).to(DEVICE)
            )

            # update
            loss = optim_helper(matcher, optimizer, loss_f, out, data_pack["answers"].to(torch.bfloat16).to(DEVICE), scheduler)
            train_losses.append(loss)
            torch.cuda.empty_cache()
        gc.collect()

        # save results after every epoch so an interrupted run keeps its progress
        with open(saved_loss_path, 'wb') as f:
            pickle.dump(np.array(train_losses).astype(np.float16), f)

        matcher.eval()
        torch.save(matcher.state_dict(), saved_weight_path)

        # for download command
        pod_n = "physio-model-59d4995db8-b6swn"
        print("kubectl cp {}:{} {} -c physio-model".format(pod_n, hint_weight_path, hint_weight_path))
        print("kubectl cp {}:ppg_bp/{} {} -c physio-model".format(pod_n, saved_loss_path, saved_loss_path))

def check_loss(embed_folder="embedding_ds", window=128):
    """Plot the training-loss history saved by ``train_matcher``.

    Reads ``matcher_loss_<embed_folder>_cf.pkl`` from the working directory
    and shows three figures: the raw per-step loss, the running mean over
    all steps so far, and the mean over the trailing ``window`` steps.

    Args:
        embed_folder: folder name that is part of the loss file name.
        window: number of most recent steps averaged in the third plot.
    """
    with open("matcher_loss_{}_cf.pkl".format(embed_folder), 'rb') as f:
        losses = pickle.load(f)
    plt.plot(losses, label="Raw loss")
    plt.legend()
    plt.show()

    # Work in float64: the history is stored as float16, which is too coarse
    # for accumulating long sums.
    values = np.asarray(losses, dtype=np.float64)
    n_steps = len(values)

    # Running mean of losses[:i] for i = 1..n, computed in one pass.
    running_means = np.cumsum(values) / np.arange(1, n_steps + 1)
    plt.plot(running_means, label='Running Means')
    plt.legend()
    plt.show()

    # Trailing-window mean. The start index is clamped at 0: a negative start
    # would be interpreted relative to the END of the array and produce empty
    # (NaN) or wrong windows for the first `window - 1` points.
    window_means = [
        np.mean(values[max(0, i - window):i]) for i in range(1, n_steps + 1)
    ]
    plt.plot(window_means, label='Window Means')
    plt.legend()
    plt.show()

# ============================================================================================================================

def max_match(query, keys, task='reg'): # (384), (N_c, 384)
    """Match one embedding against a set of candidate embeddings.

    Uses the Euclidean distance between ``query`` and every row of ``keys``.

    Args:
        query: 1-D embedding tensor.
        keys: 2-D tensor with one candidate embedding per row.
        task: ``'reg'`` to get the index of the nearest candidate; any other
            value to get a probability for every candidate.

    Returns:
        ``task == 'reg'``: 0-d tensor holding the index of the closest row.
        Otherwise: a Python list with one float per row, obtained by
        normalising the distances by their sum, flipping them into
        similarities (``1 - d / sum(d)``) and applying a softmax.
    """
    offsets = keys - query.unsqueeze(0)
    distances = torch.sqrt(torch.sum(offsets**2, dim=1)) # (N_c)

    if task == 'reg':
        return torch.argmin(distances)

    # Smaller distance -> larger score; NaN (all distances zero) becomes 0.
    scores = 1 - (distances / torch.sum(distances))
    scores = torch.nan_to_num(scores) + 1e-8
    probabilities = nn.functional.softmax(scores.float(), dim=0)
    return probabilities.cpu().numpy().tolist()

def downstream_eval(matcher, ds_name, task="class", preprocess_only=False):
    """Zero-shot evaluation of `matcher` on one downstream dataset.

    For every sample the CWT input is encoded with the module-level
    `physio_model`, fused with the stored question embedding by `matcher`,
    and the resulting embedding is compared against the embeddings of all
    candidate labels with `max_match`. The final score is computed with
    `eval_res` and printed.

    Args:
        matcher: fusion model, called as `matcher(signal_tokens, question)`.
        ds_name: dataset folder name under "../data/downstream/".
        task: "reg" (nearest label is the prediction) or anything else for
            classification (per-label probabilities are the prediction).
        preprocess_only: if True, only dump the matcher output of every
            sample to disk (for linear probing) and return without scoring.

    Returns:
        None. The score is printed, not returned.
    """
    # load ground truth label set
    print("Processing", ds_name, "...")
    with open("../data/downstream/{}/label_map.pkl".format(ds_name), 'rb') as f:
        label_map = pickle.load(f) # label (str) -> embeddings
    # label_vals[i] is the label whose embedding is row i of label_keys.
    label_vals = [k for k in label_map]
    label_keys = torch.stack([torch.from_numpy(label_map[k]) for k in label_map]).to(torch.bfloat16).to(DEVICE)

    # print(label_vals)
    # exit()
    
    # check_set = list()
    # Split file stores full paths; only the base file names are kept.
    # NOTE(review): train_fnames is computed but never used below.
    with open("../data/downstream/{}/splits".format(ds_name), 'rb') as f:
        split = pickle.load(f)
        train_fnames = [fn.replace('\\', '/').split("/")[-1] for fn in split["train_fnames"]]
        test_fnames = [fn.replace('\\', '/').split("/")[-1] for fn in split["test_fnames"]]
    
    # finetune
    # TBD, currently removed here
    
    # zero-shot inference
    y_trues, y_preds = list(), list()
    matcher.eval()
    saved_folder_name = "fusion_crossvit" # fusion, fusion_crossvit
    os.makedirs("../data/downstream/{}/{}".format(ds_name, saved_folder_name), exist_ok=True)
    # Scoring uses the test split only; preprocessing walks every sample file.
    # NOTE(review): this lists "../data/<ds>/samples" while the samples are
    # read from "../data/downstream/<ds>/samples" below — confirm the path.
    all_fns = test_fnames if not preprocess_only else sorted(os.listdir("../data/{}/samples".format(ds_name)))
    # for fn in tqdm(sorted(os.listdir("../data/{}/samples".format(ds_name)))):
    for fn in tqdm(all_fns):
        # if fn[0] == '.' or fn not in test_fnames: # change to test_fnames when not during preprocess embeddings and run only on test_fnames
        if fn[0] == '.': # process on all fnames
            continue
        with open("../data/downstream/{}/samples/{}".format(ds_name, fn), 'rb') as f:
            data = pickle.load(f)
        # Pre-computed sentence embeddings for this sample; the 'question'
        # and 'label' entries are read below.
        with open("../data/downstream/{}/nlp_embed/{}".format(ds_name, fn), 'rb') as f:
            nlp_embed = pickle.load(f)
            # check_set.append(nlp_embed['raw_label'])
        
        # phsio_model embedding
        with torch.no_grad():
            # print(torch.from_numpy(data['cwt']).float().permute(0, 3, 1, 2).shape)
            # exit()
            # A batch axis is added for the forward pass and removed again with [0].
            physio_out = physio_model.forward_all(torch.from_numpy(data['cwt']).to(torch.bfloat16).permute(0, 3, 1, 2).unsqueeze(0).to(DEVICE))[0] # (C, L, 768)
            # physio_out = physio_model(torch.from_numpy(data['cwt']).to(torch.bfloat16).permute(0, 3, 1, 2).to(DEVICE)) # (C, L, 768)

            # # read from preprocessed file
            # with open("", 'wb') as f:
            #     physio_out = pickle.load(f)

            C, L, H = physio_out.shape

            # save the phsio_out as numpy
            # if preprocess_only:
            #     # for downstream task
            #     saved_embed = {
            #         "phsio": physio_out.cpu().numpy().astype(np.float16),
            #         "question": nlp_embed['question'],
            #         "answers": nlp_embed['label']
            #     }
            #     with open("data/embeddings_ds/{}_{}".format(ds_name, fn), 'wb') as f:
            #         pickle.dump(saved_embed, f)

            #     continue

            # zero shot forward: all channels are flattened into one token
            # sequence; no padding mask is passed here.
            zero_shot_out = matcher(
                physio_out.reshape(C*L, H).unsqueeze(0).to(torch.bfloat16).to(DEVICE), 
                torch.from_numpy(nlp_embed['question']).unsqueeze(0).to(torch.bfloat16).to(DEVICE)
            )[0] # 384

            # save for linear probing
            # NOTE(review): the file is written under "data/<ds>/..." while the
            # directory created above is "../data/downstream/<ds>/..." — confirm
            # which location is intended. The `continue` also skips the cache
            # cleanup at the end of the loop body.
            if preprocess_only:
                with open("data/{}/{}/{}".format(ds_name, saved_folder_name, fn), 'wb') as f:
                    pickle.dump(zero_shot_out.cpu().float().numpy().astype(np.float16), f)
                continue

            # Ground truth is recovered by matching the stored label embedding
            # to its nearest entry in label_map.
            curr_y_true = label_vals[max_match(torch.from_numpy(nlp_embed['label']).to(DEVICE), label_keys, task='reg')]
            # curr_y_true = nlp_embed['raw_label']
            
            if task == 'reg':
                # Prediction is the label value of the nearest label embedding.
                y_preds.append(label_vals[max_match(zero_shot_out, label_keys, task=task)])
                y_trues.append(curr_y_true)
            else:
                # Prediction is one probability per entry of label_vals, in
                # label_map iteration order.
                y_preds_prob = max_match(zero_shot_out, label_keys, task=task)
                # if sum(np.nan_to_num(y_preds_prob)) != 1.0:
                #     continue
                y_preds.append(y_preds_prob)
                y_trues.append(curr_y_true)
            
            # print(y_trues[-1], y_preds[-1])
        
        torch.cuda.empty_cache()
        gc.collect()
        # # for test the pipeline
        # if len(y_trues) == 40:
        #     break

    # evaluation
    # print(set(y_trues))
    # exit()
    if preprocess_only:
        return

    # evaluate
    # NOTE(review): y_preds[0] raises IndexError when no sample was processed.
    if task == 'reg':
        if type(y_preds[0]) == tuple: # for the (sysbp, diabp)
            y_preds = [list(n) for n in y_preds]
            y_trues = [list(n) for n in y_trues]
        else: # usual case
            y_preds = [float(n) for n in y_preds]
            y_trues = [float(n) for n in y_trues]
            print(set(y_preds), set(y_trues))
    else:
        # y_preds = [int(n) for n in y_preds]
        # NOTE(review): assumes class labels in label_map are integer-like and
        # that their order matches the probability columns — confirm.
        y_trues = [int(n) for n in y_trues]
        print(set(np.argmax(y_preds, axis=1)), set(y_trues))
    print("Final Score:", eval_res(np.array(y_preds), np.array(y_trues), task=task))

if __name__ == '__main__':
    # Which cached-embedding variant the matcher weights belong to
    # (embeddings, embeddings_cross_freeze_vit, embeddings_mae, ...).
    embed_folder = 'embeddings_mae_var'
    use_var = True

    # Earlier pipeline stages, run manually when needed:
    #   preprocess_to_get_embeddings(use_var=use_var, saved_folder_name=embed_folder)
    #   train_matcher(embed_folder=embed_folder, use_var=use_var)
    #   check_loss(embed_folder=embed_folder)

    # Downstream evaluation: restore the trained fusion model.
    matcher = TemporalFusion().to(torch.bfloat16).to(DEVICE)
    weight_path = "../data/matcher_trained_weights_{}.pt".format(embed_folder)
    matcher.load_state_dict(torch.load(weight_path, map_location=DEVICE))

    # (dataset folder, task type) pairs, evaluated in this order.
    benchmarks = [
        ('ppg_hgb', 'reg'),
        ('indian-fPCG', 'reg'),
        ("PPG_HTN", "class"),
        ("PPG_DM", "class"),
        ("PPG_CVA", "class"),
        ("PPG_CVD", "class"),
        ('non_invasive_bp', 'reg'),
        ("ecg_heart_cat", "class"),
        ("gameemo", "class"),
        ("drive_fatigue", "class"),
        ("uci_har", "class"),
        ("wesad", "class"),
    ]
    for ds_name, task in benchmarks:
        downstream_eval(matcher, ds_name, task=task, preprocess_only=False)