import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import adept_envs
import gym
import math
import cv2
import numpy as np
from PIL import Image
import os
import torchvision.transforms as T
from vip import load_vip
import matplotlib.pyplot as plt
import pickle
import time
# Widen tensor printing so long rows stay on one line in the logs.
torch.set_printoptions(edgeitems=10, linewidth=500)

# Module-level globals shared by the environment wrapper and the dataset below.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device is", device, flush=True)

# Image encoder returned by load_vip(): put in eval mode and moved to `device`.
vip = load_vip()
vip.eval()
vip = vip.to(device)

# Image preprocessing: resize, centre-crop to 224x224, convert to a tensor.
# ToTensor() divides pixel values by 255.
transforms = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
])


class Robotic_Environment: # This will only be used for evaluating a trained model
    """Evaluation wrapper around the ``kitchen_relax-v1`` gym environment.

    The wrapper resets the simulator to a recorded initial configuration,
    steps it with policy actions while recording one frame per step, and
    returns the current state in the representation requested by the caller.
    """

    def __init__(self, video_resolution, gaussian_noise, camera_number, reset_information, in_hand_eval):
        """Create the environment and load the initial simulator state.

        Args:
            video_resolution: Size passed to ``cv2.resize`` and
                ``cv2.VideoWriter`` for the recorded frames.
            gaussian_noise: When true, every entry of the initial qpos/qvel is
                scaled by ``1 + N(0, 0.03)`` (multiplicative noise).
            camera_number: Unused here; kept so existing callers keep working.
            reset_information: Indexable pair ``(qpos, qvel)`` of arrays that
                are written into the simulator.
            in_hand_eval: Unused here; kept so existing callers keep working.
        """
        self.env = gym.make('kitchen_relax-v1').env
        self.video_frames = [] # These are the frames of video saved for evaluation
        self.video_resolution = video_resolution
        self.env.reset()

        init_qpos = reset_information[0]
        init_qvel = reset_information[1]
        if gaussian_noise:
            mean = 0  # Mean of the Gaussian noise
            std_dev = 0.03  # Standard deviation of the Gaussian noise
            init_qpos = init_qpos * (1 + np.random.normal(mean, std_dev, init_qpos.shape))
            init_qvel = init_qvel * (1 + np.random.normal(mean, std_dev, init_qvel.shape))
        self.env.sim.data.qpos[:] = init_qpos
        self.env.sim.data.qvel[:] = init_qvel
        self.env.sim.forward() # The environment is setup

    def _render_rgb(self):
        """Render the environment and return the frame as a numpy RGB array."""
        return np.array(self.env.render(mode='rgb_array'))

    def _embed_current_frame(self):
        """Return the encoder embedding of the current rendered frame as a list."""
        frame = self._render_rgb()
        image = transforms(Image.fromarray(frame.astype(np.uint8))).reshape(-1, 3, 224, 224)
        image = image.to(device)
        with torch.no_grad():
            # ToTensor() divided by 255; the encoder is fed the rescaled values.
            embedding = vip(image * 255.0)
        return embedding.cpu().tolist()[0]

    def step(self, action):
        """Apply ``action`` to the environment and record the resulting frame."""
        self.env.step(np.array(action)) # Execute some action
        # Frames are stored in BGR because they are later written with OpenCV.
        bgr_frame = cv2.cvtColor(self._render_rgb(), cv2.COLOR_RGB2BGR)
        self.video_frames.append(cv2.resize(bgr_frame, self.video_resolution))

    def get_current_state(self, space): # This is the state in the format specified as input
        """Return the current state as a flat list.

        Args:
            space: ``"joint_space"`` (environment observation),
                ``"image_embedding"`` (encoder embedding of the rendered
                frame) or ``"both"`` (embedding followed by observation).

        Raises:
            ValueError: If ``space`` is not one of the names above.
        """
        if space == "joint_space":
            return (self.env._get_obs()).tolist()
        if space == "image_embedding":
            return self._embed_current_frame()
        if space == "both":
            # Concatenate image + joint, in that order.
            image_part = self._embed_current_frame()
            joint_part = (self.env._get_obs()).tolist()
            return image_part + joint_part
        raise ValueError(
            f"Unknown state space {space!r}; expected 'joint_space', 'both' or 'image_embedding'"
        )

    def save_video(self, video_filename, video_filename_in_hand):
        """Write every fourth recorded frame to ``video_filename`` as mp4v at 30 fps.

        ``video_filename_in_hand`` is accepted for interface compatibility and
        is not used.
        """
        video_fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        video_out = cv2.VideoWriter(video_filename, video_fourcc, 30.0, self.video_resolution)
        try:
            for frame in self.video_frames[::4]:
                video_out.write(frame)
        finally:
            # Release the writer even if a frame fails to write.
            video_out.release()

