from typing import Dict, Any, Optional, List

import torch
import torch.nn as nn
from einops import rearrange
import torchvision.transforms as transforms

from models.labits_spline.temporal_fusion import *
from models.labits_utlis.aplof import UNet
from models.labits_utlis.extractor import *

from models.raft_spline.bezier import BezierCurves
from models.raft_spline.update import BasicUpdateBlock
from models.raft_utils.extractor import BasicEncoder
from models.raft_utils.corr import CorrComputation, CorrBlockParallelMultiTarget
from models.raft_utils.utils import coords_grid
from utils.timers import CudaTimerDummy as CudaTimer

    
class LabitsSpline(nn.Module):
    """Iteratively refines per-pixel Bezier curves from event tensors and/or boundary images.

    The model combines:
      * an event feature encoder (``fnet_ev``) and/or an image feature encoder
        (``fnet_img``) whose outputs feed the correlation volumes,
      * a frozen, pretrained ``UNet`` whose features are fused both into the
        event feature maps and into the initial hidden state,
      * a context encoder (``cnet``) and a recurrent ``BasicUpdateBlock`` that
        predicts Bezier parameter updates.
    """

    def __init__(self, model_params: Dict[str, Any]):
        """Build all sub-networks from the configuration dictionary.

        Args:
            model_params: Nested configuration. Keys read here:
                ``num_bins.context``, ``num_bins.correlation``, ``bezier_degree``,
                ``detach_bezier``, ``ltm_threshold``, ``correlation.use_cosine_sim``,
                ``correlation.ev.target_indices``, ``correlation.ev.levels``,
                ``correlation.img`` (only if ``use_boundary_images``),
                ``use_boundary_images``, ``use_events``, ``hidden.dim``,
                ``context.dim``, ``context.norm``, ``feature.dim``, ``feature.norm``,
                ``base.*`` (speed_kernel, normalization, inverse, local_extractor,
                use_grid_attn, attention_cfg, temporal_fusion, local_timemask),
                ``pretrained_unet_path``, ``pretrained_dim``, ``speed_visualization``.
                The whole dictionary is also forwarded to ``BasicUpdateBlock``.

        Raises:
            AssertionError: If the bin counts / Bezier degree are invalid, if the
                event correlation configuration is inconsistent, or if neither
                events nor boundary images are enabled.
        """
        super().__init__()
        nbins_context = model_params['num_bins']['context']
        nbins_correlation = model_params['num_bins']['correlation']
        self.bezier_degree = model_params['bezier_degree']
        self.detach_bezier = model_params['detach_bezier']
        self.ltm_threshold = model_params['ltm_threshold']
        print(f'ltm_threshold: {self.ltm_threshold}')
        print(f'Detach Bezier curves: {self.detach_bezier}')

        assert nbins_correlation > 0 and nbins_context > 0
        assert self.bezier_degree >= 1
        self.nbins_context = nbins_context
        self.nbins_corr = nbins_correlation
        # Number of input channels expected by gen_voxel_grids().
        self.total_bins = self.nbins_context + self.nbins_corr - 1

        print('Labits-Spline config:')
        print(f'Num bins context: {nbins_context}')
        print(f'Num bins correlation: {nbins_correlation}')

        corr_params = model_params['correlation']
        self.corr_use_cosine_sim = corr_params['use_cosine_sim']

        ev_corr_params = corr_params['ev']
        self.ev_corr_target_indices = ev_corr_params['target_indices']
        self.ev_corr_levels = ev_corr_params['levels']

        self.ev_corr_radius = 4

        self.img_corr_params = None
        if model_params['use_boundary_images']:
            print('Using images')
            self.img_corr_params = corr_params['img']
            assert 'levels' in self.img_corr_params
            assert 'radius' in self.img_corr_params

        self.hidden_dim = model_params['hidden']['dim']
        self.context_dim = model_params['context']['dim']
        feature_dim = model_params['feature']['dim']
        fnorm = model_params['feature']['norm']
        cnorm = model_params['context']['norm']

        self.speed_dim = 2
        base_params = model_params['base']

        self.speed_kernel = base_params['speed_kernel']
        self.normalization = base_params['normalization']
        self.inverse = base_params['inverse']
        self.local_extractor = base_params['local_extractor']
        self.use_grid_attn = base_params['use_grid_attn']
        self.attention_cfg = base_params['attention_cfg']
        self.temporal_fusion = base_params['temporal_fusion']
        self.local_timemask = base_params['local_timemask']

        self.pretrained_unet_path = model_params['pretrained_unet_path']
        self.pretrained_dim = model_params['pretrained_dim']
        self.visualization = model_params['speed_visualization']

        # Channel indices of the event tensor that are fed (after indices 20 and 24)
        # to the pretrained UNet in forward().
        self.speed_time_slice = [28, 32, 36, 40, 44, 48, 52, 56, 60, 64]

        # Per-channel image normalization applied to the boundary images in forward().
        self.img_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225])

        # Feature networks, context network and update block.
        # cnet_input_dim accumulates the channel count of the context input:
        # +3 when images are used, +nbins_context when events are used.
        cnet_input_dim = 0
        self.fnet_img = None
        if self.img_corr_params is not None:
            self.fnet_img = BasicEncoder(input_dim=3, output_dim=feature_dim, norm_fn=fnorm)
            cnet_input_dim += 3

        self.fnet_ev = None
        if model_params['use_events']:
            print('Using events')
            # Index 0 is the reference and is added implicitly in gen_voxel_grids().
            assert 0 not in self.ev_corr_target_indices
            assert len(self.ev_corr_target_indices) > 0
            assert max(self.ev_corr_target_indices) < self.nbins_context
            assert len(self.ev_corr_target_indices) == len(self.ev_corr_levels)
            self.fnet_ev = BasicEncoder(input_dim=nbins_correlation, output_dim=feature_dim, norm_fn=fnorm)

            self.pretrained_unet = UNet(1, 2, self.pretrained_dim, self.visualization)
            print('loading pretrained flow unet')
            # Load onto CPU so a checkpoint saved on a GPU can be restored on any host;
            # load_state_dict() copies the values into the module's own parameters.
            checkpoint = torch.load(self.pretrained_unet_path, map_location='cpu')

            # The weights may be wrapped under a 'state_dict' entry.
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint

            # Remove the 'net.' prefix.
            # NOTE(review): str.replace removes every 'net.' occurrence in a key, not only
            # a leading prefix; kept as-is to stay compatible with existing checkpoints.
            adjusted_state_dict = {key.replace('net.', ''): value for key, value in state_dict.items()}

            self.pretrained_unet.load_state_dict(adjusted_state_dict, strict=True)
            # NOTE(review): a later .train() call on this model switches the UNet back to
            # training mode; only the requires_grad freeze below is permanent.
            self.pretrained_unet.eval()
            print('pretrained flow unet loaded')

            for param in self.pretrained_unet.parameters():
                param.requires_grad = False
            print('pretrained flow unet frozen')

            cnet_input_dim += nbins_context
        assert self.fnet_ev is not None or self.fnet_img is not None

        # Reduces the flattened UNet features (all but the first time slice) to
        # feature_dim // 2 channels before they are fused into the hidden state.
        self.pretrained_down = nn.Sequential(
            DoubleConv(self.pretrained_dim * 88, feature_dim * 2),
            DoubleConv(feature_dim * 2, feature_dim),
            DoubleConv(feature_dim, feature_dim // 2),
        )

        self.cnet = BasicEncoder(input_dim=cnet_input_dim, output_dim=feature_dim, norm_fn=cnorm)
        self.merge_cnet_speed = DoubleConv(feature_dim, self.hidden_dim)
        self.update_block = BasicUpdateBlock(model_params, hidden_dim=self.hidden_dim)

        # forward() fuses each event feature map with a chunk of UNet features through
        # self.merge_fmaps, which was previously never created (AttributeError).
        # NOTE(review): channel counts are inferred, confirm against trained checkpoints.
        # pretrained_down consumes 11 time slices as pretrained_dim * 88 channels, i.e.
        # pretrained_dim * 8 per slice; forward() splits 12 slices into 6 chunks of 2
        # slices, giving pretrained_dim * 16 extra channels per feature map. The output
        # width is assumed to be feature_dim, matching the image feature maps.
        self.merge_fmaps = DoubleConv(feature_dim + self.pretrained_dim * 16, feature_dim)

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` submodule into evaluation mode."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()

    def initialize_flow(self, input_):
        """Create the coordinate grid and the initial Bezier curves.

        Args:
            input_: 4-D tensor ``(N, C, H, W)``; only its batch size, spatial size
                and device are used here.

        Returns:
            Tuple ``(coords0, bezier)`` where ``coords0`` comes from ``coords_grid``
            at 1/8 of the input resolution and ``bezier`` from
            ``BezierCurves.create_from_voxel_grid`` with the same downsample factor.
        """
        batch_size, _, height, width = input_.shape
        downsample_factor = 8
        coords0 = coords_grid(batch_size,
                              height // downsample_factor,
                              width // downsample_factor,
                              device=input_.device)
        bezier = BezierCurves.create_from_voxel_grid(input_,
                                                     downsample_factor=downsample_factor,
                                                     bezier_degree=self.bezier_degree)
        return coords0, bezier

    def gen_voxel_grids(self, input_: torch.Tensor):
        """Slice the event tensor into correlation grids and the context grid.

        Args:
            input_: Tensor of shape ``(N, nbins_context + nbins_corr - 1, H, W)``.

        Returns:
            Tuple ``(corr_grids, context_grid)``. ``corr_grids`` is a list with one
            ``nbins_corr``-channel window per index in ``[0] + ev_corr_target_indices``
            (index 0 is the reference); ``context_grid`` holds the last
            ``nbins_context`` channels.

        Raises:
            AssertionError: If the channel count of ``input_`` is unexpected.
        """
        assert self.nbins_context + self.nbins_corr - 1 == input_.shape[1]
        # The reference window always starts at channel 0.
        indices_with_reference = [0, *self.ev_corr_target_indices]
        corr_grids = [input_[:, idx:idx + self.nbins_corr, ...] for idx in indices_with_reference]
        context_grid = input_[:, -self.nbins_context:, ...]
        return corr_grids, context_grid

    def forward(self,
                bilts: Optional[torch.Tensor] = None,
                images: Optional[List[torch.Tensor]] = None,
                iters: int = 12,
                flow_init: Optional["BezierCurves"] = None,
                test_mode: bool = False):
        """Run feature extraction followed by ``iters`` Bezier refinement steps.

        Args:
            bilts: Event tensor; required when the model was built with events.
                Its channel count must satisfy ``gen_voxel_grids`` and it must
                contain the channel indices 20, 24 and ``speed_time_slice``.
            images: Exactly two boundary images; required when the model was built
                with boundary images.
            iters: Number of refinement iterations (must be positive).
            flow_init: Optional Bezier curves whose parameters are added to the
                zero-initialized curves before refinement.
            test_mode: If true, only the final prediction is upsampled and returned.

        Returns:
            ``(bezier, bezier_up, pred_lr, pred_hr)`` when ``test_mode`` is true,
            otherwise ``(bezier_up_predictions, pred_lr, pred_hr)`` with one
            upsampled prediction per iteration. ``pred_lr`` / ``pred_hr`` are
            ``None`` unless events are used and ``speed_visualization`` is enabled.
        """
        assert bilts is not None or images is not None
        assert iters > 0

        hdim = self.hidden_dim
        cdim = self.context_dim
        current_device = bilts.device if bilts is not None else images[0].device

        # Bound up front so that the image-only configuration does not hit a NameError.
        pred_lr = None
        pred_hr = None
        flow_feat = None

        corr_computation_events = None
        context_input = None
        with CudaTimer(current_device, 'fnet_ev'):
            if self.fnet_ev is not None:
                assert bilts is not None
                bilts = bilts.contiguous()

                # 2 + len(speed_time_slice) selected channels go through the frozen UNet.
                unet_input = bilts[:, [20, 24] + self.speed_time_slice, ...]
                if self.visualization:
                    pred_lr, pred_hr, flow_feat = self.pretrained_unet(unet_input, self.ltm_threshold)
                    # Drop the entries belonging to the two leading channels (20, 24).
                    pred_lr = pred_lr[:, 2:, ...]
                    pred_hr = pred_hr[:, 2:, ...]
                else:
                    flow_feat = self.pretrained_unet(unet_input, self.ltm_threshold)

                # flow_feat is 5-D (B, T, C, h, w) as required by the rearrange patterns.
                # Split it into 6 groups along T and fold T into the channel axis so each
                # group can be concatenated with one event feature map.
                flow_corr = torch.chunk(flow_feat, 6, dim=1)
                flow_corr = [rearrange(x, 'B T C h w -> B (T C) h w') for x in flow_corr]

                # All time slices except the first are reduced for the hidden-state fusion.
                flow_feat = flow_feat[:, 1:, ...]
                flow_feat = rearrange(flow_feat, 'B T C h w -> B (T C) h w')
                flow_feat = self.pretrained_down(flow_feat)

                corr_grids, context_input = self.gen_voxel_grids(bilts)

                # One feature map per correlation grid (reference first).
                fmaps_ev = self.fnet_ev(corr_grids)

                fmaps = [torch.cat((fmap_ev, flow_corr[i]), dim=1) for i, fmap_ev in enumerate(fmaps_ev)]
                fmaps = [self.merge_fmaps(x) for x in fmaps]

                fmap1 = fmaps[0]
                fmap2 = torch.stack(fmaps[1:], dim=0)
                corr_computation_events = CorrComputation(fmap1, fmap2, num_levels_per_target=self.ev_corr_levels)

        corr_computation_frames = None
        with CudaTimer(current_device, 'fnet_img'):
            if self.fnet_img is not None:
                assert self.img_corr_params is not None
                assert images is not None
                assert len(images) == 2
                images = [self.img_normalize(x.float().contiguous()) for x in images]
                fmaps_img = self.fnet_img(images)
                corr_computation_frames = CorrComputation(fmaps_img[0], fmaps_img[1], num_levels_per_target=self.img_corr_params['levels'])
                # The first image is appended to (or becomes) the context input.
                if context_input is not None:
                    context_input = torch.cat((context_input, images[0]), dim=-3)
                else:
                    context_input = images[0]
        assert context_input is not None

        with CudaTimer(current_device, 'cnet'):
            cnet = self.cnet(context_input)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            # Fuse the UNet features into the hidden state when events are available;
            # without events there is nothing to fuse and net keeps its hdim channels.
            if flow_feat is not None:
                net = self.merge_cnet_speed(torch.cat((net, flow_feat), dim=1))
            net = torch.tanh(net)  # hidden state of the recurrent update block, squashed to [-1, 1]
            inp = torch.relu(inp)

        # coords0: coordinate grid at 1/8 resolution; bezier: initial curve parameters.
        coords0, bezier = self.initialize_flow(context_input)

        if flow_init is not None:
            # Init the bezier curve control points parameters with the provided bezier curve.
            # Not used in the default RAFT-Spline.
            bezier.delta_update_params(flow_init.get_params())

        bezier_up_predictions = []

        # The lookup timestamps do not change between iterations, so build them once.
        lookup_timestamps = []
        if corr_computation_events is not None:
            dt = 1 / (self.nbins_context - 1)
            # 0 < time <= 1
            lookup_timestamps.extend(dt * tindex for tindex in self.ev_corr_target_indices)
        if corr_computation_frames is not None:
            lookup_timestamps.append(1)

        with CudaTimer(current_device, 'corr computation'):
            corr_block = CorrBlockParallelMultiTarget(
                corr_computation_events=corr_computation_events,
                corr_computation_frames=corr_computation_frames)
        with CudaTimer(current_device, 'all iters'):
            for itr in range(iters):
                # NOTE: original RAFT detaches the flow (bezier) here from the graph.
                # Our experiments with bezier curves indicate that detaching is lowering the validation EPE by up to 5% on DSEC.
                with CudaTimer(current_device, '1 iter'):
                    if self.detach_bezier:
                        bezier.detach_()

                    with CudaTimer(current_device, 'get_flow (per iter)'):
                        flows = bezier.get_flow_from_reference(time=lookup_timestamps)
                        coords1 = coords0 + flows

                    with CudaTimer(current_device, 'corr lookup (per iter)'):
                        corr_total = corr_block(coords1)

                    with CudaTimer(current_device, 'update (per iter)'):
                        bezier_params = bezier.get_params()
                        net, up_mask, delta_bezier = self.update_block(net, inp, corr_total, bezier_params)

                    # B(k+1) = B(k) + \Delta(B)
                    bezier.delta_update_params(delta_bezier=delta_bezier)

                    # In test mode only the last iteration needs to be upsampled.
                    if not test_mode or itr == iters - 1:
                        bezier_up = bezier.create_upsampled(up_mask)
                        bezier_up_predictions.append(bezier_up)

        if test_mode:
            return bezier, bezier_up, pred_lr, pred_hr

        return bezier_up_predictions, pred_lr, pred_hr
