import torch
import torch.nn as nn
import torch.optim as optim



class FFLinear(nn.Module):
    """Linear layer wrapper that returns its output in the input's dtype.

    ``forward`` casts the input to the dtype of the wrapped ``nn.Linear``
    weights, applies the configured dropout, runs the projection, and casts
    the result back to the dtype the input arrived in.

    Args:
        in_dim: Number of input features of the wrapped ``nn.Linear``.
        out_dim: Number of output features of the wrapped ``nn.Linear``.
        alpha: Stored as ``self.alpha``; not read by ``forward``.
        dropout: Dropout probability. ``nn.Dropout`` is used when this is
            greater than zero, otherwise ``nn.Identity``.
        use_bias: Whether the wrapped ``nn.Linear`` has a bias term.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        alpha: float = 4.0,
        dropout: float = 0.0,
        use_bias: bool = False,
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.alpha = alpha

        self.use_bias = use_bias
        # Initialised to False; nothing in this class reads it.
        self.disabled = False
        self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
        self.linear = nn.Linear(in_features=in_dim, out_features=out_dim, bias=use_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and return the result in ``x``'s original dtype.

        Args:
            x: Input tensor whose last dimension is ``in_dim``.

        Returns:
            The projected tensor, cast back to the dtype ``x`` had on entry.
        """
        previous_dtype = x.dtype
        # Follow the weight dtype rather than hard-coding float32, so the
        # layer keeps working after .half() / .double() / .bfloat16().
        x = x.to(self.linear.weight.dtype)

        # The dropout module was configured in __init__ but never applied.
        # NOTE(review): applied to the input here, before the projection;
        # confirm this is the intended placement.
        out = self.linear(self.dropout(x))

        return out.to(previous_dtype)

def main():
    """Print peak CUDA memory after each stage of one FFLinear training step.

    Builds a 1024x1024 ``FFLinear`` on the GPU, then reports
    ``torch.cuda.max_memory_allocated`` after model creation, after the
    inputs/optimizer are created, and after the forward pass, backward pass
    and optimizer step.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    # Fail early with a clear message instead of an opaque CUDA init error.
    if not torch.cuda.is_available():
        raise RuntimeError("main() requires a CUDA device to report GPU memory statistics")

    device = torch.device("cuda")

    torch.cuda.reset_peak_memory_stats(device=None)
    torch.cuda.empty_cache()
    initial1 = torch.cuda.max_memory_allocated(device=None)
    print(f"Initial1 gpu used {initial1} memory")

    input_dim = 1024
    output_dim = 1024

    ff_ft = FFLinear(in_dim=input_dim, out_dim=output_dim).to(device)
    print('ft', torch.cuda.max_memory_allocated(device=None))
    # Allocate directly on the GPU and derive shapes from the declared
    # dimensions so the tensors stay in sync with the layer.
    x = torch.randn(1, 1, 1, input_dim, device=device)
    target = torch.randn(1, 1, output_dim, device=device)
    criterion = nn.MSELoss().to(device)
    optimizer = optim.Adam(ff_ft.parameters(), lr=0.0001)
    print('targr+op_i', torch.cuda.max_memory_allocated(device=None))

    # Iterating over x walks its leading dimension, yielding one batch here.
    for batch_data in x:
        print(batch_data.shape)
        output = ff_ft(batch_data)
        print("Forward pass output shape:", output.shape)
        print('fw', torch.cuda.max_memory_allocated(device=None))

        optimizer.zero_grad()
        loss = criterion(output, target)

        loss.backward()
        print('bw', torch.cuda.max_memory_allocated(device=None))
        optimizer.step()
        print('op_s', torch.cuda.max_memory_allocated(device=None))


# Run the memory-reporting demo only when this file is executed as a script.
if __name__ == "__main__":
    main()