import torch
import torch.nn as nn
import torch.nn.functional as F
from .performer_pytorch.performer_pytorch import Performer
from .nanogpt import GPT, GPTConfig
from utils.utils import gumbel_softmax

class ActivationFunctions:
    """Name-based factory for activation modules."""

    # name -> activation class; a fresh module instance is built per lookup.
    _BY_NAME = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'silu': nn.SiLU,
    }

    @staticmethod
    def get(name):
        """Return a new activation module for *name*.

        Unrecognised names fall back to ``nn.ReLU()``.
        """
        factory = ActivationFunctions._BY_NAME.get(name, nn.ReLU)
        return factory()

class MLP(nn.Module):
    """Feed-forward stack: ``num_layers`` hidden blocks followed by a linear head.

    Each hidden block is ``Linear -> activation`` plus an optional ``LayerNorm``;
    the final ``Linear`` maps to ``output_size`` with no activation.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers, act_fn, layernorm=True):
        super().__init__()
        # widths[k] -> widths[k + 1] is the k-th hidden projection.
        widths = [input_size] + [hidden_size] * num_layers
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(ActivationFunctions.get(act_fn))
            if layernorm:
                modules.append(nn.LayerNorm(fan_out))
        modules.append(nn.Linear(widths[-1], output_size))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stack to the last dimension of *x*."""
        return self.layers(x)

