import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from utils import restore_parameters
from copy import deepcopy


def make_onehot(x):
    '''
    Turn each row of an approximately one-hot matrix into an exact one-hot.

    Args:
      x: tensor indexed as (row, column); the argmax is taken over dim 1.

    Returns a new tensor with x's shape, dtype and device that is 1 at each
    row's argmax column and 0 everywhere else.
    '''
    winners = x.argmax(dim=1)
    rows = torch.arange(x.shape[0])
    result = x.new_zeros(x.shape)
    result[rows, winners] = 1
    return result


class Pretrainer(nn.Module):
    '''
    Pretrain model with missing features.

    Args:
      model: predictor network. It is called on the output of mask_layer and
        must own at least one parameter (the first one determines the device).
      mask_layer: callable invoked as mask_layer(x, m), where m is a 0/1 float
        mask with the same shape as x; its output is fed to model.
    '''

    def __init__(self, model, mask_layer):
        super().__init__()
        self.model = model
        self.mask_layer = mask_layer

    def fit(self,
            train,
            val,
            mbsize,
            lr,
            nepochs,
            max_features,
            loss_fn,
            val_loss_fn=None,
            verbose=True):
        '''
        Train model on randomly masked inputs, keeping the best epoch.

        Args:
          train: training dataset yielding (x, y) pairs.
          val: validation dataset yielding (x, y) pairs.
          mbsize: minibatch size for both loaders.
          lr: learning rate for Adam.
          nepochs: number of passes over the training loader. With 0 epochs
            the model is left untouched.
          max_features: controls mask density. Every mask entry is 1 with
            probability max_features / x.shape[1], independently.
          loss_fn: training loss, called as loss_fn(pred, y).
          val_loss_fn: validation loss, called as val_loss_fn(pred, y);
            defaults to loss_fn. Its value is weighted by batch size, i.e. it
            is treated as a per-batch mean.
          verbose: whether to print the validation loss after each epoch.

        Returns None. The model is updated in place; if any epoch improved on
        the initial best loss of infinity, the parameters from the epoch with
        the lowest validation loss are restored at the end.
        '''
        # Set up data loaders.
        train_loader = DataLoader(
            train, batch_size=mbsize, shuffle=True, pin_memory=True,
            drop_last=True, num_workers=4)
        val_loader = DataLoader(
            val, batch_size=mbsize, shuffle=False, pin_memory=True,
            drop_last=False, num_workers=4)

        # More setup.
        model = self.model
        mask_layer = self.mask_layer
        device = next(model.parameters()).device
        opt = optim.Adam(model.parameters(), lr=lr)
        if val_loss_fn is None:
            val_loss_fn = loss_fn

        # For tracking best model.
        best_model = None
        best_loss = float('inf')

        for epoch in range(nepochs):
            for x, y in train_loader:
                # Move to device.
                x = x.to(device)
                y = y.to(device)

                # Generate missingness: independent Bernoulli mask per entry.
                m = (torch.rand(x.shape, device=device)
                     < (max_features / x.shape[1])).float()

                # Evaluate model.
                x_masked = mask_layer(x, m)
                pred = model(x_masked)

                # Calculate loss.
                loss = loss_fn(pred, y)

                # Take gradient step.
                loss.backward()
                opt.step()
                model.zero_grad()

            # Calculate validation loss.
            val_loss = self._validation_loss(
                val_loader, max_features, val_loss_fn, device)

            # Print progress.
            if verbose:
                print(f'{"-"*8}Epoch {epoch+1}{"-"*8}')
                print(f'Val loss = {val_loss:.4f}\n')

            # See if best model.
            if val_loss < best_loss:
                best_loss = val_loss
                best_model = deepcopy(model)

        # Restore best parameters. best_model stays None when no epoch ran or
        # no validation loss compared below infinity (e.g. NaN losses).
        # NOTE(review): restore_parameters is defined in utils; presumably it
        # copies best_model's parameters into model - confirm.
        if best_model is not None:
            restore_parameters(model, best_model)

    def _validation_loss(self, loader, max_features, val_loss_fn, device):
        '''
        Return the sample-weighted mean of val_loss_fn over loader.

        The model is switched to eval mode for the pass, and every submodule's
        previous training flag is restored afterwards. A fresh random mask is
        drawn for each batch, so the value is stochastic across calls.
        '''
        model = self.model
        mask_layer = self.mask_layer

        # Remember the exact per-submodule modes before switching to eval.
        prior_modes = [(sub, sub.training) for sub in model.modules()]
        model.eval()

        val_loss = 0
        n = 0
        try:
            with torch.no_grad():
                for x, y in loader:
                    # Move to device.
                    x = x.to(device)
                    y = y.to(device)

                    # Generate missingness.
                    # TODO this is not ideal, we should precompute this
                    m = (torch.rand(x.shape, device=device)
                         < (max_features / x.shape[1])).float()

                    # Evaluate model.
                    pred = model(mask_layer(x, m))
                    loss = val_loss_fn(pred, y).item()

                    # Update running mean, weighting batches by their size.
                    val_loss = (loss * len(x) + val_loss * n) / (len(x) + n)
                    n += len(x)
        finally:
            for sub, was_training in prior_modes:
                sub.training = was_training

        return val_loss
            
            