class MLP(nn.Module):
    """Fully connected network.

    Each hidden layer is ``Linear -> (BatchNorm1d) -> activation -> Dropout``
    and the stack is followed by a final ``Linear`` to ``output_dim``.
    """

    def __init__(self, input_dim, output_dim , dropout, batch_normalization, hidden_layers=(1024, 512, 256), activation=nn.ReLU()):
        """Build the layer stack.

        Args:
            input_dim: Size of the input feature vector.
            output_dim: Size of the output vector.
            dropout: Dropout probability applied after every hidden activation.
            batch_normalization: When true, a ``BatchNorm1d`` follows every
                hidden ``Linear``.
            hidden_layers: Widths of the hidden layers, in order.
            activation: Activation module; the same instance is reused after
                every hidden layer.
        """
        super().__init__()
        self.batch_normalization = batch_normalization
        layers = []
        in_dim = input_dim
        for h_dim in hidden_layers:
            layers.append(nn.Linear(in_dim, h_dim))
            if batch_normalization:
                layers.append(nn.BatchNorm1d(h_dim))
            layers.append(activation)
            layers.append(nn.Dropout(dropout))
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x): # x must be a tensor of input dimension
        """Run the network on a batch ``(N, input_dim)`` or one sample ``(input_dim,)``.

        With batch normalisation a single sample is given a temporary batch
        axis, which is removed again so the result has shape ``(output_dim,)``.
        """
        if self.batch_normalization and x.dim() == 1:
            # squeeze(0) removes only the batch axis added here; a bare
            # squeeze() would also collapse a 1-wide output to a 0-d tensor.
            return self.model(x.unsqueeze(0)).squeeze(0)
        return self.model(x)

