import torch
import torch.nn as nn


from models.QuantumLayer import QuantumLayer
from models.Embedding import Embedding



class QNerfModel(nn.Module):
    """NeRF-style model whose core is a ``QuantumLayer``.

    Pipeline implemented by :meth:`forward`:

    1. positions ``o`` and directions ``d`` are expanded with a sin/cos
       positional encoding and concatenated;
    2. ``Embedding`` maps that vector to ``2 ** n_qubits`` features;
    3. ``QuantumLayer`` followed by ``ReLU`` produces ``h``;
    4. when ``n_qubits > 4`` the columns of ``h`` are averaged into four
       channels (column ``j`` contributes to channel ``j % 4``);
    5. channels 0-2 are returned as colour, channel 3 as density, each
       optionally multiplied by a learnable scalar.

    NOTE(review): ``Embedding`` and ``QuantumLayer`` are project-local and
    not visible here. This class assumes the quantum layer returns one column
    per qubit - confirm against ``models.QuantumLayer``.

    Args:
        embedding_dim_pos: Number of frequency bands for the position encoding.
        embedding_dim_direction: Number of frequency bands for the direction
            encoding.
        hidden_dim: Forwarded to ``Embedding`` as ``hidden_dim``.
        hidden_layers: Forwarded to ``Embedding`` as ``num_hidden_layers``.
        n_qubits: Forwarded to ``QuantumLayer``; also fixes the embedding
            output width (``2 ** n_qubits``) and the channel folding.
        rep: Forwarded to ``QuantumLayer``.
        ansatz: Forwarded to ``QuantumLayer``.
        pad_with: Forwarded to ``QuantumLayer``.
        use_scaler: If true, create one learnable scalar (initialised to 1.0)
            per output channel (R, G, B, sigma).
        use_BE: Forwarded to ``Embedding``.
        custom_name: Value for ``model_name``; a name is derived from the
            hyper-parameters when ``None``.
        with_skip: Forwarded to ``Embedding``.
        skip_at: Forwarded to ``Embedding``.
        device: Stored on the instance and forwarded to both sub-modules.
    """

    def __init__(self,
                 embedding_dim_pos=10,
                 embedding_dim_direction=4,
                 hidden_dim=128,
                 hidden_layers = 3,
                 n_qubits=8,
                 rep=1,
                 ansatz='default',
                 pad_with=0,
                 use_scaler=True,
                 use_BE=False,
                 custom_name = None,
                 with_skip = False,
                 skip_at = 0,
                 device = 'cpu'
                 ):
        super().__init__()
        self.device = device

        # Quantum layer hyper-parameters.
        self.n_qubits = n_qubits
        self.rep = rep
        self.pad_with = pad_with
        self.ansatz = ansatz

        # Number of sin/cos frequency bands for each input.
        self.embedding_dim_pos = embedding_dim_pos
        self.embedding_dim_direction = embedding_dim_direction

        self.use_scaler = use_scaler
        self.use_BE = use_BE

        # Width of the concatenated encodings when `o` and `d` each have three
        # columns: 3 * (1 + 2 * L_pos) + 3 * (1 + 2 * L_dir).
        self.real_input_size = 6 * (1 + embedding_dim_pos + embedding_dim_direction)

        self.embed_func = Embedding(
            input_dim=self.real_input_size,
            output_dim=2 ** n_qubits,
            hidden_dim=hidden_dim,
            num_hidden_layers=hidden_layers,
            with_skip=with_skip,
            skip_at=skip_at,
            use_BE=use_BE,
            device=device,
        )

        self.quantum_layer = QuantumLayer(
            n_qubits=self.n_qubits,
            rep=self.rep,
            ansatz=self.ansatz,
            pad_with=self.pad_with,
            device=self.device,
        )

        # Set to True to trace shapes and value ranges during forward().
        self.DEBUG = False

        # The ReLU keeps every channel non-negative.
        self.block1 = nn.Sequential(
            self.quantum_layer,
            nn.ReLU(),
        )

        if self.use_scaler:
            # One learnable multiplier per output channel, initialised to 1.
            self.scaler_R = nn.Parameter(torch.tensor(1.0))
            self.scaler_G = nn.Parameter(torch.tensor(1.0))
            self.scaler_B = nn.Parameter(torch.tensor(1.0))
            self.scaler_sigma = nn.Parameter(torch.tensor(1.0))

        self.relu = nn.ReLU()

        if custom_name is not None:
            self.model_name = custom_name
        else:
            self.model_name = f"QNeRF_{n_qubits}_{rep}_{embedding_dim_pos}_{embedding_dim_direction}"

    @staticmethod
    def positional_encoding(x, L, DEBUG=False):
        """Return ``x`` concatenated with ``sin(2**j * x)`` and ``cos(2**j * x)``.

        Args:
            x: Tensor with at least two dimensions; concatenation is on dim 1.
            L: Number of frequency bands ``j = 0 .. L-1``.
            DEBUG: Print input and output shapes when true.

        Returns:
            Tensor whose dim 1 is ``(1 + 2 * L)`` times that of ``x``, laid
            out as ``[x, sin(x), cos(x), sin(2x), cos(2x), ...]``.
        """
        if DEBUG:
            print(f"[Positional Encoding] Input shape: {x.shape}")
        pieces = [x]
        for j in range(L):
            freq = 2 ** j
            pieces.append(torch.sin(freq * x))
            pieces.append(torch.cos(freq * x))
        result = torch.cat(pieces, dim=1)
        if DEBUG:
            print(f"[Positional Encoding] Output shape: {result.shape}")
        return result

    def _fold_into_four(self, h):
        """Average the ``n_qubits`` columns of ``h`` into 4 channels (column j -> channel j % 4)."""
        if h.shape[1] != self.n_qubits:
            # The grouping is defined per qubit; reject any other width
            # instead of silently averaging the wrong columns.
            raise IndexError(
                f"expected {self.n_qubits} columns from the quantum block, "
                f"got {h.shape[1]}"
            )
        channels = [h[:, i:self.n_qubits:4].mean(dim=1) for i in range(4)]
        # The folded output is always produced in the default dtype.
        return torch.stack(channels, dim=1).to(torch.get_default_dtype())

    def forward(self, o, d):
        """Predict colour and density for a batch of samples.

        Args:
            o: Positions, 2-D tensor (batch on dim 0).
            d: Directions, 2-D tensor with the same batch size.

        Returns:
            Tuple ``(c, sigma)``: ``c`` holds output channels 0-2 with shape
            ``[batch, 3]``; ``sigma`` holds channel 3 with shape ``[batch]``.
            Each channel is multiplied by its learnable scalar when
            ``use_scaler`` is set.
        """
        debug = self.DEBUG

        if debug:
            print("=" * 60)
            print(f"[Forward Pass] Input Shapes: o (positions): {o.shape}, d (directions): {d.shape}")
            print("=" * 60)

        # Positional encoding for spatial positions.
        emb_x = self.positional_encoding(o, self.embedding_dim_pos, DEBUG=debug)
        if debug:
            print(f"[Encoded Positions] Shape: {emb_x.shape}")

        # Positional encoding for directions.
        emb_d = self.positional_encoding(d, self.embedding_dim_direction, DEBUG=debug)
        if debug:
            print(f"[Encoded Directions] Shape: {emb_d.shape}")

        final_input = torch.cat((emb_x, emb_d), dim=1)
        if debug:
            print(f"[Final Input] Shape: {final_input.shape}")

        q_emb = self.embed_func(final_input)
        if debug:
            print(f"[Quantum Embedding] Shape: {q_emb.shape}")

        h = self.block1(q_emb)

        if self.n_qubits > 4:
            # More columns than output channels: fold them down to four.
            h = self._fold_into_four(h)
            if debug:
                print(f"[Averages] Shape: {h.shape}")

        if debug:
            if self.use_scaler:
                # The scalers only exist when use_scaler is set.
                print(
                    f"[Scalers] R: {self.scaler_R.item()}, G: {self.scaler_G.item()}, "
                    f"B: {self.scaler_B.item()}, sigma: {self.scaler_sigma.item()}"
                )
            print(f"[Final Output] Shape: {h.shape}")

        if self.use_scaler:
            c = torch.cat(
                (
                    h[:, :1] * self.scaler_R,
                    h[:, 1:2] * self.scaler_G,
                    h[:, 2:3] * self.scaler_B,
                ),
                dim=1,
            )
            sigma = h[:, 3] * self.scaler_sigma
        else:
            c = h[:, :3]
            sigma = h[:, 3]

        if debug:
            print(f"[Final Outputs] RGB Colors (c): {c.shape}, Density (sigma): {sigma.shape}")
            print(f"RGB Min: {c.min().item()}, RGB Max: {c.max().item()}, Mean: {c.mean().item()}")
            print("=" * 60)

        return c, sigma


