"""
This file defines the RGB and IMU models, the training functions and runs the experiments.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from datasets.dataset import load_dataloaders


import os
from argparse import ArgumentParser
import wandb
from tabulate import tabulate
from torchvision.models import resnet18, ResNet18_Weights, vit_b_16, ViT_B_16_Weights, VisionTransformer
import numpy as np
import matplotlib.pyplot as plt

from custom_vision_transformer import TempRepsVisionTransformer
# from transformers import ViTFeatureExtractor, ViTModel
# be careful not to confuse your custom vision transformer with the actual one



import subprocess
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from tqdm import tqdm
import time
import sys
import signal
DEBUG_TIME = False # set True to print per-batch data-loading timings inside train()

# Home directory of the current user; raises KeyError if $HOME is not set in the environment.
HOME_DIR = os.environ['HOME']

# NOTE: Need to have my custom fork of Imagebind cloned for this to work (not linking here to keep anonymity during submission)
# Make the ImageBind fork importable without packaging it: both $HOME and $HOME/ImageBind are
# appended to the import search path. These appends must run before the import below.
sys.path.append(f'{HOME_DIR}')
sys.path.append(f'{HOME_DIR}/ImageBind')
from ImageBind.finetune_lp_ANON import ImageBind_baseline, train_linear, evaluate_ib, eval_log_ib


# from models import Early_Fusion, Middle_Fusion, Late_Fusion, IMU_MLP, joint_IMU_MLP
# from train_utils import train, evaluate

# Global run configuration. Several model classes below read attributes of `args` at construction
# and forward time, so it must be populated (by main) before any model is built.
args = None #main will fill in args 
overall_start_time = time.time()  # wall-clock start of the whole script

# print(os.environ) #debugging

# Maps each 1-based action ID to its human-readable action name (27 actions).
actions_dict = dict(enumerate([
    'Swipe left',
    'Swipe right',
    'Wave',
    'Clap',
    'Throw',
    'Arm cross',
    'Basketball shoot',
    'Draw X',
    'Draw circle (clockwise)',
    'Draw circle (counter clockwise)',
    'Draw triangle',
    'Bowling',
    'Boxing',
    'Baseball swing',
    'Tennis swing',
    'Arm curl',
    'Tennis serve',
    'Push',
    'Knock',
    'Catch',
    'Pickup and throw',
    'Jog',
    'Walk',
    'Sit to stand',
    'Stand to sit',
    'Lunge',
    'Squat',
], start=1))

# Cross-modal Fusion Method
class CM_Fusion(nn.Module):
    """Cross-modal fusion model: separate RGB and IMU encoders plus a shared classification head.

    Both encoders map into a common ``hidden_size`` latent space, so the model can be
    aligned with a CLIP-style contrastive loss (``clip_align``), evaluated by cosine
    similarity retrieval (``clip_eval``), or used for classification from RGB, IMU, or
    both (fused) inputs.

    Args:
        input_size: Not used (only referenced by a commented-out IMU_MLP encoder).
        hidden_size: Dimension of the shared latent space.
        output_size: Output size of the classification head.
        rgb_video_length: Number of RGB frames per clip.
        imu_length: Number of IMU time steps per clip.
        imu_channels: IMU channels per time step.
        keep_time: If True, both encoders return one embedding per (reduced) time step and
            the head is selected by ``args.c3t_head``.

    Reads the module-level ``args``: ``c3t_head`` (when ``keep_time``) and ``cosine_fusion``.
    """
    def __init__(self, input_size, hidden_size, output_size, rgb_video_length, imu_length, imu_channels=6, keep_time=False):
        super(CM_Fusion, self).__init__()
        self.hidden_size = hidden_size # size of the shared latent space both encoders project into
        self.FE_rgb = RGB_Action(output_size = hidden_size, video_length = rgb_video_length, keep_time=keep_time)
        # self.FE_imu = IMU_MLP(input_size, hidden_size*2, hidden_size)
        self.FE_imu = IMU_Action(input_channels=imu_channels, hidden_size=hidden_size*2, output_size=hidden_size, keep_time=keep_time, imu_length = imu_length)
        # Default head. IMU_MLP is a generic MLP with dropout (nothing IMU-specific), reused here
        # for joint processing; it is replaced below when keep_time selects another head.
        self.joint_processing = IMU_MLP(input_size = hidden_size, hidden_size = hidden_size//2, output_size = output_size)
        # Learnable temperature for the contrastive logits, initialised as in CLIP (1/0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.keep_time = keep_time

        if self.keep_time:
            # Per-time-step embeddings: choose how the head consumes the sequence.
            # An unrecognised c3t_head keeps the default IMU_MLP head built above.
            if args.c3t_head=='vit_clstkn':
                # Transformer head that classifies from its class token.
                self.joint_processing  = TempRepsVisionTransformer(
                    seq_length=rgb_video_length//2, # encoders halve the RGB time axis (e.g. 30 frames -> 15 steps)
                    num_layers=1, 
                    num_heads=8, # e.g. 2048 // 8 = 256 dims per head
                    hidden_dim=hidden_size, # must equal the encoder embedding size
                    mlp_dim=hidden_size,
                    num_classes=output_size,
                )
            elif args.c3t_head=='vit_proj':
                # NOTE: This method probs can't handle variable sized inputs because it concats and projects all the tokens for the vit instead of using the class token.
                self.joint_processing = TempRepsVisionTransformer(
                    seq_length=rgb_video_length//2, # encoders halve the RGB time axis (e.g. 30 frames -> 15 steps)
                    num_layers=1, 
                    num_heads=8, # e.g. 2048 // 8 = 256 dims per head
                    hidden_dim=hidden_size, # must equal the encoder embedding size
                    mlp_dim=hidden_size,
                    num_classes=hidden_size//4,
                    linear_proj=True
                )
            elif args.c3t_head=='mlp_add':
                # Time steps are summed and re-normalised in forward() before this head.
                self.joint_processing = IMU_MLP(hidden_size, hidden_size//2, output_size)
            elif args.c3t_head=='mlp_concat':
                # Time steps are concatenated in forward(); the 15 is a hard-coded step count.
                self.joint_processing = MLP(hidden_size*15, hidden_size//2, output_size)


    def forward(self, x, sensors, entropy_fusion=False, clip_align=False, clip_eval=False):        
        """Run the model in CLIP-alignment, CLIP-evaluation or classification mode.

        Args:
            x: ``(rgb, imu)`` pair in the CLIP modes or when both sensors are used;
                otherwise the single-sensor tensor.
            sensors: Collection containing 'RGB' and/or 'IMU' (ignored in the CLIP modes).
            entropy_fusion: With both sensors, weight each modality's embedding by one minus
                the softmax of the head's prediction entropies (higher entropy -> lower weight).
            clip_align: Return ``(logits_per_rgb, logits_per_imu, ground_truth)``, where row i
                of one modality is paired with row i of the other.
            clip_eval: Return ``(pred_index, ground_truth)``, where ``pred_index[i]`` is the
                IMU row most similar to RGB row i.

        Returns:
            In classification mode, the output of ``self.joint_processing``.

        NOTE(review): if ``sensors`` contains neither 'RGB' nor 'IMU', ``z`` is never
        assigned and the head call raises UnboundLocalError.
        """
        if clip_align or clip_eval:
            # CLIP-style modes: always encode both modalities.
            # Example shapes for the CZU dataset with batch 16: rgb [16, 30, 3, 224, 224], imu [16, 180, 60].
            
            z_rgb = self.FE_rgb(x[0]) # z is the latent feature vector
            z_imu = self.FE_imu(x[1])
            z_rgb = z_rgb / z_rgb.norm(dim=-1, keepdim=True)
            z_imu = z_imu / z_imu.norm(dim=-1, keepdim=True)

            # Collapse any time dimension into the batch, so with keep_time=True every
            # (sample, time step) row is aligned as its own pair.
            z_rgb = z_rgb.view(-1, z_rgb.shape[-1])
            z_imu = z_imu.view(-1, z_imu.shape[-1])
            # Row counts only match when the encoders reduce time to the same length: the RGB
            # encoder halves time, the IMU encoder divides it by 12, so imu length should be 6x rgb length.

            if clip_align:
                # Contrastive training (https://github.com/openai/CLIP/issues/83):
                # temperature-scaled cosine similarities as logits.
                logit_scale = self.logit_scale.exp()
                logits_per_rgb = logit_scale * z_rgb @ z_imu.t()
                # Observed: the temperature can make training slow or stall in this setup.

                # logits_per_rgb =  z_rgb @ z_imu.t() #time cont one works better with temperature!
                logits_per_imu = logits_per_rgb.t()


                curr_device = x[0].device
                # Matching pairs lie on the diagonal, so the target for row i is i.
                ground_truth = torch.arange(z_rgb.shape[0],dtype=torch.long).to(curr_device)

                return logits_per_rgb, logits_per_imu, ground_truth
            elif clip_eval:
                # Cosine similarity without temperature.
                similarity = z_rgb @ z_imu.t()
                # Despite the name, `probs` holds the argmax index per RGB row, not probabilities.
                probs = torch.nn.functional.softmax(similarity, dim=-1).max(-1)[1]
                curr_device = x[0].device
                ground_truth = torch.arange(z_rgb.shape[0],dtype=torch.long).to(curr_device)
                
                return probs, ground_truth
        else:
            # Classification mode.
            if 'RGB' in sensors and 'IMU' in sensors:
                z_rgb = self.FE_rgb(x[0])
                z_rgb = z_rgb / z_rgb.norm(dim=-1, keepdim=True)
                z_imu = self.FE_imu(x[1])
                z_imu = z_imu / z_imu.norm(dim=-1, keepdim=True)
                if args.cosine_fusion:
                    # Sum, then project back onto the unit sphere.
                    z = z_rgb + z_imu
                    z = z / z.norm(dim=-1, keepdim=True)
                else:
                    z = (z_rgb+z_imu)/2

                if entropy_fusion:
                    # Overrides the fusion above: weight each modality by its prediction confidence.
                    # NOTE(review): the head is applied to the unreduced embeddings here; with
                    # keep_time=True and an MLP head this likely mismatches its input size — verify.
                    logits_rgb = self.joint_processing(z_rgb)
                    p_rgb = torch.nn.functional.softmax(logits_rgb, dim=1)
                    p_rgb = torch.clamp(p_rgb, 1e-7, 1.0 - 1e-7) # avoid log(0)
                    entropy_rgb = -torch.sum(p_rgb * torch.log(p_rgb), dim=1)

                    logits_imu = self.joint_processing(z_imu)
                    p_imu = torch.nn.functional.softmax(logits_imu, dim=1)
                    p_imu = torch.clamp(p_imu, 1e-7, 1.0 - 1e-7)
                    entropy_imu = -torch.sum(p_imu * torch.log(p_imu), dim=1)

                    entropies = torch.stack((entropy_rgb, entropy_imu), dim=1)
                    priors = torch.nn.functional.softmax(entropies, dim=1)
                    priors = torch.ones_like(priors) - priors # high entropy => low weight (two weights still sum to 1)
                    z = priors[:, 0].unsqueeze(1) * z_rgb + priors[:, 1].unsqueeze(1) * z_imu 

            elif 'RGB' in sensors:
                # RGB-only classification; decouple_inputs passes the bare RGB tensor here.
                z = self.FE_rgb(x) #the decouple_inputs function called in train.py  will only give us rgb here no need for x[0]
                z = z / z.norm(dim=-1, keepdim=True)

            elif 'IMU' in sensors:
                # IMU-only classification; x is the bare IMU tensor.
                z = self.FE_imu(x)
                z = z / z.norm(dim=-1, keepdim=True)

            if self.keep_time:
                # z is (batch, steps, hidden_size). The ViT heads consume the sequence directly;
                # the MLP heads need it reduced to one vector per sample first.
                if args.c3t_head=='mlp_concat':
                    z = z.view(z.shape[0],-1)
                elif args.c3t_head=='mlp_add':
                    z = z.sum(dim=1)
                    z = z / z.norm(dim=-1, keepdim=True) #mlp fuse
                
            out = self.joint_processing(z) # the classification head

            return out


# Student teacher model
class Student_Teacher(nn.Module):
    """Two independent unimodal classifiers, applied according to the sensors that are present.

    Structurally this is late fusion: an RGB classifier and an IMU classifier each produce
    ``output_size`` outputs. Which one acts as teacher is decided by the calling code (the
    training code here uses RGB as teacher and IMU as student).
    """

    def __init__(self, input_size, hidden_size, output_size, rgb_video_length, imu_length, imu_channels=6):
        super().__init__()
        self.hidden_size = hidden_size
        self.rgb_model = RGB_Action(output_size, rgb_video_length)
        self.imu_model = IMU_Action(
            input_channels=imu_channels,
            hidden_size=hidden_size * 2,
            output_size=output_size,
            imu_length=imu_length,
        )

    def forward(self, x, sensors,  entropy_fusion=False):
        """Classify from RGB, IMU, or both.

        ``entropy_fusion`` is accepted only to match ``CM_Fusion.forward`` and is ignored.
        With both sensors, ``x`` is ``(rgb, imu)`` and the two outputs are averaged, or
        summed and L2-normalised when ``args.cosine_fusion`` is set.
        """
        has_rgb = 'RGB' in sensors
        has_imu = 'IMU' in sensors
        if has_imu and has_rgb:
            rgb_logits = self.rgb_model(x[0])
            imu_logits = self.imu_model(x[1])
            summed = imu_logits + rgb_logits
            if args.cosine_fusion:
                out = summed / summed.norm(dim=-1, keepdim=True)
            else:
                out = summed / 2
        elif has_rgb:
            out = self.rgb_model(x)
        elif has_imu:
            out = self.imu_model(x)

        return out


# Attention-based feature fusion
class HAMLET(nn.Module):
    """Attention-based fusion of one RGB and one IMU feature vector per sample.

    Each encoder yields a ``hidden_size`` vector. The two vectors form a length-2 token
    sequence that goes through multi-head self-attention, is summed over tokens, and is
    classified by an MLP head.
    """

    def __init__(self, input_size, hidden_size, output_size, rgb_video_length, imu_length, imu_channels=6):
        super().__init__()
        self.hidden_size = hidden_size  # width at which the two modalities meet

        # Unimodal encoders, both projecting to hidden_size.
        self.rgb_model = RGB_Action(hidden_size, rgb_video_length)
        self.imu_model = IMU_Action(
            input_channels=imu_channels,
            hidden_size=hidden_size * 2,
            output_size=hidden_size,
            imu_length=imu_length,
        )

        # Self-attention across the two modality tokens.
        self.attn = nn.MultiheadAttention(hidden_size, num_heads=2, dropout=0, batch_first=True)

        # Classification head (generic MLP with dropout).
        self.joint_processing = IMU_MLP(hidden_size, hidden_size // 2, output_size)

    def forward(self, x):
        rgb_token = self.rgb_model(x[0])
        imu_token = self.imu_model(x[1])
        tokens = torch.stack((rgb_token, imu_token), dim=1)  # (batch, 2, hidden_size)
        attended, _weights = self.attn(tokens, tokens, tokens)
        pooled = attended.sum(dim=1)
        return self.joint_processing(pooled)


# Middle fusion: unimodal features are fused by element-wise addition
class Middle_Fusion(nn.Module):
    """Middle fusion: encode each modality to ``hidden_size``, add the features, then classify."""

    def __init__(self, input_size, hidden_size, output_size, rgb_video_length, imu_length, imu_channels=6):
        super().__init__()
        self.hidden_size = hidden_size  # width at which the two feature vectors are added
        self.rgb_model = RGB_Action(hidden_size, rgb_video_length)
        self.imu_model = IMU_Action(
            input_channels=imu_channels,
            hidden_size=hidden_size * 2,
            output_size=hidden_size,
            imu_length=imu_length,
        )
        self.joint_processing = IMU_MLP(hidden_size, hidden_size // 2, output_size)

    def forward(self, x):
        rgb_features, imu_features = self.rgb_model(x[0]), self.imu_model(x[1])
        return self.joint_processing(rgb_features + imu_features)

class Early_Fusion(nn.Module):
    """Early fusion: add the raw IMU signal into the raw RGB video, then run a single network.

    Each sample's IMU values are flattened, zero-padded to that sample's RGB size, and added
    element-wise to the sample's flattened RGB values. The fused tensor keeps the RGB
    shape. It goes through an ``RGB_Action`` encoder (to ``hidden_size``) followed by an
    ``IMU_MLP`` head, which is used here as a generic MLP with nothing IMU-specific.

    ``input_size`` and ``imu_channels`` are accepted for interface compatibility with the
    other fusion models but are not used.
    """

    def __init__(self, input_size, hidden_size, output_size, rgb_video_length, imu_channels=6):
        super(Early_Fusion, self).__init__()
        self.hidden_size = hidden_size  # width at which the RGB encoder hands off to the MLP head
        self.rgb_model = RGB_Action(hidden_size, rgb_video_length)
        self.imu_model = IMU_MLP(hidden_size, hidden_size//2, output_size)

    @staticmethod
    def _early_fuse(rgb, imu):
        """Return ``rgb`` with each sample's flattened ``imu`` values added to its leading elements.

        Args:
            rgb: Tensor whose first dimension is the batch, e.g. ``(b, t, c, h, w)``.
            imu: Tensor with the same batch size, e.g. ``(b, timesteps, channels)``.

        Returns:
            Tensor with the same shape as ``rgb``.

        Raises:
            ValueError: if the batch sizes differ, or if one sample's IMU data has more
                elements than its RGB data.
        """
        if rgb.shape[0] != imu.shape[0]:
            raise ValueError(
                f"RGB and IMU batch sizes differ ({rgb.shape[0]} vs {imu.shape[0]}), early fusion failed"
            )
        # Flatten per sample, not across the whole batch. Flattening the whole batch would add
        # every sample's IMU data into the first sample's video, because one video is far larger
        # than the IMU data of the entire batch.
        rgb_flat = rgb.flatten(1)
        imu_flat = imu.flatten(1)
        pad = rgb_flat.shape[1] - imu_flat.shape[1]
        if pad < 0:
            # Handling this would need padding RGB up to the IMU size and an IMU-style model.
            raise ValueError("IMU data is bigger than RGB data, early fusion failed")
        if pad > 0:
            padding = imu_flat.new_zeros(imu_flat.shape[0], pad)
            imu_flat = torch.cat((imu_flat, padding), dim=1)
        return (rgb_flat + imu_flat).view(rgb.shape)

    def forward(self, x):
        """Fuse ``x = (rgb, imu)`` at the input level and classify the result."""
        fused = self._early_fuse(x[0], x[1])
        # For simplicity/reusability the full model is the RGB encoder followed by the MLP head.
        rgb_features = self.rgb_model(fused)
        return self.imu_model(rgb_features)
    
class Late_Fusion(nn.Module):
    """Late fusion: two full unimodal classifiers whose outputs are averaged."""

    def __init__(self, input_size, hidden_size, output_size, rgb_video_length, imu_length, imu_channels=6):
        super().__init__()
        self.hidden_size = hidden_size
        self.rgb_model = RGB_Action(output_size, rgb_video_length)
        self.imu_model = IMU_Action(
            input_channels=imu_channels,
            hidden_size=hidden_size * 2,
            output_size=output_size,
            imu_length=imu_length,
        )

    def forward(self, x):
        rgb_logits, imu_logits = self.rgb_model(x[0]), self.imu_model(x[1])
        return (imu_logits + rgb_logits) / 2

# Define the 3D CNN model 
class RGB_Action(nn.Module):
    """RGB video encoder: per-frame spatial features, then temporal features, then an FC head.

    Spatial stage (applied to every frame independently):
      * ``args.rgb_svit``: a 1-layer torchvision ``VisionTransformer`` built from scratch whose
        head outputs 128*28*28 values, reshaped to a (128, 28, 28) map per frame;
      * otherwise: ``ResNet18_Weights.DEFAULT`` ResNet-18 truncated by ``children()[:-4]``,
        giving 128 x 28 x 28 maps for 224 x 224 frames.
    Temporal stage:
      * ``args.rgb_tvit``: project each frame map to ``args.hidden_size`` and run a
        ``TempRepsVisionTransformer`` whose output size is hard-coded to 188160
        (= 64 * 15 * 14 * 14, i.e. the 3D-conv output for 30-frame clips);
      * otherwise: Conv3d(128 -> 64) + BatchNorm3d + ReLU + MaxPool3d(2), which halves T, H and W.
    The spatial features are added back as a residual. Then:
      * ``keep_time=False``: one FC block over the flattened clip -> ``(b, output_size)``;
      * ``keep_time=True``: the FC block per remaining time step -> ``(b, T', output_size)``.

    Reads the module-level ``args`` (``rgb_svit``, ``rgb_tvit``, ``hidden_size``).
    """
    def __init__(self, output_size, video_length, keep_time=False):
        super(RGB_Action, self).__init__()
        self.keep_time = keep_time
        self.video_length = video_length
        if args.rgb_svit:
            # A pretrained vit_b_16 with a 128*28*28 head was tried here but is far too large;
            # this small ViT is trained from scratch instead.
            self.s_feature_extractor = VisionTransformer(
                image_size=224,
                patch_size=16,
                num_classes=128*28*28,
                num_layers=1,
                num_heads=8,
                hidden_dim=args.hidden_size,
                mlp_dim=args.hidden_size,
            )
        else:
            resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
            # Keep the stem plus layer1/layer2; drop layer3, layer4, avgpool and fc.
            self.s_feature_extractor = nn.Sequential(*list(resnet.children())[:-4])
            
        if args.rgb_tvit:
            # Frame maps (128*28*28) are too large to be transformer tokens; project them down.
            self.t_proj_down = nn.Linear(128*28*28, args.hidden_size)
            # This is our custom Vision Transformer, Temporal Representation ViT
            self.t_feature_extractor = TempRepsVisionTransformer(seq_length=video_length,
                num_layers=1, 
                num_heads=8, # e.g. 2048 // 8 = 256 dims per head
                hidden_dim=args.hidden_size, # token dim = output size of t_proj_down
                mlp_dim=args.hidden_size, 
                num_classes=188160, # = 64*15*14*14, mimics the 3D-conv output for 30-frame clips (very large)
            )
        else:
            self.t_feature_extractor = nn.Sequential(
                nn.Conv3d(128, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
                nn.BatchNorm3d(64),
                nn.ReLU(),
                nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))
            )

        if not self.keep_time:
            # Head over the whole flattened clip.
            self.fc_block = nn.Sequential(
                nn.Linear(int(video_length/2)*64*14*14, 512), # 64 channels x T/2 x 14 x 14
                nn.ReLU(),
                nn.Linear(512, output_size)
            )
        else:
            # Head applied to each remaining time step separately (a deeper 3-layer variant was tried before).
            self.fc_block = nn.Sequential(
                nn.Linear(64*14*14, output_size//2), # one time step: 64 channels x 14 x 14
                nn.BatchNorm1d(output_size//2),
                nn.ReLU(),
                nn.Linear(output_size//2, output_size)
             )
        # Re-initialise every Linear/Conv3d (Kaiming-normal weight, zero bias) and BatchNorm3d/1d
        # (weight 1, bias 0) anywhere in the module tree. The ResNet stages only contain
        # Conv2d/BatchNorm2d, so their pretrained weights are untouched; any nn.Linear inside the
        # optional transformers is re-initialised too. Originally added to help training.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                

    def forward(self, x):
        """Encode a clip ``x`` of shape ``(b, t, c, h, w)`` (e.g. ``(8, 30, 3, 224, 224)``).

        NOTE(review): the residual below splits the spatial features into 16 equal chunks,
        which only matches the temporal output size when ``t`` is even — confirm clips are even-length.
        """
         # Apply the spatial extractor to every frame: fold time into the batch.
        b, t, c, h, w = x.shape
        # reshape (not view): the input may be non-contiguous.
        x = x.reshape(-1, c, h, w)
        # x is now (b*t, c, h, w)
        if args.rgb_svit:
            # The ViT head output is used as the per-frame feature and reshaped to a conv-like map.
            x = self.s_feature_extractor(x)
            # (b*t, 100352) -> (b*t, 128, 28, 28)
            x = x.reshape(-1, 128, 28, 28)
        else:
            # Original Resnet Method
            x = self.s_feature_extractor(x)
        # Kept for the residual connection below.
        # NOTE(review): x is only rebound afterwards, never modified in place, so this clone appears unnecessary.
        post_fe_x = x.clone()
        # x is (b*t, 128, 28, 28)


        if args.rgb_tvit:
            x = x.view(b*t, -1) # flatten each frame map, then project it down to the token dimension
            x = self.t_proj_down(x)
            # Restore the time axis: (b, t, hidden_size)
            x = x.view(b, t, -1)
            x = self.t_feature_extractor(x)
            b = x.shape[0]
            c,t,h,w = 64,self.video_length//2,14,14 # THIS IS BAD, i shouldn't hard code this... ideally it should automatically be calculated based on conv structure...
            # These must match the shape produced by the 3D-conv branch below.
        else:
            # Restore the time axis: (b, t, 128, 28, 28)
            x = x.view(b, t, *x.shape[1:])
            x = x.permute(0,2,1,3,4) # permute to BCTHW bc conv3d expects C before T
            # Apply 3D convolutional layers
            x = self.t_feature_extractor(x)
            # e.g. (8, 64, 15, 14, 14): time, height and width are halved
            b,c,t,h,w = x.shape
            x = x.view(b,-1) # flatten preserve batch_size

        # x is (b, c*t*h*w), e.g. (8, 188160)

        # Residual connection: the spatial features have 2x the size in each of the last four dims
        # (C, T, H, W) compared to the temporal output (2^4 = 16), so fold them into 16 chunks and sum.
        post_fe_x = post_fe_x.view(b,-1).chunk(16,dim=1)
        chunks = torch.stack(post_fe_x,dim=1)
        summed_chunks = chunks.sum(dim=1)
        x = x+summed_chunks # intended to help gradient propagation

        if not self.keep_time:
            # Fully connected head over the whole clip.
            x = self.fc_block(x)
        else:
            x = x.view(b,c,t,h,w) #reshape back to BCTHW
            x = x.permute(0,2,1,3,4) # permute back to BTCHW
            b,t_new,c,h,w = x.shape
            x = x.reshape(b*t_new,-1) # t_new is the halved time length
            # now pass each time step through the fc block
            x = self.fc_block(x)
            x = x.reshape(b,t_new,-1) # (b, t_new, output_size)
        return x
    
class joint_IMU_MLP(nn.Module): #This is IMU to PID+Action
    """IMU action classifier that also uses features from a frozen, pretrained PID MLP.

    PID labels are the 4th element of each dataloader batch (presumably participant IDs).
    The PID network's last Linear layer is removed, so it emits ``hidden_size // 16`` features.
    These are concatenated with ``hidden_size // 4`` trainable action features and classified
    by an ``IMU_MLP`` head.

    Args:
        input_size: Flattened input size per sample.
        hidden_size: Hidden width; must match the pretrained PID checkpoint.
        output_size: Number of action classes.
        model_path: Path to the pretrained PID ``IMU_MLP`` state dict (8 output classes).
            Defaults to the checkpoint previously hard-coded here.
    """
    def __init__(self, input_size, hidden_size, output_size, model_path="./models/pid_best_model2048_68.7861.pt"):
        super(joint_IMU_MLP, self).__init__()
        self.hidden_size = hidden_size
        self.pid = IMU_MLP(input_size, hidden_size, output_size=8) #pid has 8 classes
        # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only machine.
        # load_state_dict copies the values into self.pid's existing tensors either way.
        self.pid.load_state_dict(torch.load(model_path, map_location="cpu"))
        self.pid.layers[3] = nn.Identity() # drop the final Linear: output becomes (bs, hidden_size//16)
        # Freeze pid layers
        for param in self.pid.parameters():
            param.requires_grad = False
        # NOTE(review): requires_grad=False does not stop the BatchNorm layers inside self.pid
        # from updating their running statistics, nor Dropout from being active, while the parent
        # module is in train() mode. Confirm whether that is intended.

        self.action_features = MLP(input_size, hidden_size, hidden_size//4)

        self.action = IMU_MLP(hidden_size//4+hidden_size//16, hidden_size, output_size)
        # Inputs contribute 4/5 of the concatenated features and PID 1/5
        # ((1/4) / (1/4 + 1/16) = 4/5). The reverse weighting might be worth testing.

    def forward(self, x):
        pid_out = self.pid(x)
        action_features = self.action_features(x)
        out = self.action(torch.cat((pid_out, action_features), dim=1)) #cat on dim1 to keep batch size
        return out

#Define 1D CNN model and ViT now for IMU
class IMU_Action(nn.Module):
    """IMU encoder: a 1D CNN (or ``TempRepsVisionTransformer`` when ``args.imu_vit``) plus a linear head.

    The feature extractor outputs ``hidden_size // 4`` channels over ``imu_length // 12``
    time steps. In the CNN path this comes from max-pools of 3, 2 and 2; the ViT output is
    reshaped to match.

    Args:
        input_channels: IMU channels per time step.
        hidden_size: Width of the first conv layer; later layers use ``hidden_size // 2`` and
            ``hidden_size // 4``. The cross-modal model passes ``2 * its hidden size`` here and
            its hidden size as ``output_size``.
        output_size: Size of the output embedding.
        imu_length: IMU time steps per sample (at least 12).
        keep_time: If False, flatten all reduced steps and return ``(batch, output_size)``.
            If True, apply the head per step and return ``(batch, imu_length // 12, output_size)``.

    Raises:
        ValueError: if ``imu_length < 12`` (the time axis would be reduced to zero steps).

    Reads the module-level ``args.imu_vit`` in both ``__init__`` and ``forward``.
    """

    def __init__(self, input_channels, hidden_size, output_size, imu_length, keep_time=False):
        super(IMU_Action, self).__init__()
        self.hidden_size = hidden_size
        self.keep_time = keep_time
        self.imu_length = imu_length

        # The feature extractor output is (batch, out_channels, out_steps). The two factors are
        # computed separately on purpose: the old expression `hidden_size//4*imu_length//12`
        # parsed as ((hidden_size//4)*imu_length)//12. That over-counts the features whenever
        # imu_length is not a multiple of 12 and breaks the Linear layer and the ViT reshape.
        out_channels = hidden_size // 4
        out_steps = imu_length // 12
        if out_steps < 1:
            raise ValueError(f"imu_length must be at least 12, got {imu_length}")

        if args.imu_vit:
            self.feature_extractor = TempRepsVisionTransformer(
                seq_length=imu_length, 
                num_layers=1, 
                num_heads=1, # a single head: the token dim is only input_channels wide
                hidden_dim=input_channels, # tokens are the raw per-step IMU readings
                mlp_dim=input_channels,
                num_classes=out_channels * out_steps,
            )
        else:
            self.feature_extractor = nn.Sequential(
                nn.Conv1d(input_channels, hidden_size, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=3, stride=3),
                nn.Conv1d(hidden_size, hidden_size//2, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm1d(hidden_size//2),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=2, stride=2),
                nn.Conv1d(hidden_size//2, hidden_size//4, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm1d(hidden_size//4),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=2, stride=2),
            )
        # Example: input 6 x 180 -> (hidden_size//4) x 15; time shrinks by a factor of 12.

        if not self.keep_time:
            # Flatten and apply to all time steps at once.
            self.fc = nn.Sequential(
                nn.Linear(out_channels * out_steps, output_size, dtype=torch.float32),
            )
        else:
            # Apply the head to every time step individually.
            self.fc = nn.Sequential(
                nn.Linear(out_channels, output_size, dtype=torch.float32),
            )

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, timesteps, channels)``."""
        if args.imu_vit:
            out = self.feature_extractor(x)
            # Reshape the ViT output to (batch, channels, imu_length // 12) to mirror the CNN path.
            out = out.view(out.shape[0], -1, self.imu_length//12)
        else:
            x = x.permute(0,2,1) # (batch, channels, timesteps) for conv1d
            out = self.feature_extractor(x)
        # out is (batch, hidden_size // 4, imu_length // 12), e.g. (bs, 1024, 15)

        if not self.keep_time:
            # Used in the CA model: flatten and apply to all steps at once.
            out = out.view(out.shape[0], -1)
            out = self.fc(out)
        else:
            # Used in the C3T model: apply the head to every step individually.
            out = out.permute(0,2,1) # back to (batch, steps, channels)
            b,t,c = out.shape
            out = out.reshape(b*t, -1)
            out = self.fc(out)
            out = out.reshape(b,t,-1)
            # (batch, imu_length // 12, output_size)

        return out

class IMU_MLP(nn.Module):
    """Generic MLP classifier with dropout; the name is historical and nothing here is IMU-specific.

    Width shrinks hidden_size -> hidden_size/4 -> hidden_size/16 -> output_size.
    ``self.layers`` keeps the index layout [MLP, Dropout, MLP, Linear], and callers rely on
    ``layers[3]`` being the final Linear.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        quarter = int(hidden_size / 4)
        sixteenth = int(hidden_size / 16)
        stages = [
            MLP(input_size, hidden_size, quarter),
            nn.Dropout(0.5),
            MLP(quarter, quarter, sixteenth),
            nn.Linear(sixteenth, output_size, dtype=torch.float32),
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        # Flatten everything except the batch dimension.
        flattened = x.view(x.shape[0], -1)
        return self.layers(flattened)

class MLP(nn.Module):
    """Two Linear -> BatchNorm1d -> ReLU stages; the input is flattened per sample first."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        def linear_bn_relu(n_in, n_out):
            # One fully connected stage with batch norm and ReLU.
            return [nn.Linear(n_in, n_out, dtype=torch.float32), nn.BatchNorm1d(n_out), nn.ReLU()]

        self.layers = nn.Sequential(
            *linear_bn_relu(input_size, hidden_size),
            *linear_bn_relu(hidden_size, output_size),
        )

    def forward(self, x):
        flattened = x.view(x.shape[0], -1)
        return self.layers(flattened)

def decouple_inputs(data_batch, model_info, device):
    """Select a model's inputs and labels from a dataloader batch.

    The dataloader always yields ``(rgb, imu, har_label, pid_label)`` in that order. Which
    elements a model consumes depends on ``model_info['sensors']`` and ``model_info['tasks']``.

    Args:
        data_batch: Indexable batch ``[rgb, imu, har_label, pid_label]``.
        model_info: Dict with ``'sensors'`` (containing 'RGB' and/or 'IMU') and
            ``'tasks'`` (containing 'HAR' and/or 'PID').
        device: Device the sensor tensors are moved to.

    Returns:
        ``(inputs, labels)``. ``inputs`` is an ``(rgb, imu)`` tuple when both sensors are used,
        otherwise the single sensor tensor; either way it is cast to float32 on ``device``.
        ``labels`` is ``(har, pid)`` when both tasks are used, otherwise the single label.
        Labels are returned as-is and are not moved to ``device``.

    Raises:
        ValueError: if neither known sensor, or neither known task, is requested.
            Previously this surfaced as an UnboundLocalError.
    """
    sensors = model_info['sensors']
    tasks = model_info['tasks']

    def _as_float(tensor):
        # Move and cast in a single call; same result as .to(device).type(torch.float32).
        return tensor.to(device=device, dtype=torch.float32)

    use_rgb, use_imu = 'RGB' in sensors, 'IMU' in sensors
    if use_rgb and use_imu:
        inputs = (_as_float(data_batch[0]), _as_float(data_batch[1]))
    elif use_imu:
        inputs = _as_float(data_batch[1])
    elif use_rgb:
        inputs = _as_float(data_batch[0])
    else:
        raise ValueError(f"model_info['sensors'] must contain 'RGB' and/or 'IMU', got {sensors!r}")

    use_har, use_pid = 'HAR' in tasks, 'PID' in tasks
    if use_har and use_pid:
        labels = (data_batch[2], data_batch[3])
    elif use_har:
        labels = data_batch[2]
    elif use_pid:
        labels = data_batch[3]
    else:
        raise ValueError(f"model_info['tasks'] must contain 'HAR' and/or 'PID', got {tasks!r}")

    return inputs, labels

# Define the training loop
def train(model, train_loader, val_loader, criterion, optimizer, scheduler,  num_epochs, device, model_info, psuedo_labels=False):
    """Supervised training loop with per-epoch checkpoints and best-model selection.

    Per batch:
    - Unless ``args.T_noise_in_eval`` is set, RGB/IMU are temporally cropped with
      ``random_cropbatch``.
    - Inputs and labels are selected with ``decouple_inputs``.
    - The forward pass depends on ``model_info['fusion_type']``.
    - The loss is ``criterion``, summed over both outputs when two tasks are configured.

    ``scheduler`` is stepped once per epoch.

    On rank 0 (or with ``args.single_gpu``) the mean training loss is logged to wandb and
    ``./models/<project_name>_<epoch>.pt`` is written every epoch. Every ``args.eval_every``
    epochs, ``eval_log`` runs on every rank. Rank 0 then replaces
    ``./models/<project_name>_best_model<hidden_size>_<acc>.pt`` if the selection accuracy
    improved: RGB for ``supervised_rgb``, RGB+IMU for ``supervised_fusion``, IMU otherwise.

    Args:
        psuedo_labels: only used for early/late/middle/attn fusion. When True, labels are
            the argmax of a no-grad RGB-only forward pass, and the model is trained on the
            IMU-only input (RGB zeroed).

    Raises:
        ValueError: for an unknown ``model_info['fusion_type']``.
    """
    # The per-epoch checkpoint below is written before any evaluation, so the directory
    # must exist from the first epoch on.
    os.makedirs("./models", exist_ok=True)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        # NOTE: for the mmact dataset the data-load time is very inconsistent (0.01 s to 10 s)
        # while one batch takes ~0.3-0.4 s. All num_workers load in parallel, then the next
        # num_workers batches arrive almost at once, so every num_workers-th batch is slow.
        if DEBUG_TIME: e1 = time.time()  # end of previous batch

        for i, data_batch in enumerate(train_loader):
            if DEBUG_TIME: s1 = time.time()  # start of this batch
            if DEBUG_TIME: print("Time for one data load",s1-e1)

            if not args.T_noise_in_eval:
                data_batch[0], data_batch[1] = random_cropbatch(data_batch[0], data_batch[1])

            inputs, labels = decouple_inputs(data_batch, model_info, device)

            optimizer.zero_grad()
            fusion_type = model_info['fusion_type']
            if fusion_type == 'cross_modal' or fusion_type == 'student_teacher':
                outputs = model(inputs, model_info['sensors'])
            elif fusion_type in ['early', 'late', 'middle', 'attn']:
                rgb, imu = inputs[0], inputs[1]
                rgb_only = [rgb, torch.zeros_like(imu)]
                if psuedo_labels:
                    # Label with the RGB-only prediction, then train on the IMU-only input.
                    # The real IMU tensor must be used here: rgb_only[1] is all zeros, and
                    # reusing it would feed zeros for both modalities.
                    with torch.no_grad():
                        labels = model(rgb_only).argmax(dim=1)
                    outputs = model([torch.zeros_like(rgb), imu])
                else:
                    outputs = model(rgb_only)  # only use rgb
            elif fusion_type in ['supervised_rgb', 'supervised_imu', 'supervised_fusion']:
                outputs = model(inputs)
            else:
                raise ValueError("Invalid fusion type")

            if len(model_info['tasks']) == 2:
                loss = criterion(outputs[0], labels[0].to(device)) + criterion(outputs[1], labels[1].to(device))
            else:
                loss = criterion(outputs, labels.to(device))

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if (i+1) % 5 == 0:
                if args.rank == 0 or args.single_gpu:  # only print from the zeroth gpu
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))

            if DEBUG_TIME: e1 = time.time()
            if DEBUG_TIME: print("Time for one batch",e1-s1)
        scheduler.step()

        if args.rank == 0 or args.single_gpu:  # only log/save from the zeroth gpu
            if not args.no_wandb: wandb.log({'train_loss'+(f'_{model_info["sensors"]}' if model_info['fusion_type']== "cross_modal" else ''): running_loss / len(train_loader)})

            # just always save based on project_name
            torch.save(model.state_dict(), f'./models/{model_info["project_name"]}_{epoch+1}.pt')

        # NOTE: evaluation currently assumes one output label
        if (epoch+1) % args.eval_every == 0:
            acc_rgb, acc_imu, acc_both = eval_log(model, val_loader, device, model_info)
            if args.rank == 0 or args.single_gpu:
                if model_info['fusion_type'] == 'supervised_rgb':
                    acc = acc_rgb
                elif model_info['fusion_type'] == 'supervised_fusion':
                    acc = acc_both
                else:
                    # for supervised imu and the UMA setting, select the model on imu data
                    acc = acc_imu

                # Keep a single best checkpoint; its accuracy is encoded in the filename.
                best_val_acc = 0
                best_val_file = None
                prefix = model_info['project_name'] + '_best_model'
                try:
                    os.makedirs("./models", exist_ok=True)
                    for f in os.listdir('./models/'):
                        if f.startswith(prefix):
                            # "..._<acc>.pt" -> "<acc>" (strip the 3-char ".pt" suffix)
                            best_val_acc = float(f.split("_")[-1][:-3])
                            best_val_file = os.path.join("./models", f)
                    if acc > best_val_acc:
                        if best_val_file: os.remove(best_val_file)
                        if model_info['fusion_type'] not in ['supervised_rgb', 'supervised_imu', 'supervised_fusion']:
                            # the single-modal supervised baselines have no latent vector (no hidden size)
                            hidden_size = model.hidden_size if args.single_gpu else model.module.hidden_size
                        else:
                            hidden_size = 0
                        fname = f'./models/{model_info["project_name"]}_best_model{hidden_size}_{acc:.4f}.pt'
                        torch.save(model.state_dict(), fname)
                except (FileExistsError, FileNotFoundError) as e:
                    print("Error saving model",e)

def inter_train(model, train_loader, train_2_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, model_info):
    """Alternate CLIP-style RGB/IMU alignment and RGB-only HAR training within each epoch.

    Phase 1 iterates ``train_loader``. It minimises the mean of the cross-entropies of
    ``logits_per_rgb`` and ``logits_per_imu`` against ``ground_truth``, all returned by
    ``model(..., clip_align=True)``; the loss is scaled by ``1 - args.beta`` when
    ``args.beta`` is set. Every ``args.eval_every`` epochs ``CLIP_evaluate`` runs, and
    rank 0 keeps the best ``./models/<project_name>-FEs_best_model<hidden>_<acc>.pt``.

    Phase 2 iterates ``train_2_loader`` and trains the RGB HAR output with ``criterion``,
    scaled by ``args.beta`` when set. Every ``args.eval_every`` epochs ``eval_log`` runs.

    NOTE(review): ``scheduler`` is accepted but never stepped in this function.
    """
    os.makedirs("./models", exist_ok=True)
    loss_imu = nn.CrossEntropyLoss()
    loss_rgb = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):

        # ---- Phase 1: CLIP alignment ----
        model.train()
        running_loss_CLIP = 0.0
        for i, data_batch in enumerate(train_loader):

            inputs, labels = decouple_inputs(data_batch, model_info, device)

            optimizer.zero_grad()

            logits_per_rgb, logits_per_imu, ground_truth = model(inputs, model_info['sensors'], clip_align=True)

            total_loss = (loss_rgb(logits_per_rgb,ground_truth) + loss_imu(logits_per_imu,ground_truth))/2
            if args.beta is not None:
                total_loss = total_loss*(1-args.beta)

            total_loss.backward()
            optimizer.step()
            # Accumulate a Python float: summing the loss tensor would keep every batch's
            # autograd history alive for the whole epoch.
            running_loss_CLIP += total_loss.item()
            if args.rank == 0 or args.single_gpu:  # only print from the zeroth gpu
                if (i+1) % 5 == 0:
                    print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), total_loss.item()))

                # just always save based on project_name
                torch.save(model.state_dict(), f'./models/{model_info["project_name"]}_{epoch+1}.pt')

        if args.rank == 0 or args.single_gpu:  # only log from the zeroth gpu
            if not args.no_wandb: wandb.log({'CLIP_loss': running_loss_CLIP / len(train_loader)})

        # Eval and save best model
        if (epoch+1) % args.eval_every == 0:
            acc = CLIP_evaluate(model, val_loader, device, model_info)
            if args.rank == 0 or args.single_gpu:  # only print from the zeroth gpu
                print('Test accuracy CLIP: {:.4f} %'.format(acc))
                if not args.no_wandb: wandb.log({'CLIP_val_acc'+(f'_{model_info["sensors"]}' if model_info['fusion_type']== "cross_modal" else ''): acc})
                # Keep a single best checkpoint; its accuracy is encoded in the filename.
                best_val_acc = 0
                best_val_file = None
                prefix = model_info['project_name']+"-FEs"+'_best_model'
                for f in os.listdir('./models/'):
                    if f.startswith(prefix):
                        # "..._<acc>.pt" -> "<acc>" (strip the 3-char ".pt" suffix)
                        best_val_acc = float(f.split("_")[-1][:-3])
                        best_val_file = os.path.join("./models",f)
                if acc > best_val_acc:
                    if best_val_file: os.remove(best_val_file)
                    fname = f'./models/{model_info["project_name"]}-FEs_best_model{model.hidden_size if args.single_gpu else model.module.hidden_size}_{acc:.4f}.pt'
                    torch.save(model.state_dict(), fname)

        # ---- Phase 2: RGB-only HAR training ----
        camera_only = model_info.copy()
        camera_only['sensors'] = ['RGB']
        # The evaluation helpers above may leave the model in eval mode (assumed; they are
        # defined elsewhere), so re-enter training mode before the HAR updates.
        model.train()
        running_loss = 0.0
        for i, data_batch in enumerate(train_2_loader):

            inputs, labels = decouple_inputs(data_batch, model_info=camera_only, device=device)

            optimizer.zero_grad()

            assert model_info['fusion_type'] == 'cross_modal'  # otherwise next line will be an error
            outputs = model(inputs, sensors=['RGB'])

            # assumes a single task for now
            loss = criterion(outputs, labels.to(device))
            if args.beta is not None:
                loss = loss*args.beta
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if args.rank == 0 or args.single_gpu:  # only print from the zeroth gpu
                if (i+1) % 5 == 0:
                    print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_2_loader), loss.item()))

                    # just always save based on project_name
                    torch.save(model.state_dict(), f'./models/{model_info["project_name"]}_{epoch+1}.pt')

        if args.rank == 0 or args.single_gpu:  # only log from the zeroth gpu
            if not args.no_wandb: wandb.log({'train_loss'+(f'_{camera_only["sensors"]}' if model_info['fusion_type']== "cross_modal" else ''): running_loss / len(train_2_loader)})

        if (epoch+1) % args.eval_every == 0:
            eval_log(model, val_loader, device, model_info)


