import os
import random
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from utils import nuclear_norm


class Network(nn.Module):
    """Single-layer recurrent network (vanilla RNN or GRU) with a linear readout.

    ``params`` is a plain dict of hyper-parameters. Keys read by this class:

    * construction: ``n_hidden``, ``dim_input``, ``dim_output``, ``rnn_type``
      ('vrnn' or 'gru'), ``device``, ``trainable_ratio`` and optionally
      ``small_weight_init`` / ``init_sigma``;
    * forward: ``gamma`` (the readout is divided by it);
    * training: ``lr``, ``lr_scheduler``, ``W_rank_reg``, ``W_l1_reg`` and
      optionally ``weight_decay``;
    * saving: ``seed``, ``task_name``, ``save_path``.
    """

    def __init__(self, params, pretrained_model=None):
        """Build the recurrent and readout layers and initialise their weights.

        Args:
            params: hyper-parameter dict (see the class docstring).
            pretrained_model: optional ``Network`` whose recurrent
                hidden-to-hidden weight and bias are copied into this model.

        Raises:
            ValueError: if ``params['rnn_type']`` is not 'vrnn' or 'gru'.
        """
        super(Network, self).__init__()

        self.params = params
        n_hidden = self.params['n_hidden']
        dim_output = self.params['dim_output']
        dim_input = self.params['dim_input']

        # Network
        if self.params['rnn_type'] == 'vrnn':
            self.recurrent = nn.RNN(dim_input, n_hidden, batch_first=True)
        elif self.params['rnn_type'] == 'gru':
            self.recurrent = nn.GRU(dim_input, n_hidden, batch_first=True)
        else:
            raise ValueError('rnn_type must be one of [vrnn, gru] but was %s' % self.params['rnn_type'])

        self.fc_out = nn.Linear(n_hidden, dim_output)

        if self.params.get('small_weight_init', False):
            # Uniform(-init_sigma, init_sigma) recurrent weights, zero biases.
            for name, param in self.recurrent.named_parameters():
                if 'weight' in name:
                    nn.init.uniform_(param, a=-self.params['init_sigma'], b=self.params['init_sigma'])
                elif 'bias' in name:
                    nn.init.constant_(param, 0)
        else:
            nn.init.kaiming_uniform_(self.recurrent.weight_ih_l0, a=0, nonlinearity='relu')
            nn.init.orthogonal_(self.recurrent.weight_hh_l0)
            nn.init.zeros_(self.recurrent.bias_ih_l0)
            nn.init.zeros_(self.recurrent.bias_hh_l0)

        nn.init.xavier_uniform_(self.fc_out.weight)
        nn.init.zeros_(self.fc_out.bias)

        if pretrained_model is not None:
            # Copy values in place so weight_hh_l0 / bias_hh_l0 remain this
            # module's own trainable nn.Parameter objects (copy_ also handles
            # a source living on another device and rejects shape mismatches).
            with torch.no_grad():
                source = pretrained_model.recurrent
                self.recurrent.weight_hh_l0.copy_(source.weight_hh_l0)
                self.recurrent.bias_hh_l0.copy_(source.bias_hh_l0)

        # Moves every submodule, including self.recurrent.
        self.to(self.params['device'])

        if self.params['trainable_ratio'] != 1.0:
            self.trainable_mask = self._set_trainable_weights(self.params['trainable_ratio'])

    def _zero_hidden(self, batch_size):
        """Return a zero initial hidden state of shape [1, batch_size, n_hidden] on the configured device."""
        return torch.zeros(
            [1, batch_size, self.params['n_hidden']],
            dtype=torch.float32,
            device=self.params['device'],
        )

    @staticmethod
    def _cycle_batches(loader):
        """Yield batches from ``loader`` indefinitely, restarting it each time it is exhausted.

        Stops (so ``next`` raises StopIteration) only if a full pass yields nothing.
        """
        while True:
            produced = False
            for batch in loader:
                produced = True
                yield batch
            if not produced:
                return

    def _set_trainable_weights(self, trainable_ratio):
        """Choose a random subset of each hidden-to-hidden weight to keep trainable.

        Args:
            trainable_ratio: fraction of entries to keep trainable.

        Returns:
            Dict mapping parameter name (names containing 'weight_hh') to a
            boolean mask with the parameter's shape and device, holding
            ``int(trainable_ratio * numel)`` True entries picked with Python's
            ``random`` module.
        """
        trainable_mask = {}
        for name, param in self.named_parameters():
            if 'weight_hh' in name:
                total_params = param.numel()
                num_trainable = int(trainable_ratio * total_params)

                indices = random.sample(range(total_params), num_trainable)

                mask = torch.zeros(total_params, dtype=torch.bool, device=param.device)
                mask[indices] = True
                trainable_mask[name] = mask.view(param.shape)

        return trainable_mask

    def apply_trainable_mask(self):
        """Zero, in place, the gradient entries of masked (non-trainable) weights.

        NOTE(review): optimizer terms that are not taken from ``.grad`` (e.g.
        Adam's ``weight_decay``, which is added inside ``step()``) can still
        move masked entries.
        """
        for name, param in self.named_parameters():
            if name in self.trainable_mask and param.grad is not None:
                param.grad.mul_(self.trainable_mask[name].to(param.device))

    def forward(self, inp, h):
        """Run the RNN over ``inp`` (batch-first) starting from hidden state ``h``.

        Returns:
            ``(out, h)``: the readout at every time step divided by
            ``params['gamma']``, and the final hidden state.
        """
        out, h = self.recurrent(inp, h)
        out = self.fc_out(out) / self.params['gamma']
        return out, h

    def train(self, loader=True, epochs=None, steps_per_epoch=None, optimizer=None,
              early_stopping_threshold=None, patience=None):
        """Train with MSE loss or, when given a bool, toggle train/eval mode.

        This method shadows ``nn.Module.train(mode)``. ``nn.Module.eval()``
        calls ``self.train(False)``, so a bool first argument (or no argument)
        is forwarded to the base implementation instead of starting a run.

        Args:
            loader: iterable of ``(x_batch, y_batch)`` pairs. It is consumed
                sequentially and restarted whenever it is exhausted.
            epochs: number of epochs.
            steps_per_epoch: optimisation steps per epoch.
            optimizer: optimizer to use. Defaults to Adam over all parameters
                with ``params['lr']`` and ``params.get('weight_decay', 0.0)``.
            early_stopping_threshold, patience: if both are given, stop once
                the end-of-epoch loss has stayed below the threshold for
                ``patience`` consecutive epochs.

        Returns:
            The last end-of-epoch loss as a numpy scalar (``0`` if no epoch
            ran), or ``self`` when used as a mode toggle.

        Raises:
            TypeError: if a loader is given without ``epochs`` / ``steps_per_epoch``.
            ValueError: if ``params['lr_scheduler']`` is not None, 'cosine',
                'reduce' or 'step'.
        """
        if isinstance(loader, bool):
            # Called as nn.Module.train(mode), e.g. from self.eval().
            return super(Network, self).train(loader)
        if epochs is None or steps_per_epoch is None:
            raise TypeError('train() needs epochs and steps_per_epoch when given a loader')

        scheduler_name = self.params['lr_scheduler']
        if scheduler_name not in (None, 'cosine', 'reduce', 'step'):
            raise ValueError('lr_scheduler must be one of [None, cosine, reduce, step] but was %s' % scheduler_name)

        # Ensure training mode in case eval() was called earlier.
        super(Network, self).train(True)

        if optimizer is None:
            optimizer = torch.optim.Adam(
                self.parameters(),
                lr=self.params['lr'],
                weight_decay=self.params.get('weight_decay', 0.0),
            )

        print('Using weight decay of', optimizer.param_groups[0]['weight_decay'])

        loss_function = nn.MSELoss()

        scheduler = None
        if scheduler_name == 'cosine':
            scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1000, T_mult=1, eta_min=0)
        elif scheduler_name == 'reduce':
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=20, threshold=1e-2, verbose=True)
        elif scheduler_name == 'step':
            scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.8)

        device = self.params['device']
        self.train_history = []
        self.delta_W = {name: [] for name, param in self.named_parameters() if param.requires_grad}

        batches = self._cycle_batches(loader)
        no_improvement_count = 0
        loss = 0
        for epoch in range(epochs):
            for step in range(steps_per_epoch):
                x_batch, y_batch = next(batches)
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)

                h = self._zero_hidden(x_batch.size(0))

                # Snapshot trainable weights so track_delta_W can measure this step's update.
                with torch.no_grad():
                    old_weights = {name: param.clone() for name, param in self.named_parameters() if param.requires_grad}

                optimizer.zero_grad()
                output, _ = self(x_batch, h)
                loss = loss_function(output, y_batch)

                # Penalty terms are only built when their coefficient is non-zero.
                rank_reg = self.params['W_rank_reg']
                if rank_reg:
                    loss = loss + rank_reg * nuclear_norm(self.recurrent.weight_hh_l0)
                l1_reg = self.params['W_l1_reg']
                if l1_reg:
                    loss = loss + l1_reg * sum(p.abs().sum() for p in self.parameters())

                loss.backward()

                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)

                if self.params['trainable_ratio'] != 1.0:
                    self.apply_trainable_mask()

                optimizer.step()

                self.track_delta_W(old_weights)

                # Also record the first step's loss (computed before any optimizer step).
                if epoch == 0 and step == 0:
                    self.train_history.append(loss.detach().cpu().numpy())

            loss = loss.detach().cpu().numpy()
            if scheduler is not None:
                if scheduler_name == 'reduce':
                    # ReduceLROnPlateau.step() takes the monitored metric, not an epoch index.
                    scheduler.step(float(loss))
                else:
                    scheduler.step(epoch)
            self.train_history.append(loss)
            print(f'train loss epoch {epoch+1}: {loss}')

            if early_stopping_threshold is not None and patience is not None:
                if loss < early_stopping_threshold:
                    no_improvement_count += 1
                    if no_improvement_count >= patience:
                        print(f'Early stopping at epoch {epoch+1} as the loss has been below {early_stopping_threshold} for {patience} epochs.')
                        break
                else:
                    no_improvement_count = 0

        return loss

    def get_test_loss(self, X, Y):
        """Return the MSE between the model's outputs on ``X`` and ``Y`` as a numpy scalar.

        Starts from a zero hidden state and runs without autograd.
        """
        X = X.to(self.params['device'])
        Y = Y.to(self.params['device'])
        with torch.no_grad():
            output, _ = self(X, self._zero_hidden(X.size(0)))
            loss = nn.MSELoss()(output, Y)
        return loss.cpu().numpy()

    def predict(self, inputs):
        """Run the model on ``inputs`` from a zero hidden state.

        Returns:
            ``(outputs, h_final)`` as numpy arrays (copied to host memory).
        """
        inputs = inputs.to(self.params['device'])
        h = self._zero_hidden(inputs.shape[0])
        with torch.no_grad():
            outputs, h_final = self(inputs, h)
        return outputs.detach().cpu().numpy(), h_final.detach().cpu().numpy()

    def get_activations(self, inputs):
        """Step the model one time step at a time and collect outputs and hidden states.

        Returns:
            ``(outputs, activations)`` where ``outputs`` is a list with one
            numpy array per time step and ``activations`` is a numpy array of
            shape [batch, time, n_hidden].
        """
        inputs = inputs.to(self.params['device'])
        n_rows = inputs.shape[1]
        h = self._zero_hidden(inputs.shape[0])

        outputs = []
        activations = []

        with torch.no_grad():
            for idx in range(n_rows):
                x = inputs[:, idx, :].unsqueeze(1)
                y, h = self(x, h)
                outputs.append(y.detach().cpu().numpy())
                activations.append(h.detach().cpu().numpy())

        # Stack to [time, 1, batch, hidden], drop the layer axis, then put batch first.
        activations = np.array(activations)[:, 0, :, :]
        activations = np.swapaxes(activations, 0, 1)

        return outputs, activations

    def save_measures(self, activations, eval_name=None, initial_save=False):
        """Save activations and, unless ``initial_save``, weights, losses and weight changes.

        Files are written under
        ``save_path/<task_name, or 'dd_curriculum' for a list>[/eval_name]/``
        into the ``weights``, ``hxs``, ``losses`` and ``delta_w`` subfolders,
        named ``seed_<seed>``. With ``initial_save=True`` only
        ``hxs/seed_<seed>_initial.npy`` is written. Otherwise ``train`` must
        have run first (``train_history`` and ``delta_W`` are read).
        """
        filename = f"seed_{self.params['seed']}"

        task_name = self.params['task_name']
        task_dir = 'dd_curriculum' if isinstance(task_name, list) else task_name
        path = os.path.join(self.params['save_path'], task_dir)
        if eval_name is not None:
            path = os.path.join(path, eval_name)

        netpath, hxspath, losspath, dwpath = (
            os.path.join(path, sub) for sub in ('weights', 'hxs', 'losses', 'delta_w')
        )
        for directory in (netpath, hxspath, losspath, dwpath):
            os.makedirs(directory, exist_ok=True)

        if initial_save:
            np.save(os.path.join(hxspath, filename + '_initial.npy'), activations)
        else:
            torch.save(self.state_dict(), os.path.join(netpath, filename + '.pt'))
            np.save(os.path.join(hxspath, filename + '.npy'), activations)
            np.save(os.path.join(losspath, filename + '.npy'), self.train_history)
            with open(os.path.join(dwpath, filename + '.pickle'), 'wb') as f:
                pickle.dump(self.delta_W, f)

    def track_delta_W(self, old_weights):
        """Append, for each trainable parameter, the Frobenius norm of its change since ``old_weights``."""
        with torch.no_grad():
            for name, param in self.named_parameters():
                if param.requires_grad:
                    delta_W = torch.norm(param - old_weights[name]).item()
                    self.delta_W[name].append(delta_W)

                    
                    

