import torch
from torch.cuda.amp import autocast
from transformers import AutoModel

class BrainEncodingModel(AutoModel):
    """Linear brain-encoding head on top of a language model's hidden states.

    Hidden states from one layer of ``language_model`` are averaged over
    selected token positions, optionally standardized and PCA-projected, and
    mapped by ``linear_layer`` to predicted activity. Both forward methods
    return the squared error against ``true_activity`` averaged inside each
    ROI mask.

    NOTE(review): ``super().__init__`` is never called, presumably because
    ``AutoModel`` is a factory class that is not meant to be instantiated.
    Confirm that no ``nn.Module`` behaviour (submodule registration,
    ``.to()``, ``__call__`` dispatching to ``forward``) is expected here.
    """

    def __init__(
        self,
        model,
        model_type: str,
        model_name: str,
        device: str,
    ):
        """Copy the language model and PCA settings from ``model``.

        Args:
            model: Object exposing ``apply_pca``, ``per_layer_pca`` and
                ``model`` (the language model that is run on the inputs).
            model_type: ``'encoder'``, ``'decoder'`` or ``'encoder-decoder'``;
                selects how hidden states are extracted.
            model_name: Name of the language model (stored; not used by this
                class).
            device: Device for the linear head and intermediate buffers.
        """
        self.apply_pca = model.apply_pca
        self.per_layer_pca = model.per_layer_pca
        self.language_model = model.model
        self.model_type = model_type
        self.model_name = model_name
        self.device = device

    def update_weights(self, linear_weights: torch.Tensor):
        """Load encoding weights into ``linear_layer``, rebuilding it only if needed.

        Args:
            linear_weights: Matrix of shape ``(in_features, out_features)``
                (``(H, V)``). It is transposed into ``nn.Linear``'s
                ``(out, in)`` layout and cast to float32 before copying.
        """
        in_dim, out_dim = linear_weights.shape  # (H, V)

        layer = getattr(self, "linear_layer", None)
        must_recreate = (
            layer is None
            or layer.in_features != in_dim
            or layer.out_features != out_dim
        )

        if must_recreate:
            if layer is not None:
                # Drop every reference to the old layer first so that
                # empty_cache() can release its cached weight memory.
                del self.linear_layer
                del layer
                torch.cuda.empty_cache()
            self.linear_layer = torch.nn.Linear(
                in_dim, out_dim, bias=False
            ).to(self.device)

        # Copy the new weights in place (the existing parameter is reused).
        with torch.no_grad():
            self.linear_layer.weight.copy_(linear_weights.t().float())

    def _layer_hidden_states(self, inputs, attention_mask, layer_idx, input_embeds=False):
        """Run the language model and return the hidden states at ``1 + layer_idx``.

        Entry 0 of the hidden-state tuple is skipped (for Hugging Face models
        it is the embedding-layer output). For encoder-decoder models the
        encoder and decoder tuples are concatenated before indexing, and the
        decoder is fed the same sequence as the encoder.

        Args:
            inputs: Token ids, or input embeddings when ``input_embeds``.
            attention_mask: Attention mask passed straight to the model.
            layer_idx: Layer index (offset by one, see above).
            input_embeds: Whether ``inputs`` are embeddings rather than ids.

        Raises:
            ValueError: If ``self.model_type`` is not supported.
        """
        if self.model_type in ("encoder", "decoder"):
            decoder_kwargs = {}
        elif self.model_type == "encoder-decoder":
            # Embeddings must go through ``decoder_inputs_embeds``;
            # ``decoder_input_ids`` only accepts integer token ids.
            key = "decoder_inputs_embeds" if input_embeds else "decoder_input_ids"
            decoder_kwargs = {key: inputs}
        else:
            raise ValueError(
                f"Unsupported model_type {self.model_type!r}; expected "
                "'encoder', 'decoder' or 'encoder-decoder'"
            )

        if input_embeds:
            outputs = self.language_model(
                inputs_embeds=inputs,
                attention_mask=attention_mask,
                output_hidden_states=True,
                **decoder_kwargs,
            )
        else:
            outputs = self.language_model(
                inputs,
                attention_mask=attention_mask,
                output_hidden_states=True,
                **decoder_kwargs,
            )

        if self.model_type == "encoder-decoder":
            hidden_states = outputs["encoder_hidden_states"] + outputs["decoder_hidden_states"]
        else:
            hidden_states = outputs["hidden_states"]
        return hidden_states[1 + layer_idx]

    def _apply_pca(self, tr_embeddings, layer_idx):
        """Standardize and PCA-project ``tr_embeddings`` when ``apply_pca`` is set.

        Returns ``tr_embeddings`` unchanged when PCA is disabled.

        NOTE(review): standardization uses the mean/std across the rows of
        *this batch*. If all rows are identical, the std is 0 and every value
        collapses to 0. With a single row, the (unbiased) std is NaN. Confirm
        that batch statistics, not statistics stored when the PCA was fitted,
        are intended.
        """
        if not self.apply_pca:
            return tr_embeddings
        layer_trs = torch.nan_to_num(tr_embeddings)
        layer_trs = (layer_trs - layer_trs.mean(axis=0)) / (layer_trs.std(axis=0) + 1e-10)
        return self.per_layer_pca[layer_idx].transform(layer_trs)  # (rows, num_components)

    @staticmethod
    def _roi_mask_tensor(roi_masks, device):
        """Stack NumPy ROI masks into a float tensor (``(R, V)``) on ``device``."""
        return torch.stack(
            [torch.from_numpy(roi_mask).float() for roi_mask in roi_masks]
        ).to(device)

    @staticmethod
    def _squared_error(pred_activity, true_activity):
        """Element-wise squared error, broadcasting ``true_activity`` to ``pred_activity``."""
        target = torch.as_tensor(
            true_activity, device=pred_activity.device, dtype=pred_activity.dtype
        )  # (V,) or (1, V)
        return (pred_activity - target.expand_as(pred_activity)) ** 2

    def forward(self, token_ids, attention_mask, layer_idx, token_idxs_to_avg, true_activity, roi_masks, n_steps=None, input_embeds=False, **kwargs):
        """Return the per-ROI mean squared error of the predicted activity.

        Args:
            token_ids: Token ids (or embeddings when ``input_embeds``). The
                leading dim is ``R * B``. It is viewed as ``(R, B, ...)``, so
                it is presumably ROI-major.
            attention_mask: Attention mask for ``token_ids``.
            layer_idx: Layer whose hidden states are used.
            token_idxs_to_avg: Per-context token positions to average over.
            true_activity: Target activity broadcastable to ``(R, V)``.
            roi_masks: Sequence of NumPy voxel masks, or ``None`` to average
                over all voxels (then ``R == 1``).
            n_steps: Attribution steps. Required when the batch holds more
                items than ``token_idxs_to_avg`` (GradientSHAP / IG batches).
            input_embeds: Whether ``token_ids`` are input embeddings.

        Returns:
            Tensor of shape ``(R,)``. An empty ROI mask yields NaN (0/0).

        Raises:
            ValueError: If ``n_steps`` is needed but not given, or
                ``model_type`` is unsupported.
        """
        R = 1 if roi_masks is None else len(roi_masks)
        B = token_ids.shape[0] // R

        layer_embeddings = self._layer_hidden_states(
            token_ids, attention_mask, layer_idx, input_embeds
        )  # (R*B, num_tokens, H)
        layer_embeddings = layer_embeddings.view(R, B, -1, layer_embeddings.shape[-1])
        num_items = layer_embeddings.shape[1]

        expanded_batch = len(token_idxs_to_avg) < num_items
        if expanded_batch and n_steps is None:
            raise ValueError(
                "n_steps is required when the batch is larger than token_idxs_to_avg"
            )

        # Average the token embeddings of each context's last word.
        tr_embeddings = torch.zeros(
            R, num_items, layer_embeddings.shape[-1], device=self.device
        )
        if expanded_batch and not input_embeds:
            # GradientSHAP: each context occupies n_steps consecutive items
            # (0 0 0 ... 1 1 1 ...).
            for b_idx in range(num_items):
                ctx_idx = b_idx // n_steps
                tr_embeddings[:, b_idx, :] = layer_embeddings[
                    :, b_idx, token_idxs_to_avg[ctx_idx], :
                ].mean(dim=1)
        else:
            if expanded_batch:
                # IG: assumes token_idxs_to_avg is a list, so ``*`` tiles the
                # whole context list n_steps times (0 1 ... 0 1 ...).
                token_idxs_to_avg = token_idxs_to_avg * n_steps
            for b_idx in range(num_items):
                tr_embeddings[:, b_idx, :] = layer_embeddings[
                    :, b_idx, token_idxs_to_avg[b_idx], :
                ].mean(dim=1)
        tr_embeddings = tr_embeddings.mean(dim=1)  # (R, H)

        tr_embeddings = self._apply_pca(tr_embeddings, layer_idx)
        pred_activity = self.linear_layer(tr_embeddings)  # (R, V)
        sq_err = self._squared_error(pred_activity, true_activity)  # (R, V)

        if roi_masks is None:
            return sq_err.mean(dim=1)
        mask_f = self._roi_mask_tensor(roi_masks, sq_err.device)  # (R, V)
        return (sq_err * mask_f).sum(dim=1) / mask_f.sum(dim=1)  # (R,)

    @torch.inference_mode
    @autocast()
    def brain_alignment_forward(self, tr_token_ids, tr_attention_mask, layer_idx, token_idxs_to_avg, true_activity, roi_masks, device, num_delays=4, **kwargs):
        """Return per-mask, per-ROI MSE from the concatenated embeddings of several TRs.

        Runs under ``torch.inference_mode`` and CUDA autocast.

        Args:
            tr_token_ids: One token-id tensor per TR, each of shape
                ``(num_masks, B * R, num_tokens)``.
            tr_attention_mask: Attention masks matching ``tr_token_ids``.
            layer_idx: Layer whose hidden states are used.
            token_idxs_to_avg: ``token_idxs_to_avg[tr][b]`` holds the token
                positions to average for item ``b`` of TR ``tr``.
            true_activity: Target activity broadcastable to ``(num_masks*R, V)``.
            roi_masks: Sequence of ``R`` NumPy voxel masks.
            device: Device the token tensors are moved to.
            num_delays: Number of TR feature blocks the linear head expects.
                Missing blocks are zero-padded after the available ones.

        Returns:
            Tensor of shape ``(num_masks, R)``. An empty ROI mask yields NaN.

        Raises:
            ValueError: If ``tr_token_ids`` is empty or ``model_type`` is
                unsupported.
        """
        if len(tr_token_ids) == 0:
            raise ValueError("tr_token_ids must contain at least one TR")
        R = len(roi_masks)

        embeddings_to_concatenate = []
        for tr_idx, token_ids in enumerate(tr_token_ids):
            num_masks = token_ids.shape[0]
            B = token_ids.shape[1] // R

            batch_tok = token_ids.view(-1, token_ids.shape[-1]).to(device)  # (num_masks*B*R, num_tokens)
            att = tr_attention_mask[tr_idx]
            batch_att = att.view(-1, att.shape[-1]).to(device)  # (num_masks*B*R, num_tokens)

            layer_embeddings = self._layer_hidden_states(batch_tok, batch_att, layer_idx)
            # Row m*R + r of the first dim is mask m, ROI r.
            layer_embeddings = layer_embeddings.view(
                num_masks * R, B, -1, layer_embeddings.shape[-1]
            )

            # Average the token embeddings of each context's last word.
            tr_embeddings = torch.zeros(
                num_masks * R, B, layer_embeddings.shape[-1], device=self.device
            )
            word_idxs = token_idxs_to_avg[tr_idx]
            for b_idx in range(B):
                tr_embeddings[:, b_idx, :] = layer_embeddings[
                    :, b_idx, word_idxs[b_idx], :
                ].mean(dim=1)
            tr_embeddings = tr_embeddings.mean(dim=1)  # (num_masks*R, H)

            embeddings_to_concatenate.append(self._apply_pca(tr_embeddings, layer_idx))

        # (num_masks*R, features_per_tr * num_trs)
        cat_tr_embeddings = torch.cat(embeddings_to_concatenate, dim=1)
        # Zero-pad when fewer than num_delays TRs were supplied.
        target_dim = embeddings_to_concatenate[-1].shape[-1] * num_delays
        if cat_tr_embeddings.shape[-1] < target_dim:
            padding = torch.zeros(
                cat_tr_embeddings.shape[0],
                target_dim - cat_tr_embeddings.shape[-1],
                device=cat_tr_embeddings.device,
                dtype=cat_tr_embeddings.dtype,
            )
            cat_tr_embeddings = torch.cat([cat_tr_embeddings, padding], dim=1)

        pred_activity = self.linear_layer(cat_tr_embeddings)  # (num_masks*R, V)
        sq_err = self._squared_error(pred_activity, true_activity)  # (num_masks*R, V)

        # Rows are mask-major (row m*R + r is mask m, ROI r), so the (R, V)
        # ROI masks must be tiled num_masks times, not interleaved.
        mask_f = self._roi_mask_tensor(roi_masks, sq_err.device).repeat(num_masks, 1)
        mse_loss = (sq_err * mask_f).sum(dim=1) / mask_f.sum(dim=1)  # (num_masks*R,)
        return mse_loss.view(num_masks, R)
