import copy
import os
import sys
import warnings
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
from torch.nn.modules.container import ModuleList

# NOTE(review): this silences every warning process-wide for anything that
# imports this module, not just warnings raised here — consider narrowing it.
warnings.filterwarnings("ignore")


class Custom_block(Module):
    """Stack of ``num_layers`` independent deep copies of ``block``, run in sequence.

    Each layer is called as ``layer(x, single_eval_pos=...)`` and must return a
    pair ``(output, ortho_loss)``. Outputs are chained from layer to layer and
    the per-layer ``ortho_loss`` tensors are concatenated along dim 0.

    Args:
        block: Module to replicate. Each copy must accept
            ``(x, single_eval_pos=...)`` and return ``(Tensor, Tensor)``; it
            must also implement ``_get_attn_weight()`` if that method is used.
        num_layers: Number of copies to stack (0 is allowed).
        norm: Optional module stored as ``self.norm``.
            NOTE(review): it is registered (its parameters appear in
            ``state_dict``) but ``forward`` never applies it — confirm whether
            that is intended before relying on it.
    """

    __constants__ = ['norm']

    def __init__(
        self,
        block,
        num_layers: int,
        norm: Optional[Module] = None) -> None:

        super().__init__()
        # Same API-usage logging hook that torch's built-in nn modules call.
        torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
        # deepcopy => every layer owns its own parameters (no weight sharing).
        self.layers = ModuleList([copy.deepcopy(block) for _ in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm

    def _get_attn_weight(self):
        """Return ``[layer._get_attn_weight() for each layer]``, in layer order."""
        return [layer._get_attn_weight() for layer in self.layers]

    def forward(
            self,
            src: Tensor,
            single_eval_pos: int,
            topk: Optional[int] = None) -> Tuple[Tensor, Tensor]:
        """Run ``src`` through every layer in order.

        Args:
            src: Input to the first layer; each later layer receives the
                previous layer's output.
            single_eval_pos: Forwarded unchanged to every layer as a keyword
                argument.
            topk: Unused; kept so existing call sites keep working.

        Returns:
            ``(output, ortho_losses)``. ``output`` is the last layer's output
            (``src`` itself when there are no layers). ``ortho_losses`` is all
            per-layer ortho losses concatenated along dim 0; it is empty when
            there are no layers, and a 0-d per-layer loss contributes one
            element.
        """
        output = src
        # Seed with an empty tensor on src's device/dtype: an empty stack then
        # still returns a tensor, and the single cat below yields the same
        # result (dtype promotion, empty-tensor skipping) as concatenating
        # incrementally onto this seed.
        losses = [torch.tensor([], device=src.device, dtype=src.dtype)]
        for layer in self.layers:
            output, ortho_loss = layer(output, single_eval_pos=single_eval_pos)
            # torch.cat rejects 0-d tensors; promote scalars to shape (1,),
            # leave tensors with >= 1 dim untouched.
            losses.append(torch.atleast_1d(ortho_loss))
        # One concatenation instead of re-concatenating on every layer.
        return output, torch.cat(losses)