def intra_train(model, train_loader, train_2_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, model_info):
    """Train CLIP alignment and RGB-only HAR together, with one summed loss per optimizer step.

    Batches of ``train_loader`` (RGB+IMU, for alignment) and ``train_2_loader`` (RGB-only,
    for HAR) are consumed in lockstep with ``zip``, so each epoch ends with the shorter
    loader. Each step minimises
    ``CE(logits_per_rgb) + CE(logits_per_imu) + criterion(rgb_har_outputs)``.
    No temporal cropping (``random_cropbatch``) is applied, so this is not set up for the
    noise experiments.

    Every ``args.eval_every`` epochs, ``CLIP_evaluate`` and ``eval_log`` run on every rank.
    Rank 0 then keeps the best ``./models/<project_name>-FEs_best_model<hidden>_<acc>.pt``
    by CLIP accuracy.

    NOTE(review): ``scheduler`` is accepted but never stepped in this function.
    """
    os.makedirs("./models", exist_ok=True)
    loss_imu = nn.CrossEntropyLoss()
    loss_rgb = nn.CrossEntropyLoss()
    camera_only = model_info.copy()
    camera_only['sensors'] = ['RGB']

    for epoch in range(num_epochs):
        model.train()
        # Running losses cover the whole epoch, not just one batch.
        running_loss_CLIP = 0.0
        running_loss_HAR = 0.0
        num_steps = 0

        if len(train_loader) != len(train_2_loader):
            if args.rank == 0 or args.single_gpu:  # only print from the zeroth gpu
                print("Warning: train_loader and train_2_loader have different lengths, this might cause issues")
        steps_per_epoch = min(len(train_loader), len(train_2_loader))  # zip stops at the shorter loader

        for i, (data_clip, data_rgbHAR) in enumerate(zip(train_loader, train_2_loader)):
            inputs_clip, labels_clip = decouple_inputs(data_clip, model_info, device)
            inputs_rgbHAR, labels_rgbHAR = decouple_inputs(data_rgbHAR, model_info=camera_only, device=device)

            optimizer.zero_grad()
            # CLIP alignment loss. Sharing the batch with HAR may hurt CLIP, which benefits
            # from larger batches.
            logits_per_rgb, logits_per_imu, ground_truth = model(inputs_clip, model_info['sensors'], clip_align=True)
            loss_CLIP = loss_rgb(logits_per_rgb,ground_truth) + loss_imu(logits_per_imu,ground_truth)

            # RGB-only HAR loss
            assert model_info['fusion_type'] == 'cross_modal'  # otherwise next line will be an error
            outputs = model(inputs_rgbHAR, sensors=['RGB'])
            loss_HAR = criterion(outputs, labels_rgbHAR.to(device))
            total_loss = loss_CLIP + loss_HAR

            total_loss.backward()
            optimizer.step()
            # Log both losses separately so a dominating term can be spotted and rescaled.
            # Accumulate floats so the autograd history is not kept across the epoch.
            running_loss_CLIP += loss_CLIP.item()
            running_loss_HAR += loss_HAR.item()
            num_steps += 1

            if args.rank == 0 or args.single_gpu:  # only print from the zeroth gpu
                if (i+1) % 5 == 0:
                    print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, steps_per_epoch, total_loss.item()))

                    # just always save based on project_name
                    torch.save(model.state_dict(), f'./models/{model_info["project_name"]}_{epoch+1}.pt')

        if args.rank == 0 or args.single_gpu:  # only log from the zeroth gpu
            if not args.no_wandb:
                denom = max(num_steps, 1)  # mean over the steps actually taken
                wandb.log({'CLIP_loss': running_loss_CLIP / denom})
                wandb.log({'train_loss'+(f'_{camera_only["sensors"]}' if model_info['fusion_type']== "cross_modal" else ''): running_loss_HAR / denom})

        # Eval and save best model
        if (epoch+1) % args.eval_every == 0:
            acc = CLIP_evaluate(model, val_loader, device, model_info)
            if args.rank == 0 or args.single_gpu:  # only print from the zeroth gpu
                print('Test accuracy CLIP: {:.4f} %'.format(acc))
                if not args.no_wandb: wandb.log({'CLIP_val_acc'+(f'_{model_info["sensors"]}' if model_info['fusion_type']== "cross_modal" else ''): acc})
            # Evaluate 3 different inputs on the HAR output. Called on every rank, like
            # CLIP_evaluate above and eval_log in the other training loops.
            eval_log(model, val_loader, device, model_info)
            if args.rank == 0 or args.single_gpu:
                # Keep a single best checkpoint by CLIP accuracy; the accuracy is encoded
                # in the filename.
                best_val_acc = 0
                best_val_file = None
                prefix = model_info['project_name']+"-FEs"+'_best_model'
                for f in os.listdir('./models/'):
                    if f.startswith(prefix):
                        # "..._<acc>.pt" -> "<acc>" (strip the 3-char ".pt" suffix)
                        best_val_acc = float(f.split("_")[-1][:-3])
                        best_val_file = os.path.join("./models",f)
                if acc > best_val_acc:
                    if best_val_file: os.remove(best_val_file)
                    # Unwrap DDP like the other training loops do.
                    hidden_size = model.hidden_size if args.single_gpu else model.module.hidden_size
                    fname = f'./models/{model_info["project_name"]}-FEs_best_model{hidden_size}_{acc:.4f}.pt'
                    torch.save(model.state_dict(), fname)