class DatastructureModel(nn.Module):
    """Sequence model with a selectable backbone and an optional permutation read-out.

    ``forward`` takes a tensor of shape (B, N, input_dim). For the ``'transformer'``
    and ``'lintrans'`` backbones the inputs are projected to ``n_embd``, padded with
    ``max(n_outputs - n_inputs, 0)`` learned extra tokens, passed through the
    backbone, and the first ``n_outputs`` positions are read out. The ``'mlp'``
    backbone instead scores every input vector independently.
    """

    def __init__(self, n_inputs, n_outputs, input_dim, permute, extra_space_dim,
                 output_dim=1, n_embd=128, n_layer=12, n_head=4, arch='transformer'):
        """
        Args:
            n_inputs: number of real input tokens; used to size the extra-token
                padding and to split the permute read-out.
            n_outputs: number of backbone positions that are read out.
            input_dim: feature size of each input token.
            permute: if True, read out one scalar per input position plus an
                ``extra_space_dim`` vector per extra position; otherwise read out
                ``output_dim`` features per position.
            extra_space_dim: width of the extra-position read-out (permute only).
            output_dim: width of the read-out when ``permute`` is False.
            n_embd, n_layer, n_head: backbone width, depth and number of heads.
            arch: 'transformer' (also the fallback for unknown names),
                'lintrans' or 'mlp'.
        """
        super().__init__()

        # The backbone receives n_inputs real tokens plus
        # max(n_outputs - n_inputs, 0) extra tokens, i.e. max(n_inputs, n_outputs)
        # positions, so block_size has to cover n_outputs as well.
        config = GPTConfig(
            block_size=max(n_embd * 4, n_inputs, n_outputs),
            n_layer=n_layer,
            n_head=n_head,
            n_embd=n_embd,
            causal_self_attn=False,
        )

        self.n_inputs = n_inputs
        self.n_outputs = n_outputs
        self.permute = permute
        self.arch = arch

        self._read_in = nn.Linear(input_dim, n_embd)

        self._backbone = self._select_backbone(arch, config, n_embd, n_layer, n_head, input_dim)

        self._configure_output_layers(permute, n_embd, output_dim, extra_space_dim)

        self.n_extra_tokens = max(n_outputs - n_inputs, 0)
        if self.n_extra_tokens > 0:
            self.embedding = nn.Embedding(self.n_extra_tokens, n_embd)

    def _select_backbone(self, arch, config, n_embd, n_layer, n_head, input_dim):
        """Construct only the backbone named by *arch*.

        Building every candidate and discarding the unused ones wasted memory
        and time, and it consumed RNG state. Unknown names fall back to the GPT
        transformer.
        """
        if arch == 'lintrans':
            return Performer(dim=n_embd, depth=n_layer, heads=n_head, causal=False,
                             dim_head=n_embd // n_head)
        if arch == 'mlp':
            return MLP(input_size=input_dim, hidden_size=1024, output_size=1,
                       num_layers=3, act_fn='relu')
        return GPT(config)

    def _configure_output_layers(self, permute, n_embd, output_dim, extra_space_dim):
        """Create the read-out head(s); the extra head exists only if n_outputs > n_inputs."""
        if permute:
            self._read_out = nn.Linear(n_embd, 1)
            self._extra_read_out = (
                nn.Linear(n_embd, extra_space_dim) if self.n_outputs > self.n_inputs else None
            )
        else:
            self._read_out = nn.Linear(n_embd, output_dim)

    def get_extra_tokens(self, batch_size, device=None):
        """Return the learned extra-token embeddings, shape (batch_size, n_extra_tokens, n_embd).

        Args:
            batch_size: number of copies along the leading dimension.
            device: device for the index tensor. Defaults to the embedding
                table's own device (previously hard-coded to 'cuda:0').
        """
        if device is None:
            device = self.embedding.weight.device
        # Build the indices directly on the target device and broadcast them
        # over the batch instead of materialising a repeated CPU tensor.
        idx = torch.arange(self.n_extra_tokens, device=device)
        return self.embedding(idx.unsqueeze(0).expand(batch_size, -1))

    def forward(self, input):
        """Run the model on *input* of shape (B, N, input_dim).

        Returns:
            * ``arch == 'mlp'``: ``(scores, None)`` with scores of shape (B, N, 1).
            * ``permute``: ``(permute_outs, extra_outs)``. ``permute_outs`` reads
              one scalar from each of the first ``n_inputs`` positions.
              ``extra_outs`` holds the ``extra_space_dim`` read-outs of the extra
              positions, or is None when there are none.
            * otherwise: a single tensor read out from the first ``n_outputs``
              positions with ``output_dim`` features each.
        """
        B, N, d = input.shape

        if self.arch == 'mlp':
            # reshape (not view) so non-contiguous inputs are accepted.
            scores = self._backbone(input.reshape(B * N, d))
            return scores.reshape(B, N, 1), None

        embeds = self._read_in(input)

        if self.n_extra_tokens > 0:
            embeds = torch.cat([embeds, self.get_extra_tokens(B, device=input.device)], dim=1)

        output = self._backbone(embeds)[:, :self.n_outputs]

        if not self.permute:
            return self._read_out(output)

        permute_outs = self._read_out(output[:, :self.n_inputs])
        extra_outs = None
        if self.n_extra_tokens > 0 and self._extra_read_out is not None:
            extra_outs = self._extra_read_out(output[:, self.n_inputs:])
        return permute_outs, extra_outs

class MLPQueryModel(nn.Module):
    """One MLP per query step, each producing ``args.n_inputs`` outputs.

    In adaptive mode, step ``i > 0`` receives the flattened base query, the
    flattened previous ``queries`` and the ``masks`` list concatenated along
    dim 1. Otherwise every step sees only the base query.
    """

    def __init__(self, args):
        super().__init__()

        n_slots = args.n_inputs + args.n_extra
        # NOTE(review): this falls back only when the *sum* equals -1. If the
        # intent was "n_extra == -1 means no extra slots", the test should be on
        # args.n_extra. Confirm before changing, because it alters layer sizes.
        if n_slots == -1:
            n_slots = args.n_inputs

        self.adaptive = args.adaptive
        self.query_model = self._create_query_models(args, n_slots)

    def _create_query_models(self, args, n_inputs):
        """Build ``args.max_queries`` MLPs, one per query step."""
        def input_width(step):
            if not args.adaptive:
                return args.dim
            # Base query plus `step` previous queries (args.dim each) plus
            # `step` chunks of width n_inputs (presumably one mask per step).
            return (step + 1) * args.dim + step * n_inputs

        return nn.ModuleList(
            MLP(
                input_width(step),
                args.query_mlp_hidden_dim,
                args.n_inputs,
                num_layers=args.query_mlp_num_layers,
                act_fn=args.query_mlp_act_fn,
            )
            for step in range(args.max_queries)
        )

    def forward(self, base_query, queries, masks, i):
        """Return the output of the step-``i`` MLP.

        ``masks`` must be a list, because it is concatenated to a list of tensors.
        """
        if i != 0 and self.adaptive:
            pieces = [base_query.flatten(start_dim=1)]
            pieces += [q.flatten(start_dim=1) for q in queries]
            return self.query_model[i](torch.cat(pieces + masks, dim=1))
        return self.query_model[i](base_query)
    

class FeatureModel(nn.Module):
    """Small CNN mapping a single-channel 2-D input to one scalar per example.

    The network has two stages of ``conv3x3 (same padding) -> ReLU -> maxpool 2x2``
    followed by ``Linear -> ReLU -> Linear``. Each pooling stage floors the
    spatial size by 2, so the flattened feature size is
    ``64 * (H // 4) * (W // 4)``.
    """

    def __init__(self, input_hw=(28, 84)):
        """
        Args:
            input_hw: (H, W) of the inputs passed to ``forward``. The default
                reproduces the original hard-coded feature size of 64*7*21.
        """
        super(FeatureModel, self).__init__()
        height, width = input_hw
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        flat_features = 64 * (height // 4) * (width // 4)
        self.fc1 = nn.Linear(in_features=flat_features, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=1)

    def forward(self, x):
        """Map *x* of shape (B, H, W) to scores of shape (B, 1).

        A channel axis is inserted before the first convolution.
        """
        x = self.pool(F.relu(self.conv1(x[:, None])))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten each example separately. view(-1, C) could silently fold a
        # spatial-size mismatch into the batch dimension; with flatten, fc1
        # raises on a shape mismatch instead.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
    
    
class CSMLPQueryModel(nn.Module):
    """Per-step query generator with configurable adaptive inputs.

    There is one MLP per query step (``args.max_queries`` of them, or two for
    ``'mlp_last_mask_last_query_shared'``). Each maps its input to
    ``args.n_state`` outputs. In adaptive mode, step ``i > 0`` also sees the
    previous queries and/or masks selected by the architecture name (see
    ``_ADAPTIVE_INPUTS``). When ``args.pred_network != 'default'`` an extra
    ``pred_model`` MLP is created. It is not used by ``forward``.
    """

    # arch -> (which previous queries, which previous masks) are concatenated
    # after the flattened base query in adaptive mode: 'all', 'last' or 'none'.
    _ADAPTIVE_INPUTS = {
        'mlp_all_mask_all_query': ('all', 'all'),
        'mlp_last_mask_all_query': ('all', 'last'),
        'mlp_last_mask_last_query': ('last', 'last'),
        'mlp_last_mask_last_query_shared': ('last', 'last'),
        'mlp_all_mask_last_query': ('last', 'all'),
        'mlp_no_mask_all_query': ('all', 'none'),
        'mlp_no_mask_last_query': ('last', 'none'),
    }

    def __init__(self, args):
        super(CSMLPQueryModel, self).__init__()
        arch = args.query_model_arch
        max_queries = args.max_queries
        query_hidden_size = args.query_mlp_hidden_dim
        query_act_fn = args.query_mlp_act_fn
        query_num_layers = args.query_mlp_num_layers
        n_state = args.n_state

        # forward() reads both attributes. `adaptive` was previously never
        # assigned here, so forward() raised AttributeError.
        self.adaptive = args.adaptive
        self.arch = arch

        if args.pred_network != 'default':
            pred_out = 1 if args.pred_network == 'scalar' else args.max_queries
            self.pred_model = MLP(args.max_queries, query_hidden_size, pred_out,
                                  query_num_layers, query_act_fn, layernorm=True)

        steps = range(max_queries)
        # Input width of the MLP used at step i (i previous queries so far).
        # NOTE(review): most variants assume a base query of width args.n_embd,
        # but the *_shared, all_mask_last and no_mask variants assume width 1,
        # and the query widths (1 vs args.dim) also differ between variants.
        # Confirm these against the callers.
        if not args.adaptive or arch == 'mlp_only_query':
            in_dims = [args.n_embd for _ in steps]
        elif arch == 'mlp_all_mask_all_query':
            in_dims = [i + args.n_embd + n_state * i for i in steps]
        elif arch == 'mlp_last_mask_all_query':
            in_dims = [i + args.n_embd + n_state * (i != 0) for i in steps]
        elif arch == 'mlp_last_mask_last_query':
            in_dims = [(i != 0) * args.dim + args.n_embd + n_state * (i != 0) for i in steps]
        elif arch == 'mlp_last_mask_last_query_shared':
            # Module 0 for the first step, module 1 shared by every later step.
            in_dims = [(i != 0) + 1 + n_state * (i != 0) for i in range(2)]
        elif arch == 'mlp_all_mask_last_query':
            in_dims = [(i != 0) + 1 + n_state * i for i in steps]
        elif arch == 'mlp_no_mask_all_query':
            in_dims = [i + 1 for i in steps]
        elif arch == 'mlp_no_mask_last_query':
            in_dims = [(i != 0) + 1 for i in steps]
        else:
            raise ValueError(f'{arch} arch not supported!')

        self.query_model = nn.ModuleList([
            MLP(d, query_hidden_size, n_state, num_layers=query_num_layers, act_fn=query_act_fn)
            for d in in_dims
        ])

    @staticmethod
    def _pick(items, mode):
        """Return all of *items*, only the last one, or none, as a new list."""
        if mode == 'all':
            return list(items)
        if mode == 'last':
            return [items[-1]]
        return []

    def forward(self, base_query, queries, masks, i):
        """Return the output of the step-``i`` query MLP.

        Args:
            base_query: input used at every step; flattened from dim 1 onward
                when concatenated with other inputs.
            queries: sequence of previously issued queries (flattened from dim 1).
            masks: sequence of previous masks, concatenated as-is along dim 1.
            i: index of the current query step.
        """
        if i == 0 or not self.adaptive or self.arch == 'mlp_only_query':
            return self.query_model[i](base_query)

        query_mode, mask_mode = self._ADAPTIVE_INPUTS[self.arch]
        if self.arch == 'mlp_last_mask_last_query_shared':
            i = 1  # every step after the first reuses the shared module

        query_in = [q.flatten(start_dim=1) for q in self._pick(queries, query_mode)]
        mask_in = self._pick(masks, mask_mode)
        ins = torch.cat([base_query.flatten(start_dim=1)] + query_in + mask_in, dim=1)
        return self.query_model[i](ins)



class CSDatastructureModel(nn.Module):
    """Sum of per-head cumulative "state" updates driven by a token stream.

    Each head is an MLP with ``n_state + 1`` outputs applied to ``embeds``.
    Rows of that output are gathered by the token ids in ``stream``. The first
    ``n_state`` columns become a mask (hard one-hot or soft gumbel-softmax).
    The last column is the update value, or a single learned scalar when
    ``fix_value`` is set. The state is the cumulative sum of ``mask * value``
    over stream positions, summed over heads.
    """

    def __init__(self, d_in, n_state, n_layer=4, n_queries=1, fix_value=False):
        super(CSDatastructureModel, self).__init__()

        self.fix_value = fix_value
        if fix_value:
            # One shared, learnable update value replacing the per-token column.
            self.update_value = nn.Parameter(torch.tensor(1.), requires_grad=True)
        heads = [MLP(d_in, 1024, n_state + 1, n_layer, 'relu') for _ in range(n_queries)]
        self.models = nn.ModuleList(heads)

    def forward(self, hard, temp, embeds=None, stream=None, noise_scale=1.):
        """Compute states, masks and update values for *stream*.

        Args:
            hard: if ``== True``, masks are argmax one-hots; otherwise they are
                soft gumbel-softmax samples at temperature ``temp``.
            temp: gumbel-softmax temperature (soft mode only).
            embeds: input to every head; its rows are indexed by the ids in
                ``stream`` (assumed shape (V, d_in), TODO confirm).
            stream: integer tensor of shape (B, N) holding token ids.
            noise_scale: forwarded to ``gumbel_softmax``.

        Returns:
            ``(states, masks, updates, outs)``:
            * ``states`` has shape (B, N, n_state) and is summed over heads.
            * ``masks`` has shape (B, n_queries, N, n_state).
            * ``updates`` has shape (B, n_queries, N, 1).
            * ``outs`` is a list holding each head's full output for ``embeds``.
        """
        batch, length = stream.shape
        use_hard = hard == True  # noqa: E712 -- keeps the original comparison semantics
        flat_ids = stream.flatten()

        head_states, head_masks, head_values, head_tables = [], [], [], []
        for head in self.models:
            table = head(embeds)
            rows = table[flat_ids].view(batch, length, -1)
            head_tables.append(table)

            logits = rows[:, :, :-1]
            if use_hard:
                mask = F.one_hot(logits.argmax(dim=-1), logits.size(-1)).float()
            else:
                mask = gumbel_softmax(logits, temperature=temp, hard=False, noise_scale=noise_scale)

            value = (
                self.update_value.view(1, 1, 1).repeat(batch, length, 1)
                if self.fix_value
                else rows[:, :, [-1]]
            )

            head_states.append((mask * value).cumsum(dim=1))
            head_masks.append(mask)
            head_values.append(value)

        states = torch.stack(head_states, dim=0).sum(0)
        return (
            states,
            torch.stack(head_masks, dim=1),
            torch.stack(head_values, dim=1),
            head_tables,
        )


class Embedder(nn.Module):
    """Token embedder that passes raw values through when ``n_embd <= 1``.

    When ``n_embd > 1`` a learned ``nn.Embedding(n_vals, n_embd)`` table is used.
    Otherwise no parameters are created and ``forward`` returns the input ids as
    a float tensor with a trailing size-1 axis.
    """

    def __init__(self, n_vals, n_embd):
        super(Embedder, self).__init__()
        self.n_embd = n_embd
        if n_embd > 1:
            self.embeddings = nn.Embedding(n_vals, n_embd)

    def forward(self, input):
        """Embed *input* ids, or return them as floats of shape ``input.shape + (1,)``."""
        if self.n_embd > 1:
            return self.embeddings(input)
        return input.unsqueeze(-1).float()