
# Cell
from torch.distributions.geometric import Geometric
from torch.distributions.binomial import Binomial
import torch
from torch import nn

from typing import Optional
from numpy import random

# import sys
# sys.path.insert(1, '/dccstor/nnguyen1/projects/TSFoundation/')
from .core import Callback
from src.learner import get_model

# Cell
def create_subsequence_mask(o, r=.15, lm=3, stateful=True, sync=False):
    """Build a boolean mask made of random runs along the last axis of ``o``.

    Args:
        o: tensor to mask. A 2-D input gets a leading axis of size 1; the
            (possibly expanded) tensor is unpacked as
            (n_masks, mask_dims, mask_len) — presumably
            (samples, variables, steps), TODO confirm against callers.
        r: target fraction of masked (True) entries. ``r <= 0`` returns an
            all-False mask with the shape of the original ``o``.
        lm: mean length of a masked run when ``stateful`` is True.
        stateful: if True, alternate unmasked/masked runs whose lengths are
            drawn from geometric distributions; otherwise every entry is
            masked independently with probability ``r``.
        sync: if truthy, one mask row is drawn per sample and repeated across
            ``mask_dims``; ``'random'`` enables it with probability 0.5 using
            numpy's global RNG.

    Returns:
        Boolean tensor of shape (n_masks, mask_dims, mask_len).
    """
    if r <= 0: return torch.zeros_like(o).bool()
    device = o.device
    if o.ndim == 2: o = o[None]
    n_masks, mask_dims, mask_len = o.shape
    if sync == 'random': sync = random.random() > .5
    # With sync, only one row per sample is drawn; it is repeated across dims at the end.
    dims = 1 if sync else mask_dims
    if stateful:
        # All rows are generated as one flat sequence of runs of total length numels, then reshaped.
        numels = n_masks * dims * mask_len
        # Run lengths are Geometric(p) + 1, whose mean is 1/p. Masked runs use pm (mean lm);
        # unmasked runs use pu (mean lm * (1 - r) / r before clipping), so the expected
        # masked fraction is r.
        pm = torch.tensor([1 / lm], device=device)
        pu = torch.clip(pm * (r / max(1e-6, 1 - r)), 1e-3, 1)
        # Choose whether the sequence starts with an unmasked run (labels [False, True]) or a
        # masked run (labels [True, False]); proba_a / proba_b follow the labels so masked runs
        # always use pm. An unmasked start is chosen when a numpy uniform draw exceeds pm.
        zot, proba_a, proba_b = (torch.as_tensor([False, True], device=device), pu, pm) if random.random() > pm else \
        (torch.as_tensor([True, False], device=device), pm, pu)
        # Initial number of run pairs to draw: twice the expected number of pairs
        # (each pair has expected length 1/pm + 1/pu) needed to cover numels.
        max_len = max(1, 2 * torch.div(numels, (1/pm + 1/pu), rounding_mode='floor').long().item())
        # Draw more pairs (at most 10 rounds) until their total length exceeds numels.
        # NOTE(review): if the total is still short after 10 rounds, the argmax below returns 0
        # and the final reshape fails.
        for i in range(10):
            _dist_a = (Geometric(probs=proba_a).sample([max_len])+1).long()
            _dist_b = (Geometric(probs=proba_b).sample([max_len])+1).long()
            dist_a = _dist_a if i == 0 else torch.cat((dist_a, _dist_a), dim=0)
            dist_b = _dist_b if i == 0 else torch.cat((dist_b, _dist_b), dim=0)
            add = torch.add(dist_a, dist_b)
            if torch.gt(torch.sum(add), numels): break
        # Smallest number of pairs whose cumulative length reaches numels ...
        dist_len = torch.argmax((torch.cumsum(add, 0) >= numels).float()) + 1
        # ... rounded up to an even count (the reason is not evident from this code).
        if dist_len%2: dist_len += 1
        # dist_a / dist_b have shape (N, 1) because pm / pu have shape (1,), so concatenating on
        # the last axis gives (dist_len, 2) and flatten interleaves a0, b0, a1, b1, ...
        repeats = torch.cat((dist_a[:dist_len], dist_b[:dist_len]), -1).flatten()
        # One label per run, alternating in the order chosen for zot above.
        zot = zot.repeat(dist_len)
        # Expand each label to its run length, trim to numels and reshape into rows.
        mask = torch.repeat_interleave(zot, repeats)[:numels].reshape(n_masks, dims, mask_len)
    else:
        # Independent Bernoulli(r) draw for every entry.
        probs = torch.tensor(r, device=device)
        mask = Binomial(1, probs).sample((n_masks, dims, mask_len)).bool()
    if sync: mask = mask.repeat(1, mask_dims, 1)
    return mask