class Network_muP(Network):
    """Rate RNN driven by its own trainable matrices in a mean-field scaling.

    The parent constructor still builds ``recurrent`` and ``fc_out``; every
    parameter it creates is frozen here. The model's computation uses three
    new trainable matrices: ``J`` (recurrent), ``U`` (input) and
    ``w_readout`` (readout). Extra ``params`` keys read: ``tau``, ``gain``
    and ``gamma``.
    """

    def __init__(self, params):
        super(Network_muP, self).__init__(params)

        # Freeze weights initialized in the parent class.
        frozen_prefixes = ("recurrent", "fc_out")
        for pname, weight in self.named_parameters():
            if any(tag in pname for tag in frozen_prefixes):
                weight.requires_grad_(False)

        n_units = params['n_hidden']
        self.N = n_units              # Number of hidden units
        self.tau = params['tau']      # Time constant (multiplies each update step in forward)
        self.g = params['gain']       # Gain scaling factor
        self.gamma = params['gamma']  # Controls the amount of feature learning

        # Trainable matrices. J and U entries have std g * sqrt(N); w_readout has std g.
        # NOTE(review): U is scaled by sqrt(N) like J, but unlike J its product is not
        # divided by N in forward(), so the input drive grows with sqrt(N) — confirm intended.
        wide_scale = self.g * (self.N ** 0.5)
        self.J = nn.Parameter(torch.randn(n_units, n_units) * wide_scale)
        self.U = nn.Parameter(torch.randn(n_units, params['dim_input']) * wide_scale)
        self.w_readout = nn.Parameter(torch.randn(params['dim_output'], n_units) * (self.g))

        # Pointwise nonlinearity
        self.phi = torch.tanh

        self.to(self.params['device'])

    def forward(self, x, h):
        """Integrate the rate dynamics over the sequence and read out every step.

        Args:
            x: Tensor of shape [batch_size, seq_len, input_size].
            h: Initial hidden state of shape [1, batch_size, hidden_size].

        Returns:
            outputs: Tensor of shape [batch_size, seq_len, output_size].
            h: Final hidden state of shape [1, batch_size, hidden_size].
        """
        _, seq_len, _ = x.shape

        inv_n = 1 / self.N
        readout_scale = 1 / (self.N * self.gamma)

        state = h.squeeze(0)
        per_step = []
        for t in range(seq_len):
            inp_t = x[:, t, :]
            state = state + self.tau * (
                -state
                + inv_n * self.phi(state) @ self.J.T
                + inp_t @ self.U.T
            )
            readout = readout_scale * (self.w_readout @ self.phi(state).T).T
            per_step.append(readout.unsqueeze(1))  # [batch_size, 1, output_size]

        return torch.cat(per_step, dim=1), state.unsqueeze(0)
