from src.models.hyena import HyenaDNAModel

import torch.nn as nn
import torch
from transformers.models.gpt2.modeling_gpt2 import GPT2Config, GPT2Model, GPT2LMHeadModel

class HyenaWithTaskHead(nn.Module):
    """HyenaDNA backbone followed by a per-position 1x1-convolution head.

    The head maps each ``d_model``-dimensional backbone output either to
    ``vocab_size`` logits, or to a single value when ``is_reg=True``
    (regression). Only the trailing ``labels.size(-1)`` positions are scored.
    """

    def __init__(self, d_model: int, vocab_size: int, is_reg=False, **hyena_kwargs):
        super().__init__()
        self.is_reg = is_reg
        # A kernel_size=1 Conv1d applies the same linear map at every position.
        # It is created before the backbone, matching the original
        # parameter-registration order.
        out_channels = 1 if is_reg else vocab_size
        self.linear = nn.Conv1d(d_model, out_channels, kernel_size=1)
        self.hyena = HyenaDNAModel(d_model=d_model, vocab_size=vocab_size, **hyena_kwargs)

    @staticmethod
    def _take_last(output: torch.Tensor, k: int) -> torch.Tensor:
        """Return the last ``k`` positions along dim 1 (all of them if ``k`` exceeds the length).

        Unlike ``output[:, -k:]``, this yields an empty slice for ``k == 0``
        instead of the whole sequence (``-0 == 0``).
        """
        return output[:, max(output.size(1) - k, 0):, :]

    def forward(self, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
        """Score the trailing positions of ``input_ids``.

        Args:
            input_ids: token ids, shape (batch_size, seq_length).
            labels: shape (batch_size, label_seq_length). Only its length is
                used; it selects how many trailing positions to score.

        Returns:
            A tensor of shape (batch_size, label_seq_length, vocab_size)
            holding logits, or (batch_size, label_seq_length, 1) holding
            values when ``is_reg`` is set. If ``label_seq_length`` exceeds
            ``seq_length``, all ``seq_length`` positions are returned.
        """
        assert len(input_ids.size()) == 2
        assert len(labels.size()) == 2
        assert input_ids.size(0) == labels.size(0)
        label_seq_length = labels.size(-1)

        output = self.hyena(input_ids)  # (batch_size, seq_length, d_model)
        assert input_ids.size(0) == output.size(0) and input_ids.size(1) == output.size(1)
        output = self._take_last(output, label_seq_length)  # (batch_size, label_seq_length, d_model)

        if output.size(1) == 0:
            # Conv1d rejects zero-length inputs; there is nothing to score.
            return output.new_zeros(output.size(0), 0, self.linear.out_channels)

        # Conv1d expects channels first, so move d_model to dim 1 and back.
        return self.linear(output.transpose(-1, -2)).transpose(-1, -2)

    def param_count(self):
        """Count parameters by component.

        Returns:
            A dict with three keys:
            - "ssm": parameters of each backbone layer's ``mixer``.
            - "mlp": each layer's ``mlp``, plus ``backbone.front_mlp`` if present.
            - "all": every parameter of this module, head included.
        """
        def numel(module):
            return sum(p.numel() for p in module.parameters())

        backbone = self.hyena.backbone
        ssm_param_count = sum(numel(block.mixer) for block in backbone.layers)
        mlp_param_count = sum(numel(block.mlp) for block in backbone.layers)
        front_mlp = getattr(backbone, "front_mlp", None)
        if front_mlp is not None:
            mlp_param_count += numel(front_mlp)
        return {"ssm": ssm_param_count, "mlp": mlp_param_count, "all": numel(self)}

class GPT2WithTaskHead(nn.Module):
    """GPT-2 model exposing ``forward(input_ids, labels)`` and ``param_count()``.

    With ``is_reg=False`` a ``GPT2LMHeadModel`` produces per-token logits.
    With ``is_reg=True`` a bare ``GPT2Model`` is followed by a 1x1-convolution
    head that emits one value per position. Only the trailing
    ``labels.size(-1)`` positions are returned.
    """

    def __init__(self, d_model: int, vocab_size: int, n_layer: int, d_inner: int,
                 is_reg=False,
                 **kwargs):
        super().__init__()
        self.is_reg = is_reg

        # Extra keyword arguments are accepted but ignored; report them so
        # misconfigurations are visible.
        for k, v in kwargs.items():
            print(f"Unused keyword argument: {k}={v}")

        # NOTE(review): n_positions is left at GPT2Config's default; confirm
        # it covers the longest input_ids this model will see.
        self.gpt2_config = GPT2Config(vocab_size=vocab_size,
                                      n_embd=d_model,
                                      n_layer=n_layer,
                                      n_head=self._default_n_head(d_model),
                                      n_inner=d_inner)

        if is_reg:
            self.gpt2 = GPT2Model(self.gpt2_config)
            self.linear = nn.Conv1d(d_model, 1, kernel_size=1)
        else:
            self.lm_head_gpt2 = GPT2LMHeadModel(self.gpt2_config)

    @staticmethod
    def _default_n_head(d_model: int) -> int:
        """Pick an attention head count that evenly divides ``d_model``.

        The preferred count is 8 heads when that leaves at least 2 dims per
        head, and otherwise ``d_model // 2`` heads (at least 1). GPT-2
        requires ``n_embd % n_head == 0``, so if the preferred count does not
        divide ``d_model``, the largest smaller count that does is used. When
        the preferred count already divides ``d_model``, it is returned
        unchanged.
        """
        preferred = 8 if d_model // 8 >= 2 else max(d_model // 2, 1)
        # h == 1 always divides, so this always finds a value.
        return next(h for h in range(preferred, 0, -1) if d_model % h == 0)

    @staticmethod
    def _take_last(output: torch.Tensor, k: int) -> torch.Tensor:
        """Return the last ``k`` positions along dim 1 (all of them if ``k`` exceeds the length).

        Unlike ``output[:, -k:]``, this yields an empty slice for ``k == 0``
        instead of the whole sequence (``-0 == 0``).
        """
        return output[:, max(output.size(1) - k, 0):, :]

    def forward(self, input_ids: torch.LongTensor, labels: torch.LongTensor) -> torch.Tensor:
        """Score the trailing positions of ``input_ids``.

        Args:
            input_ids: token ids, shape (batch_size, seq_length).
            labels: shape (batch_size, label_seq_length). Only its length is
                used; it selects how many trailing positions to return.

        Returns:
            The trailing ``label_seq_length`` positions of the model output:
            the LM-head logits, or (batch_size, label_seq_length, 1) values
            when ``is_reg`` is set. If ``label_seq_length`` exceeds
            ``seq_length``, all positions are returned.
        """
        assert len(input_ids.size()) == 2
        assert len(labels.size()) == 2
        assert input_ids.size(0) == labels.size(0)
        label_seq_length = labels.size(-1)

        if self.is_reg:
            hidden = self.gpt2(input_ids).last_hidden_state
            hidden = self._take_last(hidden, label_seq_length)
            if hidden.size(1) == 0:
                # Conv1d rejects zero-length inputs; there is nothing to score.
                return hidden.new_zeros(hidden.size(0), 0, self.linear.out_channels)
            # Conv1d expects channels first, so move d_model to dim 1 and back.
            return self.linear(hidden.transpose(-1, -2)).transpose(-1, -2)

        logits = self.lm_head_gpt2(input_ids).logits
        return self._take_last(logits, label_seq_length)

    def param_count(self):
        """Count parameters by component.

        Returns:
            A dict with three keys:
            - "attn": parameters of every transformer block's ``attn``.
            - "mlp": parameters of every block's ``mlp``.
            - "all": every parameter of this module.
        """
        transformer = self.gpt2 if self.is_reg else self.lm_head_gpt2.transformer
        attn_param_count = sum(p.numel() for block in transformer.h for p in block.attn.parameters())
        mlp_param_count = sum(p.numel() for block in transformer.h for p in block.mlp.parameters())
        all_param_count = sum(p.numel() for p in self.parameters())
        return {"attn": attn_param_count, "mlp": mlp_param_count, "all": all_param_count}