
import sys
import copy
import torch
import numpy as np
import torch.nn as nn
from cqr import helper
from sklearn.model_selection import train_test_split
import time 
from helper import independence_penalty

# Toggle for the ExponentialLR learning-rate decay stepped in
# LearnerOptimizedCrossing.fit (both the CV phase and the full-data phase).
SCHEDULER = True


# Default compute device: the first CUDA GPU when torch reports one as
# available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

###############################################################################
# Helper functions
###############################################################################

def epoch_internal_train(model, loss_func, x_train, y_train, batch_size, optimizer, cnt=0, best_cnt=np.inf):
    """ Sweep once over shuffled mini-batches of the data and update the model's parameters

    Parameters
    ----------

    model : class of neural net model
    loss_func : class of loss function, called as loss_func(preds, batch_y)
    x_train : pytorch tensor n training features, each of dimension p (nXp)
    y_train : pytorch tensor of the n training labels, row-aligned with x_train
    batch_size : integer, size of the mini-batch
    optimizer : class of SGD solver
    cnt : integer, counting the gradient steps
    best_cnt: number, stop the sweep as soon as the step counter reaches
              best_cnt (default: no limit)

    Returns
    -------

    epoch_loss : mean loss value over the mini-batches processed in this call
    cnt : integer, cumulative number of gradient steps

    """
    model.train()

    # Shuffle with NumPy's global RNG so runs follow np.random.seed().
    n_rows = x_train.shape[0]
    shuffle_idx = np.arange(n_rows)
    np.random.shuffle(shuffle_idx)
    x_train = x_train[shuffle_idx]
    y_train = y_train[shuffle_idx]

    epoch_losses = []
    for start in range(0, n_rows, batch_size):
        cnt += 1
        optimizer.zero_grad()
        # Slicing past the end is clamped, so the last batch may be shorter.
        batch_x = x_train[start:start + batch_size]
        batch_y = y_train[start:start + batch_size]
        loss = loss_func(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

        # Stop exactly when the step budget is spent (cnt == best_cnt).
        if cnt >= best_cnt:
            break

    epoch_loss = np.mean(epoch_losses)

    return epoch_loss, cnt

def rearrange(all_quantiles, quantile_low, quantile_high, test_preds):
    """ Produce monotonic quantiles

    Parameters
    ----------

    all_quantiles : numpy array (q), grid of quantile levels in the range (0,1)
    quantile_low : float, desired low quantile in the range (0,1)
    quantile_high : float, desired high quantile in the range (0,1)
    test_preds : numpy array of predicted quantile (nXq)

    Returns
    -------

    q_fixed : numpy array (nX2), containing the rearranged estimates of the
              desired low and high quantile

    References
    ----------
    .. [1]  Chernozhukov, Victor, Iván Fernández‐Val, and Alfred Galichon.
            "Quantile and probability curves without crossing."
            Econometrica 78.3 (2010): 1093-1125.

    """
    first_level = all_quantiles[0]
    scaling = all_quantiles[-1] - first_level
    # Map the desired levels from the [first, last] grid range onto [0, 1];
    # np.quantile then takes per-row empirical quantiles of the (internally
    # sorted) predictions, which is the rearrangement of [1].
    low_val = (quantile_low - first_level) / scaling
    high_val = (quantile_high - first_level) / scaling
    # Linear interpolation is NumPy's default; the old `interpolation=` keyword
    # is deprecated (NumPy >= 1.22), so it is not passed.
    q_fixed = np.quantile(test_preds, (low_val, high_val), axis=1)
    return q_fixed.T

###############################################################################
# Deep conditional mean regression
# Minimizing MSE loss
###############################################################################

# Define the network
class mse_model(nn.Module):
    """ Conditional mean estimator, formulated as neural net

    A two-hidden-layer ReLU MLP with dropout mapping p features to one scalar
    prediction per row.
    """

    def __init__(self,
                 in_shape=1,
                 hidden_size=64,
                 dropout=0.5):
        """ Initialization

        Parameters
        ----------

        in_shape : integer, input signal dimension (p)
        hidden_size : integer, hidden layer dimension
        dropout : float, dropout rate

        """

        super().__init__()
        self.in_shape = in_shape
        self.out_shape = 1
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.build_model()
        self.init_weights()

    def build_model(self):
        """ Construct the network
        """
        self.base_model = nn.Sequential(
            nn.Linear(self.in_shape, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.hidden_size, 1),
        )

    def init_weights(self):
        """ Initialize the network parameters (orthogonal weights, zero biases)
        """
        for m in self.base_model:
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """ Run forward pass

        Parameters
        ----------
        x : pytorch tensor of features (nXp)

        Returns
        -------
        pytorch tensor of predictions (n)

        """
        # squeeze(-1) drops only the unit output dimension. A bare squeeze()
        # would also collapse a single-row batch to a 0-d tensor, which no
        # longer lines up with a (1,)-shaped target in the loss.
        return self.base_model(x).squeeze(-1)

# Define the training procedure
class LearnerOptimized:
    """ Fit a neural network (conditional mean) to training data

    ``fit`` runs in two phases. A cross-validation (CV) phase trains a copy of
    the model on a train split and records the gradient-step count with the
    lowest held-out loss. ``self.model`` is then trained on all the data for
    that many steps.
    """
    def __init__(self, model, optimizer_class, loss_func, device='cpu', test_ratio=0.2, random_state=0):
        """ Initialization

        Parameters
        ----------

        model : class of neural network model
        optimizer_class : class of SGD optimizer (e.g. Adam), called as
                          optimizer_class(parameters)
        loss_func : loss to minimize
        device : string, "cuda:0" or "cpu"
        test_ratio : float, test size used in cross-validation (CV)
        random_state : int, seed to be used in CV when splitting to train-test

        """
        self.model = model.to(device)
        self.optimizer_class = optimizer_class
        self.optimizer = optimizer_class(self.model.parameters())
        self.loss_func = loss_func.to(device)
        self.device = device
        self.test_ratio = test_ratio
        self.random_state = random_state
        self.loss_history = []
        self.test_loss_history = []
        self.full_loss_history = []

    def _to_tensor(self, arr):
        """ Convert a numpy array to a float32 tensor on self.device (no grad) """
        return torch.from_numpy(arr).float().to(self.device).requires_grad_(False)

    def fit(self, x, y, epochs, batch_size, verbose=False):
        """ Fit the model to data

        Parameters
        ----------

        x : numpy array, containing the training features (nXp)
        y : numpy array, containing the training labels (n)
        epochs : integer, maximal number of epochs
        batch_size : integer, mini-batch size for SGD
        verbose : boolean, print progress every 100 epochs

        """

        sys.stdout.flush()
        # The CV phase trains a copy, so self.model enters the full-data phase
        # with its initial weights. self.device (not the module default) keeps
        # the copy on the same device as the data tensors below.
        model = copy.deepcopy(self.model).to(self.device)
        optimizer = self.optimizer_class(model.parameters())
        best_epoch = epochs

        x_train, xx, y_train, yy = train_test_split(x, y, test_size=self.test_ratio, random_state=self.random_state)

        x_train = self._to_tensor(x_train)
        xx = self._to_tensor(xx)
        y_train = self._to_tensor(y_train)
        yy = self._to_tensor(yy)

        best_cnt = 1e10
        best_test_epoch_loss = 1e10

        cnt = 0
        for e in range(epochs):
            epoch_loss, cnt = epoch_internal_train(model, self.loss_func, x_train, y_train, batch_size, optimizer, cnt)
            self.loss_history.append(epoch_loss)

            # Held-out loss; no_grad skips building an unused autograd graph.
            model.eval()
            with torch.no_grad():
                test_epoch_loss = self.loss_func(model(xx), yy).cpu().numpy()

            self.test_loss_history.append(test_epoch_loss)

            if test_epoch_loss <= best_test_epoch_loss:
                best_test_epoch_loss = test_epoch_loss
                best_epoch = e
                best_cnt = cnt

            if (e+1) % 100 == 0 and verbose:
                print("CV: Epoch {}: Train {}, Test {}, Best epoch {}, Best loss {}".format(e+1, epoch_loss, test_epoch_loss, best_epoch, best_test_epoch_loss))
                sys.stdout.flush()

        # use all the data to train the model, for best_cnt steps
        x = self._to_tensor(x)
        y = self._to_tensor(y)

        # best_epoch is a 0-based index, so best_epoch + 1 epochs may be needed.
        # min() caps the run at `epochs` when no epoch was ever selected.
        cnt = 0
        for e in range(min(best_epoch + 1, epochs)):
            # epoch_internal_train returns exactly best_cnt once the budget is
            # spent; testing with '>' would let one extra gradient step through.
            if cnt >= best_cnt:
                break

            epoch_loss, cnt = epoch_internal_train(self.model, self.loss_func, x, y, batch_size, self.optimizer, cnt, best_cnt)
            self.full_loss_history.append(epoch_loss)

            if (e+1) % 100 == 0 and verbose:
                print("Full: Epoch {}: {}, cnt {}".format(e+1, epoch_loss, cnt))
                sys.stdout.flush()

    def predict(self, x):
        """ Estimate the label given the features

        Parameters
        ----------
        x : numpy array of features (nXp); converted to float32 to match the
            model's parameters

        Returns
        -------
        ret_val : numpy array of predicted labels (n)

        """
        self.model.eval()
        with torch.no_grad():
            ret_val = self.model(self._to_tensor(x)).cpu().numpy()
        return ret_val


##############################################################################
# Quantile regression
# Implementation inspired by:
# https://github.com/ceshine/quantile-regression-tensorflow
##############################################################################

class AllQuantileLoss(nn.Module):
    """ Pinball (quantile) loss summed over a grid of quantile levels
    """
    def __init__(self, quantiles):
        """ Initialize

        Parameters
        ----------
        quantiles : sequence (e.g. pytorch vector) of quantile levels, each in
                    the range (0,1); column i of the predictions is scored at
                    level quantiles[i]

        """
        super().__init__()
        self.quantiles = quantiles

    def forward(self, preds, target):
        """ Compute the pinball loss

        Parameters
        ----------
        preds : pytorch tensor of estimated quantiles (nXq)
        target : pytorch tensor of true labels (n)

        Returns
        -------
        loss : per-sample pinball losses summed over the q levels, then
               averaged over the n samples

        """
        assert not target.requires_grad
        assert preds.size(0) == target.size(0)

        columns = []
        for col, level in enumerate(self.quantiles):
            residual = target - preds[:, col]
            # Elementwise max of the two pinball branches.
            columns.append(torch.maximum((level - 1) * residual, level * residual).unsqueeze(1))

        per_sample = torch.cat(columns, dim=1).sum(dim=1)
        return per_sample.mean()
    
    
# from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric

class CI_loss(nn.Module):
    """
    Interval loss on a predicted (lower, upper) pair.

    With lower bound ``l``, upper bound ``u``, label ``y``, ``c = coverage``
    and ``lam = penalty``, let ``p = (y - l) * (y - u)``. ``p`` is negative
    exactly when ``y`` lies strictly between the two bounds. The per-sample
    value is

        max(p * (c + 2*lam), p * (c + 2*lam - 1)) + 0.5 * lam * (u - l)**2

    and the loss is its mean over the batch.
    """

    def __init__(
        self,
        coverage,
        penalty):
        """
        Parameters
        ----------
        coverage : float, stored as ``self.c``
        penalty : float, stored as ``self.l``; also weights the squared width
        """
        super().__init__()
        self.c = coverage
        self.l = penalty

    def forward(self, preds: torch.Tensor,  target: torch.Tensor) -> torch.Tensor:
        lower = preds[:, 0]
        upper = preds[:, 1]

        steep = self.c + 2 * self.l
        shallow = self.c + 2 * self.l - 1
        product = (target - lower) * (target - upper)
        position_term = torch.maximum(product * steep, product * shallow)
        width_term = self.l * torch.square(upper - lower) * 0.5

        return torch.mean(position_term + width_term)


class HQ_loss(nn.Module):
    """
    Interval loss: captured-interval width plus a soft-coverage shortfall penalty.

    With lower bound ``l``, upper bound ``u``, labels ``y``, batch size ``n``
    and ``c = coverage``:

        MPIW_capt + penalty * n / (c * (1 - c)) * max(0, c - PICP_S)**2

    MPIW_capt is the mean width ``u - l`` over samples whose label lies
    strictly inside (l, u), using a hard 0/1 indicator. PICP_S is the batch
    mean of a sigmoid-softened version of that indicator.
    """

    def __init__(
        self,
        coverage,
        penalty,
        soften=160.0):
        """
        Parameters
        ----------
        coverage : float in (0, 1), target fraction of labels inside the interval
        penalty : float, weight of the coverage-shortfall term
        soften : float, sharpness of the sigmoids in the soft indicator
                 (default 160, the value previously hard-coded)

        Raises
        ------
        ValueError : if coverage is not strictly between 0 and 1, since the
                     penalty divides by c * (1 - c)
        """
        super().__init__()
        if not 0 < coverage < 1:
            raise ValueError("coverage must be in the open interval (0, 1), got {}".format(coverage))
        self.c = coverage
        self.penalty = penalty
        self.soften = soften

    def forward(self, preds: torch.Tensor,  target: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        preds : pytorch tensor (nX2), column 0 = lower bound, column 1 = upper bound
        target : pytorch tensor of true labels (n)

        Returns
        -------
        loss : pytorch tensor of shape (1,)
        """
        lower = preds[:, 0]
        upper = preds[:, 1]
        n = lower.shape[0]
        # Shape-(1,) zero on the inputs' device/dtype, used as the floor of the
        # max() calls; it also makes the returned loss shape (1,).
        zero = preds.new_zeros(1)

        # Hard indicator: 1 where lower < y < upper, else 0.
        k_hard = torch.maximum(zero, torch.sign(upper - target)) * torch.maximum(zero, torch.sign(target - lower))
        # Differentiable (sigmoid) approximation of the same indicator.
        k_soft = torch.sigmoid(self.soften * (upper - target)) * torch.sigmoid(self.soften * (target - lower))

        # Mean width of the capturing intervals; + 0.001 avoids dividing by
        # zero when no label is captured.
        n_captured = torch.sum(k_hard) + 0.001
        mpiw_capt = (1 / n_captured) * torch.sum(k_hard * (upper - lower))

        picp_soft = torch.mean(k_soft)
        shortfall = torch.maximum(zero, self.c - picp_soft)

        loss = mpiw_capt + self.penalty * n * (1 / ((1 - self.c) * self.c)) * torch.square(shortfall)
        return loss

class Winkler_Loss(nn.Module):
    """ Winkler interval score for a predicted (lower, upper) pair
    """
    def __init__(self, quantiles):
        """ Initialize

        Parameters
        ----------
        quantiles : pair (low, high) of quantile levels, each in the range (0,1);
                    the nominal miscoverage is alpha = 1 - (high - low)


        """
        super().__init__()
        self.q1 = quantiles[0]
        self.q2 = quantiles[1]

    def forward(self, preds, target):
        """ Compute the mean Winkler score

        Parameters
        ----------
        preds : pytorch tensor (nX2), column 0 = lower bound l, column 1 = upper bound u
        target : pytorch tensor of true labels y (n)

        Returns
        -------
        loss : mean over samples of
               (u - l) + (2/alpha) * (l - y) * 1{y < l} + (2/alpha) * (y - u) * 1{y > u}

        """
        assert not target.requires_grad
        assert preds.size(0) == target.size(0)
        c = self.q2 - self.q1
        alpha = 1 - c
        y_pred_q1 = preds[:, 0]
        y_pred_q2 = preds[:, 1]

        # Penalties apply only when the label is outside the interval:
        # relu(l - y) == (l - y) * 1{y < l}. Thresholding at a negative margin
        # instead of 0 would add negative terms for labels just inside the bounds.
        below = torch.relu(y_pred_q1 - target)
        above = torch.relu(target - y_pred_q2)

        loss = (y_pred_q2 - y_pred_q1) + (2 / alpha) * below + (2 / alpha) * above
        return loss.mean()



class all_q_model(nn.Module):
    """ Conditional quantile estimator, formulated as neural net
    """
    def __init__(self,
                 quantiles,
                 in_shape=1,
                 hidden_size=64,
                 dropout=0.5,
                 use_HQ_loss = False,):

        """ Initialization

        Parameters
        ----------
        quantiles : numpy array of quantile levels (q), each in the range (0,1)
        in_shape : integer, input signal dimension (p)
        hidden_size : integer, hidden layer dimension
        dropout : float, dropout rate
        use_HQ_loss : boolean, when True and the network has exactly two
                      outputs, the output-layer bias starts at (0, 1) instead
                      of zeros

        """
        super().__init__()
        self.quantiles = quantiles
        self.use_HQ_loss = use_HQ_loss
        self.num_quantiles = len(quantiles)
        self.hidden_size = hidden_size
        self.in_shape = in_shape
        self.out_shape = len(quantiles)
        self.dropout = dropout
        self.build_model()
        self.init_weights()

    def build_model(self):
        """ Construct the network (two ReLU+dropout hidden layers, q outputs)
        """
        self.base_model = nn.Sequential(
            nn.Linear(self.in_shape, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.hidden_size, self.num_quantiles),
        )

    def init_weights(self):
        """ Initialize the network parameters

        Linear weights are orthogonal and biases are zero. The exception is
        use_HQ_loss with two outputs: the final layer's bias starts at (0, 1),
        presumably so the initial interval has positive width.
        """
        output_layer = self.base_model[-1]
        for m in self.base_model:
            if not isinstance(m, nn.Linear):
                continue
            nn.init.orthogonal_(m.weight)
            # Restrict to the output layer: a hidden layer with hidden_size == 2
            # also has a bias of shape [2] and must keep its zero bias.
            if self.use_HQ_loss and m is output_layer and m.bias.shape == torch.Size([2]):
                with torch.no_grad():
                    m.bias.copy_(torch.tensor([0.0, 1.0]))
            else:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """ Run forward pass; returns a tensor of quantile estimates (nXq)
        """
        return self.base_model(x)

class LearnerOptimizedCrossing:
    """ Fit a neural network (conditional interval / quantiles) to training data

    ``fit`` runs in two phases. A cross-validation (CV) phase trains a copy of
    the model and records the gradient-step count whose held-out intervals
    reach the target coverage with the smallest average length. ``self.model``
    is then trained on all the data for that many steps.
    """
    def __init__(self, model, optimizer_class, loss_func, device='cpu', test_ratio=0.2, random_state=0,
                 qlow=0.05, qhigh=0.95, use_rearrangement=False, use_CI_loss = False, penalty = 1, scheduler1 = 0.999, scheduler2 = 0.995):
        """ Initialization

        Parameters
        ----------

        model : class of neural network model
        optimizer_class : class of SGD optimizer (e.g. pytorch's Adam)
        loss_func : loss to minimize
        device : string, "cuda:0" or "cpu"
        test_ratio : float, test size used in cross-validation (CV)
        random_state : integer, seed used in CV when splitting to train-test
        qlow : float, low quantile level in the range (0,1)
        qhigh : float, high quantile level in the range (0,1)
        use_rearrangement : boolean, use the rearrangement algorithm (True)
                            or not (False) in predict(); requires loss_func
                            to expose a `quantiles` grid
        use_CI_loss : boolean, stored as self.use_CI_loss
        penalty : not used by this class; kept for interface compatibility
        scheduler1 : float, ExponentialLR decay factor in the CV phase
        scheduler2 : float, ExponentialLR decay factor in the full-data phase

        """
        self.model = model.to(device)
        self.use_rearrangement = use_rearrangement
        self.compute_coverage = True
        self.quantile_low = qlow
        self.quantile_high = qhigh
        # Percent scale, compared against helper.compute_coverage_len's coverage.
        self.target_coverage = 100.0*(self.quantile_high - self.quantile_low)
        self.use_CI_loss = use_CI_loss
        # Quantile grid for rearrange() in predict(): taken from the loss when
        # it has one, otherwise None.
        self.all_quantiles = self._quantile_grid(loss_func)
        self.optimizer_class = optimizer_class
        self.optimizer = optimizer_class(self.model.parameters())
        self.loss_func = loss_func.to(device)
        self.device = device
        self.test_ratio = test_ratio
        self.random_state = random_state
        self.loss_history = []
        self.test_loss_history = []
        self.full_loss_history = []
        self.scheduler1 = scheduler1
        self.scheduler2 = scheduler2

    @staticmethod
    def _quantile_grid(loss_func):
        """ Return loss_func.quantiles as a float numpy array, or None if absent """
        quantiles = getattr(loss_func, "quantiles", None)
        if quantiles is None:
            return None
        if torch.is_tensor(quantiles):
            quantiles = quantiles.detach().cpu().numpy()
        return np.asarray(quantiles, dtype=float)

    @staticmethod
    def _order_interval(preds):
        """ Return a copy of preds (nXq) with columns 0 and 1 swapped where they cross """
        ordered = np.array(preds, copy=True)
        ordered[:, 0] = np.minimum(preds[:, 0], preds[:, 1])
        ordered[:, 1] = np.maximum(preds[:, 0], preds[:, 1])
        return ordered

    def _to_tensor(self, arr):
        """ Convert a numpy array to a float32 tensor on self.device (no grad) """
        return torch.from_numpy(arr).float().to(self.device).requires_grad_(False)

    def fit(self, x, y, epochs, batch_size, verbose=False):
        """ Fit the model to data

        Parameters
        ----------

        x : numpy array of training features (nXp)
        y : numpy array of training labels (n)
        epochs : integer, maximal number of epochs
        batch_size : integer, mini-batch size used in SGD solver
        verbose : boolean, print progress every 100 epochs

        """
        sys.stdout.flush()
        # The CV phase trains a copy on self.device (not the module default),
        # so it lives on the same device as the data tensors below.
        model = copy.deepcopy(self.model).to(self.device)
        optimizer = self.optimizer_class(model.parameters())
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.scheduler1)

        best_epoch = epochs

        x_train, xx, y_train, yy = train_test_split(x,
                                                    y,
                                                    test_size=self.test_ratio,
                                                    random_state=self.random_state)

        yy_cpu = yy
        x_train = self._to_tensor(x_train)
        xx = self._to_tensor(xx)
        y_train = self._to_tensor(y_train)
        yy = self._to_tensor(yy)

        best_avg_length = 1e10
        best_coverage = 0
        best_cnt = 1e10

        cnt = 0
        for e in range(epochs):
            epoch_loss, cnt = epoch_internal_train(model, self.loss_func, x_train, y_train, batch_size, optimizer, cnt)
            if SCHEDULER:
                lr_scheduler.step()
            self.loss_history.append(epoch_loss)

            # Held-out evaluation; no_grad skips building an unused autograd graph.
            model.eval()
            with torch.no_grad():
                preds = model(xx)
                test_epoch_loss = self.loss_func(preds, yy).cpu().numpy()
            self.test_loss_history.append(test_epoch_loss)

            # No np.squeeze: it would turn a single held-out row (1X2) into (2,).
            bounds = self._order_interval(preds.cpu().numpy())
            coverage, avg_length = helper.compute_coverage_len(yy_cpu, bounds[:, 0], bounds[:, 1])
            if coverage >= self.target_coverage and avg_length < best_avg_length:
                best_avg_length = avg_length
                best_coverage = coverage
                best_epoch = e
                best_cnt = cnt
            if (e+1) % 100 == 0 and verbose:
                print("CV: Epoch {}: Train {}, Test {}, Best epoch {}, Best Coverage {} Best Length {} Cur Coverage {}".format(e+1, epoch_loss, test_epoch_loss, best_epoch, best_coverage, best_avg_length, coverage))
                sys.stdout.flush()

        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.scheduler2)

        x = self._to_tensor(x)
        y = self._to_tensor(y)
        if verbose:
            print("Best cnt: ", best_cnt)

        # best_epoch is a 0-based index, so best_epoch + 1 epochs may be needed
        # (range(best_epoch) would skip training entirely when it is 0). min()
        # caps the run at `epochs` when the coverage target was never reached.
        cnt = 0
        for e in range(min(best_epoch + 1, epochs)):
            # epoch_internal_train returns exactly best_cnt once the budget is
            # spent; testing with '>' would let one extra gradient step through.
            if cnt >= best_cnt:
                break
            epoch_loss, cnt = epoch_internal_train(self.model, self.loss_func, x, y, batch_size, self.optimizer, cnt, best_cnt)

            if SCHEDULER:
                lr_scheduler.step()

            self.full_loss_history.append(epoch_loss)

            if (e+1) % 100 == 0 and verbose:
                print("Full: Epoch {}: {}, cnt {}".format(e+1, epoch_loss, cnt))
                sys.stdout.flush()

    def predict(self, x):
        """ Estimate the conditional low and high quantile given the features

        Parameters
        ----------
        x : numpy array of features (nXp); converted to float32

        Returns
        -------
        test_preds : numpy array of predicted low and high quantiles (nX2).
                     Without rearrangement, columns 0 and 1 are ordered so that
                     column 0 <= column 1.

        Raises
        ------
        ValueError : if use_rearrangement is set but loss_func had no quantiles

        """
        self.model.eval()
        with torch.no_grad():
            test_preds = self.model(self._to_tensor(x)).cpu().numpy()
        if self.use_rearrangement:
            if self.all_quantiles is None:
                raise ValueError("use_rearrangement=True requires a loss_func with a `quantiles` attribute")
            return rearrange(self.all_quantiles, self.quantile_low, self.quantile_high, test_preds)
        return self._order_interval(test_preds)
