import numpy as np
from tqdm import tqdm
import torch
import torch
import torch.nn as nn


# Default device for models and training: CUDA when available, otherwise CPU.
global_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_activation(F):
    """Build an activation callable from a zero-argument factory ``F``.

    ``F`` (e.g. an ``nn.Module`` subclass such as ``nn.ReLU``) is called exactly
    once; the returned function forwards its argument to that single instance
    and returns whatever the instance returns.
    """
    instance = F()

    def activation(x):
        return instance(x)

    return activation

def null_activation(x):
    """Identity activation: hand ``x`` back unchanged (the same object, no copy)."""
    return x

class RNN(nn.Module):
    """Continuous-time rate RNN integrated with forward-Euler steps.

    With ``x`` a ``[batch, N]`` state (row-vector convention), one step is::

        x_t = x_{t-1} + dt/tau * (-x_{t-1} + f(x_{t-1}) @ W + u_t @ W_in.T + b)
        y_t = g(f(x_t) @ W_out.T + b_out)

    where ``f`` is ``activation_func`` and ``g`` is ``output_nonlinearity``.
    The tensors named in ``train`` are registered as ``nn.Parameter``s; the
    others are stored as plain (non-trainable) float32 tensor attributes.
    """

    # Names accepted in the ``train`` argument.
    _TRAINABLE_NAMES = ('weights', 'outputs', 'inputs', 'init_state', 'state_bias', 'output_bias')

    def __init__(self, tau, weight_matrix, init_state, output_weight_matrix,
                 input_weight_matrix, state_bias, output_bias,
                 output_nonlinearity=null_activation,
                 activation_func=nn.ReLU(), train=('init_state',), device=global_device):
        '''
        Initializes an instance of the RNN class.

        Args:
            tau: time constant; each Euler step is scaled by dt / tau.
            weight_matrix: [N, N] recurrent weights, applied as f(x) @ W.
            init_state: [N] initial state, shared by every batch element.
            output_weight_matrix: [num_outputs, N] readout weights.
            input_weight_matrix: [N, num_inputs] input weights.
            state_bias: [N] bias added to the state derivative.
            output_bias: [num_outputs] bias added to the readout.
            output_nonlinearity: callable applied to the readout.
            activation_func: callable applied to the state before W and W_out.
            train: iterable of names from ``_TRAINABLE_NAMES`` to register as
                trainable ``nn.Parameter``s.
            device: currently unused (kept for backward compatibility); move
                the model with ``to_device``.
        '''
        super().__init__()
        self.tau = tau
        # Basic tests to ensure correct input shapes.
        assert len(weight_matrix.shape) == 2, "weight_matrix must be 2-D"
        assert weight_matrix.shape[0] == weight_matrix.shape[1], "weight_matrix must be square"
        N = weight_matrix.shape[0]
        assert N == init_state.shape[0], "init_state length must match weight_matrix"

        assert len(output_weight_matrix.shape) == 2, "output_weight_matrix must be 2-D"
        assert output_weight_matrix.shape[1] == N, "output_weight_matrix must have N columns"

        assert len(input_weight_matrix.shape) == 2, "input_weight_matrix must be 2-D"
        assert input_weight_matrix.shape[0] == N, "input_weight_matrix must have N rows"

        assert state_bias.ndim == 1, "state_bias must be 1-D"
        assert state_bias.shape[0] == N, "state_bias length must be N"
        assert output_bias.ndim == 1, "output_bias must be 1-D"
        assert output_bias.shape[0] == output_weight_matrix.shape[0], \
            "output_bias length must match the number of outputs"

        unknown = [name for name in train if name not in self._TRAINABLE_NAMES]
        assert not unknown, f"Unknown entries in train: {unknown}"

        # Registration order matches the original so model.parameters() is unchanged.
        self.weight_matrix = self._as_float_tensor(weight_matrix, 'weights' in train)
        self.output_weight_matrix = self._as_float_tensor(output_weight_matrix, 'outputs' in train)
        self.input_weight_matrix = self._as_float_tensor(input_weight_matrix, 'inputs' in train)
        self.init_state = self._as_float_tensor(init_state, 'init_state' in train)
        self.state_bias = self._as_float_tensor(state_bias, 'state_bias' in train)
        self.output_bias = self._as_float_tensor(output_bias, 'output_bias' in train)

        self.activation_func = activation_func
        self.output_nonlinearity = output_nonlinearity
        self.num_nodes = self.weight_matrix.shape[0]
        self.num_outputs = self.output_weight_matrix.shape[0]

    @staticmethod
    def _as_float_tensor(value, trainable):
        """Copy ``value`` into a float32 tensor; wrap it as an ``nn.Parameter`` if trainable."""
        if isinstance(value, torch.Tensor):
            # Equivalent to torch.tensor(value) (detached copy, same device)
            # without the UserWarning torch.tensor emits for tensor input.
            tensor = value.detach().to(dtype=torch.float32).clone()
        else:
            tensor = torch.tensor(value, dtype=torch.float32)
        return nn.Parameter(tensor) if trainable else tensor

    def to_device(self, device):
        """Move parameters, submodules and plain-tensor attributes to ``device``.

        ``nn.Module.to`` only moves registered parameters/buffers, so the
        non-trainable tensors stored as ordinary attributes are moved here too.
        Returns ``self`` for chaining.
        """
        self.to(device)
        for name, value in list(vars(self).items()):
            if isinstance(value, torch.Tensor):
                setattr(self, name, value.to(device))
        return self

    def detach(self):
        """Cut every direct tensor attribute out of any autograd graph, in place.

        Parameters are replaced by fresh ``nn.Parameter`` objects, so an
        optimizer built before this call no longer tracks them. Plain tensors
        are replaced by their ``.detach()``.
        """
        for name, param in list(self.named_parameters(recurse=False)):
            setattr(self, name, nn.Parameter(param.detach()))
        for name, value in list(vars(self).items()):
            if isinstance(value, torch.Tensor):
                setattr(self, name, value.detach())

    # Over-ridden by Dale RNN
    def weight_matrix_(self):
        """Return the effective recurrent weight matrix used by ``forward``."""
        return self.weight_matrix

    def forward(self, inputs, dt, disable_progress_bar=False, return_state=False):
        """Integrate the network over ``inputs`` with Euler steps of size ``dt``.

        Args:
            inputs: [batch_size, sequence_length, num_inputs] input sequence.
            dt: integration step; each step adds dt/tau times the state derivative.
            disable_progress_bar: suppress the tqdm progress bar.
            return_state: also return the state trajectory.

        Returns:
            Outputs of shape [batch_size, sequence_length, num_outputs] and, if
            ``return_state``, states of shape [batch_size, sequence_length,
            num_nodes] (the initial state itself is not included).
        """
        batch_size, sequence_length = inputs.shape[0], inputs.shape[1]

        # init_state is either [N] (shared by the whole batch) or
        # [batch_size, N] (one initial state per trial).
        x_prev = self.init_state
        if x_prev.dim() == 1:
            x_prev = x_prev.unsqueeze(0)
        x_prev = x_prev.expand(batch_size, -1)  # [batch_size, state_size]

        # The recurrent matrix is constant within a rollout: build it once
        # (for DaleRNN this avoids recomputing sign * exp(log|W|) per step).
        weights = self.weight_matrix_()
        step = dt / self.tau

        outputs_seq = []
        states_seq = []
        for t in tqdm(range(sequence_length), position=0, leave=True, disable=disable_progress_bar):
            # Plain 2-D matmuls keep every dimension; the previous bare
            # .squeeze() dropped the batch axis when batch_size == 1 and the
            # output axis when num_outputs == 1.
            x = x_prev + step * (
                -x_prev
                + self.activation_func(x_prev) @ weights
                + inputs[:, t, :] @ self.input_weight_matrix.T
                + self.state_bias
            )

            y = self.output_nonlinearity(
                self.activation_func(x) @ self.output_weight_matrix.T
                + self.output_bias
            )

            outputs_seq.append(y)
            states_seq.append(x)
            x_prev = x

        compiled_outputs = torch.stack(outputs_seq, dim=1)  # [batch_size, sequence_length, output_size]
        compiled_state = torch.stack(states_seq, dim=1)     # [batch_size, sequence_length, state_size]
        if return_state:
            return compiled_outputs, compiled_state
        return compiled_outputs