def random_cropbatch(rgb, sensor, crop_args=None):
    """Temporally crop a batch of RGB and IMU sequences, then restore the fixed lengths.

    A window of ``crop`` times the nominal length (30 video frames / 180 IMU samples) is
    cut from both streams at one shared random offset in ``[0, 1 - crop)``. With
    ``misalign`` set, only the IMU stream is cropped and the RGB is left as-is. The result
    is brought back to 30 / 180 steps along dim 1. With ``dialate`` this uses
    nearest-neighbour interpolation; otherwise zeros are appended at the end.

    Args:
        rgb: RGB batch with time on dim 1. The ``dialate`` path requires a 5-D
            (batch, time, channels, height, width) tensor.
        sensor: IMU batch with time on dim 1. The ``dialate`` path requires a 3-D tensor.
        crop_args: object with ``crop``, ``misalign`` and ``dialate`` attributes. Defaults
            to the module-level ``args`` filled in by main.

    Returns:
        ``(rgb, sensor)``. The inputs are returned unchanged when ``crop == 1.0``.
    """
    opts = args if crop_args is None else crop_args
    if opts.crop == 1.0:
        return rgb, sensor

    vid_length = 30
    imu_length = 180

    crop = opts.crop  # fraction of each sequence that is kept (always exactly this value)
    # Shared start offset, as a fraction of the sequence, in [0, 1 - crop).
    t_shift = torch.rand(1) * (1.0 - crop)
    if not opts.misalign:  # with misalign only the IMU is cropped/shifted
        rgb_start = int(t_shift * vid_length)
        rgb = rgb[:, rgb_start:rgb_start + int(crop * vid_length)]
    imu_start = int(t_shift * imu_length)
    sensor = sensor[:, imu_start:imu_start + int(crop * imu_length)]

    # Always restore the full nominal lengths.
    if opts.dialate:
        # interpolate treats dim 2 onward as spatial, so move time from dim 1 to dim 2:
        # (batch, time, C, H, W) -> (batch, C, time, H, W) and (batch, time, C) -> (batch, C, time).
        rgb = rgb.permute(0, 2, 1, 3, 4)
        rgb = nn.functional.interpolate(rgb, size=(vid_length, rgb.shape[-2], rgb.shape[-1]), mode='nearest')
        rgb = rgb.permute(0, 2, 1, 3, 4)
        sensor = sensor.permute(0, 2, 1)
        sensor = nn.functional.interpolate(sensor, size=imu_length, mode='nearest')
        sensor = sensor.permute(0, 2, 1)
    else:
        # Zero-pad at the end (student_teacher and ANON). new_zeros keeps each input's
        # dtype and device.
        pad_size_rgb = vid_length - rgb.shape[1]
        pad_size_imu = imu_length - sensor.shape[1]
        rgb = torch.cat([rgb, rgb.new_zeros(rgb.shape[0], pad_size_rgb, *rgb.shape[2:])], dim=1)
        sensor = torch.cat([sensor, sensor.new_zeros(sensor.shape[0], pad_size_imu, *sensor.shape[2:])], dim=1)
    return rgb, sensor

