import torch
from torch import nn

class SLIMP(nn.Module):
    """Combine a patient-level tabular encoder, a lesion-level tabular encoder
    and an image encoder, and expose loss / feature-extraction entry points.

    ``forward`` dispatches on the ``mode`` keyword:

    * ``"nested_loss"``    -> :meth:`forward_nested_loss`
    * ``"flattened_loss"`` -> :meth:`forward_flattened_loss`
    * ``"default"`` (or omitted) -> :meth:`forward_features_per_lesion`

    Args:
        tabular_model_patient: Encoder applied to patient-level tabular data.
            May be ``None``; otherwise it must expose ``hidden_size``.
        tabular_model_lesion: Encoder applied to lesion-level tabular data.
            May be ``None``; otherwise it must expose ``hidden_size``.
        vit_model: Encoder applied to lesion images.
        d_model: Embedding width; the lesion fusion layer maps
            ``2 * d_model -> d_model``.
        loss_fn: Callable ``loss_fn(a, b)`` used for both the inner and the
            outer loss. Required by the two loss modes.
        lambda_outer: Weight of the outer loss; the inner loss is weighted
            by ``1 - lambda_outer``.
    """

    def __init__(self,
                 tabular_model_patient,
                 tabular_model_lesion,
                 vit_model,
                 d_model,
                 loss_fn=None,
                 lambda_outer=0.1):
        super().__init__()
        self.trace_patient = tabular_model_patient
        self.trace_lesion = tabular_model_lesion
        self.vit = vit_model
        self.d_model = d_model
        self.loss_fn = loss_fn
        self.lambda_outer = lambda_outer

        if self.trace_patient is not None:
            self.d_tabular_model_patient = tabular_model_patient.hidden_size

        if self.trace_lesion is not None:
            self.d_tabular_model_lesion = tabular_model_lesion.hidden_size
            # Fuses the concatenated (lesion tabular, lesion image) embedding
            # back down to d_model.
            self.lesion_fc = nn.Linear(d_model + d_model, d_model)

        self.head = nn.Identity()

    def forward_patient_level(self, x_patient):
        """Encode the patient-level tabular data of the whole batch.

        Args:
            x_patient: 2-tuple whose first element is the patient tabular
                input; the second element is ignored.

        Returns:
            Output of the patient encoder (expected shape ``(B, D)``).
        """
        patient_tabular, _ = x_patient
        z1 = self.trace_patient(patient_tabular)  # clinical representation: (B, D)
        return z1

    def forward_lesion_level(self, x_lesion, patient_idx):
        """Encode all lesions of one patient and pool them into one vector.

        Args:
            x_lesion: 3-tuple ``(lesion_tab, lesion_img, _)``; both tensors
                are indexed by patient along their first dimension.
            patient_idx: Index of the patient inside the batch.

        Returns:
            Tuple ``(z2, z3, z_lesion_total)``: the per-lesion tabular
            embeddings, the per-lesion image embeddings, and the pooled and
            fused lesion representation of the patient.
        """
        lesion_tab, lesion_img, _ = x_lesion
        # Lesion metadata of the current patient through the tabular encoder.
        z2 = self.trace_lesion(lesion_tab[patient_idx])  # (N, D)
        # Lesion images of the current patient through the image encoder.
        z3 = self.vit(lesion_img[patient_idx])  # (N, D)

        # TODO: consider cross attention pooling here
        z_lesion_total = torch.cat((z2, z3), dim=-1)  # (N, 2D)
        # Mean over the lesion dimension drops it: (N, 2D) -> (2D,)
        z_lesion_total = torch.mean(z_lesion_total, dim=0)
        z_lesion_total = self.lesion_fc(z_lesion_total)  # (D,)

        return z2, z3, z_lesion_total

    def forward_nested_loss(self, x_patient, x_lesion):
        """Compute the combined inner (per-patient) and outer (batch) loss.

        The inner loss is ``loss_fn(z2, z3)`` averaged over the patients of
        the batch; the outer loss is ``loss_fn`` between the patient
        embeddings and the stacked pooled lesion embeddings.

        Returns:
            Tuple ``(total_loss, total_inner_loss, outer_loss)`` where
            ``total_loss = (1 - lambda_outer) * inner + lambda_outer * outer``.
        """
        z1 = self.forward_patient_level(x_patient)

        z_lesion_list = []
        total_inner_loss = 0
        num_patients = x_patient[0].size(0)

        for p in range(num_patients):
            z2, z3, z_lesion_total = self.forward_lesion_level(x_lesion, p)
            # Dividing each term by num_patients makes the sum a batch mean.
            total_inner_loss += self.loss_fn(z2, z3) / num_patients
            z_lesion_list.append(z_lesion_total)
        z_lesion_batch = torch.stack(z_lesion_list, dim=0)  # (B, D)

        outer_loss = self.loss_fn(z1, z_lesion_batch)

        total_loss = ((1 - self.lambda_outer) * total_inner_loss
                      + self.lambda_outer * outer_loss)
        return total_loss, total_inner_loss, outer_loss

    def forward_flattened_loss(self, x_lesion):
        """Compute the lesion-level loss over all lesions of the batch at once.

        Args:
            x_lesion: 3-tuple ``(lesion_tab, lesion_img, _)``.

        Returns:
            Tuple ``(total_loss, inner_loss)``; both are the same value.
        """
        lesion_tab, lesion_img, _ = x_lesion
        z2 = self.trace_lesion(lesion_tab)  # lesion metadata representation: (N, D)
        z3 = self.vit(lesion_img)  # lesion image representation: (N, D)
        inner_loss = self.loss_fn(z2, z3)
        total_loss = inner_loss
        return total_loss, inner_loss

    def forward_features_per_lesion(self, x, args, return_attn=False):
        """Extract concatenated features for a per-lesion batch.

        Args:
            x: Tuple in one of the per-lesion dataset formats:
                ``(lesion_tab, patient_tabular, lesion_img, _)``,
                ``(lesion_tab, lesion_img, _)`` or ``(lesion_img, _)``.
            args: Object with boolean attributes ``image_only`` and
                ``inner_only`` selecting how many modalities are used.
            return_attn: When true, also return the attention outputs of
                every encoder that was evaluated.

        Returns:
            The feature tensor (image only, lesion+image, or
            patient+lesion+image, concatenated on the last dim). With
            ``return_attn`` a tuple ``(features, attn)`` is returned, where
            ``attn`` is the image attention alone in image-only mode and a
            tuple of attentions otherwise.

        Raises:
            ValueError: If ``x`` has an unsupported length or lacks an input
                required by the selected configuration.
        """
        n_inputs = len(x)
        if n_inputs == 4:
            lesion_tab, patient_tabular, lesion_img, _ = x
        elif n_inputs == 3:
            lesion_tab, lesion_img, _ = x
        elif n_inputs == 2:
            lesion_img, _ = x
        else:
            raise ValueError(
                f"Expected an input tuple of length 2, 3 or 4, got {n_inputs}"
            )

        # input tensor: (B, C, H, W)
        if return_attn:
            z3, image_attn = self.vit.get_last_selfattention(lesion_img)
        else:
            z3 = self.vit(lesion_img)  # lesion image output: (B, D)
        if args.image_only:
            return (z3, image_attn) if return_attn else z3  # (B, D)

        if n_inputs < 3:
            raise ValueError(
                "Lesion tabular data is required unless args.image_only is set"
            )
        # input tensor: (B, metadata_features)
        if return_attn:
            z2, lesion_attn = self.trace_lesion(lesion_tab, return_attn=True)
        else:
            z2 = self.trace_lesion(lesion_tab)  # lesion metadata output: (B, D)
        if args.inner_only:
            features = torch.cat((z2, z3), dim=-1)  # (B, 2D)
            if return_attn:
                return features, (lesion_attn, image_attn)
            return features

        if n_inputs < 4:
            raise ValueError(
                "Patient tabular data is required unless args.image_only "
                "or args.inner_only is set"
            )
        # input tensor: (B, patient_features)
        if return_attn:
            z1, patient_attn = self.trace_patient(patient_tabular, return_attn=True)
        else:
            z1 = self.trace_patient(patient_tabular)  # patient output: (B, D)
        features = torch.cat((z1, z2, z3), dim=-1)  # (B, 3D)
        if return_attn:
            return features, (patient_attn, lesion_attn, image_attn)
        return features

    def forward(self, *args, return_features=False, **kwargs):
        """Dispatch to a loss computation or to feature extraction.

        Args:
            *args: Positional inputs forwarded to the selected method.
            return_features: In feature mode, return the raw features
                instead of passing them through ``self.head``.
            **kwargs: ``mode`` selects the path (``"nested_loss"``,
                ``"flattened_loss"`` or ``"default"``). Remaining keywords
                are forwarded to :meth:`forward_features_per_lesion` in
                feature mode and ignored by the loss modes.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        # Pop `mode` so it is never forwarded to forward_features_per_lesion,
        # which does not accept it (an explicit mode="default" used to fail).
        mode = kwargs.pop("mode", "default")
        if mode == "nested_loss":
            return self.forward_nested_loss(*args)
        if mode == "flattened_loss":
            return self.forward_flattened_loss(*args)
        if mode != "default":
            raise ValueError(
                f"Unknown mode {mode!r}; expected 'default', 'nested_loss' "
                "or 'flattened_loss'"
            )

        x = self.forward_features_per_lesion(*args, **kwargs)
        if return_features:
            return x
        return self.head(x)