# Dataset for behavioural cloning
class TrajectoryDataset(Dataset):
    """Behaviour-cloning dataset of (state, action) pairs from recorded trajectories.

    Every trajectory directory provides ``data.pkl`` (observations, actions)
    and ``camera_<n>.avi``.  Each time step is flattened into one row:

        [0:60]     observation
        [60:63]    three zero buffer columns
        [63]       time-step index
        [64:73]    action
        [73:1097]  image embedding of the frame (assumed to hold 1024 values)
        [1097]     1/0 flag: frame is within ``subgoal_frames_delta`` of its subgoal
        [1098:]    subgoal representation, laid out according to ``subgoal_format``

    The last two entries exist only when ``subgoal_conditioned`` is true.
    """

    # Column layout of one row (see the class docstring).
    _OBS = slice(0, 60)
    _TIME = slice(63, 64)
    _ACTION = slice(64, 73)
    _EMBEDDING = slice(73, 1097)
    _SUBGOAL_FLAG = slice(1097, 1098)
    _SUBGOAL_START = 1098

    def __init__(self, Trajectory_directories, base_directory , state_space,subgoal_conditioned, subgoal_format, subgoal_directory_path, camera_number, subgoal_frames_delta, subgoal_network_input, subgoal_change_format ,timestamp_input , action_chunking, subgoal_change_model):
        """Store the configuration and eagerly load every trajectory.

        Args:
            Trajectory_directories: Names of the trajectory directories.
            base_directory: Directory that contains the trajectory directories.
            state_space: ``"joint_space"``, ``"both"`` or ``"image_embedding"``.
            subgoal_conditioned: Whether rows carry a subgoal flag and subgoal.
            subgoal_format: Representation of the subgoal (same names as
                ``state_space``).
            subgoal_directory_path: Sub-path holding ``<frame>.png`` subgoal files.
            camera_number: Camera whose video is embedded.
            subgoal_frames_delta: Frame distance within which the flag is 1.
            subgoal_network_input: ``"subtract"`` or ``"append"``; only read
                when ``subgoal_change_model`` is true.
            subgoal_change_format: ``"same_network"`` prepends the flag to the
                action target.
            timestamp_input: Whether the time-step index is appended to the state.
            action_chunking: Number of consecutive actions per target.
            subgoal_change_model: Boolean; when true the target is the flag only.
        """
        self.Trajectory_directories = Trajectory_directories # List of all the directories
        self.base_directory = base_directory
        self.state_space = state_space
        self.subgoal_conditioned = subgoal_conditioned
        self.subgoal_format = subgoal_format
        self.subgoal_directory_path = subgoal_directory_path
        self.camera_number = camera_number
        self.subgoal_frames_delta = subgoal_frames_delta
        self.action_chunking = action_chunking
        self.subgoal_change_model = subgoal_change_model
        self.subgoal_network_input = subgoal_network_input
        self.subgoal_change_format = subgoal_change_format
        self.timestamp_input = timestamp_input
        # Loading needs the attributes above, so it comes last.
        self.trajectories = self._load_trajectories()

    def _read_csv(self, file_path, directory):
        """Load one trajectory and return ``(rows, action_chunks)``.

        Despite the name, ``file_path`` is a pickle file with the keys
        ``'observations'`` and ``'actions'``.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # trajectory files from a trusted source.
        with open(file_path, 'rb') as f:
            data_dict = pickle.load(f)

        observations = data_dict['observations']
        actions = data_dict['actions']
        # observation + 3 buffer columns + time-step index + action
        data = [
            list(observations[i]) + [0., 0., 0.] + [i] + list(actions[i])
            for i in range(observations.shape[0])
        ]
        self._append_frame_embeddings(data, directory)
        if self.subgoal_conditioned:
            self._append_subgoals(data, directory)
        return data, self._build_action_chunks(data)

    def _append_frame_embeddings(self, data, directory):
        """Extend every row with the embedding of its video frame."""
        video_path = f"{self.base_directory}/{directory}/camera_{self.camera_number}.avi"
        cap = cv2.VideoCapture(video_path)
        try:
            for index, row in enumerate(data):
                ret, frame = cap.read()  # ret is a boolean indicating success, frame is the image
                if not ret:
                    raise RuntimeError(f"Could not read frame {index} from {video_path}")
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image = transforms(Image.fromarray(frame.astype(np.uint8))).reshape(-1, 3, 224, 224)
                image = image.to(device)
                with torch.no_grad():
                    embedding = vip(image * 255.0)
                row.extend(embedding.cpu().tolist()[0])
        finally:
            cap.release()

    def _append_subgoals(self, data, directory):
        """Extend every row with its subgoal flag and subgoal representation."""
        subgoals_directory = f"{self.base_directory}/{directory}/{self.subgoal_directory_path}"
        # Subgoal files are named <frame index>.png.
        list_of_subgoals = sorted(
            int(name.replace('.png', ''))
            for name in os.listdir(subgoals_directory)
            if name.endswith('.png')
        )
        # The smallest frame index is dropped (presumably the initial state,
        # which the evaluation code also excludes from the subgoals).
        list_of_subgoals.pop(0)
        for i, row in enumerate(data):
            subgoal_index = self.get_subgoal_index(i, list_of_subgoals)
            if subgoal_index is None:
                raise ValueError(
                    f"No subgoal at or after frame {i} in {subgoals_directory}"
                )
            row.append(1 if abs(i - subgoal_index) <= self.subgoal_frames_delta else 0)
            goal_row = data[subgoal_index]
            if self.subgoal_format == "joint_space":
                row.extend(goal_row[self._OBS])
            elif self.subgoal_format == "both":
                row.extend(goal_row[self._EMBEDDING]) # image embedding
                row.extend(goal_row[self._OBS]) # joint state
            elif self.subgoal_format == "image_embedding":
                row.extend(goal_row[self._EMBEDDING])

    def _build_action_chunks(self, data):
        """Return, per row, the next ``action_chunking`` actions, zero-padded at the end."""
        padding_width = self._ACTION.stop - self._ACTION.start
        list_of_actions = []
        for i in range(len(data)):
            chunk = data[i][self._ACTION]
            for j in range(i + 1, i + self.action_chunking):
                if j >= len(data):
                    chunk += [0.] * padding_width
                else:
                    chunk += data[j][self._ACTION]
            list_of_actions.append(chunk)
        return list_of_actions

    def get_subgoal_index(self, ind, list_of_subgoals): # This gives the subgoal number for some index
        """Return the first entry of ``list_of_subgoals`` that is ``>= ind``.

        Returns ``None`` when no entry qualifies.
        """
        return next((subgoal for subgoal in list_of_subgoals if subgoal >= ind), None)

    def _load_trajectories(self):
        """Return a flat list of ``(row, action_chunk)`` tuples over all directories."""
        trajectories = []
        for directory in self.Trajectory_directories:
            file_path = f"{self.base_directory}/{directory}/data.pkl"
            rows, action_chunks = self._read_csv(file_path, directory)
            trajectories.extend(zip(rows, action_chunks))
        return trajectories

    def __len__(self):
        """Return the total number of time steps over all trajectories."""
        return len(self.trajectories)

    def __getitem__(self, idx): # This gives the exact state, action pair
        """Return ``(state, target)`` as float32 tensors for sample ``idx``.

        Raises:
            ValueError: If ``state_space`` is not a recognised name.
        """
        trajectory_data, list_of_actions = self.trajectories[idx]
        # Slicing copies, so the cached row is never mutated below.
        if self.state_space == "joint_space":
            state = trajectory_data[self._OBS]
        elif self.state_space == "both":
            state = trajectory_data[self._EMBEDDING] + trajectory_data[self._OBS]
        elif self.state_space == "image_embedding":
            state = trajectory_data[self._EMBEDDING]
        else:
            raise ValueError(
                f"Unknown state space {self.state_space!r}; expected 'joint_space', 'both' or 'image_embedding'"
            )

        if self.subgoal_conditioned:
            subgoal = trajectory_data[self._SUBGOAL_START:]
            if self.subgoal_change_model and self.subgoal_network_input == "subtract":
                # Subgoal-change network fed with the squared difference.
                state = [(a - b) ** 2 for a, b in zip(state, subgoal)]
            else:
                state += subgoal
        if self.timestamp_input:
            state += trajectory_data[self._TIME] # This is timestamp at that point
        state = torch.tensor(state, dtype=torch.float32)

        subgoal_flag = trajectory_data[self._SUBGOAL_FLAG]
        if self.subgoal_conditioned and self.subgoal_change_format == "same_network":
            # Subgoal-transition flag followed by the action chunk.
            action = torch.tensor(subgoal_flag + list_of_actions, dtype=torch.float32)
        elif self.subgoal_change_model:
            # The model only predicts whether the subgoal has been achieved.
            action = torch.tensor(subgoal_flag, dtype=torch.float32)
        else:
            action = torch.tensor(list_of_actions, dtype=torch.float32)
        return state, action

# Define custom loss function
class BCEWithLogitsLoss_MSELoss(nn.Module):
    """Joint loss for a network whose first output is a subgoal-change logit.

    Column 0 of the prediction is scored with a positively weighted binary
    cross-entropy on logits; all remaining columns are scored with mean
    squared error.  The result is ``bce + 100 * mse``.
    """

    def __init__(self):
        super().__init__()
        # Positive samples weigh six times as much as negative ones.
        positive_weight = torch.tensor(6.0, dtype=torch.float32)
        self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=positive_weight)
        self.mse_loss = nn.MSELoss()

    def forward(self, predicted_value, target):
        """Return the combined scalar loss for 2-D ``predicted_value`` and ``target``."""
        subgoal_term = self.bce_loss(predicted_value[:, 0], target[:, 0])
        action_term = self.mse_loss(predicted_value[:, 1:], target[:, 1:])
        # The regression term is scaled by a fixed factor of 100.
        return subgoal_term + 100*action_term

def find_largest_number(file_path):
    """Return the integer that starts the last non-blank line of a text file.

    The script appends one line per run, each starting with an id one larger
    than the previous, so the last line carries the largest id.

    Args:
        file_path: Path of the text file to read.

    Returns:
        The first whitespace-separated word of the last non-blank line, as an
        int, or 0 when the file has no non-blank line.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If that first word is not an integer.
    """
    last_line = ""
    with open(file_path, 'r') as file:
        for line in file:
            # Skip blank lines so a trailing newline does not hide the last id.
            if line.strip():
                last_line = line
    if not last_line:
        return 0
    return int(last_line.split()[0])

if __name__ == '__main__':
    # Parameters
    train = True
    eval = True
    output_dimension = 9 # Action will always be 8 dimensional 7 dimension joint angles + 1 dimension gripper

    state_space = "both" # "joint_space", "both", "image_embedding"
    if(state_space == "joint_space"):
        input_dimension =60
    elif(state_space == "both"):
        input_dimension =1024+60 # image , joint
    elif(state_space == "image_embedding"):
        input_dimension = 1024 # See how many dimension is the image embedding????

    Trajectory_directories = ['1.1', '1.2', '1.3', '1.4', '1.5', '2.1', '2.2', '2.3', '2.4', '2.5', '3.1', '3.2', '3.3', '3.4', '3.5', '4.1', '4.2', '4.3', '4.4', '4.5', '5.1', '5.2', '5.3', '5.4', '5.5', '6.1', '6.2', '6.3', '6.4', '6.5', '7.1', '7.2', '7.3', '7.4', '7.5', '8.1', '8.2', '8.3', '8.4', '8.5', '9.1', '9.2', '9.3', '9.4', '9.5', '10.1', '10.2', '10.3', '10.4', '10.5', '11.1', '11.2', '11.3', '11.4', '11.5', '12.1', '12.2', '12.3', '12.4', '12.5', '13.1', '13.2', '13.3', '13.4', '13.5', '14.1', '14.2', '14.3', '14.4', '14.5', '15.1', '15.2', '15.3', '15.4', '15.5', '16.1', '16.2', '16.3', '16.4', '16.5', '17.1', '17.2', '17.3', '17.4', '17.5', '18.1', '18.2', '18.3', '18.4', '18.5', '19.1', '19.2', '19.3', '19.4', '19.5', '20.1', '20.2', '20.3', '20.4', '20.5', '21.1', '21.2', '21.3', '21.4', '21.5', '22.1', '22.2', '22.3', '22.4', '22.5', '23.1', '23.2', '23.3', '23.4', '23.5', '24.1', '24.2', '24.3', '24.4', '24.5', '25.1', '25.2', '25.3', '25.4', '25.5']
    list_of_subgoals_directory_to_eval = ['1.1', '1.2', '1.3', '1.4', '1.5', '1.6', '1.7', '1.8', '1.9', '1.10', '2.1', '2.2', '2.3', '2.4', '2.5', '2.6', '2.7', '2.8', '2.9', '2.10', '3.1', '3.2', '3.3', '3.4', '3.5', '3.6', '3.7', '3.8', '3.9', '3.10', '4.1', '4.2', '4.3', '4.4', '4.5', '4.6', '4.7', '4.8', '4.9', '4.10', '5.1', '5.2', '5.3', '5.4', '5.5', '5.6', '5.7', '5.8', '5.9', '5.10', '6.1', '6.2', '6.3', '6.4', '6.5', '6.6', '6.7', '6.8', '6.9', '6.10', '7.1', '7.2', '7.3', '7.4', '7.5', '7.6', '7.7', '7.8', '7.9', '7.10', '8.1', '8.2', '8.3', '8.4', '8.5', '8.7', '8.8', '8.9', '8.10', '9.1', '9.2', '9.3', '9.4', '9.5', '9.6', '9.7', '9.8', '9.9', '9.10', '10.1', '10.2', '10.3', '10.4', '10.5', '10.6', '10.7', '10.8', '10.9', '10.10', '11.1', '11.2', '11.3', '11.4', '11.5', '11.6', '11.7', '11.8', '11.9', '11.10', '12.1', '12.2', '12.3', '12.4', '12.5', '12.6', '12.7', '12.8', '12.9', '12.10', '13.1', '13.2', '13.3', '13.4', '13.5', '13.6', '13.7', '13.8', '13.9', '13.10', '14.1', '14.2', '14.3', '14.4', '14.5', '14.6', '14.7', '14.8', '14.9', '14.10', '15.1', '15.2', '15.3', '15.4', '15.5', '15.6', '15.7', '15.8', '15.9', '15.10', '16.1', '16.2', '16.3', '16.4', '16.5', '16.6', '16.7', '16.8', '16.9', '16.10', '17.1', '17.2', '17.3', '17.4', '17.5', '17.6', '17.7', '17.8', '17.9', '17.10', '18.1', '18.2', '18.3', '18.4', '18.5', '18.6', '18.7', '18.8', '18.9', '18.10', '19.1', '19.2', '19.3', '19.4', '19.5', '19.6', '19.7', '19.8', '19.9', '19.10', '20.1', '20.2', '20.3', '20.4', '20.5', '20.6', '20.7', '20.8', '20.9', '20.10', '21.1', '21.2', '21.3', '21.4', '21.5', '21.6', '21.7', '21.8', '21.9', '21.10', '22.1', '22.2', '22.3', '22.4', '22.5', '22.6', '22.7', '22.8', '22.9', '22.10', '23.1', '23.2', '23.3', '23.4', '23.5', '23.6', '23.7', '23.8', '23.9', '23.10', '24.1', '24.2', '24.3', '24.4', '24.5', '24.6', '24.7', '24.8', '24.9', '24.10', '25.1', '25.2', '25.3', '25.4', '25.5', '25.6', '25.7', '25.8', '25.9', '25.10']
    total_number_of_iterations=1 # number of iterations per task (helpful with gaussian noise)

    num_epochs = 10000 # number of epochs on the training dataset
    lr = 0.0003
    dropout = 0.0
    timestamp_input = False
    gaussian_noise = False
    batch_normalization = True
    in_hand_eval = False # Get in hand camera video or not
    camera_number = 2 # Camera for subgoals
    hidden_layers = [1024, 512, 256] # [1024, 2048, 4096] #[1024, 512, 256]
    action_chunking = 25 # Action chunking = 1 means only 1 step prediction
    temporal_ensemble = 0 # Weight given for combining actions
    output_dimension*=action_chunking

    subgoal_conditioned = True # goal conditioned subgoals
    subgoal_format= None # This is subgoal state space
    subgoal_directory_path = None
    subgoal_change_format = None # How to change subgoal using epsilon or neural nets
    subgoal_network_input = None # valid only for different network configuration, "subtract" , "append"    
    subgoal_frames_delta = 0 # This is number of frames before the subgoal we consider the subgoal as achieved

    if(subgoal_conditioned):
        subgoal_format = "both" # "joint_space", "both", "image_embedding"
        subgoal_change_format = "epsilon" # same_network, different_network, epsilon. This tells how do we detect whether a subgoal is achieved or not during inference
        subgoal_directory_path = f"decomposed_frames/mininterval_18/divisions_1/gamma_0.08/camera_{camera_number}" # Can change 0.08 to something else if required
        if(subgoal_format == "joint_space"):
            input_dimension+= 60
        elif(subgoal_format == "both"):
            input_dimension += 1024+60
        elif(subgoal_format == "image_embedding"):
            input_dimension += 1024 # See how many dimension is the image embedding????

        if(subgoal_change_format == "different_network"): # If we use a different network to predict subgoal achievement, then can subtract inputs or append them in the network
            subgoal_network_input = "append" #  "subtract" , "append"

    if(timestamp_input):
        input_dimension+=1
    if(subgoal_conditioned):
        if(subgoal_change_format == "same_network"):
            model = MLP(input_dimension, output_dimension +1 , dropout, batch_normalization,hidden_layers=hidden_layers).to(device) # Extra 1 bit for the subgoal change output (logit)
        elif(subgoal_change_format == "different_network"):
            model = MLP(input_dimension, output_dimension, dropout, batch_normalization,hidden_layers=hidden_layers).to(device)
            if(subgoal_network_input == "subtract"):
                if(timestamp_input):
                    model_subgoal_change = MLP(input_dimension//2+1 , 1 , dropout, batch_normalization,hidden_layers=hidden_layers).to(device) # This model is used to detect a subgoal change, we will give (input - subgoal)**2 as input to the network
                else:
                    model_subgoal_change = MLP(input_dimension//2 , 1 , dropout, batch_normalization,hidden_layers=hidden_layers).to(device) # This model is used to detect a subgoal change, we will give (input - subgoal)**2 as input to the network
            elif(subgoal_network_input == "append"):
                model_subgoal_change = MLP(input_dimension , 1 , dropout, batch_normalization,hidden_layers=hidden_layers).to(device)
        elif(subgoal_change_format == "epsilon"):
            model = MLP(input_dimension, output_dimension, dropout, batch_normalization,hidden_layers=hidden_layers).to(device)
    else: # when there is no subgoals used while training
        model = MLP(input_dimension, output_dimension, dropout, batch_normalization,hidden_layers=hidden_layers).to(device)

    saving_formatter = str(find_largest_number("./Parameter_mappings.txt")+1)

    with open('./Parameter_mappings.txt', 'a') as file:
        file.write(f'{saving_formatter}        : state_space_{state_space}_num_epochs_{num_epochs}_lr_{lr}_subgoal_conditioned_{subgoal_conditioned}_subgoal_{subgoal_format}_camera_{camera_number}_batch_normalization_{batch_normalization}_dropout_{dropout}_gaussian_noise_{gaussian_noise}_subgoal_change_format_{subgoal_change_format}_subgoal_frames_delta_{subgoal_frames_delta}_subgoal_network_input_{subgoal_network_input}_timestamp_input_{timestamp_input}_hidden_layers_{hidden_layers}_action_chunking_{action_chunking}_temporal_ensemble_{temporal_ensemble}_Training_directory_{Trajectory_directories}\n')  # Add a newline character to separate lines
    model_dump_file_path = f"./Trained_Models/{saving_formatter}.pth"
    model_subgoal_dump_file_path = f"./Trained_Models/{saving_formatter}_subgoal_change_model.pth"
    base_directory = f"./../../Data_Franka_Kitchen"
    print(model, flush=True)
    total_params = sum(p.numel() for p in model.parameters())
    print("Total number of parameters in the neural network is: ", total_params, flush=True)

    if(train):
        train_start_time = time.time()
        trajectory_dataset = TrajectoryDataset(Trajectory_directories, base_directory , state_space, subgoal_conditioned, subgoal_format, subgoal_directory_path, camera_number, subgoal_frames_delta, subgoal_network_input , subgoal_change_format, timestamp_input, action_chunking, subgoal_change_model=False ) # subgoal change model= False means the model is not for detecting subgoal change, used only with different network subgoal change format
        data_loader = DataLoader(trajectory_dataset, batch_size=512, shuffle=True, num_workers=4)
        if(subgoal_conditioned and subgoal_change_format == "same_network"): # add a new loss function here
            loss_function = BCEWithLogitsLoss_MSELoss() # This is a custom loss function defined above for only this type of network
        else: # if not subgoal conditioned or if subgoal change format is epsilon or different network
            loss_function = torch.nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs) # cosine decay of learning rate
        # Supervised Learning Loop
        for epoch in range(num_epochs):
            model.train()  # Set the model to training mode
            running_loss = 0.0  # Initialize running loss for the epoch
            num_batches = 0     # Initialize batch counter
            for batch in data_loader:
                states, actions = batch
                states = states.to(device)
                actions = actions.to(device)
                optimizer.zero_grad()
                predicted_actions = model(states)
                loss = loss_function(predicted_actions, actions)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()  # Accumulate loss
                num_batches += 1             # Increment batch counter

            scheduler.step()

            if(epoch%50 == 0): # Print loss every 50 epochs
                current_lr = optimizer.param_groups[0]['lr'] # Current learning rate
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/num_batches}, Learning Rate: {current_lr}", flush=True)

        torch.save(model.state_dict(), model_dump_file_path)
        print(f"Model saved to {model_dump_file_path}", flush=True)

        if(subgoal_conditioned and subgoal_change_format == "different_network"): # Train a new network to detect a subgoal change in this case
            print("Training the subgoal change network...")
            num_epochs = 250 # Training epochs for subgoal change network
            trajectory_dataset = TrajectoryDataset(Trajectory_directories, base_directory , state_space, subgoal_conditioned, subgoal_format, subgoal_directory_path, camera_number, subgoal_frames_delta ,subgoal_network_input,subgoal_change_format ,timestamp_input , action_chunking, subgoal_change_model=True )
            data_loader = DataLoader(trajectory_dataset, batch_size=64, shuffle=True, num_workers=4)
            loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(10.0, dtype=torch.float32)) # Give 10 times more weight to positive samples than negative samples
            optimizer = optim.Adam(model_subgoal_change.parameters(), lr=lr)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs) # cosine decay of learning rate
            # Supervised Learning Loop
            for epoch in range(num_epochs):
                model_subgoal_change.train()  # Set the model to training mode
                for batch in data_loader:
                    states, actions = batch
                    states, actions = states.to(device), actions.to(device)
                    optimizer.zero_grad()
                    predicted_actions = model_subgoal_change(states)
                    loss = loss_function(predicted_actions, actions)
                    loss.backward()
                    optimizer.step()
                scheduler.step()
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

            torch.save(model_subgoal_change.state_dict(), f"{model_subgoal_dump_file_path}")
            print(f"Model saved to {model_subgoal_dump_file_path}")

        print("Time taken to train the model is ", (time.time() - train_start_time)/3600 , " hrs")

    if(eval):

        length_of_trajectories_during_inference = {}
        reset_info_of_trajectories_during_inference = {} # information about the trajectory to infer
        for directory in list_of_subgoals_directory_to_eval:
            file_path = f"{base_directory}/{directory}/data.pkl" # pkl file path
            with open(file_path, 'rb') as f: # Read the pickel file
                data_dict = pickle.load(f)
                length_of_trajectories_during_inference[directory] = data_dict['observations'].shape[0]
                reset_info_of_trajectories_during_inference[directory] = (data_dict['init_qpos'] , data_dict['init_qvel']) 

        for iteration_number in range(1,total_number_of_iterations+1,1): # Number of times to evaluate a single trajectory, to get the evaluation metrics
            for directory_for_subgoals in list_of_subgoals_directory_to_eval: # These are all the trajectories to get subgoals from and evaluate
                video_resolution = (224, 224) # This is during evaluation
                reset_information  = reset_info_of_trajectories_during_inference[directory_for_subgoals]
                robot_env = Robotic_Environment(video_resolution, gaussian_noise, camera_number, reset_information, in_hand_eval)

                def robot_inference(directory_for_subgoals): # Function to actually evaluate the neural network
                    max_steps = length_of_trajectories_during_inference[directory_for_subgoals]
                    if(subgoal_conditioned):
                        subgoals_directory = f"{base_directory}/{directory_for_subgoals}/{subgoal_directory_path}"
                        files = os.listdir(subgoals_directory)
                        png_files = [f for f in files if f.endswith('.png')]
                        numbers = [int(f.replace('.png', '')) for f in png_files]
                        list_of_subgoals = sorted(numbers) # This is the sorted list of all the subgoals for some trajectory
                        list_of_subgoals.pop(0) # Dont want initial state to be a subgoal
                        actual_subgoals = [] # This is either 8 or 4 or 1024 dimensional
                        if(subgoal_format == "joint_space"):
                            file_path = f"{base_directory}/{directory_for_subgoals}/data.pkl"
                            with open(file_path, 'rb') as f: # Read the pickel file
                                data_dict = pickle.load(f)
                            observations = data_dict['observations']  # Shape: (244, 60)
                            actions = data_dict['actions']  # Shape: (244, 9)
                            data = []
                            for i in range(observations.shape[0]):
                                observation = observations[i]
                                action = actions[i]
                                row = list(observation) + [0., 0., 0.] + [i] + list(action) # Create the row: 60 observation columns + 3 buffer columns + 1 timestamp + 9 action columns
                                data.append(row) # data now in csv format
                            for subgoal_index in list_of_subgoals:
                                actual_subgoals.append(data[subgoal_index][:60])
                        elif(subgoal_format == "both"):
                            video_path = f"{base_directory}/{directory_for_subgoals}/camera_{camera_number}.avi"
                            cap = cv2.VideoCapture(video_path)
                            for subgoal_index in list_of_subgoals:
                                cap.set(cv2.CAP_PROP_POS_FRAMES, subgoal_index)
                                ret, frame = cap.read()
                                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                                preprocessed_image = transforms(Image.fromarray(frame.astype(np.uint8))).reshape(-1, 3, 224, 224)
                                preprocessed_image = preprocessed_image.to(device)
                                with torch.no_grad():
                                    subgoal_embedding = vip(preprocessed_image * 255.0)
                                actual_subgoals.append(subgoal_embedding.cpu().tolist()[0])
                            cap.release()
                            file_path = f"{base_directory}/{directory_for_subgoals}/data.pkl"
                            with open(file_path, 'rb') as f: # Read the pickel file
                                data_dict = pickle.load(f)
                            observations = data_dict['observations']  # Shape: (244, 60)
                            actions = data_dict['actions']  # Shape: (244, 9)
                            data = []
                            for i in range(observations.shape[0]):
                                observation = observations[i]
                                action = actions[i]
                                row = list(observation) + [0., 0., 0.] + [i] + list(action) # Create the row: 60 observation columns + 3 buffer columns + 1 timestamp + 9 action columns
                                data.append(row) # data now in csv format
                            for iterator in range(len(list_of_subgoals)):
                                subgoal_index = list_of_subgoals[iterator]
                                actual_subgoals[iterator] += data[subgoal_index][:60] # Add the 
                        elif(subgoal_format == "image_embedding"):
                            video_path = f"{base_directory}/{directory_for_subgoals}/camera_{camera_number}.avi"
                            cap = cv2.VideoCapture(video_path)
                            for subgoal_index in list_of_subgoals:
                                cap.set(cv2.CAP_PROP_POS_FRAMES, subgoal_index)
                                ret, frame = cap.read()
                                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                                preprocessed_image = transforms(Image.fromarray(frame.astype(np.uint8))).reshape(-1, 3, 224, 224)
                                preprocessed_image = preprocessed_image.to(device)
                                with torch.no_grad():
                                    subgoal_embedding = vip(preprocessed_image * 255.0)
                                actual_subgoals.append(subgoal_embedding.cpu().tolist()[0])
                            cap.release()
                    model.load_state_dict(torch.load(model_dump_file_path))
                    model.eval()  # Set the model to evaluation mode
                    if(subgoal_conditioned):
                        current_subgoal_index = 0
                        if(subgoal_change_format == "epsilon"): # Define the subgoal change epsilon values for different tasks and state spaces
                            if(subgoal_format == "joint_space"):
                                subgoal_change_epsilon = 0.008
                            elif(subgoal_format == "image_embedding"):
                                subgoal_change_epsilon = 0.01
                            elif(subgoal_format == "both"):
                                subgoal_change_epsilon = 0.0001
                        elif(subgoal_change_format == "different_network"):
                            model_subgoal_change.load_state_dict(torch.load(model_subgoal_dump_file_path))
                            model_subgoal_change.eval()
                        
                    Buffer = [[] for _ in range(max_steps + action_chunking)]  # Initialize buffer correctly
                    for i in range(max_steps):
                        if(subgoal_conditioned and (current_subgoal_index == len(actual_subgoals)) ):
                            break
                        if(i%100==0):
                            print(f"Timestamp: {i}/{max_steps}")
                        state = robot_env.get_current_state(state_space)
                        if(subgoal_conditioned):
                            state += actual_subgoals[current_subgoal_index]
                        if(timestamp_input):
                            state += [i] # Add current timestamp as input
                        state_tensor = torch.tensor(state, dtype=torch.float32).to(device)
                        with torch.no_grad(): 
                            action = model(state_tensor)
                        action = action.cpu().tolist()

                        if(subgoal_conditioned and subgoal_change_format == "same_network"):
                            subgoal_achievement_bit = action[0]
                            action = action[1:] # action without the subgoal achievement bit

                        action = np.array(action, dtype='float32')
                        action = action.reshape(action_chunking, output_dimension // action_chunking)  # Reshape into action chunks

                        for j in range(action_chunking): # Add the action chunks to the buffer
                            Buffer[i + j].append(action[j])
                        weights = np.exp(-temporal_ensemble * np.arange(len(Buffer[i])))  # Perform temporal ensemble: weighted average of the actions
                        weights /= weights.sum()  # Normalize weights
                        current_action = np.sum([w * a for w, a in zip(weights, Buffer[i])], axis=0)
                        current_action = current_action.tolist()  # Convert to list before passing to `step`
                        robot_env.step(current_action)

                        if(subgoal_conditioned): # see if the subgoal is achieved then transition to the next one
                            if(subgoal_change_format == "same_network"):
                                if(subgoal_achievement_bit > 0):
                                    print(f"Subgoal number {current_subgoal_index+1} achieved at timestamp {i}")
                                    current_subgoal_index+=1
                            elif(subgoal_change_format == "different_network"):
                                state = robot_env.get_current_state(subgoal_format)
                                if(subgoal_network_input == "subtract"):
                                    state = [(a-b)**2 for a,b in zip(state, actual_subgoals[current_subgoal_index])]
                                elif(subgoal_network_input == "append"):
                                    state += actual_subgoals[current_subgoal_index]
                                if(timestamp_input):
                                    state += [i]
                                state_tensor = torch.tensor(state, dtype=torch.float32).to(device)
                                with torch.no_grad(): 
                                    subgoal_change_bit = model_subgoal_change(state_tensor)
                                subgoal_change_bit = subgoal_change_bit.cpu().tolist()
                                if(subgoal_change_bit[0] >0): # Change the subgoal
                                    print(f"Subgoal number {current_subgoal_index} achieved at timestamp {i}")
                                    current_subgoal_index+=1
                            elif(subgoal_change_format == "epsilon"):
                                if ( np.mean((np.array(robot_env.get_current_state(subgoal_format)) - np.array(actual_subgoals[current_subgoal_index]))**2)  < subgoal_change_epsilon):
                                    print(f"Subgoal number {current_subgoal_index} achieved at timestamp {i}")
                                    current_subgoal_index+=1

                # Run one evaluation rollout driven by the subgoals of this demonstration directory.
                print(
                    f"Evaluating subgoals from {directory_for_subgoals}, iteration number {iteration_number}...",
                    flush=True,
                )
                robot_inference(directory_for_subgoals)
                print("------------------------------")

                # Evaluation videos are grouped by run configuration, then by subgoal source.
                video_path = f"./Evaluation/{saving_formatter}/subgoals_{directory_for_subgoals}"
                os.makedirs(video_path, exist_ok=True)  # Directory to save Evaluation Videos

                # Two output paths are handed to save_video: the main video and an "_in_hand" variant,
                # both named after the current iteration number.
                video_stem = f"{video_path}/{iteration_number}"
                video_filename = f"{video_stem}.mp4"
                video_filename_in_hand = f"{video_stem}_in_hand.mp4"
                robot_env.save_video(video_filename, video_filename_in_hand)