# CLIP based cosine similarity representation alignment training
def CLIP_train(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, model_info):
    """Align RGB and IMU representations with a CLIP-style (or MSE) objective.

    Per batch, RGB/IMU are temporally cropped with ``random_cropbatch`` unless
    ``args.T_noise_in_eval`` is set. ``model(..., clip_align=True)`` returns
    ``logits_per_rgb``, ``logits_per_imu`` and ``ground_truth``. The loss depends on
    ``args.align_type``:
    - ``'cosine'``: mean of the two cross-entropies against ``ground_truth``.
    - ``'mse'``: MSE between the two logit tensors.

    ``scheduler`` steps once per epoch. Rank 0 logs the mean loss and writes
    ``./models/<project_name>_<epoch>.pt`` each epoch. Every ``args.eval_every`` epochs,
    ``CLIP_evaluate`` and ``eval_log`` run on every rank. Rank 0 then keeps the best
    ``./models/<project_name>-FEs_best_model<hidden>_<acc>.pt`` by CLIP accuracy.

    ``criterion`` is unused (kept for a uniform training-function signature).

    Raises:
        ValueError: if ``args.align_type`` is neither ``'cosine'`` nor ``'mse'``. This is
            checked up front instead of failing mid-batch with an unbound loss.
    """
    if args.align_type not in ('cosine', 'mse'):
        raise ValueError(f"Unknown align_type {args.align_type!r}; expected 'cosine' or 'mse'")
    # The per-epoch checkpoint is written before any evaluation.
    os.makedirs("./models", exist_ok=True)

    loss_imu = nn.CrossEntropyLoss()
    loss_rgb = nn.CrossEntropyLoss()
    loss_mse = nn.MSELoss()

    # NOTE: the epoch count is deliberately not halved for alignment. Every method should
    # consume the same number of seeded batches, so stu_teach / ANON / TANON see the same data.
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for i, data_batch in enumerate(train_loader):
            if not args.T_noise_in_eval:
                data_batch[0], data_batch[1] = random_cropbatch(data_batch[0], data_batch[1])
            inputs, labels = decouple_inputs(data_batch, model_info, device)

            optimizer.zero_grad()
            logits_per_rgb, logits_per_imu, ground_truth = model(inputs, model_info['sensors'], clip_align=True)
            ground_truth = ground_truth.to(device)
            if args.align_type == 'cosine':
                total_loss = (loss_rgb(logits_per_rgb,ground_truth) + loss_imu(logits_per_imu,ground_truth))/2
            else:  # 'mse' (validated above)
                total_loss = loss_mse(logits_per_rgb,logits_per_imu)
            total_loss.backward()
            optimizer.step()
            running_loss += total_loss.item()
            if (i+1) % 5 == 0:
                if args.rank == 0 or args.single_gpu:  # only print from the zeroth gpu
                    print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), total_loss.item()))
        scheduler.step()

        if args.rank == 0 or args.single_gpu:  # only log/save from the zeroth gpu
            if not args.no_wandb: wandb.log({'CLIP_loss': running_loss / len(train_loader)})

            # just always save based on project_name
            torch.save(model.state_dict(), f'./models/{model_info["project_name"]}_{epoch+1}.pt')

        # Eval and save best model
        if (epoch+1) % args.eval_every == 0:
            acc = CLIP_evaluate(model, val_loader, device, model_info)
            if args.rank == 0 or args.single_gpu:
                print('CLIP val accuracy: {:.4f} %'.format(acc))
                if not args.no_wandb: wandb.log({'CLIP_val_acc'+(f'_{model_info["sensors"]}' if model_info['fusion_type']== "cross_modal" else ''): acc})
            eval_log(model, val_loader, device, model_info)
            if args.rank == 0 or args.single_gpu:
                # Keep a single best checkpoint; its accuracy is encoded in the filename.
                best_val_acc = 0
                best_val_file = None
                os.makedirs("./models", exist_ok=True)
                prefix = model_info['project_name']+"-FEs"+'_best_model'
                for f in os.listdir('./models/'):
                    if f.startswith(prefix):
                        # "..._<acc>.pt" -> "<acc>" (strip the 3-char ".pt" suffix)
                        best_val_acc = float(f.split("_")[-1][:-3])
                        best_val_file = os.path.join("./models",f)
                if acc > best_val_acc:
                    if best_val_file: os.remove(best_val_file)
                    fname = f'./models/{model_info["project_name"]}-FEs_best_model{model.hidden_size if args.single_gpu else model.module.hidden_size}_{acc:.4f}.pt'
                    torch.save(model.state_dict(), fname)

