import torch.nn as nn
import torch.nn.functional as F
import torch
from torch import nn, Tensor
from .basic import MLP
from .data_representation import Batch, BatchIndicator
from rotary_embedding_torch import RotaryEmbedding



class Transformer(nn.Module):
    """Transformer encoder applied to the ``data`` tensor of a ``Batch``.

    The regular path runs ``self.encoder`` (an ``nn.TransformerEncoder``).
    With ``return_attn`` the very same encoder layers are applied one at a
    time so that the attention weights of every layer can be collected; the
    transformed data is therefore the same network output in both modes.

    Args:
        input_dim: Model width (``d_model``) of every encoder layer.
        hidden_dim: Width of the feed-forward sub-layer.
        output_dim: Accepted for interface compatibility; not used here.
        num_layers: Number of encoder layers.
        num_heads: Number of attention heads per layer.
        **kwargs: Ignored.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=4, num_heads=8, **kwargs):
        super().__init__()

        # Legacy stack. It is not used by ``forward``; it stays registered
        # (and is created first, as before) so that parameter names,
        # state_dict keys and seeded initialisation order are unchanged.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=input_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim,
                batch_first=True,
            )
            for _ in range(num_layers)
        ])

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)

    def forward(self, x, return_attn=False):
        """Encode ``x.data`` and re-wrap the result as a ``Batch``.

        Args:
            x: Object exposing ``data``, ``n_nodes`` and ``order``
                (presumably a ``Batch`` -- TODO confirm).
            return_attn: When truthy, also return the attention weights of
                every encoder layer.

        Returns:
            A ``Batch`` built by ``Batch.from_batched``; with ``return_attn``
            a tuple ``(batch, attention_weights)`` where ``attention_weights``
            is a list holding one weight tensor per layer.
        """
        data = x.data
        indicator = x.n_nodes
        order = x.order

        if not return_attn:
            data = self.encoder(data)
            return Batch.from_batched(data=data, order=order, n_nodes=indicator)

        attention_weights = []
        for layer in self.encoder.layers:
            # The layers are built with the default post-norm layout, so the
            # attention inside ``layer`` is computed on ``data`` itself and
            # this call reproduces its weights. In training mode attention
            # dropout is resampled, so use eval() for exact maps.
            _, weights = layer.self_attn(data, data, data, need_weights=True)
            attention_weights.append(weights)
            # Full layer: attention + residual + norm + feed-forward.
            data = layer(data)

        return Batch.from_batched(data=data, order=order, n_nodes=indicator), attention_weights

# class RotaryTransformer(nn.Module):
#     def __init__(self, input_dim, hidden_dim, output_dim, num_layers=4, num_heads=8, **kwargs):
#         super(Transformer, self).__init__()
       
#         self.layers = nn.ModuleList([
#             nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads, dim_feedforward=hidden_dim,
#             batch_first=True)
#             for _ in range(num_layers)
#         ])

#         encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, 
#                          nhead=num_heads, dim_feedforward=hidden_dim)
#         self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)

#     def forward(self, x, return_attn=False):
#         data = x.data
#         indicator = x.n_nodes
#         order = x.order

#         if return_attn == True:
#             attention_weights = []

#             for layer in self.layers:
#                 data, weights = layer.self_attn(data, data, data, need_weights=True)
#                 attention_weights.append(weights)

#             return Batch.from_batched(data=data, order=order, n_nodes=indicator), attention_weights

#         else:
#             #apply rotary encoding here
#             data = self.encoder(data)
#             return Batch.from_batched(data=data, order=order, n_nodes=indicator)


class ConvexHullNNTransformer(nn.Module):
    """Linear embedding -> ``Transformer`` -> MLP head over point sets.

    Args:
        input_dim: Feature size of the incoming points.
        embedding_dim: Width of the embedding fed to the transformer.
        hidden_dim: Feed-forward width of the transformer and MLP head.
        transformer_output_dim: Forwarded to ``Transformer`` as ``output_dim``.
        output_dim: Size of the per-node output of the MLP head.
        depth: Number of transformer layers.
        num_heads: Number of attention heads.
        return_attn: When truthy, ``forward`` also returns attention maps.
        *args: Ignored.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, transformer_output_dim,
                 output_dim, depth, num_heads, return_attn=False, *args):
        super().__init__()
        self.return_attn = return_attn
        self.initial = nn.Linear(in_features=input_dim, out_features=embedding_dim)
        self.transformer = Transformer(
            input_dim=embedding_dim,
            hidden_dim=hidden_dim,
            output_dim=transformer_output_dim,
            num_layers=depth,
            num_heads=num_heads,
        )
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Run the model.

        Returns:
            The MLP output, or ``(output, attention_maps)`` when
            ``self.return_attn`` is truthy.
        """
        out = self.initial(x)
        if self.return_attn:
            out, attention_maps = self.transformer(out, self.return_attn)
            return self.mlp(out), attention_maps
        return self.mlp(self.transformer(out, self.return_attn))

    def get_approx_chull(self, probabilities, x: "Tensor | Batch"):
        """Build one approximate hull per point set as softmax-weighted sums.

        The scores of each point set are normalised with a softmax over the
        nodes of that set only, then used as weights over the set's points.
        Normalising per slice gives the same result as the batched reshape
        for equally sized sets and also handles sets of different sizes.

        Args:
            probabilities: Per-node scores, read through ``.data`` and
                flattened to ``(total_nodes, k)``.
            x: Object exposing ``n_nodes`` (per-set sizes) and ``data``
                (points stacked along dim 0).

        Returns:
            A list with one ``(k, point_dim)`` tensor per point set.
        """
        # NOTE(review): for a plain torch.Tensor ``.data`` is detached from
        # autograd; presumably the inputs are Batch objects -- TODO confirm.
        scores = probabilities.data
        scores = scores.reshape(-1, scores.size(-1))
        points = x.data

        hulls = []
        start = 0
        for size in x.n_nodes:
            end = start + int(size)
            weights = F.softmax(scores[start:end], dim=0)
            hulls.append(torch.mm(weights.T, points[start:end]))
            start = end
        return hulls

class ConvexHullNNTransformer_L1(ConvexHullNNTransformer):
    """Variant whose hull weights are L1-normalised leaky-ReLU scores."""

    def get_approx_chull(self, probabilities, x: "Tensor | Batch"):
        """Build one approximate hull per point set from L1-normalised scores.

        For every point set the scores are passed through ``F.leaky_relu`` and
        divided by their L1 norm over the nodes of that set, then used as
        weights over the set's points. Normalising per slice matches the
        batched reshape for equally sized sets and also handles sets of
        different sizes.

        Args:
            probabilities: Per-node scores, read through ``.data`` and
                flattened to ``(total_nodes, k)``.
            x: Object exposing ``n_nodes`` (per-set sizes) and ``data``
                (points stacked along dim 0).

        Returns:
            A list with one ``(k, point_dim)`` tensor per point set.
        """
        # NOTE(review): leaky_relu can emit small negative values, so the
        # weights are not guaranteed to be non-negative -- confirm intended.
        scores = probabilities.data
        scores = scores.reshape(-1, scores.size(-1))
        points = x.data

        hulls = []
        start = 0
        for size in x.n_nodes:
            end = start + int(size)
            activated = F.leaky_relu(scores[start:end])  # only for L1 norm
            weights = F.normalize(activated, p=1.0, dim=0)
            hulls.append(torch.mm(weights.T, points[start:end]))
            start = end
        return hulls
       



class RotaryMultiheadAttention(nn.Module):
    """Multi-head self-attention with rotary position embedding on q and k.

    Args:
        dim: Model width; must be divisible by ``heads``.
        heads: Number of attention heads.
        rotary_emb: Object providing ``rotate_queries_or_keys``.

    Raises:
        ValueError: If ``dim`` is not divisible by ``heads``.
    """

    def __init__(self, dim, heads, rotary_emb):
        super().__init__()
        if dim % heads != 0:
            # Otherwise the per-head reshape in forward() can never succeed.
            raise ValueError(f"dim ({dim}) must be divisible by heads ({heads})")
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5

        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x, return_attn=False):
        """Attend within each point set.

        Args:
            x: Object exposing ``n_nodes``, ``data`` and ``view``; its data is
                regrouped to ``(batch, n, dim)`` using ``n_nodes[0]``.
            return_attn: When truthy, also return the attention weights.

        Returns:
            The projected attention output of shape ``(batch, n, dim)``, or
            ``(output, weights)`` with weights of shape
            ``(batch, heads, n, n)`` when ``return_attn`` is truthy.
        """
        # TODO: assumes every point set in the batch has the same size
        # (holds for our data).
        n = x.n_nodes[0].item()
        dim = x.data.shape[-1]
        batch_size = x.data.shape[0] // n

        x = x.view(batch_size, n, dim)
        B, N, C = x.data.shape

        # TODO: permute x here

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # (B, N, C) -> (B, heads, N, head_dim)
        q, k, v = (
            t.reshape(B, N, self.heads, self.head_dim).transpose(1, 2)
            for t in qkv
        )

        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = attn_scores.softmax(dim=-1)

        out = torch.matmul(attn_weights, v)
        out = self.out_proj(out.transpose(1, 2).reshape(B, N, C))
        if return_attn:
            return out, attn_weights
        return out


class RotaryTransformerEncoderLayer(nn.Module):
    """Post-norm encoder layer built around ``RotaryMultiheadAttention``.

    Args:
        d_model: Model width.
        nhead: Number of attention heads.
        dim_feedforward: Width of the feed-forward sub-layer.
        dropout: Dropout probability used in all dropout modules.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.rotary_emb = RotaryEmbedding(dim=d_model // nhead)
        self.self_attn = RotaryMultiheadAttention(d_model, nhead,
                                                  rotary_emb=self.rotary_emb)

        # Feedforward layers
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Normalization layers
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, return_attn=False):
        """Apply self-attention and feed-forward, each with residual + norm.

        Args:
            src: Input accepted by ``RotaryMultiheadAttention.forward``.
            return_attn: When truthy, also return the attention weights.

        Returns:
            The layer output, or ``(output, weights)`` when ``return_attn``
            is truthy.
        """
        # Self-attention block
        if return_attn:
            attn_out, weights = self.self_attn(src, return_attn=True)
        else:
            attn_out, weights = self.self_attn(src), None
        # NOTE(review): attn_out is (batch, n, dim) while ``src`` keeps the
        # caller's layout; the residual add relies on these being
        # broadcast-compatible -- confirm for batches with several sets.
        src = self.norm1(src + self.dropout1(attn_out))

        # Feedforward block
        ff_out = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = self.norm2(src + self.dropout2(ff_out))

        if return_attn:
            return src, weights
        return src