class GreedyAdaptiveFS(nn.Module):
    '''
    Greedy adaptive feature selection.

    Args:
      selector: network called on a masked input; its output (logits) is
        combined elementwise with a mask shaped like x, so it needs one logit
        per input entry.
      model: predictor network called on a masked input. It must own at least
        one parameter (the first one determines the device).
      mask_layer: callable invoked as mask_layer(x, m) with a mask m shaped
        like x.
      selector_layer: callable invoked as selector_layer(logits, temp).
        NOTE(review): presumably returns a relaxed, approximately one-hot
        sample whose sharpness is set by temp - confirm.
    '''

    def __init__(self, selector, model, mask_layer, selector_layer):
        super().__init__()
        self.selector = selector
        self.model = model
        self.mask_layer = mask_layer
        self.selector_layer = selector_layer

    def fit(self,
            train,
            val,
            mbsize,
            lr,
            nepochs,
            max_features,
            loss_fn,
            val_loss_fn=None,
            train_model=True,
            train_selector=True,
            start_temp=10.0,
            end_temp=0.01,
            argmax=True,
            no_repeats=True,
            validation_mode='final',
            verbose=True):
        '''
        Train selector and/or model, keeping the best epoch.

        Args:
          train: training dataset yielding (x, y) pairs.
          val: validation dataset yielding (x, y) pairs.
          mbsize: minibatch size for both loaders.
          lr: learning rate for the Adam optimizers.
          nepochs: number of passes over the training loader. With 0 epochs
            both networks are left untouched.
          max_features: number of selection steps per sample; must be at
            least 1 when any training or validation batch is processed.
          loss_fn: training loss, called as loss_fn(pred, y).
          val_loss_fn: validation loss, called as val_loss_fn(pred, y);
            defaults to loss_fn.
          train_model: whether to update model.
          train_selector: whether to update selector.
          start_temp: temperature passed to selector_layer at the first step.
          end_temp: temperature reached after the last training step
            (geometric decay, one decay per minibatch).
          argmax: validation only - pick the arg-max logit directly instead
            of going through selector_layer at temperature 1e-6.
          no_repeats: validation only - subtract 1e6 from the logits of
            already selected entries.
          validation_mode: 'final' scores the prediction after the last
            selection step; 'mean' averages the loss over all steps.
          verbose: whether to print the validation loss after each epoch.

        Returns None. selector and model are updated in place; if any epoch
        improved on the initial best loss of infinity, the parameters from
        the epoch with the lowest validation loss are restored at the end.
        '''
        # Set up data loaders.
        train_loader = DataLoader(
            train, batch_size=mbsize, shuffle=True, pin_memory=True,
            drop_last=True, num_workers=4)
        val_loader = DataLoader(
            val, batch_size=mbsize, shuffle=False, pin_memory=True,
            drop_last=False, num_workers=4)

        # More setup.
        selector = self.selector
        model = self.model
        mask_layer = self.mask_layer
        selector_layer = self.selector_layer
        device = next(model.parameters()).device
        assert validation_mode in ('final', 'mean'), (
            "validation_mode must be 'final' or 'mean'")
        assert train_model or train_selector, (
            'at least one of train_model / train_selector must be set')
        if train_model:
            model_opt = optim.Adam(model.parameters(), lr=lr)
        if train_selector:
            selector_opt = optim.Adam(selector.parameters(), lr=lr)
        if val_loss_fn is None:
            val_loss_fn = loss_fn

        # Temperature setup: geometric decay that reaches end_temp after the
        # last step. With zero steps there is nothing to anneal.
        total_steps = nepochs * len(train_loader)
        if total_steps > 0:
            r = (end_temp / start_temp) ** (1 / total_steps)
        else:
            r = 1.0
        temp = start_temp

        # For tracking best model.
        best_model = None
        best_selector = None
        best_loss = float('inf')

        for epoch in range(nepochs):
            for x, y in train_loader:
                # Move to device.
                x = x.to(device)
                y = y.to(device)

                # Setup.
                m_hard = torch.zeros(x.shape, dtype=x.dtype, device=device)
                total_loss = 0

                for _ in range(max_features):
                    # Evaluate selector model on what is selected so far.
                    x_masked = mask_layer(x, m_hard)
                    logits = selector(x_masked)

                    # Get selections.
                    soft = selector_layer(logits, temp)

                    # Evaluate model with the soft selection added. The loss
                    # reaches the selector through soft only, because m_hard
                    # is assembled from make_onehot outputs.
                    m_soft = torch.max(m_hard, soft)
                    x_masked = mask_layer(x, m_soft)
                    pred = model(x_masked)

                    # Calculate loss.
                    loss = loss_fn(pred, y)
                    total_loss = total_loss + loss

                    # Update hard selections.
                    hard = make_onehot(soft)
                    m_hard = torch.max(m_hard, hard)

                # Take gradient step on the loss averaged over steps.
                total_loss = total_loss / max_features
                total_loss.backward()
                if train_model:
                    model_opt.step()
                if train_selector:
                    selector_opt.step()
                # Clear both networks, including the one not being trained.
                model.zero_grad()
                selector.zero_grad()
                temp *= r

            # Calculate validation loss.
            val_loss = self._validation_loss(
                val_loader, max_features, val_loss_fn, argmax, no_repeats,
                validation_mode, device)

            # Print progress.
            if verbose:
                print(f'{"-"*8}Epoch {epoch+1}{"-"*8}')
                print(f'Val loss = {val_loss:.4f}\n')

            # See if best model.
            if val_loss < best_loss:
                best_loss = val_loss
                best_model = deepcopy(model)
                best_selector = deepcopy(selector)

        # Restore best parameters. Compare against None explicitly: module
        # truthiness falls back to __len__ for container modules.
        # NOTE(review): restore_parameters is defined in utils; presumably it
        # copies the second module's parameters into the first - confirm.
        if best_model is not None:
            restore_parameters(model, best_model)
        if best_selector is not None:
            restore_parameters(selector, best_selector)

    def _enter_eval(self):
        '''Put selector and model in eval mode; return their prior flags.'''
        prior_modes = [(sub, sub.training)
                       for net in (self.selector, self.model)
                       for sub in net.modules()]
        self.selector.eval()
        self.model.eval()
        return prior_modes

    @staticmethod
    def _restore_modes(prior_modes):
        '''Reset every submodule to the training flag saved by _enter_eval.'''
        for sub, was_training in prior_modes:
            sub.training = was_training

    def _select_next(self, x, m_hard, argmax, no_repeats):
        '''Return m_hard with the selector's next hard pick added.'''
        logits = self.selector(self.mask_layer(x, m_hard))
        if no_repeats:
            # Push already selected entries far below the others.
            logits = logits - 1e6 * m_hard
        if argmax:
            hard = make_onehot(logits)
        else:
            hard = make_onehot(self.selector_layer(logits, 1e-6))
        return torch.max(m_hard, hard)

    def _validation_loss(self, loader, max_features, val_loss_fn, argmax,
                         no_repeats, validation_mode, device):
        '''
        Return the sample-weighted mean validation loss over loader.

        Selections are hard at every step. selector and model run in eval
        mode, and their previous training flags are restored afterwards.
        '''
        val_loss = 0
        n = 0
        prior_modes = self._enter_eval()
        try:
            with torch.no_grad():
                for x, y in loader:
                    # Move to device.
                    x = x.to(device)
                    y = y.to(device)

                    # Setup.
                    m_hard = torch.zeros(
                        x.shape, dtype=x.dtype, device=device)
                    total_loss = 0

                    for _ in range(max_features):
                        # Select one more entry, then score the prediction.
                        m_hard = self._select_next(
                            x, m_hard, argmax, no_repeats)
                        pred = self.model(self.mask_layer(x, m_hard))
                        loss = val_loss_fn(pred, y).item()
                        total_loss = total_loss + loss

                    # 'final' keeps the loss from the last step only.
                    if validation_mode == 'final':
                        batch_loss = loss
                    elif validation_mode == 'mean':
                        batch_loss = total_loss / max_features

                    # Update running mean, weighting batches by their size.
                    val_loss = (
                        (batch_loss * len(x) + val_loss * n) / (len(x) + n))
                    n += len(x)
        finally:
            self._restore_modes(prior_modes)

        return val_loss

    def forward(self, x, max_features, argmax=True, no_repeats=True):
        '''
        Make predictions using selected features.

        Args:
          x: input data (torch.Tensor).
          max_features: max features to observe.
          argmax: whether to select the next feature using the max probability.
          no_repeats: whether to ensure that no repeated selections occur.

        Returns the model output for x masked by the accumulated selections.
        '''
        # Setup. The mask starts empty and lives on the model's device.
        device = next(self.model.parameters()).device
        m_hard = torch.zeros(x.shape, dtype=x.dtype, device=device)

        # Update selections, one hard pick per step.
        for _ in range(max_features):
            m_hard = self._select_next(x, m_hard, argmax, no_repeats)

        # Make predictions.
        x_masked = self.mask_layer(x, m_hard)
        pred = self.model(x_masked)
        return pred

    def evaluate(self, dataset, max_features, loss_fn, batch_size, argmax=True,
                 no_repeats=True):
        '''
        Evaluate mean performance across a dataset.

        Args:
          dataset: dataset yielding (x, y) pairs.
          max_features: number of selection steps per sample.
          loss_fn: called as loss_fn(pred, y); its value is weighted by batch
            size, i.e. it is treated as a per-batch mean.
          batch_size: minibatch size used for the evaluation loader.
          argmax: forwarded to forward.
          no_repeats: forwarded to forward.

        Returns the sample-weighted mean loss as a Python number. selector
        and model run in eval mode and get their previous training flags
        back afterwards.
        '''
        # Setup.
        device = next(self.model.parameters()).device
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False, pin_memory=True,
            drop_last=False, num_workers=4)

        # For calculating mean loss.
        mean_loss = 0
        n = 0

        prior_modes = self._enter_eval()
        try:
            with torch.no_grad():
                for x, y in loader:
                    # Move to the model's device.
                    x = x.to(device)
                    y = y.to(device)

                    # Calculate loss.
                    pred = self.forward(x, max_features, argmax, no_repeats)
                    loss = loss_fn(pred, y).item()

                    # Update average.
                    mean_loss = (mean_loss * n + loss * len(x)) / (n + len(x))
                    n += len(x)
        finally:
            self._restore_modes(prior_modes)

        return mean_loss
            
            