# Shared representation space alignment training
def shared_train(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, model_info):
    """Align camera and IMU representations by making them equal, instead of using CLIP.

    The outputs of ``model.FE_rgb`` and ``model.FE_imu`` are L2-normalised along the last
    dimension, and the MSE (L2) loss between them is minimised. Every ``args.eval_every``
    epochs:
    - ``CLIP_evaluate`` measures how well RGB can be predicted from IMU.
    - The best ``./models/<project_name>-FEs_best_model<hidden>_<acc>.pt`` is kept.
    - ``eval_log`` evaluates imu->har, rgb->har and rgb+imu->har.

    ``criterion`` and ``scheduler`` are unused.

    NOTE(review): ``model.FE_rgb`` / ``model.FE_imu`` are accessed on ``model`` directly,
    which requires an unwrapped (single-GPU) model. Printing, logging and saving are not
    rank-gated here either.
    """
    os.makedirs("./models", exist_ok=True)
    loss_imu = nn.MSELoss()

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for i, data_batch in enumerate(train_loader):

            inputs, labels = decouple_inputs(data_batch, model_info, device)

            optimizer.zero_grad()
            z_rgb = model.FE_rgb(inputs[0])  # z is the latent feature vector
            z_imu = model.FE_imu(inputs[1])
            z_rgb = z_rgb / z_rgb.norm(dim=-1, keepdim=True)
            z_imu = z_imu / z_imu.norm(dim=-1, keepdim=True)

            total_loss = loss_imu(z_imu, z_rgb)
            total_loss.backward()
            optimizer.step()
            running_loss += total_loss.item()
            if (i+1) % 5 == 0:
                print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), total_loss.item()))

        # Log the loss to wandb
        if not args.no_wandb: wandb.log({'Shared_loss': running_loss / len(train_loader)})

        # Eval and save best model
        if (epoch+1) % args.eval_every == 0:
            acc = CLIP_evaluate(model, val_loader, device, model_info)
            print('Val accuracy: {:.4f} %'.format(acc))
            if not args.no_wandb: wandb.log({'Shared_val_acc'+(f'_{model_info["sensors"]}' if model_info['fusion_type']== "cross_modal" else ''): acc})
            # Keep a single best checkpoint; its accuracy is encoded in the filename.
            best_val_acc = 0
            best_val_file = None
            prefix = model_info['project_name']+"-FEs"+'_best_model'
            for f in os.listdir('./models/'):
                if f.startswith(prefix):
                    # "..._<acc>.pt" -> "<acc>" (strip the 3-char ".pt" suffix)
                    best_val_acc = float(f.split("_")[-1][:-3])
                    best_val_file = os.path.join("./models",f)
            if acc > best_val_acc:
                if best_val_file: os.remove(best_val_file)
                # Same DDP-aware lookup as the other loops. The old unconditional
                # model.module.hidden_size crashed on the unwrapped model used above.
                hidden_size = model.hidden_size if args.single_gpu else model.module.hidden_size
                fname = f'./models/{model_info["project_name"]}-FEs_best_model{hidden_size}_{acc:.4f}.pt'
                torch.save(model.state_dict(), fname)

            # shared val acc is how well we can predict rgb from imu;
            # now eval imu->har, rgb->har, and rgb+imu->har
            eval_log(model, val_loader, device, model_info)

# Define fineutning loop
def fine_tune(model, tune_loader, test_loader, criterion, optimizer, scheduler, num_epochs, device, model_info):
    """Fine-tune on growing prefixes of ``tune_loader`` and record test accuracy after each.

    For ``num_shot`` = 0 .. len(tune_loader) - 1, the model trains for ``num_epochs``
    epochs. Each epoch uses the first ``num_shot + 1`` batches yielded by ``tune_loader``.
    Afterwards the model is evaluated on ``test_loader`` with ``eval_log``. The model is
    not reset between shot counts, so training accumulates across them.

    ``scheduler`` is unused.

    Returns:
        ``(shots, accs)``, where ``shots[k] = (k + 1) * args.batch_size`` and ``accs[k]``
        is whatever ``eval_log`` returned after stage ``k``.
    """
    shots = []
    accs = []
    for num_shot in range(len(tune_loader)):
        for epoch in range(num_epochs):
            # eval_log (run after the previous shot) may leave the model in eval mode
            # (assumed; it is defined elsewhere), so re-enter training mode every epoch.
            model.train()
            for i, data_batch in enumerate(tune_loader):
                inputs, labels = decouple_inputs(data_batch, model_info, device)

                optimizer.zero_grad()
                if model_info['fusion_type'] == 'cross_modal':
                    outputs = model(inputs, model_info['sensors'])
                else:
                    inputs = [inputs[0], torch.zeros_like(inputs[1])]  # only test rgb
                    outputs = model(inputs)
                if len(model_info['tasks']) == 2:
                    loss = criterion(outputs[0], labels[0].to(device)) + criterion(outputs[1], labels[1].to(device))
                else:
                    loss = criterion(outputs, labels.to(device))
                loss.backward()
                optimizer.step()
                if (i+1) % 5 == 0:
                    print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(tune_loader), loss.item()))

                if i == num_shot:
                    break  # trained on num_shot + 1 batches this epoch
        # save the accuracies after fine-tuning on num_shot + 1 batches
        shots.append((num_shot+1)*args.batch_size)
        accs.append(eval_log(model, test_loader, device, model_info))

        print("num_shot:",shots[-1],"acc:",accs[-1])

    return shots, accs


def _unwrap_ddp(model):
    """Return the bare model: ``model`` itself with ``--single_gpu``, else the DDP-wrapped ``model.module``."""
    return model if args.single_gpu else model.module


def _save_best_student_checkpoint(model, model_info, acc):
    """Keep one ``<project>_best_model<hidden>_<acc>.pt`` file in ./models holding the best weights so far.

    The previous best accuracy is parsed back out of the existing filename; if
    ``acc`` beats it, the old file is deleted and the current weights are saved.
    """
    os.makedirs("./models", exist_ok=True)
    prefix = model_info['project_name'] + '_best_model'
    best_val_acc = 0
    best_val_file = None
    for fname in os.listdir('./models/'):
        if fname.startswith(prefix):
            # Filename ends in "_<acc>.pt": take the last "_" field and drop the ".pt" suffix.
            best_val_acc = float(fname.split("_")[-1][:-3])
            best_val_file = os.path.join("./models", fname)
    if acc > best_val_acc:
        if best_val_file:
            os.remove(best_val_file)
        # hidden_size lives on the bare model; model.module only exists under DDP.
        hidden_size = _unwrap_ddp(model).hidden_size
        out_path = f'./models/{model_info["project_name"]}_best_model{hidden_size}_{acc:.4f}.pt'
        torch.save(model.state_dict(), out_path)