class RotaryTransformer(nn.Module):
    """Stack of ``RotaryTransformerEncoderLayer`` blocks.

    Args:
        input_dim: Model width of every layer.
        hidden_dim: Feed-forward width of every layer.
        output_dim: Accepted for interface compatibility; not used here.
        num_layers: Number of layers.
        num_heads: Number of attention heads per layer.
        **kwargs: Ignored.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=4, num_heads=8, **kwargs):
        super().__init__()

        self.layers = nn.ModuleList([
            RotaryTransformerEncoderLayer(
                d_model=input_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim,
            )
            for _ in range(num_layers)
        ])

    def forward(self, x, return_attn=False):
        """Run all layers on ``x`` and re-wrap the result as a ``Batch``.

        Args:
            x: Object exposing ``n_nodes`` and ``order``; it is passed to the
                layers as-is (not ``x.data``).
            return_attn: When truthy, also return per-layer attention weights.

        Returns:
            A ``Batch`` built by ``Batch.from_batched``, or
            ``(batch, attention_weights)`` with one weight tensor per layer.
        """
        data = x
        indicator = x.n_nodes
        order = x.order

        # NOTE(review): each layer reads ``.n_nodes`` from its input, so the
        # output of one layer must still expose it for the next one; that
        # depends on Batch semantics not visible here -- TODO confirm.
        attention_weights = []
        for layer in self.layers:
            if return_attn:
                data, weights = layer(data, return_attn=True)
                attention_weights.append(weights)
            else:
                data = layer(data)

        batch = Batch.from_batched(data=data, order=order, n_nodes=indicator)
        if return_attn:
            return batch, attention_weights
        return batch


class ConvexHullNNRotaryTransformer(nn.Module):
    """Linear embedding -> ``RotaryTransformer`` -> MLP head over point sets.

    Args:
        input_dim: Feature size of the incoming points.
        embedding_dim: Width of the embedding fed to the transformer.
        hidden_dim: Feed-forward width of the transformer and MLP head.
        transformer_output_dim: Forwarded to ``RotaryTransformer`` as
            ``output_dim``.
        output_dim: Size of the per-node output of the MLP head.
        depth: Number of transformer layers.
        num_heads: Number of attention heads.
        return_attn: When truthy, ``forward`` also returns attention maps.
        *args: Ignored.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, transformer_output_dim,
                 output_dim, depth, num_heads, return_attn=False, *args):
        super().__init__()
        self.return_attn = return_attn
        self.initial = nn.Linear(in_features=input_dim, out_features=embedding_dim)
        self.transformer = RotaryTransformer(
            input_dim=embedding_dim,
            hidden_dim=hidden_dim,
            output_dim=transformer_output_dim,
            num_layers=depth,
            num_heads=num_heads,
        )
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        """Run the model.

        Returns:
            The MLP output, or ``(output, attention_maps)`` when
            ``self.return_attn`` is truthy.
        """
        # The per-call debug print of the embedding type was removed here.
        out = self.initial(x)
        if self.return_attn:
            out, attention_maps = self.transformer(out, self.return_attn)
            return self.mlp(out), attention_maps
        return self.mlp(self.transformer(out, self.return_attn))

    def get_approx_chull(self, probabilities, x: "Tensor | Batch"):
        """Build one approximate hull per point set as softmax-weighted sums.

        The scores of each point set are normalised with a softmax over the
        nodes of that set only, then used as weights over the set's points.
        Normalising per slice gives the same result as the batched reshape
        for equally sized sets and also handles sets of different sizes.

        Args:
            probabilities: Per-node scores, read through ``.data`` and
                flattened to ``(total_nodes, k)``.
            x: Object exposing ``n_nodes`` (per-set sizes) and ``data``
                (points stacked along dim 0).

        Returns:
            A list with one ``(k, point_dim)`` tensor per point set.
        """
        # NOTE(review): for a plain torch.Tensor ``.data`` is detached from
        # autograd; presumably the inputs are Batch objects -- TODO confirm.
        scores = probabilities.data
        scores = scores.reshape(-1, scores.size(-1))
        points = x.data

        hulls = []
        start = 0
        for size in x.n_nodes:
            end = start + int(size)
            weights = F.softmax(scores[start:end], dim=0)
            hulls.append(torch.mm(weights.T, points[start:end]))
            start = end
        return hulls


    