class DaleRNN(RNN):
    """RNN whose recurrent weights keep a fixed sign pattern.

    The recurrent matrix is stored as a frozen ``sign(W)`` plus ``log|W|``.
    The effective matrix is ``sign(W) * exp(log|W|)``. When ``'weights'`` is in
    ``train``, only the log-magnitudes are an ``nn.Parameter``. Optimization can
    therefore rescale entries but never flip their sign. Entries that are zero
    have ``sign == 0``, so they stay exactly zero.
    """

    def __init__(self, tau, weight_matrix, init_state, output_weight_matrix,
                 input_weight_matrix, state_bias, output_bias,
                 output_nonlinearity=null_activation,
                 activation_func=nn.ReLU(), train=('init_state',), device=global_device):

        super().__init__(tau, weight_matrix, init_state, output_weight_matrix,
                         input_weight_matrix, state_bias, output_bias,
                         output_nonlinearity,
                         activation_func, train, device)

        if isinstance(weight_matrix, torch.Tensor):
            weights = weight_matrix.detach().to(dtype=torch.float32)
        else:
            weights = torch.tensor(weight_matrix, dtype=torch.float32)
        # log|0| = -inf; harmless because the matching sign entry is 0.
        log_magnitude = torch.log(torch.abs(weights))

        if 'weights' in train:
            # Remove the plain W registered by RNN.__init__; it would otherwise
            # be an unused duplicate Parameter.
            delattr(self, 'weight_matrix')
            # Freeze the signs; only the log-magnitudes are trained.
            self.weight_matrix_base = nn.Parameter(log_magnitude)
        else:
            self.weight_matrix_base = log_magnitude
        self.weight_matrix_sgn = torch.sign(weights)

    def weight_matrix_(self):
        """Return the effective recurrent matrix ``sign(W) * exp(log|W|)``."""
        return torch.multiply(self.weight_matrix_sgn, torch.exp(self.weight_matrix_base))