def train_student_teacher(model, train_loader, train_2_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, model_info):
    """Two-stage student/teacher training.

    Stage 1 trains the teacher (``rgb_model``) on ``train_2_loader`` with the
    ground-truth labels. Stage 2 freezes the teacher and trains the student
    (``imu_model``) on ``train_loader`` using the teacher's argmax predictions
    as pseudo-labels. Each stage runs ``num_epochs`` epochs.

    Side effects:
        - Saves ``model.state_dict()`` to ``./models/<project>_<epoch>.pt`` after every
          epoch of both stages (the student stage overwrites the teacher's files).
        - Keeps the best student checkpoint (by IMU-only validation accuracy) as
          ``./models/<project>_best_model<hidden>_<acc>.pt``.
        - Logs to wandb unless ``--no_wandb``.

    ``scheduler`` is accepted for interface compatibility but not stepped.
    Printing, logging and saving happen only on rank 0 (or in single-GPU mode).
    """
    is_main = args.rank == 0 or args.single_gpu
    net = _unwrap_ddp(model)
    os.makedirs("./models", exist_ok=True)  # per-epoch checkpoints are written below

    camera_only = model_info.copy()
    camera_only['sensors'] = ['RGB']

    # ---- Stage 1: supervised teacher training (RGB -> HAR) ----
    if is_main: print("Training teacher model")
    for epoch in range(num_epochs):
        # evaluate() puts the whole model in eval mode, so re-enable training every epoch.
        net.rgb_model.train()
        running_loss = 0.0
        for i, data_batch in enumerate(train_2_loader):
            if not args.T_noise_in_eval:
                data_batch[0], data_batch[1] = random_cropbatch(data_batch[0], data_batch[1])
            inputs, labels = decouple_inputs(data_batch, camera_only, device)  # camera data only

            optimizer.zero_grad()
            # NOTE(review): calling the sub-module directly bypasses DDP.forward; in recent
            # PyTorch versions this may skip gradient all-reduce — confirm replicas stay in sync.
            outputs = net.rgb_model(inputs)
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if (i + 1) % 5 == 0 and is_main:
                print('Teacher Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_2_loader), loss.item()))

        if is_main:
            if not args.no_wandb: wandb.log({'teacher_train_loss': running_loss / len(train_2_loader)})
            torch.save(model.state_dict(), f'./models/{model_info["project_name"]}_{epoch+1}.pt')

        if (epoch + 1) % args.eval_every == 0:
            # Run on every rank: evaluate() performs collective all-reduces under DDP.
            acc = evaluate(model, val_loader, device, camera_only)
            if is_main:
                print('Teacher Test accuracy: {:.4f} %'.format(acc))
                if not args.no_wandb: wandb.log({"teacher_val_acc_['RGB']": acc})

    # ---- Stage 2: student training on teacher pseudo-labels (IMU -> HAR) ----
    if is_main: print("Training student model")
    for param in net.rgb_model.parameters():
        param.requires_grad = False  # the teacher is frozen from here on

    for epoch in range(num_epochs):
        # Teacher only produces pseudo-labels: eval mode keeps its batch-norm statistics
        # and dropout fixed. The student must be switched back to train mode each epoch
        # because eval_log() -> evaluate() calls model.eval() on the whole model.
        net.rgb_model.eval()
        net.imu_model.train()
        running_loss = 0.0
        for i, data_batch in enumerate(train_loader):
            if not args.T_noise_in_eval:
                data_batch[0], data_batch[1] = random_cropbatch(data_batch[0], data_batch[1])
            inputs, _ = decouple_inputs(data_batch, model_info, device)

            optimizer.zero_grad()
            with torch.no_grad():
                pseudo_labels = net.rgb_model(inputs[0]).argmax(dim=1)
            outputs = net.imu_model(inputs[1])
            loss = criterion(outputs, pseudo_labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if (i + 1) % 5 == 0 and is_main:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

        if is_main:
            if not args.no_wandb: wandb.log({'train_loss' + (f'_{model_info["sensors"]}' if model_info['fusion_type'] == "cross_modal" else ''): running_loss / len(train_loader)})
            torch.save(model.state_dict(), f'./models/{model_info["project_name"]}_{epoch+1}.pt')

        # NOTE: EVALUATION CURRENTLY ASSUMES ONE OUTPUT LABEL
        if (epoch + 1) % args.eval_every == 0:
            _acc_rgb, acc_imu, _acc_both = eval_log(model, val_loader, device, model_info)
            if is_main:
                # The student is the IMU model, so "best" is judged on IMU-only accuracy.
                _save_best_student_checkpoint(model, model_info, acc_imu)




# NOTE: EVALUATION CURRENTLY ASSUMES ONE OUTPUT LABEL
# Define evaluation loop
def evaluate(model, val_loader, device, model_info, top_k=1, entropy_fusion=False):
    """Return the top-k accuracy (percent) of ``model`` on ``val_loader``.

    ImageBind models are delegated to ``evaluate_ib`` with a sensor name derived
    from ``model_info['sensors']`` (depth variants for ``czu-ANON``). All other
    models are put in eval mode and run under ``torch.no_grad()``. Correct and
    total counts are summed over all batches and, unless ``--single_gpu``,
    all-reduced across processes, so every rank must call this function.

    Args:
        model: model to evaluate (left in eval mode on return).
        val_loader: evaluation loader.
        device: device inputs are moved to and the result tensor lives on.
        model_info: dict with ``fusion_type`` and ``sensors``.
        top_k: a sample counts as correct if its label is among the top ``top_k`` logits.
        entropy_fusion: forwarded to cross-modal / student-teacher models.

    Returns:
        For non-ImageBind models, a 0-dim float32 tensor on ``device`` equal to
        ``100 * correct / total``.

    Raises:
        ValueError: for an ImageBind model with neither RGB nor IMU in ``sensors``,
            or an unsupported ``fusion_type``.
    """
    if model_info['fusion_type'] == 'imagebind':
        sensors = model_info['sensors']
        use_depth = args.dataset == 'czu-ANON'
        if "RGB" in sensors and "IMU" in sensors:
            ib_sensor = "both_depth" if use_depth else "both"
        elif "RGB" in sensors:
            ib_sensor = "depth" if use_depth else "vision"
        elif "IMU" in sensors:
            ib_sensor = "imu"
        else:
            raise ValueError(f"ImageBind evaluation needs 'RGB' and/or 'IMU' in sensors, got {sensors}")
        if args.rank == 0 or args.single_gpu:
            print("Running eval on imagebind model with sensor:", ib_sensor)
        return evaluate_ib(model, ib_sensor, val_loader, args)

    fusion_type = model_info['fusion_type']
    model.eval()
    with torch.no_grad():
        correct = 0  # plain Python ints; converted to tensors once for the all-reduce
        total = 0
        if args.rank == 0 or args.single_gpu:
            pbar = tqdm(val_loader, desc="Evaluating")
        else:
            pbar = val_loader
        for data_batch in pbar:
            if args.T_noise_in_eval:
                data_batch[0], data_batch[1] = random_cropbatch(data_batch[0], data_batch[1])
            inputs, labels = decouple_inputs(data_batch, model_info, device=device)

            if fusion_type in ['cross_modal', 'student_teacher']:
                # These models handle missing modalities internally via the sensors list.
                outputs = model(inputs, model_info['sensors'], entropy_fusion=entropy_fusion)
            elif fusion_type in ['early', 'late', 'middle', 'attn']:
                # Sensor-fusion models always expect both inputs: replace a missing one with
                # zeros. decouple_inputs already returned just the single tensor, so the
                # zero tensor is shaped from the raw batch instead.
                if 'RGB' in model_info['sensors'] and 'IMU' not in model_info['sensors']:
                    inputs = [inputs, torch.zeros_like(data_batch[1]).to(device)]
                elif 'RGB' not in model_info['sensors'] and 'IMU' in model_info['sensors']:
                    inputs = [torch.zeros_like(data_batch[0]).to(device), inputs]
                outputs = model(inputs)
            elif fusion_type in ['supervised_rgb', 'supervised_imu', 'supervised_fusion']:
                outputs = model(inputs)
            else:
                raise ValueError(f"Unsupported fusion_type for evaluation: {fusion_type}")

            logits = outputs.cpu()
            if top_k == 1:
                _, predicted = torch.max(logits, 1)
                correct += (predicted == labels).sum().item()
            else:
                # predicted: (batch, top_k) indices; labels.view(-1, 1) broadcasts across
                # the k columns, and topk indices are distinct so each row matches at most once.
                _, predicted = torch.topk(logits, top_k, 1, largest=True, sorted=True)
                correct += (labels.view(-1, 1) == predicted).sum().item()
            total += labels.size(0)

        # Aggregate across all processes.
        correct_tensor = torch.tensor(correct, dtype=torch.float32, device=device)
        total_tensor = torch.tensor(total, dtype=torch.float32, device=device)

        if not args.single_gpu:
            dist.all_reduce(correct_tensor, op=dist.ReduceOp.SUM)
            dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)

        accuracy = 100 * correct_tensor / total_tensor

        return accuracy


def eval_log(model, val_loader, device, model_info):
    """Evaluate ``model`` on ``val_loader``, print/log the accuracies and return them.

    - ``imagebind``: returns ``eval_log_ib(model, val_loader, args)`` unchanged.
    - ``supervised_rgb`` / ``supervised_imu`` / ``supervised_fusion``: evaluates the
      configured sensors once and returns ``(acc_rgb, acc_imu, acc_both)`` with the
      evaluated slot filled in and the other two set to ``None``.
    - anything else: evaluates RGB-only, IMU-only and RGB+IMU and returns all three
      accuracies as Python floats.

    Printing and wandb logging happen only on rank 0 or in single-GPU mode.
    ``evaluate`` is still called on every rank because it performs collectives under DDP.
    """
    fusion_type = model_info['fusion_type']
    if fusion_type == 'imagebind':
        return eval_log_ib(model, val_loader, args)

    is_main = args.rank == 0 or args.single_gpu

    # Position of each supervised baseline's accuracy in the (rgb, imu, both) tuple.
    supervised_slot = {'supervised_rgb': 0, 'supervised_imu': 1, 'supervised_fusion': 2}
    if fusion_type in supervised_slot:
        if is_main: print(f"Evaluating on {model_info['sensors']} only")
        acc = evaluate(model, val_loader, device, model_info=model_info)
        # Same guard as everywhere else: in single-GPU mode rank is the device index,
        # so `rank == 0` alone would silently drop logging on e.g. cuda:1.
        if is_main:
            if not args.no_wandb: wandb.log({f'val_acc_{model_info["sensors"]}': acc})
            print(f'Test accuracy {model_info["sensors"]}: {acc:.4f} %')
        results = [None, None, None]
        results[supervised_slot[fusion_type]] = acc.item()
        return tuple(results)

    # Finally evaluate on camera->HAR, imu->HAR, camera+imu->HAR.
    settings = [
        ("RGB only", ['RGB'], 'Test accuracy RGB: {:.4f} %'),
        ("IMU only", ['IMU'], 'Test accuracy IMU: {:.4f} %'),
        ("RGB and IMU", ['RGB', 'IMU'], 'Test accuracy: {:.4f} %'),
    ]
    # Cross-modal style models get the sensor list appended to the wandb key.
    tag_with_sensors = fusion_type in ["cross_modal", "student_teacher"]
    accs = []
    for description, sensors, message in settings:
        setting_info = model_info.copy()
        setting_info['sensors'] = sensors
        if is_main: print(f"Evaluating on {description}")
        acc = evaluate(model, val_loader, device, model_info=setting_info)
        if is_main:
            if not args.no_wandb:
                wandb.log({'val_acc' + (f'_{sensors}' if tag_with_sensors else ''): acc})
            print(message.format(acc))
        accs.append(acc.item())

    return tuple(accs)



def CLIP_evaluate(model, val_loader, device, model_info):
    """Return the CLIP-style matching accuracy (percent) of ``model`` on ``val_loader``.

    For each batch the model is called with ``clip_eval=True`` and must return
    ``(predictions, ground_truth)``; an element is correct when the two are equal.
    Counts are summed over all batches and, unless ``--single_gpu``, all-reduced
    across processes exactly once after the loop (so ranks with different batch
    counts cannot deadlock, and an empty loader does not leave the result unbound).

    Returns:
        0-dim float32 tensor on ``device`` equal to ``100 * correct / total``.
    """
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        if args.rank == 0 or args.single_gpu:
            pbar = tqdm(val_loader, desc="Evaluating")
        else:
            pbar = val_loader
        for data_batch in pbar:
            inputs, _labels = decouple_inputs(data_batch, model_info, device=device)
            predictions, ground_truth = model(inputs, model_info['sensors'], clip_eval=True)

            correct += (predictions == ground_truth).sum().item()
            total += ground_truth.size(0)

        # Aggregate across all processes once, after every local batch has been counted.
        correct_tensor = torch.tensor(correct, dtype=torch.float32, device=device)
        total_tensor = torch.tensor(total, dtype=torch.float32, device=device)

        if not args.single_gpu:
            dist.all_reduce(correct_tensor, op=dist.ReduceOp.SUM)
            dist.all_reduce(total_tensor, op=dist.ReduceOp.SUM)

        accuracy = 100 * correct_tensor / total_tensor

        return accuracy



def main():
    # Best Performance: 2048_Adam_0.0001485682045159312_8_240 under toy-rgb-imu-har-middle
    #    2 Args:  Namespace(num_epochs=240, batch_size=8, learning_rate=0.0001485682045159312, optimizer='Adam', test=False, hidden_size=2048)
    # peaked 94 % at 149 steps (or 150 steps)
    # 95.95% epochs at 347 steps
    # python train.py --batch_size=16 --learning_rate=0.00015 --optimizer=Adam --hidden_size=2048 --num_epochs=20 --device='cuda:0' --experiment=1 --fusion_type='cross_modal' --single_gpu --no_wandb

    # Scratch:
    # python train.py --no_wandb --single_gpu --eval_every=1 --num_epochs=1 --batch_size=16 --fusion_type=supervised_rgb
    # python train.py --no_wandb --single_gpu --eval_every=1 --num_epochs=1 --batch_size=16 --fusion_type=supervised_imu
    # python train.py --batch_size=8 --learning_rate=0.00015 --optimizer=Adam --hidden_size=1024 --num_epochs=100 --device='cuda:0' --fusion_type='cross_modal' --experiment=1 --keep_time="True" --test --single_gpu
    # python train.py --batch_size=8 --experiment=3 --learning_rate=0.00015 --beta=None
    # python train.py --batch_size=8 --learning_rate=0.00015 --optimizer=Adam --hidden_size=2048 --num_epochs=100  --single_gpu --no_wandb --fusion_type='student_teacher' --dataset='ANON-ANON'
    # python train.py --batch_size=8 --learning_rate=0.00015 --optimizer=Adam --hidden_size=2048 --num_epochs=100 --no_wandb --fusion_type='student_teacher' --dataset='ANON-ANON'
    # python train.py --batch_size=16 --learning_rate=0.00015 --optimizer=Adam --hidden_size=2048 --num_epochs=20 --device='cuda:0' --experiment=1 --fusion_type='imagebind' --single_gpu --no_wandb
    

    # Try this for all experiments
    # python train.py --batch_size=8 --learning_rate=0.015 --optimizer=Adam --hidden_size=2048 --num_epochs=100 --device='cuda:0' --experiment=2 --fusion_type='cross_modal'
    

    # Parse command-line arguments
    parser = ArgumentParser()
    # parser.add_argument('--sweep', action='store_true')
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--learning_rate', type=float, default=.001)
    parser.add_argument('--optimizer', type=str, default='Adam')
    parser.add_argument('--test', action='store_true',default=False)
    parser.add_argument('--no_wandb', action='store_true',default=False)
    parser.add_argument('--hidden_size', type=int, default=2048)
    parser.add_argument('--experiment', type=int, default=1)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--fusion_type', type=str, default='cross_modal')
    parser.add_argument('--beta', type=float, default=None) #weighting clip vs har in inter loss
    parser.add_argument('--lr_gamma', type=float, default=0.9) #lr decay
    parser.add_argument('--lr_step_size', type=int, default=10)
    parser.add_argument('--dataset', type=str, default='ANON-ANON') #ANON-ANON or mmact or mmea or czu-ANON
    # parser.add_argument('--keep_time', action='store_true',default=False) #keep time dim in the model
    parser.add_argument('--keep_time', type=bool, default=False) #keep time dim in the model '' keeps it false 'anything' makes it true
    parser.add_argument('--learning_rate_2', type=float, default=None) #learning rate for second model in student teacher
    parser.add_argument('--lr_gamma_2', type=float, default=None) #lr decay for second model in student teacher
    parser.add_argument('--lr_step_size_2', type=int, default=None) #lr decay for second model in student teacher
    parser.add_argument('--single_gpu', action='store_true',default=False) # default use DDP multi GPU.
    parser.add_argument('--rank', type=int, default=0) # i'm kind of using this as a global variable for when distributed training is not used
    parser.add_argument('--eval_every', type=int, default=10) # how often to perform eval (every eval_every epochs)
    parser.add_argument('--load_model', type=str, default=None) # load a model from a file
    parser.add_argument('--seed', type=int, default=1) # seed for reproducibility
    parser.add_argument('--crop', type=float, default=1.0) # randomly crop the video to max this percentage
    parser.add_argument('--misalign', action='store_true', default=False) # randomly misalign the video and imu data
    parser.add_argument('--dialate', action='store_true', default=False) # randomly misalign the video and imu data
    parser.add_argument('--T_noise_in_eval', action='store_true', default=False) # by default if there is temporal noise its in training, set this to true and it will ONLY be in testing
    parser.add_argument('--rgb_svit', action='store_true', default=False) # use the vision transformer for the spatial extractor rgb model
    parser.add_argument('--rgb_tvit', action='store_true', default=False) # use the vision transformer for the temporal extractor in the rgb model
    parser.add_argument('--imu_vit', action='store_true', default=False) # use the vision transformer for the imu model
    parser.add_argument('--cosine_fusion', action='store_true', default=False) # use cosine similarity for fusion (A+B)/ ||A+B|| instead of (A+B)/2
    parser.add_argument('--c3t_head', type=str, default='vit_clstkn') # head for the c3t model, options: vit_clstkn, vit_proj, mlp_add, mlp_concat
    parser.add_argument('--align_type', type=str, default='cosine') # type of alignment to use, options: cosine, mse

    global args
    global overall_start_time
    overall_start_time = time.time()
    args = parser.parse_args()

    # Set the seed for reproducibility
    torch.manual_seed(args.seed) # 1 for trial 1s, 2 for trial 2s, 3 for trial 3s.

    if args.test:
        args.no_wandb = True

    if args.learning_rate_2 is None:
        args.learning_rate_2 = args.learning_rate
    if args.lr_gamma_2 is None:
        args.lr_gamma_2 = args.lr_gamma
    if args.lr_step_size_2 is None:
        args.lr_step_size_2 = args.lr_step_size
    # Set the hyperparameters (note most set in argparser above)
    model_info = {
        'sensors' : ['RGB', 'IMU'], #['RGB', 'IMU']
        'tasks' : ['HAR'], #['HAR', 'PID'],
        'fusion_type' : args.fusion_type, #'middle', #'cross_modal', # 'early', 'middle', 'late', 'cross_modal', 'student_teacher
        'num_classes' : -1,
        'project_name' : ""
    }
    rgb_video_length = 30 #*2
    imu_length= 180 #*2 
    dataset = args.dataset #'ANON-ANON' #'mmact'#'ANON-ANON'

    #lets fill in some more model_info stuff
    #create a project name
    model_info['project_name'] = "toy-"+"-".join(model_info['sensors']+model_info['tasks'])+'-'+model_info['fusion_type']+(f'{args.experiment}' if model_info['fusion_type'] == 'cross_modal' else '')

    # Handle GPU and multiprocessing
    if args.single_gpu:
        # OLD GPU STUFF
        if int(args.device.split(":")[-1]) < torch.cuda.device_count(): # if user specifies a device use it.
            device = torch.device(args.device)
        else: # otherwise find a device
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        args.device = device
        print("Using device: ",device)
        local_rank = device.index
        args.rank = local_rank
        world_size=1
    else:
        # Check if running under Slurm
        if 'SLURM_JOB_ID' in os.environ:
            if 'SLURM_LOCALID' in os.environ:
                local_rank = int(os.environ['SLURM_LOCALID'])
            else:
                raise RuntimeError(f"SLURM_LOCALID environment variable not found. Check your Slurm setup for job {os.environ['SLURM_JOB_ID']}.")
            
            if 'SLURM_PROCID' in os.environ:
                args.rank = int(os.environ['SLURM_PROCID'])
            else:
                raise RuntimeError(f"SLURM_PROCID environment variable not found. Check your Slurm setup for job {os.environ['SLURM_JOB_ID']}.")

            nnodes = int(os.environ['SLURM_JOB_NUM_NODES'])
            nproc_per_node = int(os.environ['SLURM_NTASKS_PER_NODE'])
            # Export SLURM_LAUNCH_NODE_IPADDR environ variable as MASTER_ADDR
            os.environ["MASTER_ADDR"] = os.environ["SLURM_LAUNCH_NODE_IPADDR"]
            os.environ["MASTER_PORT"] = "29500"
        else:
            # Assume running without Slurm, fallback to rank and local_rank
            local_rank = int(os.environ['LOCAL_RANK'])
            args.rank = int(os.environ['RANK'])
            nnodes = 1  # Assume single node if not under Slurm
            nproc_per_node = int(os.environ['NPROC_PER_NODE'])
        # print("Local rank: ",local_rank)
        # print("Rank: ",args.rank)

        # Calculate world_size
        world_size = nnodes * nproc_per_node

        # Set device for the current process
        # device = torch.cuda.current_device() # For some reason, this doesn't work and torchrun is not automatically assigning gpu devices correctly yet, let's do it manually with local_rank
        device = torch.device('cuda', local_rank)
        args.device = device #ignore the user inputted device, each process will have its own device
        torch.cuda.set_device(device)

        # Initialize process group
        dist.init_process_group(backend='nccl', rank=args.rank, world_size=world_size)

        # Get world size
        # world_size = dist.get_world_size()
        if args.rank == 0 or args.single_gpu: print("World size: ",world_size)
        assert args.batch_size % world_size == 0, "Batch size must be divisible by the number of GPUs."

    # Load the dataloaders
    if model_info['fusion_type'] == 'imagebind':
        #Imagebind model inputs the path to the video (not the raw video tensor like everything else)
        train_loader, train_2_loader, val_loader, test_loader, model_info = load_dataloaders(dataset, model_info, rgb_video_length, imu_length, world_size, args, return_path=True)
    else:
        train_loader, train_2_loader, val_loader, test_loader, model_info = load_dataloaders(dataset, model_info, rgb_video_length, imu_length, world_size, args)

        test_loader = val_loader # NOTE: IMPORTANT: NOT RECOMMENDED! THIS IS JUST FOR FINAL RUN TESTING in case we hit a time wall on the HPC and it doesn't run the test at the end of training we still have result to report.

    # keep_time indicates T-ANON Model, add to project name
    if args.keep_time:
        model_info['project_name'] += "_keep_time"

    #MLP Specficic:
    input_size = 180*6  #imu length * 6 sensors see IMU/dataset.py for more info
    hidden_size = args.hidden_size
    output_size = model_info['num_classes']
    print("Args: ", args)

    # Initialize wandb
    if args.rank == 0 or args.single_gpu: #only init wandb once
        if not args.test and not args.no_wandb: #don't run this if we are just running eval or way say no_wandb=true
            # start a new wandb run to track this script
            wandb.init(
                # set the wandb project where this run will be logged
                project=model_info['project_name'],
                
                # track hyperparameters and run metadata
                config={
                'num_epochs': args.num_epochs,
                'batch_size': args.batch_size,
                'learning_rate': args.learning_rate,
                'optimizer': args.optimizer,
                'hidden_size': args.hidden_size,
                'experiment': args.experiment,
                'fusion_type': args.fusion_type,
                'rgb_video_length': rgb_video_length,
                'imu_length': imu_length
            })
            wandb_name = f"{args.hidden_size}_{args.optimizer}_{args.learning_rate}_{args.batch_size}_{args.num_epochs}"
            wandb.run.name = wandb_name
            print("Creating a wandb run:",wandb_name)

    # Define the model, loss function, and optimizer
    if 'IMU' in model_info['sensors'] and 'RGB' in model_info['sensors']:
        if model_info['fusion_type'] == 'cross_modal':
            model = CM_Fusion(input_size, hidden_size, output_size, rgb_video_length=rgb_video_length, imu_length=imu_length, keep_time=args.keep_time, imu_channels=model_info['num_imu_channels']).to(device).float()
        elif model_info['fusion_type'] == 'early':
            model = Early_Fusion(input_size, hidden_size, output_size, rgb_video_length=rgb_video_length, imu_channels=model_info['num_imu_channels']).to(device).float() #doesn't need imu_length assuming imu < rgb data size, see model for details
        elif model_info['fusion_type'] == 'middle':
            model = Middle_Fusion(input_size, hidden_size, output_size, rgb_video_length=rgb_video_length, imu_length=imu_length, imu_channels=model_info['num_imu_channels']).to(device).float()
        elif model_info['fusion_type'] == 'late':
            model = Late_Fusion(input_size, hidden_size, output_size, rgb_video_length=rgb_video_length, imu_length=imu_length, imu_channels=model_info['num_imu_channels']).to(device).float()
        elif model_info['fusion_type'] == 'attn':
            model = HAMLET(input_size, hidden_size, output_size, rgb_video_length=rgb_video_length, imu_length=imu_length, imu_channels=model_info['num_imu_channels']).to(device).float()
        elif model_info['fusion_type'] == 'student_teacher':
            model = Student_Teacher(input_size, hidden_size, output_size, rgb_video_length=rgb_video_length, imu_length=imu_length, imu_channels=model_info['num_imu_channels']).to(device).float() 
        elif model_info['fusion_type'] == 'imagebind':
            model = ImageBind_baseline(output_size).to(device).float()
        elif model_info['fusion_type'] == 'supervised_rgb':
            model = RGB_Action(output_size, rgb_video_length).to(device).float()
            model_info['sensors'] = ['RGB']
        elif model_info['fusion_type'] == 'supervised_imu':
            model = IMU_Action(input_channels=model_info['num_imu_channels'], hidden_size=hidden_size*2, output_size=output_size, imu_length=imu_length).to(device).float()
            model_info['sensors'] = ['IMU']
        elif model_info['fusion_type'] == 'supervised_fusion':
            # Note this is different from the middle fusion above since this method performs supervised training not UMA training
            model = Middle_Fusion(input_size, hidden_size, output_size, rgb_video_length=rgb_video_length, imu_length=imu_length, imu_channels=model_info['num_imu_channels']).to(device).float()
            model_info['sensors'] = ['RGB', 'IMU']
        else:
            raise ValueError("Invalid fusion type")   
    
    elif 'IMU' in model_info['sensors']:
        if 'HAR' in model_info['tasks'] and 'PID' in model_info['tasks']: #this doesn't make full sense rn, bc my joint imu -> 2 task model doesn't output pid rn
            model = joint_IMU_MLP(input_size, hidden_size, output_size).to(device).float()
            train_loader = ((x[1], x[2], x[3]) for x in train_loader)
            val_loader  = ((x[1], x[2], x[3]) for x in val_loader)
        else:
            model = IMU_MLP(input_size, hidden_size, output_size).to(device).float()
            if 'PID' in model_info['tasks']:
                train_loader = ((x[1], x[2]) for x in train_loader)
                val_loader  = ((x[1], x[2]) for x in val_loader)
            elif 'HAR' in model_info['tasks']:
                train_loader = ((x[1], x[3]) for x in train_loader)
                val_loader  = ((x[1], x[3]) for x in val_loader)

    if args.single_gpu:
        # class dummy_ddp(model):
        #     def __init__(self, model):
        #         super().__init__()
        #         self.module = model

        # model = dummy_ddp(model) # When freezing weights and loading state dict allows for consistency... hopefully
        # print("Setting module")
        # model.module = model
        # print(model.module)
        model = model.to(device)
    else:
        # Wrap the model with DataParallel using the free GPUs
        # model = nn.DataParallel(model, device_ids=free_gpus)
        if args.fusion_type in ['supervised_rgb', 'supervised_imu', 'supervised_fusion']:
            #huge speedup (like 10x)
            # model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)
            # Wait after adding ablations (if statements to use ViT for certain parts of the models) this is not working anymore :( . Torch is throwing errors asking me to turn find_unused back to True...
            model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
        else:
            model = DDP(model, device_ids=[local_rank], find_unused_parameters=True) # Device ids is if model is too big to fit one GPU, you can split it across multiple
            # Also, I think i must set find_unused_parameters to True for the cross modal models and UMA setting to work (handles when some parameters are not used in the model for certain updates)

        torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) #need for batchnorm

        # Move the model to the first free GPU (this is required for DataParallel to work correctly)
        model.to(local_rank)

        #to launch:
        # python -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --node_rank=0 --master_addr="192.17.100.218" --master_port=12355 train.py --learning_rate=0.00015 --optimizer=Adam --hidden_size=1024 --num_epochs=3 --experiment=1 --fusion_type=cross_modal --keep_time="" --dataset="mmact" --batch_size=32 --no_wandb
        # python -m torch.distributed.launch --nproc_per_node=1 --nnodes=1 --node_rank=0 --master_addr="172.28.23.127" --master_port=12355 
        # torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="192.17.100.218" --master_port=12355 train.py --learning_rate=0.00015 --optimizer=Adam --hidden_size=1024 --num_epochs=3 --experiment=1 --fusion_type=cross_modal --keep_time="" --dataset="mmea" --batch_size=32 --no_wandb
        #NOTE: TO specifiy devices, do as below:
        # CUDA_VISIBLE_DEVICES=1,2 torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="192.17.100.218" --master_port=12354 train.py --learning_rate=0.00015 --optimizer=Adam --hidden_size=1024 --num_epochs=3 --experiment=1 --fusion_type=cross_modal --keep_time="" --dataset="mmea" --batch_size=8 --no_wandb

        # torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr="192.17.100.218" --master_port=12355 train.py --learning_rate=0.00015 --optimizer=Adam --hidden_size=1024 --num_epochs=3 --experiment=1 --fusion_type=cross_modal --keep_time="" --dataset="mmea" --batch_size=16 --no_wandb

        # torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr="192.17.100.218" --master_port=12355 train.py --learning_rate=0.00015 --optimizer=Adam --hidden_size=1024 --num_epochs=100 --experiment=1 --fusion_type=cross_modal --keep_time="True" --dataset="ANON-ANON" --batch_size=16

        # For ANON
        # torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 --master_addr="172.28.23.117" --master_port=12355 train.py --learning_rate=0.00015 --optimizer=Adam --hidden_size=1024 --num_epochs=100 --experiment=1 --fusion_type=cross_modal --keep_time="True" --dataset="ANON-ANON" --batch_size=16 --no_wandb
        # torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 --master_addr="172.28.23.127" --master_port=12355 train.py --learning_rate=0.00015 --optimizer=Adam --hidden_size=1024 --num_epochs=100 --experiment=1 --fusion_type=cross_modal --keep_time="True" --dataset="ANON-ANON" --batch_size=16 --no_wandb
        # torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 --master_addr="172.28.23.102" --master_port=12355 train.py --learning_rate=0.00015 --optimizer=Adam --hidden_size=1024 --num_epochs=100 --experiment=1 --fusion_type=cross_modal --keep_time="True" --dataset="ANON-ANON" --batch_size=16 --no_wandb


    if args.load_model:
        model.load_state_dict(torch.load(args.load_model))
        print("Loaded model from: ",args.load_model)

    
    criterion = nn.CrossEntropyLoss()
    if args.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
        optimizer_2 = optim.Adam(model.parameters(), lr=args.learning_rate_2)
    elif args.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9)
        optimizer_2 = optim.SGD(model.parameters(), lr=args.learning_rate_2, momentum=0.9)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
    scheduler_2 = torch.optim.lr_scheduler.StepLR(optimizer_2, step_size=args.lr_step_size_2, gamma=args.lr_gamma_2)

    # Test or train the model
    start = time.time()
    if args.test:
        # UNCOMMENT FOR AUTO BEST MODEL
        # models = os.listdir('./models/')
        # for m in models:
        #     # prefix = model_info['project_name']+'_best_model'
        #     prefix = model_info['project_name']+'-FEs'+'_best_model'
        #     if m.startswith(prefix):
        #         model_path = os.path.join('./models/', m)
        #         break
        # print("Evaluating model: ", model_path)
        # model.load_state_dict(torch.load(model_path))
        # # acc = evaluate(model, test_loader, device, model_info=model_info)
        # # print('Test accuracy: {:.4f} %'.format(acc))
        # # print("Before Finetuning:")
        # # eval_log(model, test_loader, device, model_info=model_info)

        #UNCOMMENT FOR CUSTOM MODEL
        model_path = f"{HOME_DIR}/toy_HAR/ANON/models/mmact-toy-RGB-IMU-HAR-student_teacher_18.pt"
        model_path = f"{HOME_DIR}/toy_HAR/ANON/models/toy-RGB-IMU-HAR-cross_modal2_best_model1024_89.6552.pt"
        model_path = f"{HOME_DIR}/toy_HAR/ANON/models/toy-RGB-IMU-HAR-cross_modal1_keep_time_99.pt"
        # model_path = f"{HOME_DIR}/toy_HAR/ANON/models/toy-RGB-IMU-HAR-cross_modal1-FEs_best_model2048_94.1860.pt"
        model_path = "/u/ANON/ANON/models/trial1/toy-RGB-IMU-HAR-cross_modal1_100.pt"
        # model_path = "/u/ANON/ANON/models/toy-RGB-IMU-HAR-cross_modal1_keep_time_100.pt"
        # dist.init_process_group(backend='nccl', rank=0, world_size=1)
        # model = DDP(model)
        model.load_state_dict(torch.load(model_path))
        

        # UNCOMMENT BELOW for fine-tuning
        # print("After Finetuning:")
        # #fine-tune for 5 epochs
        # shots, accs =  fine_tune(model, val_loader, test_loader, criterion, optimizer, scheduler, 5, device, model_info)

        # # Save a plot of shots on the x axis and rgb,imu,both accuracy on the y axis
        # plt.title("Few Shot Learning")
        # accs = np.array(accs)
        # print(accs)
        # print(shots)
        # accs_rgb, accs_imu, accs_both = accs[:,0], accs[:,1], accs[:,2]
        # plt.plot(shots, accs_rgb, label='RGB')
        # plt.plot(shots, accs_imu, label='IMU')
        # plt.plot(shots, accs_both, label='Both')
        # plt.xlabel("Number of Shots")
        # plt.ylabel("Accuracy")
        # plt.legend()
        # plt.savefig(f"entropy_{model_info['project_name']}_few_shot.png")

    else:
        if model_info['fusion_type'] == 'cross_modal':
            
            if args.experiment==1:

                # ----------- EXPERIMENT 1 ----------- 
                # First Align, then freeze encoders and train RGB, then evaluate

                #First align modalities representations
                if args.rank == 0 or args.single_gpu: print("Aligning Modalities")
                fname = f'./models/cross-modal/trained-FEs-{model_info["project_name"]}.pt'
                CLIP_train(model, train_loader, val_loader, criterion, optimizer, scheduler, args.num_epochs, device, model_info)
                ## or load them from the best model
                # print("Loading encoders")
                # torch.load(fname)
                # models = os.listdir('./models/')
                # for m in models:
                #     prefix = model_info['project_name']+'-FEs'+'_best_model'
                #     if m.startswith(prefix):
                #         model_path = os.path.join('./models/', m)
                #         break
                # print("Loaded model: ", model_path)
                # model.load_state_dict(torch.load(model_path))

                # if args.rank == 0: 
                #Perform CLIP Eval:
                if args.rank == 0 or args.single_gpu: print("Evaluating on RGB and IMU CLIP representations")
                acc = CLIP_evaluate(model, val_loader, device, model_info)
                if args.rank == 0 or args.single_gpu: 
                    print('\tTest accuracy: {:.4f} %'.format(acc))
                    if not os.path.exists('./models/cross-modal/'):
                        os.mkdir('./models/cross-modal/')
                    fname = f'./models/cross-modal/trained-FEs-{model_info["project_name"]}-acc{acc}.pt'
                    torch.save(model.state_dict(), fname)
                    print("Saved model as ", fname)
                    # exit()

                #Freeze the encoders
                if args.single_gpu:
                    for param in model.FE_rgb.parameters():
                        param.requires_grad = False
                    for param in model.FE_imu.parameters():
                        param.requires_grad = False
                else:
                    #Freeze weights of model.FE_rgb and model.FE_imu
                    for param in model.module.FE_rgb.parameters():
                        param.requires_grad = False
                    for param in model.module.FE_imu.parameters():
                        param.requires_grad = False

                camera_only = model_info.copy()
                camera_only['sensors'] = ['RGB']

                # # hybrid https://wandb.ai/ANON/toy-RGB-IMU-HAR-cross_modal2_keep_time/runs/dgl4nuef/logs?nw=nwuserANON
                # args.lr_gamma=.1 
                # args.learning_rate=.006
                # args.num_epochs=60
                # optimizer_2 = optim.Adam(model.parameters(), lr=args.learning_rate)
                # scheduler_2 = torch.optim.lr_scheduler.StepLR(optimizer_2, step_size=args.lr_step_size, gamma=args.lr_gamma)

                #Then train the model with camera HAR 
                print("Training on RGB only")
                train(model, train_2_loader, val_loader, criterion, optimizer_2, scheduler_2, args.num_epochs, device, model_info=camera_only)

                
                # or load it from the best model
                # print("Loading model RGB trained model")
                # models = os.listdir('./models/')
                # for m in models:
                #     prefix = model_info['project_name']+'_best_model'
                #     if m.startswith(prefix):
                #         model_path = os.path.join('./models/', m)
                #         break
                # print("Loaded model: ", model_path)
                # model.load_state_dict(torch.load(model_path))

            elif args.experiment==2:
                # ----------- EXPERIMENT 2 ----------- 
                # First Train RGB HAR, freeze RGB encoder, then align (w/ l2 metric!!!), then evaluate 
                # I feel like this should work for shared representation space...

                # First train RGB HAR model
                if args.rank == 0 or args.single_gpu: print("Training RGB HAR")

                # print size of model
                # print("Model size:", sum(p.numel() for p in model.parameters() if p.requires_grad))
                # print("RGB size:", sum(p.numel() for p in model.FE_rgb.parameters() if p.requires_grad))
                # print("IMU size:", sum(p.numel() for p in model.FE_imu.parameters() if p.requires_grad))
                # print("Joint processing size:", sum(p.numel() for p in model.joint_processing.parameters() if p.requires_grad))
                # exit()

                # model._set_static_graph() # Attempted workaround to use DDP for multi batch inference.
                # Wait is't multi batch training exp3? why did i put this here. I'm actually getting an error to set this static_graph to false which is defualt behavior.

                camera_only = model_info.copy()
                camera_only['sensors'] = ['RGB']
                train(model, train_2_loader, val_loader, criterion, optimizer_2, scheduler_2, args.num_epochs, device, model_info=camera_only)

                if args.rank == 0 or args.single_gpu: print("Evaluating on RGB only")
                acc = evaluate(model, val_loader, device, model_info=camera_only)
                if args.rank == 0 or args.single_gpu: 
                    print('\tTest accuracy: {:.4f} %'.format(acc))
                    fname = f'./models/cross-modal/trained-FEs-{model_info["project_name"]}-acc{acc}.pt'
                    torch.save(model.state_dict(), fname)
                    print("Saved model as ", fname)


                # print("Loading model RGB trained model")
                # fname=f'{HOME_DIR}/toy_HAR/ANON/models/cross-modal/trained-FEs-toy-RGB-IMU-HAR-cross_modal2_keep_time-acc93.02325439453125.pt'
                # model.load_state_dict(torch.load(fname))



                # Then Freeze RGB
                if args.single_gpu:
                    for param in model.FE_rgb.parameters():
                        param.requires_grad = False
                else:
                    for param in model.module.FE_rgb.parameters():
                        param.requires_grad = False

                # Align the representations directly not through cosine similarity
                if args.rank == 0 or args.single_gpu: print("Aligning Modalities")
                # shared_train(model, train_loader, val_loader, criterion, optimizer, scheduler, args.num_epochs, device, model_info)
                CLIP_train(model, train_loader, val_loader, criterion, optimizer, scheduler, args.num_epochs, device, model_info)
                
                # torch.save(model.state_dict(), "debug_exp2.pt")
                
            elif args.experiment==3:
                    # print("In here")
                    # ----------- EXPERIMENT 3 -----------
                    # While training RGB HAR Algin Representations (do both intermittantly), than evaluate

                    #Align modalities and train rgb har together
                    print("Aligning Modalities and Training RGB HAR")
                    inter_train(model, train_loader, train_2_loader, val_loader, criterion, optimizer, scheduler, args.num_epochs, device, model_info)

                    #Perform CLIP Eval:
                    print("Evaluating on RGB and IMU CLIP representations")
                    acc = CLIP_evaluate(model, val_loader, device, model_info)
                    print('\tTest accuracy: {:.4f} %'.format(acc))


            elif args.experiment==4:
                # ----------- EXPERIMENT 4 -----------
                # While training RGB HAR Algin Representations this time within the batch, than evaluate

                # model._set_static_graph() # Attempted workaround to use DDP for multi batch inference.
                
                #Align modalities and train rgb har together
                print("Aligning Modalities and Training RGB HAR")
                intra_train(model, train_loader, train_2_loader, val_loader, criterion, optimizer, scheduler, args.num_epochs, device, model_info)

                #Perform CLIP Eval:
                print("Evaluating on RGB and IMU CLIP representations")
                acc = CLIP_evaluate(model, val_loader, device, model_info)
                print('\tTest accuracy: {:.4f} %'.format(acc))

        elif model_info['fusion_type'] == 'student_teacher':
            # train_student_teacher(model, train_loader, val_loader, criterion, optimizer, scheduler, args.num_epochs, device, model_info)
            train_student_teacher(model, train_loader, train_2_loader, val_loader, criterion, optimizer, scheduler, args.num_epochs, device, model_info)

        elif model_info['fusion_type'] == 'imagebind':
            train_linear(train_2_loader, test_loader, model, "vision", args)

        elif model_info['fusion_type'] in ['early', 'middle', 'late', 'attn']:
            train(model, train_loader, val_loader, criterion, optimizer, scheduler, args.num_epochs, device, model_info=model_info)
            train(model, train_2_loader, val_loader, criterion, optimizer, scheduler, args.num_epochs, device, model_info=model_info, psuedo_labels=True)
        elif model_info['fusion_type'] in ['supervised_rgb', 'supervised_imu', 'supervised_fusion']:
            #NOTE train_2 is the (X_rgb,Y) used for HAR, train_1 is the (X_IMU, X_rgb). Let's just use it for supervised as well
            train(model, train_2_loader, val_loader, criterion, optimizer, scheduler, args.num_epochs, device, model_info=model_info)
            


        if args.rank == 0 or args.single_gpu:
            print("Training time:", (time.time()-start)/60, "minutes")

    # REGARDLESS OF WHICH EXPERIMENT, END WITH FINAL TESTING:
        
    #Finally Test on camera, imu, camera+imu
    imu_only = model_info.copy()
    imu_only['sensors'] = ['IMU']
    camera_only = model_info.copy()
    camera_only['sensors'] = ['RGB']

    if args.rank == 0 or args.single_gpu:
        print("\n\nTesting on Final Test Dataset")

    for k in [1,3,5,7,10]:

        if model_info['fusion_type'] in ['supervised_rgb', 'supervised_imu', 'supervised_fusion']:
            if args.rank==0 or args.single_gpu: print(f"Evaluating on {model_info['sensors']} only")
            acc = evaluate(model, val_loader, device, model_info=model_info, top_k=k)
            if args.rank==0 or args.single_gpu: 
                print(f'Top-{k} Test accuracy {model_info["sensors"]}: {acc:.4f} %')

            if args.rank == 0 or args.single_gpu:
                # Log the final test accuracies
                if not args.no_wandb:
                    if k==1: #for consistency with old runs
                        wandb.log({f'final_test_acc_{model_info["sensors"]}': acc})
                    else:
                        wandb.log({f'final_test_acc_{model_info["sensors"]}_top{k}': acc})

        else:
            if args.rank == 0 or args.single_gpu:
                print(f"Top-{k} Accuracy")
                print("Evaluating on RGB only")
            acc_rgb = evaluate(model, test_loader, device, model_info=camera_only, top_k=k)
            if args.rank == 0 or args.single_gpu:
                print('\tTest accuracy: {:.4f} %'.format(acc_rgb))

                print("Evaluating on IMU only")
            acc_imu = evaluate(model, test_loader, device, model_info=imu_only, top_k=k)
            if args.rank == 0 or args.single_gpu:
                print('\tTest accuracy: {:.4f} %'.format(acc_imu))

                print("Evaluating on RGB and IMU")
            acc_both = evaluate(model, test_loader, device, model_info=model_info, top_k=k)
            if args.rank == 0 or args.single_gpu:
                print('\tTest accuracy: {:.4f} %'.format(acc_both))
                print()

            if args.rank == 0 or args.single_gpu:
                # Log the final test accuracies
                if not args.no_wandb:
                    if k==1: #for consistency with old runs
                        wandb.log({'final_test_acc_RGB': acc_rgb})
                        wandb.log({'final_test_acc_IMU': acc_imu})
                        wandb.log({'final_test_acc_both': acc_both})
                    else:
                        wandb.log({f'final_test_acc_RGB_top{k}': acc_rgb})
                        wandb.log({f'final_test_acc_IMU_top{k}': acc_imu})
                        wandb.log({f'final_test_acc_both_top{k}': acc_both})

    # Perform CLIP only once
    if model_info['fusion_type'] == 'cross_modal':
        if args.rank == 0 or args.single_gpu: print("CLIP Accuracy")
        acc_clip = CLIP_evaluate(model, test_loader, device, model_info)
        if args.rank == 0 or args.single_gpu: print('\tTest accuracy: {:.4f} %'.format(acc_clip))
    else:
        acc_clip = None
    if (args.rank == 0 or args.single_gpu) and not args.no_wandb:
            wandb.log({'final_test_acc_CLIP': acc_clip})

    if args.rank == 0 or args.single_gpu:
        # save latent representations to a file
        if model_info['fusion_type'] == 'cross_modal':
            print("Saving latent representations of test set")
            # for i, data_batch in enumerate(train_loader):
            for i, data_batch in enumerate(test_loader):
                inputs, labels = decouple_inputs(data_batch, model_info, device=device)

                with torch.no_grad():
                    if model_info['fusion_type'] == 'cross_modal':
                        logits = outputs = model(inputs, ["RGB","IMU"])
                        if args.single_gpu:
                            z_rgb = model.FE_rgb(inputs[0])
                            z_imu = model.FE_imu(inputs[1])
                        else:
                            z_rgb = model.module.FE_rgb(inputs[0])
                            z_imu = model.module.FE_imu(inputs[1])
                        z_rgb = z_rgb.cpu().detach().numpy()
                        z_imu = z_imu.cpu().detach().numpy()
                        logits = logits.cpu().detach().numpy()
                        labels = labels.cpu().detach().numpy()
                    else:
                        # outputs = model(inputs)
                        # outputs = outputs.cpu().detach().numpy()
                        raise NotImplementedError("To implement look at evaluate() function")
                if i == 0:
                    if model_info['fusion_type'] == 'cross_modal':
                        rgb_latent = z_rgb
                        imu_latent = z_imu
                        logits_latent = logits
                        labels_all = labels
                    else:
                        outputs_latent = outputs
                else:
                    if model_info['fusion_type'] == 'cross_modal':
                        # print("Concatenating")
                        # print(rgb_latent.shape, z_rgb.shape)
                        # print(imu_latent.shape, z_imu.shape)
                        # print(logits_latent.shape, logits.shape)
                        # print(labels_all.shape, labels.shape)
                        rgb_latent = np.vstack((rgb_latent, z_rgb))
                        imu_latent = np.vstack((imu_latent, z_imu))
                        logits_latent = np.vstack((logits_latent, logits))
                        labels_all = np.hstack((labels_all, labels))
                    else:
                        outputs_latent = np.vstack((outputs_latent, outputs))
            if model_info['fusion_type'] == 'cross_modal':
                np.save(f'./latent_representations/{model_info["project_name"]}_rgb.npy', rgb_latent)
                np.save(f'./latent_representations/{model_info["project_name"]}_imu.npy', imu_latent)
                np.save(f'./latent_representations/{model_info["project_name"]}_logits.npy', logits_latent)
                np.save(f'./latent_representations/{model_info["project_name"]}_labels.npy', labels_all)
            else:
                np.save(f'./latent_representations/{model_info["project_name"]}.npy', outputs_latent)
            print(f"Saved Latent Representations as ./latent_representations/{model_info['project_name']}_stuff.npy")
        else:
            print("Saving latent reprensetations for baselines (not cross-modal) is not implemented yet")

        # Let's also print the size of the model for future reference
        # print("Model size:", sum(p.numel() for p in model.parameters() if p.requires_grad))
        print("Model size:", sum(p.numel() for p in model.parameters()))

        #The code below just prints 0 for RGB and IMU individually and everything for joint -- TODO: debug or delete
        # if args.single_gpu:
        #     print("RGB size:", sum(p.numel() for p in model.FE_rgb.parameters() if p.requires_grad))
        #     print("IMU size:", sum(p.numel() for p in model.FE_imu.parameters() if p.requires_grad))
        #     print("Joint processing size:", sum(p.numel() for p in model.joint_processing.parameters() if p.requires_grad))
        # else:
        #     print("RGB size:", sum(p.numel() for p in model.module.FE_rgb.parameters() if p.requires_grad))
        #     print("IMU size:", sum(p.numel() for p in model.module.FE_imu.parameters() if p.requires_grad))
        #     print("Joint processing size:", sum(p.numel() for p in model.module.joint_processing.parameters() if p.requires_grad))



    if not args.single_gpu:
        # Cleanup
        dist.destroy_process_group()