class ConvexHullEncoderTransformer(nn.Module):
    """MLP encoder followed by a ``ConvexHullNNTransformer`` processor.

    Args:
        input_dim: Feature size of the raw input.
        encoder_depth: Number of hidden layers in the encoder MLP.
        encoder_width: Width of each encoder hidden layer.
        encoder_output_dim: Output size of the encoder, fed to the processor.
        transformer_depth: Number of transformer layers in the processor.
        num_heads: Number of attention heads in the processor.
        transformer_od: Passed to the processor as ``transformer_output_dim``.
        processor_embedding_dim: Embedding width inside the processor.
        processor_hidden_dim: Hidden width inside the processor.
        processor_output_dim: Output size of the processor.
        return_attn: Whether ``forward`` also returns attention maps.
        **config: Extra keyword arguments; ignored.
    """

    def __init__(self, input_dim, encoder_depth, encoder_width, encoder_output_dim,
                 transformer_depth, num_heads, transformer_od, processor_embedding_dim,
                 processor_hidden_dim, processor_output_dim, return_attn=False, **config):
        super().__init__()

        self.return_attn = return_attn

        hidden_widths = [encoder_width] * encoder_depth
        self.encoder = MLP(
            input_dim,
            *hidden_widths,
            encoder_output_dim,
            batchnorm=False,
            activation=nn.LeakyReLU,
        )
        self.processor = ConvexHullNNTransformer(
            input_dim=encoder_output_dim,
            embedding_dim=processor_embedding_dim,
            hidden_dim=processor_hidden_dim,
            transformer_output_dim=transformer_od,
            output_dim=processor_output_dim,
            depth=transformer_depth,
            num_heads=num_heads,
            return_attn=self.return_attn,
        )

    def forward(self, x):
        """Encode ``x`` and hand the result to the processor.

        Returns:
            The processor output; when ``self.return_attn`` equals ``True``
            the processor result is unpacked and returned as the pair
            ``(output, attention_maps)``.
        """
        result = self.processor(self.encoder(x))
        if self.return_attn == True:  # noqa: E712 -- keeps the original comparison
            out, attn_maps = result
            return out, attn_maps
        return result


