import torch
import torch.nn as nn

from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, TensorDataset

import matplotlib.pyplot as plt
import numpy as np

from .nets import FNN, TCN

class Hedger(nn.Module):
    """ Base class for learned hedging strategies.

    Subclasses assign ``self.model`` and implement :meth:`hedge`. This class
    provides the squared hedging-error loss and a simple training loop.
    """

    def __init__(self):
        super().__init__()
        # Underlying network; assigned by subclasses.
        self.model = None

    def hedge(self, x):
        """ Compute hedging strategy.

        Arguments:
        ----------
        x : torch.tensor, shape (batch_size, n_features, T)
            Input data.

        Returns:
        --------
        hedge : torch.tensor, shape (batch_size, T-1)
            Hedging strategies.
        """

        raise NotImplementedError

    def loss(self, x, prices_index, premium, payoff):
        """ Compute hedging loss.

        Arguments:
        ----------
        x : torch.tensor, shape (batch_size, n_features, T)
            Input data.
        prices_index : int
            Index of price feature along dimension 1 (features' dimension).
        premium : torch.tensor, shape ([])
            Option premium (i.e. price charged at the initial time step).
        payoff : callable
            Option payoff function, applied to the prices at the final time step.

        Returns:
        --------
        loss : torch.tensor, shape ([])
            Mean over the batch of the squared difference between the payoff
            and (trading gain + premium).
        """

        hedge = self.hedge(x)
        prices = x[:, prices_index, :]
        # Trading gain: sum over t of hedge_t * (price_{t+1} - price_t).
        hedge_gain = torch.sum(torch.diff(prices, dim=1) * hedge, dim=1)
        loss = torch.mean((payoff(prices[:, -1]) - hedge_gain - premium) ** 2)

        return loss

    def train_model(self, dataset, prices_index, premium, payoff, epochs=1000, batch_size=1000, lr=1e-3, milestones=None, gamma=0.1):
        """ Train the hedger with Adam and a multi-step learning-rate schedule.

        Arguments:
        ----------
        dataset : torch.utils.data.Dataset
            Training data. Each item is unpacked as a 1-tuple (e.g. a
            TensorDataset built from a single tensor); presumably of shape
            (n_features, T) per sample -- confirm against caller.
        prices_index : int
            Index of the price feature, forwarded to :meth:`loss`.
        premium : torch.tensor, shape ([])
            Option premium, forwarded to :meth:`loss`.
        payoff : callable
            Option payoff function, forwarded to :meth:`loss`.
        epochs : int
            Number of passes over ``dataset``.
        batch_size : int
            Mini-batch size for the shuffled DataLoader.
        lr : float
            Initial Adam learning rate.
        milestones : list of int or None
            Scheduler step indices at which the learning rate is multiplied by
            ``gamma``. The scheduler is stepped once per *batch*, so these
            count optimizer steps, not epochs. ``None`` means no decay.
        gamma : float
            Multiplicative learning-rate decay factor.

        Returns:
        --------
        losses : list of float
            Training loss recorded after every optimizer step (also plotted
            on a log scale).
        """

        if milestones is None:
            # Explicit "no decay"; avoids relying on the scheduler accepting None.
            milestones = []

        # Make sure dropout and similar layers are active even if eval() was
        # called on this module earlier.
        self.train()

        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        losses = []
        loop = tqdm(range(epochs))
        for _ in loop:
            for batch, in loader:

                optimizer.zero_grad()

                # Compute loss
                loss = self.loss(batch, prices_index, premium, payoff)

                # Backpropagate
                loss.backward()
                optimizer.step()
                # Stepped per batch: milestones are measured in optimizer steps.
                scheduler.step()

                loss_value = loss.item()
                losses.append(loss_value)
                loop.set_postfix(loss=loss_value, lr=optimizer.param_groups[0]["lr"])

        # Plot loss
        plt.plot(np.log(losses), label='training loss')
        plt.legend()
        plt.show()

        return losses


class GCausalHedger(Hedger):
    """ Hedger whose position at each time step is an FNN of selected features.

    Arguments:
    ----------
    mb_indices : sequence of int or None
        Indices along the feature dimension of the inputs fed to the network
        (the "adapted Markov blanket"). ``None`` keeps all features.
    layers : list of int or None
        Layer sizes passed to ``FNN``; defaults to ``[3, 32, 32, 1]``.
        Presumably ``layers[0]`` must equal the number of selected features
        and ``layers[-1]`` must be 1 -- confirm against FNN.
    activation : str
        Activation name passed to ``FNN``.
    """

    def __init__(self, mb_indices=None, layers=None, activation='relu'):
        super().__init__()

        if layers is None:
            # Fresh list per instance instead of a shared mutable default.
            layers = [3, 32, 32, 1]
        self.model = FNN(layers=layers, activation=activation)
        self.mb_indices = mb_indices

    def hedge(self, x):
        """ Compute hedging strategy.

        Arguments:
        ----------
        x : torch.tensor, shape (batch_size, n_features, T)
            Input data.

        Returns:
        --------
        hedge : torch.tensor, shape (batch_size, T-1)
            Hedging strategies.
        """

        batch_size = x.shape[0]
        # Keep only input features in the adapted Markov blanket
        # (None -> all features; indexing with None would insert a new axis).
        if self.mb_indices is not None:
            x = x[:, self.mb_indices, :]
        # Keep only input features up to time T-1
        x = x[:, :, :-1]
        # Move time next to batch so the FNN is applied to every time step in parallel
        x = x.permute(0, 2, 1)

        out = self.model(x)
        # reshape instead of an unqualified squeeze: squeeze would also drop the
        # batch dim when batch_size == 1 and the time dim when T == 2.
        hedge = out.reshape(batch_size, -1)

        return hedge


class TCNHedger(Hedger):
    """ Hedger whose positions are produced by a temporal convolutional network.

    Arguments:
    ----------
    input_size, output_size, num_channels, kernel_size, dropout
        Forwarded unchanged to ``TCN``. Presumably ``input_size`` equals
        ``n_features`` and ``output_size`` is 1 -- confirm against TCN.
    """

    def __init__(self, input_size, output_size, num_channels, kernel_size, dropout):
        super().__init__()
        self.input_size = input_size
        self.model = TCN(input_size, output_size, num_channels, kernel_size, dropout)

    def hedge(self, x):
        """ Compute hedging strategy.

        Arguments:
        ----------
        x : torch.tensor, shape (batch_size, n_features, T)
            Input data.

        Returns:
        --------
        hedge : torch.tensor, shape (batch_size, T-1)
            Hedging strategies.
        """

        batch_size = x.shape[0]
        # Keep only input features up to time T-1
        x = x[:, :, :-1]
        out = self.model(x)
        # reshape instead of an unqualified squeeze: squeeze would also drop the
        # batch dim when batch_size == 1 and the time dim when T == 2.
        hedge = out.reshape(batch_size, -1)
        return hedge