def signal_handler(sig, frame):
    """
    Signal handler function.

    Performs best-effort cleanup of the torch.distributed process group, prints
    the total wall-clock time since this module was loaded, and terminates the
    process with the conventional ``128 + signal number`` exit status, so that
    launchers (torchrun, shells, job schedulers) see the run as interrupted
    rather than as a success.

    Args:
        sig: Number of the received signal (a ``signal.Signals`` member when
            invoked by the interpreter).
        frame: Stack frame that was executing when the signal arrived (unused).

    Raises:
        SystemExit: Always; this handler never returns normally.
    """
    # Forked children (e.g. DataLoader workers) inherit this handler but do not
    # own the process group; let them exit without touching it.
    if os.getpid() != _HANDLER_OWNER_PID:
        sys.exit(128 + sig)

    print("\nExiting gracefully...")
    # destroy_process_group() raises when no group was ever initialised
    # (e.g. --single_gpu runs), so only tear it down if one actually exists.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

    print("Total execution time:", (time.time()- overall_start_time)/60, "minutes")
    sys.exit(128 + sig)


# PID of the process that installed the handlers (the training process itself);
# used by signal_handler to skip process-group cleanup in forked children.
_HANDLER_OWNER_PID = os.getpid()

# Register the handler for asynchronous termination signals. This MUST happen
# before main() runs -- otherwise the handlers only get installed after training
# has already finished and never take effect.
# SIGSEGV, SIGILL and SIGFPE are intentionally NOT handled: they are raised
# synchronously by faulting C code, and a Python-level handler merely returns
# into the faulting instruction, which can make the process appear to hang.
for _sig_name in ("SIGINT", "SIGTERM", "SIGABRT", "SIGQUIT"):
    if hasattr(signal, _sig_name):  # SIGQUIT does not exist on Windows
        signal.signal(getattr(signal, _sig_name), signal_handler)


if __name__ == '__main__':
    main()