def loss_fn(y, target_seq, start_index=0):
    """Mean squared Euclidean error between ``y`` and ``target_seq``.

    Time steps (axis 1) before ``start_index`` are ignored. For every remaining
    (batch, time) entry, the squared L2 norm of the error over the last axis is
    taken, and the mean over those entries is returned as a scalar tensor.
    """
    residual = y[:, start_index:] - target_seq[:, start_index:]
    squared_error = residual.norm(dim=-1).pow(2)
    return squared_error.mean()

# Assumes a batched input and target (as does the RNN forward function)
def train_model(model, target_seq, inputs, pretrain_inputs=None,
                num_epochs=1000, lr=0.01, dt=1e-2, start_index=0,
                lr_scheduler=None, device=None):
    """Fit ``model`` to ``target_seq`` with Adam, minimizing ``loss_fn``.

    Args:
        model: an ``RNN`` (or subclass); it is moved to ``device`` in place.
        target_seq: batched targets, compared against ``model.forward(inputs)``.
        inputs: batched input sequence fed to ``model.forward``.
        pretrain_inputs: optional batched inputs run through the model before
            training. The final state of that run replaces ``model.init_state``.
            A single pretraining trial gives a shared ``[N]`` state. Several
            trials give one state per trial, ``[pretrain_batch, N]``, whose batch
            size must then match ``inputs``.
        num_epochs, lr: number of full-batch Adam steps and learning rate.
        dt: Euler integration step passed to ``model.forward``.
        start_index: time steps before this index are excluded from the loss.
        lr_scheduler: optional factory ``lr_scheduler(optimizer)``, stepped once
            per epoch.
        device: target device; defaults to ``global_device``.

    Returns:
        The trained ``model`` (the same object, modified in place).
    """
    if device is None:
        device = global_device

    def as_float_tensor(value):
        # Detached float32 tensor on `device`. torch.as_tensor avoids the
        # warning torch.tensor emits for tensor input; nothing here is modified
        # in place, so sharing memory with the caller's array is safe.
        return torch.as_tensor(value, dtype=torch.float32).detach().to(device)

    model.to_device(device)
    target_seq = as_float_tensor(target_seq)
    inputs = as_float_tensor(inputs)

    if pretrain_inputs is not None:
        pretrain_inputs = as_float_tensor(pretrain_inputs)
        # The pretraining run only seeds the initial state, so no graph is needed.
        with torch.no_grad():
            _, pretrain_states = model.forward(pretrain_inputs, dt=dt, return_state=True)
        final_state = pretrain_states[:, -1, :]  # last time step: [pretrain_batch, N]
        if final_state.shape[0] == 1:
            # A single trial keeps the shared [N] layout the model was built with.
            final_state = final_state[0]
        final_state = final_state.clone()
        # A trainable init_state is a registered nn.Parameter; assigning a plain
        # tensor to it raises TypeError, so preserve its kind.
        if isinstance(model.init_state, nn.Parameter):
            final_state = nn.Parameter(final_state)
        model.init_state = final_state
        model.detach()

    # Create the optimizer after any pretraining, which may replace parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = lr_scheduler(optimizer) if lr_scheduler is not None else None

    # Training loop
    for epoch in tqdm(range(num_epochs)):
        optimizer.zero_grad()
        y_seq = model.forward(inputs, dt=dt, disable_progress_bar=True)
        loss = loss_fn(y_seq, target_seq, start_index)
        # Every epoch builds a fresh graph, so it can be freed by backward().
        loss.backward()
        optimizer.step()
        # Release gradients before logging / the next forward pass.
        optimizer.zero_grad()
        if (epoch + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

        if scheduler is not None:
            scheduler.step()

    return model