class GlobalSelector(nn.Module):
    '''
    Global feature selection.

    Args:
      model: predictor network called on the output of selector_layer. It
        must own at least one parameter (the first one determines the device).
      selector_layer: module invoked as selector_layer(x, temp); its
        parameters are trained jointly with model.
        NOTE(review): presumably it returns x with a learned, temperature
        controlled feature selection applied - confirm.
    '''

    def __init__(self, model, selector_layer):
        super().__init__()
        self.model = model
        self.selector_layer = selector_layer

    def fit(self,
            train,
            val,
            mbsize,
            lr,
            nepochs,
            loss_fn,
            val_loss_fn=None,
            start_temp=10.0,
            end_temp=0.01,
            verbose=True):
        '''
        Train model and selector layer jointly, keeping the best epoch.

        Args:
          train: training dataset yielding (x, y) pairs.
          val: validation dataset yielding (x, y) pairs.
          mbsize: minibatch size for both loaders.
          lr: learning rate for Adam.
          nepochs: number of passes over the training loader. With 0 epochs
            model and selector_layer are left untouched.
          loss_fn: training loss, called as loss_fn(pred, y).
          val_loss_fn: validation loss, called as val_loss_fn(pred, y);
            defaults to loss_fn. Its value is weighted by batch size, i.e. it
            is treated as a per-batch mean.
          start_temp: temperature passed to selector_layer at the first step.
          end_temp: temperature reached after the last training step
            (geometric decay, one decay per minibatch).
          verbose: whether to print the validation loss after each epoch.

        Returns None. model and selector_layer are updated in place; if any
        epoch improved on the initial best loss of infinity, the parameters
        from the epoch with the lowest validation loss are restored at the
        end.
        '''
        # Set up data loaders.
        train_loader = DataLoader(
            train, batch_size=mbsize, shuffle=True, pin_memory=True,
            drop_last=True, num_workers=4)
        val_loader = DataLoader(
            val, batch_size=mbsize, shuffle=False, pin_memory=True,
            drop_last=False, num_workers=4)

        # More setup.
        model = self.model
        selector_layer = self.selector_layer
        device = next(model.parameters()).device
        opt = optim.Adam(
            list(model.parameters()) + list(selector_layer.parameters()), lr=lr)
        if val_loss_fn is None:
            val_loss_fn = loss_fn

        # Temperature setup: geometric decay that reaches end_temp after the
        # last step. With zero steps there is nothing to anneal.
        total_steps = nepochs * len(train_loader)
        if total_steps > 0:
            r = (end_temp / start_temp) ** (1 / total_steps)
        else:
            r = 1.0
        temp = start_temp

        # For tracking best model.
        best_model = None
        best_selector = None
        best_loss = float('inf')

        for epoch in range(nepochs):
            for x, y in train_loader:
                # Move to device.
                x = x.to(device)
                y = y.to(device)

                # Select features and make prediction.
                x_masked = selector_layer(x, temp)
                pred = model(x_masked)

                # Calculate loss.
                loss = loss_fn(pred, y)

                # Take gradient step.
                loss.backward()
                opt.step()
                model.zero_grad()
                selector_layer.zero_grad()
                temp *= r

            # Calculate validation loss (selector_layer at temperature 1e-6).
            val_loss = self._mean_loss(val_loader, val_loss_fn, device)

            # Print progress.
            if verbose:
                print(f'{"-"*8}Epoch {epoch+1}{"-"*8}')
                print(f'Val loss = {val_loss:.4f}\n')

            # See if best model.
            if val_loss < best_loss:
                best_loss = val_loss
                best_model = deepcopy(model)
                best_selector = deepcopy(selector_layer)

        # Restore best parameters. Both stay None when no epoch ran or no
        # validation loss compared below infinity (e.g. NaN losses).
        # NOTE(review): restore_parameters is defined in utils; presumably it
        # copies the second module's parameters into the first - confirm.
        if best_model is not None:
            restore_parameters(model, best_model)
        if best_selector is not None:
            restore_parameters(selector_layer, best_selector)

    def _mean_loss(self, loader, loss_fn, device):
        '''
        Return the sample-weighted mean of loss_fn over loader.

        Predictions come from forward. Only self.model is switched to eval
        mode for the pass (selector_layer's mode is left as the caller set
        it); every submodule's previous training flag is restored afterwards.
        '''
        model = self.model
        prior_modes = [(sub, sub.training) for sub in model.modules()]
        model.eval()

        mean_loss = 0
        n = 0
        try:
            with torch.no_grad():
                for x, y in loader:
                    # Move to the model's device.
                    x = x.to(device)
                    y = y.to(device)

                    # Calculate loss.
                    pred = self.forward(x)
                    loss = loss_fn(pred, y).item()

                    # Update running mean, weighting batches by their size.
                    mean_loss = (mean_loss * n + loss * len(x)) / (n + len(x))
                    n += len(x)
        finally:
            for sub, was_training in prior_modes:
                sub.training = was_training

        return mean_loss

    def forward(self, x):
        '''
        Make predictions with selected features.

        Args:
          x: input data (torch.Tensor).

        Returns the model output for selector_layer(x, 1e-6).
        '''
        x_masked = self.selector_layer(x, 1e-6)
        pred = self.model(x_masked)
        return pred

    def evaluate(self, dataset, loss_fn, batch_size):
        '''
        Evaluate mean performance across a dataset.

        Args:
          dataset: dataset yielding (x, y) pairs.
          loss_fn: called as loss_fn(pred, y); its value is weighted by batch
            size, i.e. it is treated as a per-batch mean.
          batch_size: minibatch size used for the evaluation loader.

        Returns the sample-weighted mean loss as a Python number.
        '''
        # Setup.
        device = next(self.model.parameters()).device
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False, pin_memory=True,
            drop_last=False, num_workers=4)

        return self._mean_loss(loader, loss_fn, device)
