import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F



class Implicit_Temporal_Func(nn.Module):
    """Resample a [B, T, D] sequence along T with a learned implicit function.

    Every target time step looks up its two nearest source time steps
    (one on each side), feeds the neighbourhood features of each source step
    together with the relative coordinate through a small MLP (``tsnet``) and
    blends the two predictions with distance-based weights.  A 1x1
    convolution (``down``) then reduces the MLP channels.

    Only ``unfold_dim == "self"`` with ``unfold_style == "one"`` is
    implemented; the other accepted option values raise
    ``NotImplementedError`` when ``unfolding`` is called.
    """

    def __init__(self, dim, hidden_dim, down_in, down_out, args, device= "cuda"):
        """Build the MLP and the channel-reduction convolution.

        Args:
            dim: Channel count of the unfolded feature fed to ``tsnet`` (the
                MLP input is ``dim + 2`` because the 2-component relative
                coordinate is concatenated).  The implemented unfolding
                produces 3 channels, so ``dim`` is expected to be 3.  It is
                also the MLP output width.
            hidden_dim: Width of the MLP hidden layer.
            down_in: Input channels of the 1x1 convolution; must match the
                MLP output width (``dim``).
            down_out: Output channels of the 1x1 convolution.  ``forward``
                reshapes the result to ``[B, T', D]``, which requires 1.
            args: Object exposing ``unfold_dim`` ("self" | "all") and
                ``unfold_style`` ("one" | "dilation").
            device: Device the sub-modules are created on.
        """
        super(Implicit_Temporal_Func, self).__init__()
        unfold_dim = args.unfold_dim
        unfold_style = args.unfold_style
        assert unfold_dim in ["self", "all"]
        assert unfold_style in ["one", "dilation"]
        self.unfold_dim = unfold_dim
        self.unfold_style = unfold_style
        self.device = device
        self.tsnet = nn.Sequential(
            nn.Linear(dim + 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim),
        ).to(self.device)
        self.down = nn.Conv1d(down_in, down_out, 1).to(self.device)

    def unfolding(self, feat):      # feat in [B, T, D]
        """Gather each time step's temporal neighbourhood as channels.

        Args:
            feat: Tensor of shape [B, T, D].

        Returns:
            Tensor of shape [B, 3, T, D] where channel ``k`` holds
            ``feat[:, t + k - 1, :]`` (zero padded at both ends of T).

        Raises:
            NotImplementedError: for every configuration other than
                ``unfold_dim == "self"`` and ``unfold_style == "one"``.
        """
        if self.unfold_dim == "self" and self.unfold_style == "one":
            feat = feat.unsqueeze(1).permute(0, 1, 3, 2)        # [B, 1, D, T]
            unfold_feat = F.unfold(
                feat,
                kernel_size=(1, 3),
                padding=(0, 1),
            )                                                   # [B, 3, D * T]
            unfold_feat = unfold_feat.view(
                feat.shape[0], feat.shape[1] * 3, feat.shape[2], feat.shape[3]
            )                                                   # [B, 3, D, T]
            return unfold_feat.permute(0, 1, 3, 2)              # [B, 3, T, D]
        # Previously these configurations fell through to an unbound local.
        raise NotImplementedError(
            f"unfolding is not implemented for unfold_dim={self.unfold_dim!r}, "
            f"unfold_style={self.unfold_style!r}"
        )

    def coord_gen(self, shape):
        """Return cell-centre coordinates in [-1, 1] for a grid of ``shape``.

        Args:
            shape: Sequence of ints, one per grid axis.

        Returns:
            Float tensor of shape ``(*shape, len(shape))`` (on the CPU) whose
            last axis holds the coordinate of each cell along every axis.
        """
        v0, v1 = -1, 1
        coord_seqs = []
        for n in shape:
            r = (v1 - v0) / (2 * n)         # half a cell width
            coord_seqs.append(v0 + r + (2 * r) * torch.arange(n).float())
        # "ij" is the legacy default; being explicit avoids the deprecation warning.
        return torch.stack(torch.meshgrid(*coord_seqs, indexing="ij"), dim=-1)

    @staticmethod
    def _sample_nearest(stacked, grid, n_batch):
        """Nearest-neighbour sample ``stacked`` ([D*B, C, T, 1]) at ``grid``.

        Returns the samples channel-last as [B, T', D, C].
        """
        out = F.grid_sample(
            input=stacked,
            grid=grid,
            mode='nearest',
            align_corners=False,
        )                                                       # [D*B, C, T', 1]
        n_ch, n_t = out.shape[1], out.shape[2]
        return out.reshape(-1, n_batch, n_ch, n_t).permute(1, 3, 0, 2)

    def ensemble(self, feat, HR_T):
        """Predict ``HR_T`` time steps from ``feat`` by local ensembling.

        Args:
            feat: Tensor of shape [B, T, D].
            HR_T: Number of target time steps.

        Returns:
            Tensor of shape [B, HR_T, D, C] with C the ``tsnet`` output width.
        """
        unfold_feat = self.unfolding(feat)     # [B, C==3, T, D]
        LR_B, LR_C, LR_T, LR_D = unfold_feat.shape
        device = unfold_feat.device             # follow the data, not the ctor argument
        r = 2 / LR_T / 2      # radius along T axis
        eps_shift = 1e-6

        # Target coordinates, broadcast over batch and D: [B, HR_T, D, (x, y)]
        coord = self.coord_gen([HR_T, 1]).to(device)
        coord = coord.unsqueeze(0).expand(LR_B, HR_T, LR_D, 2)

        # Source coordinates laid out like the features: [B, 2 (x_i, 0.0), T, D]
        feat_coord = self.coord_gen([LR_T, 1]).to(device)
        feat_coord = feat_coord.unsqueeze(0).expand(LR_B, LR_T, LR_D, 2)
        feat_coord = feat_coord.permute(0, 3, 1, 2)

        # Fold D into the batch axis (D-major) so grid_sample treats every
        # column as its own [C, T, 1] image.  Loop invariant, so built once.
        stacked_feat = unfold_feat.permute(3, 0, 1, 2).reshape(
            LR_D * LR_B, LR_C, LR_T).unsqueeze(-1)
        stacked_coord = feat_coord.permute(3, 0, 1, 2).reshape(
            LR_D * LR_B, 2, LR_T).unsqueeze(-1)

        base_grid = coord[:, :, 0, :]           # [B, HR_T, 2]; identical for every D

        preds, dists = [], []
        for vx in (-1, 1):
            # Shift the query half a source cell left/right to hit both neighbours.
            shifted = base_grid.clone()
            shifted[:, :, 0] += vx * r + eps_shift
            shifted.clamp_(-1 + eps_shift, 1 - eps_shift)
            # grid_sample expects (x, y) = (W, H); T lives on H, hence the flip.
            grid = shifted.repeat(LR_D, 1, 1).unsqueeze(2).flip(-1)

            q_feat = self._sample_nearest(stacked_feat, grid, LR_B)     # [B, HR_T, D, C]
            q_coord = self._sample_nearest(stacked_coord, grid, LR_B)   # [B, HR_T, D, 2]

            rel_coord = coord - q_coord
            rel_coord[:, :, :, 0] *= LR_T       # express the offset in source cells
            tsnet_input = torch.cat([q_feat, rel_coord], dim=-1)

            b, t, d = tsnet_input.shape[:3]
            # pseudo batch size is b * t * d; the MLP acts on the channel axis
            pred = self.tsnet(tsnet_input.reshape(b * t * d, -1))
            preds.append(pred.view(b, t, d, -1))                # [B, T, D, C]

            dists.append(torch.abs(rel_coord[:, :, :, 0]) + 1e-9)

        tot_dists = torch.stack(dists).sum(dim=0)
        # Each prediction is weighted by the distance to the *other*
        # neighbour, so the closer source sample contributes more.
        ensemble_res = 0
        for pred, dist in zip(preds, reversed(dists)):
            ensemble_res = ensemble_res + pred * (dist / tot_dists).unsqueeze(-1)
        return ensemble_res

    def forward(self, feat, HR_T_list, verbose= False):
        """Resample ``feat`` through one stage per entry of ``HR_T_list``.

        Args:
            feat: Tensor of shape [B, T, D].
            HR_T_list: Non-empty sequence of target lengths; stage ``i``
                resamples the previous stage's output to ``HR_T_list[i]``.
            verbose: Print the output size of every stage.

        Returns:
            Tensor of shape [B, HR_T_list[-1], D].
        """
        assert len(HR_T_list) >= 1, "HR_T_list must contain at least one target length"
        ensembled = feat
        for i, HR_T in enumerate(HR_T_list):
            ensembled = self.ensemble(ensembled, HR_T)          # [B, T', D, C]
            n_b, n_t, n_d, n_c = ensembled.shape
            # The 1x1 conv acts on every (b, t, d) position independently, so
            # positions are only flattened into a [N, C, L] layout and restored.
            ensembled = ensembled.reshape(n_b * n_d, n_t, n_c).permute(0, 2, 1)
            ensembled = self.down(ensembled)
            # Back to 3-D so the next stage can unfold it (needs down_out == 1).
            ensembled = ensembled.permute(0, 2, 1).reshape(n_b, n_t, n_d)
            if verbose : print(f"ITF schema {i} with size {ensembled.size()}")
        return ensembled




