import torch
from torch import nn
from .visual_areas_neural_network.visual_network import VisualNetwork
from .decision_making_areas_neural_network.decision_making_network import DecisionMakingNetwork
from .action_planning_areas_neural_network.action_planning_network import ActionPlanningNetwork

class BrainLikeModel(nn.Module):
    """Composite model chaining a visual, a decision-making and an action-planning network.

    The three sub-networks communicate through a shared ``batch_dict``: each one
    is called with the dict, and the values returned by the decision-making and
    action-planning networks are returned from ``forward``.

    The major part of ``version_identifier`` (the text before the first ``"."``)
    selects the forward strategy:

    * ``"v1"``: the decision-making and action-planning networks are simply
      called on ``batch_dict`` after the visual network.
    * ``"v2"``: ``lstm_eeg`` and ``lstm_canbus`` of the two networks are first
      unrolled frame by frame (the hidden state of the former is the input of
      the latter), the per-frame hidden states are stored in ``batch_dict``,
      and only then are the two networks called on ``batch_dict``.
    """

    def __init__(self, config_dict):
        """Build the three sub-networks from ``config_dict``.

        Args:
            config_dict: Mapping with the keys ``"version_identifier"``,
                ``"visual_network"``, ``"decision_making_network"``,
                ``"action_planning_network"``, ``"eeg_loss_weight"`` and
                ``"canbus_loss_weight"``. The three ``*_network`` entries are
                themselves mappings handed to the matching sub-network.

        Note:
            ``config_dict`` is modified in place: the top-level
            ``"version_identifier"`` is written into each of the three
            sub-network configs before the sub-networks are constructed.
        """
        super().__init__()
        version_identifier = config_dict["version_identifier"]
        self.version_identifier = version_identifier

        # Every sub-network receives the same version identifier through its
        # own config; all three are updated before any network is constructed.
        for sub_config_key in (
            "visual_network",
            "decision_making_network",
            "action_planning_network",
        ):
            config_dict[sub_config_key]["version_identifier"] = version_identifier

        self.visual_network = VisualNetwork(config_dict["visual_network"])
        self.decision_making_network = DecisionMakingNetwork(config_dict["decision_making_network"])
        self.action_planning_network = ActionPlanningNetwork(config_dict["action_planning_network"])

        # Weights of the two terms summed in get_loss().
        self.eeg_loss_weight = config_dict["eeg_loss_weight"]
        self.canbus_loss_weight = config_dict["canbus_loss_weight"]

    def forward(self, batch_dict):
        """Run the full pipeline on ``batch_dict``.

        Args:
            batch_dict: Mutable mapping shared by all sub-networks. For ``v2``
                it must contain ``"visual_info_processed"`` once the visual
                network has run (presumably written by the visual network
                itself; confirm against ``VisualNetwork``).

        Returns:
            Tuple ``(eeg_info_output, canbus_info_output)``: the return values
            of the decision-making and the action-planning network.

        Raises:
            NotImplementedError: If the major version is neither ``"v1"`` nor
                ``"v2"``. The visual network has already been run at that point.
        """
        # The return value is discarded; the visual network is expected to
        # communicate through batch_dict.
        self.visual_network(batch_dict)

        # Only the major part selects the strategy, so an identifier without a
        # minor part (e.g. "v1") is accepted as well.
        major_version = self.version_identifier.partition(".")[0]
        if major_version == "v2":
            self._run_coupled_lstms(batch_dict)
        elif major_version != "v1":
            raise NotImplementedError(
                f"not implemented! unsupported version_identifier: {self.version_identifier!r}"
            )

        eeg_info_output = self.decision_making_network(batch_dict)
        canbus_info_output = self.action_planning_network(batch_dict)
        return eeg_info_output, canbus_info_output

    def _run_coupled_lstms(self, batch_dict):
        """Unroll ``lstm_eeg`` and ``lstm_canbus`` over the frames and store their hidden states.

        Reads ``batch_dict["visual_info_processed"]`` of shape
        ``(batch_size, n_frame, n_feature)`` and writes
        ``batch_dict["output_lstm_eeg"]`` of shape ``(batch_size, n_frame, d_hidden_dm)``
        and ``batch_dict["output_lstm_canbus"]`` of shape
        ``(batch_size, n_frame, d_hidden_ap)``.
        """
        visual_info_processed = batch_dict["visual_info_processed"]
        # Unpacking doubles as a check that the input is three-dimensional.
        batch_size, _n_frame, _n_feature = visual_info_processed.shape
        device = visual_info_processed.device

        lstm_eeg = self.decision_making_network.lstm_eeg
        lstm_canbus = self.action_planning_network.lstm_canbus
        d_hidden_dm = self.decision_making_network.d_hidden
        d_hidden_ap = self.action_planning_network.d_hidden

        # Zero initial hidden/cell states, allocated directly on the input's
        # device instead of being created on the default device and copied.
        h_dm = torch.zeros((batch_size, d_hidden_dm), dtype=torch.float32, device=device)
        c_dm = torch.zeros((batch_size, d_hidden_dm), dtype=torch.float32, device=device)
        h_ap = torch.zeros((batch_size, d_hidden_ap), dtype=torch.float32, device=device)
        c_ap = torch.zeros((batch_size, d_hidden_ap), dtype=torch.float32, device=device)

        hidden_states_eeg = []
        hidden_states_canbus = []
        # Walk along the time axis; each item is (batch_size, n_feature).
        for frame_visual_info in visual_info_processed.unbind(dim=1):
            # lstm_eeg goes first; its new hidden state feeds lstm_canbus.
            h_dm, c_dm = lstm_eeg(frame_visual_info, (h_dm, c_dm))
            h_ap, c_ap = lstm_canbus(h_dm, (h_ap, c_ap))
            hidden_states_eeg.append(h_dm)
            hidden_states_canbus.append(h_ap)

        # Stacking on dim=1 restores the frame axis.
        batch_dict["output_lstm_eeg"] = torch.stack(hidden_states_eeg, dim=1)
        batch_dict["output_lstm_canbus"] = torch.stack(hidden_states_canbus, dim=1)

    def get_loss(self, batch_dict):
        """Return the weighted sum of the two sub-network losses.

        Args:
            batch_dict: Mapping passed unchanged to ``get_loss`` of the
                decision-making and the action-planning network.

        Returns:
            ``eeg_loss_weight * eeg_loss + canbus_loss_weight * canbus_loss``.
        """
        eeg_loss = self.decision_making_network.get_loss(batch_dict)
        canbus_loss = self.action_planning_network.get_loss(batch_dict)

        return self.eeg_loss_weight * eeg_loss + self.canbus_loss_weight * canbus_loss