def create_variable_mask(o, r=.15):
    """Mask whole (sample, variable) rows of ``o`` chosen at random.

    ``int(n_masks * mask_dims * r)`` rows are drawn without replacement from
    all rows of the batch, so a single sample may end up with more or fewer
    masked variables than ``int(mask_dims * r)``.

    Args:
        o: tensor unpacked as (n_masks, mask_dims, mask_len). A 2-D input is
            treated as a single sample (a leading axis of size 1 is added),
            like the other mask builders in this module.
        r: fraction of rows to mask. ``r <= 0`` returns an all-False mask
            shaped like ``o``. Nothing is masked when ``int(mask_dims * r)``
            is 0, i.e. when ``r`` is too small to select one variable per sample.

    Returns:
        Boolean tensor of shape (n_masks, mask_dims, mask_len); each row along
        the last axis is either entirely True or entirely False.
    """
    if r <= 0: return torch.zeros_like(o).bool()
    device = o.device
    # Accept a single 2-D sample, consistent with create_subsequence_mask / create_future_mask.
    if o.ndim == 2: o = o[None]
    n_masks, mask_dims, mask_len = o.shape
    n_rows = n_masks * mask_dims
    _mask = torch.zeros((n_rows, mask_len), device=device)
    if int(mask_dims * r) > 0:
        n_masked_vars = int(n_rows * r)
        # Uniform weights: multinomial without replacement picks distinct rows uniformly.
        p = torch.tensor([1. / n_rows], device=device).repeat([n_rows])
        sel_dims = p.multinomial(num_samples=n_masked_vars, replacement=False)
        _mask[sel_dims] = 1
    return _mask.reshape(*o.shape).bool()

def create_future_mask(o, r=.15, sync=False):
    """Mask a random-length tail of every row along the last axis of ``o``.

    A Bernoulli(r) draw is made per entry and each row is then sorted in
    ascending order, so all True values end up at the end of the row.

    Args:
        o: tensor unpacked as (n_masks, mask_dims, mask_len); a 2-D input gets
            a leading axis of size 1.
        r: per-entry masking probability. ``r <= 0`` returns an all-False mask
            with the shape of the original ``o``.
        sync: if truthy, one row per sample is drawn and repeated across
            ``mask_dims``; ``'random'`` enables it with probability 0.5.

    Returns:
        Boolean tensor of shape (n_masks, mask_dims, mask_len).
    """
    if r <= 0:
        return torch.zeros_like(o).bool()
    x = o[None] if o.ndim == 2 else o
    n_samples, n_vars, seq_len = x.shape
    if sync == 'random':
        sync = random.random() > .5
    draw_vars = 1 if sync else n_vars
    bernoulli = Binomial(1, torch.tensor(r, device=x.device))
    draws = bernoulli.sample((n_samples, draw_vars, seq_len))
    if sync:
        draws = draws.repeat(1, n_vars, 1)
    # Ascending sort moves the 1s (masked positions) to the end of each row.
    ordered, _ = draws.sort(dim=-1)
    return ordered.bool()
    
def self_mask(o):
    """Mask positions that are missing in the neighbouring sample but observed here.

    The NaN pattern of sample ``i - 1`` along axis 0 (cyclically, so sample 0
    receives the pattern of the last sample) is applied to sample ``i``.
    Positions that are already NaN in sample ``i`` are excluded.

    Args:
        o: tensor whose first axis indexes samples; NaN marks missing values.

    Returns:
        Boolean tensor shaped like ``o``.
    """
    mask1 = torch.isnan(o)
    # Element i receives the missingness pattern of element i - 1 (with wrap-around).
    # torch.roll is used because no rotate_axis0 helper is defined or imported in this module.
    mask2 = torch.roll(mask1, shifts=1, dims=0)
    return torch.logical_and(mask2, ~mask1)

# Cell
def create_mask(o,  r=.15, lm=3, stateful=True, sync=False, subsequence_mask=True, variable_mask=False, future_mask=False):
    """Build a mask for ``o`` by dispatching to one of the mask builders above.

    Precedence: ``future_mask`` wins. Otherwise, when both ``subsequence_mask``
    and ``variable_mask`` are enabled, one of the two is picked at random for
    this call.

    Args:
        o: tensor to mask.
        r: fraction of entries to mask; ``r <= 0`` or ``r >= 1`` returns an
            all-False mask shaped like ``o``.
        lm, stateful, sync: forwarded to ``create_subsequence_mask``
            (``sync`` also affects the random choice between mask kinds).
        subsequence_mask, variable_mask, future_mask: which mask kinds are allowed.

    Returns:
        Boolean mask tensor.

    Raises:
        ValueError: if no mask kind is enabled.
    """
    if r <= 0 or r >= 1: return torch.zeros_like(o).bool()
    # r is too small to select even one entry along axis 1: disable variable masking.
    if int(r * o.shape[1]) == 0:
        variable_mask = False
    if future_mask:
        return create_future_mask(o, r=r)
    if subsequence_mask and variable_mask:
        # Keep only one of the two kinds for this call; the choice then falls
        # through to the dispatch below instead of returning None.
        random_thr = 1/3 if sync == 'random' else 1/2
        if random.random() > random_thr:
            variable_mask = False
        else:
            subsequence_mask = False
    if subsequence_mask:
        return create_subsequence_mask(o, r=r, lm=lm, stateful=stateful, sync=sync)
    if variable_mask:
        return create_variable_mask(o, r=r)
    raise ValueError('You need to set subsequence_mask, variable_mask or future_mask to True or pass a custom mask.')




