import torch 
import numpy as np
import torch.nn as nn

class C8(nn.Module):
    """Cyclic group C8: the eight planar rotations by multiples of pi / 4.

    Group elements are represented by their rotation angle in radians, held
    in tensors. ``product`` and ``inverse`` wrap their results into the
    half-open interval [0, 2 * pi).
    """

    def __init__(self):
        super().__init__()
        # Registered as a buffer so it follows the module across
        # ``.to(device)`` calls; other methods read its device to decide
        # where to allocate their outputs.
        self.register_buffer('identity', torch.tensor([0.]))

    def size(self):
        """Return the number of elements in this group."""
        return 8

    def elements(self):
        """Return all elements of this group.

        Returns:
            A 1-D tensor with ``size()`` entries holding the angles
            ``0, pi / 4, ..., 7 * pi / 4``, allocated on the same device as
            the ``identity`` buffer.
        """
        order = self.size()
        delta = 2 * np.pi / order
        return torch.tensor(
            [k * delta for k in range(order)],
            device=self.identity.device,
        )

    def product(self, h, g):
        """Compute the product of two elements ``h`` and ``g`` of C8.

        Composing two rotations adds their angles; the sum is wrapped back
        into [0, 2 * pi).

        Args:
            h: Tensor of rotation angle(s) in radians.
            g: Tensor of rotation angle(s) in radians.

        Returns:
            Tensor with ``(h + g) mod 2 * pi``.
        """
        return torch.remainder(h + g, 2 * np.pi)

    def inverse(self, h):
        """Compute the inverse of the element ``h`` of C8.

        Args:
            h: Tensor of rotation angle(s) in radians.

        Returns:
            Tensor with ``(-h) mod 2 * pi``.
        """
        return torch.remainder(-h, 2 * np.pi)

    def matrix_representation(self, h):
        """Return the 2x2 rotation-matrix representation of ``h``.

        The matrix is assembled with ``torch.stack`` rather than
        ``torch.tensor([...])``: the latter copies every entry out as a
        Python scalar, which detaches the result from autograd and only
        works for single-element inputs.

        Args:
            h: Tensor of rotation angle(s) in radians.

        Returns:
            Tensor on the device of the ``identity`` buffer. For a
            single-element ``h`` (of any shape) the result has shape
            ``(2, 2)``; otherwise it has shape ``(*h.shape, 2, 2)``, one
            rotation matrix per angle.
        """
        h = h.to(self.identity.device)
        if h.numel() == 1:
            # Backward compatibility: a one-element tensor of any shape has
            # always produced a plain (2, 2) matrix.
            h = h.reshape(())
        cos_t = torch.cos(h)
        sin_t = torch.sin(h)
        top_row = torch.stack((cos_t, -sin_t), dim=-1)
        bottom_row = torch.stack((sin_t, cos_t), dim=-1)
        return torch.stack((top_row, bottom_row), dim=-2)