# class Transformer(nn.Module):
#     def __init__(self, input_dim, hidden_dim, output_dim, num_layers=4, num_heads=8, **kwargs):
#         super(Transformer, self).__init__()

#         encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, 
#                         nhead=num_heads, dim_feedforward=hidden_dim)
#         self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
    
#     def forward(self, x: Tensor|Batch):
#         data = x.data
#         indicator = x.n_nodes
#         order = x.order

#         data, weights = self.encoder(data)

#         return Batch.from_batched(data=data, order=order, n_nodes=indicator), weights



# class ConvexHullNNTransformer(nn.Module):
#     def __init__(self, input_dim, embedding_dim, hidden_dim, transformer_output_dim, 
#                 output_dim, depth, num_heads, *args):
#         super().__init__()
#         self.initial = nn.Linear(in_features=input_dim, out_features=embedding_dim)
#         self.transformer = Transformer(input_dim = embedding_dim, hidden_dim=hidden_dim, 
#                             output_dim = transformer_output_dim, num_layers=depth, num_heads=num_heads)
       

#         self.mlp = nn.Sequential(
#             nn.Linear(embedding_dim, hidden_dim),
#             nn.LeakyReLU(),
#             nn.Linear(hidden_dim, hidden_dim),
#             nn.LeakyReLU(),
#             nn.Linear(hidden_dim, output_dim)
#         )
        
       
#     def forward(self, x: Tensor|Batch):
#         out = self.initial(x)
#         out = self.transformer(out)
#         # print(f'size after transformer {x.data.shape}')
#         out = self.mlp(out)

#         # print(f'out data shape (forward): {out.data.shape}')
        
#         return out