class MVPSimpleCB(Callback):
    """Masked-value reconstruction pretraining callback.

    Before every training/validation batch the input is masked, the learner's
    batch becomes ``(masked_input, original_input)``, and the loss is computed
    only on the masked positions.
    """
    order = 60

    def __init__(self, r: float = .15, subsequence_mask: bool = True, lm: float = 3., stateful: bool = True,
                 sync: bool = False, variable_mask: bool = False,
                 future_mask: bool = False, custom_mask = None, sel_vars: Optional[list] = None, nan_to_num: int = 0,
                 window_size: Optional[tuple] = None, dropout: float = 0, crit: callable = None,
                 ddp: bool = True):
        r"""
        Callback used to perform the pretext task of reconstructing the original data after a binary mask has been applied.

        Args:
            r:                probability of masking.
            subsequence_mask: apply a mask to random subsequences.
            lm:               average mask length when using stateful (geometric) masking.
            stateful:         geometric distribution is applied so that average mask length is lm.
            sync:             all variables share the same mask.
            variable_mask:    apply a mask to random variables. Only applicable to multivariate time series.
            future_mask:      mask the end of each sequence (used to train a forecasting model).
            custom_mask:      callable that receives the input batch tensor and returns a mask
                              tensor (True = masked). When given, it replaces the built-in masks.
            sel_vars:         indices along axis 1 of the batch that may be masked. If None, all variables may be masked.
            nan_to_num:       value used to fill masked positions.
            window_size:      stored on the callback but not used by it.
            dropout:          dropout applied to the head of the model during pretraining.
            crit:             loss function applied to the masked positions. If None, nn.MSELoss().
            ddp:              if True, the model's head is replaced through ``learner.model.module``
                              (i.e. the model is expected to be wrapped); otherwise through ``learner.model``.
        """
        import warnings  # not imported at module level; needed for the warnings below

        # `custom_mask is not None` avoids ambiguous truthiness if a tensor is passed.
        assert subsequence_mask or variable_mask or future_mask or custom_mask is not None, \
            'you must set (subsequence_mask and/or variable_mask) or future_mask to True or use a custom_mask'
        if custom_mask is not None and (future_mask or subsequence_mask or variable_mask):
            warnings.warn("Only custom_mask will be used")
        elif future_mask and (subsequence_mask or variable_mask):
            warnings.warn("Only future_mask will be used")

        self.subsequence_mask, self.variable_mask = subsequence_mask, variable_mask
        self.future_mask, self.custom_mask = future_mask, custom_mask
        self.r, self.lm, self.dropout = r, lm, dropout
        self.stateful, self.sync = stateful, sync
        self.crit = crit
        self.nan_to_num = nan_to_num
        self.ddp = ddp
        self.window_size = window_size
        self.sel_vars = sel_vars

    def before_fit(self):
        """Install the masked loss and replace the model head with a reconstruction head."""
        if self.crit is None: self.crit = nn.MSELoss()
        self.learner.loss_func = self._loss
        device = self.learner.device

        assert hasattr(get_model(self.learner.model), "head"), "model must have a head attribute to be trained with MVP"

        # Keep an existing pretrain head instead of replacing it.
        if getattr(get_model(self.learner.model), "pretrain_head", False):
            print('The model already has a pretrain head')
            return

        model = self.learner.model.module if self.ddp else self.learner.model
        # 1x1 conv over the feature axis (equivalent to a linear layer applied to dim=1):
        # maps head_nf features back to one output per input variable at every step.
        model.head = nn.Sequential(nn.Dropout(self.dropout),
                                   nn.Conv1d(model.head_nf, self.learner.dls.vars, 1)
                                   ).to(device=device)

        # Sanity check: the reconstruction must have the same shape as the input.
        with torch.no_grad():
            xb = torch.randn(2, self.learner.dls.vars, self.learner.dls.len).to(device=device)
            assert xb.shape == self.learner.model(xb).shape, 'the model cannot reproduce the input shape'

    def before_batch_train(self): self.masking()
    def before_batch_valid(self): self.masking()

    def masking(self):
        """Mask the current input batch; the unmasked input becomes the target."""
        xb, _ = self.batch
        if self.custom_mask is not None:
            self.mask = self.custom_mask(xb).bool()
        else:
            self.mask = create_mask(xb, r=self.r, lm=self.lm, stateful=self.stateful,
                                    sync=self.sync, subsequence_mask=self.subsequence_mask,
                                    variable_mask=self.variable_mask,
                                    future_mask=self.future_mask).bool()
        if self.sel_vars is not None:
            # Only the selected variables (axis 1) keep their mask; the others stay unmasked.
            sel_mask = torch.zeros_like(self.mask)
            sel_mask[:, self.sel_vars] = self.mask[:, self.sel_vars]
            self.mask = sel_mask

        xb_mask = xb.masked_fill(self.mask, self.nan_to_num)
        self.learner.batch = xb_mask, xb

    def _loss(self, preds, target):
        """Apply ``crit`` only to the masked positions."""
        return self.crit(preds[self.mask], target[self.mask])
    