import math
import matplotlib.pyplot as plt
from functools import partial
import itertools
import numpy as np
from tqdm import tqdm
from typing import *
from pylab import cm

import torch
from torch import Tensor, vmap
from torch.func import grad_and_value, jacrev, vmap
import torch.nn as nn
from torch.nn.functional import leaky_relu, sigmoid, softmax
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
import torch.nn.utils as nn_utils
from torch.distributions import Dirichlet, Categorical, Normal, Uniform
from torchdiffeq import odeint_adjoint, odeint

from zuko.distributions import DiagNormal
from unet import *

# from dataloader_pinwheel import *

# Import-time global configuration (affects the whole process, not just this
# module): tensors created afterwards without an explicit dtype are float64,
# and tensors print with 3 digits after the decimal point.
torch.set_default_dtype(torch.float64)
torch.set_printoptions(precision=3)

class MLP(nn.Sequential):
    """Fully connected network assembled as an ``nn.Sequential``.

    Every hidden layer is ``Linear -> [BatchNorm1d | LayerNorm | Dropout] -> fct``
    (at most one of the bracketed modules, chosen in that priority order). The
    network always ends on a bare output ``Linear`` producing ``out_features``,
    with no normalization, dropout or activation after it.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the output.
        hidden_features: widths of the hidden layers; an empty sequence gives a
            single ``Linear(in_features, out_features)``.
        fct: activation module. The SAME instance is inserted after every hidden
            layer (and, being a default argument, is shared by every MLP built
            without an explicit ``fct``), so an activation with parameters
            would share them across layers.
        batch_norm: insert ``BatchNorm1d`` after each hidden ``Linear``.
        dropout: insert ``Dropout(p)`` after each hidden ``Linear`` (only when
            neither ``batch_norm`` nor ``layer_norm`` is set).
        weight_norm: wrap every ``Linear`` (output layer included) with
            ``torch.nn.utils.weight_norm``.
        layer_norm: insert ``LayerNorm`` after each hidden ``Linear`` (only when
            ``batch_norm`` is not set).
        p: dropout probability.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Sequence[int] = (64, 64),
        fct=nn.Tanh(),
        batch_norm=False,
        dropout=False,
        weight_norm=False,
        layer_norm=False,
        p=0.2,
    ):
        layers = []

        for a, b in zip(
            (in_features, *hidden_features),
            (*hidden_features, out_features),
        ):
            linear_layer = nn.Linear(a, b)
            if weight_norm:
                # NOTE: torch.nn.utils.weight_norm is deprecated upstream in favour
                # of torch.nn.utils.parametrizations.weight_norm; switching would
                # change state_dict keys, so it is kept here.
                linear_layer = nn_utils.weight_norm(linear_layer)
            layers.append(linear_layer)
            if batch_norm:
                layers.append(nn.BatchNorm1d(b))
            elif layer_norm:
                layers.append(nn.LayerNorm(b))
            elif dropout:
                layers.append(nn.Dropout(p=p))
            layers.append(fct)

        # The loop also appended [norm/dropout] + activation after the output
        # Linear; strip exactly those so the network ends on the output layer.
        # Weight norm wraps the Linear in place and adds no module, so only
        # batch_norm / layer_norm / dropout change how many trailing modules exist.
        n_trailing = 2 if (batch_norm or layer_norm or dropout) else 1
        super().__init__(*layers[:-n_trailing])


class cnnLLK(nn.Module):
    """Time-conditioned image vector field ``v(t, x)`` with CNF helpers.

    The backbone is selected by ``unet``:

    * ``unet=True``: ``input_proj`` (3x3 conv to ``mod_ch`` channels), then
      ``num_blocks`` ``ResBlock_nz`` blocks that also receive the time
      embedding, then ``output_proj`` (GroupNorm -> SiLU -> 3x3 conv back to
      ``in_ch`` channels).
    * ``unet=False``: two conv/softplus/max-pool stages, a linear map to
      ``hidden_dim``, and an ``MLP`` on ``[features, time embedding]`` whose
      output is reshaped to ``(batch, in_ch, H, W)``.

    Time is embedded as ``[cos(k*pi*t), sin(k*pi*t)]`` for ``k = 1..freqs``.

    ODE helpers: ``encode`` integrates t: 0 -> 1, ``decode`` integrates
    t -> 0, and ``log_prob`` integrates t -> 1 and scores the endpoint under
    ``source``, so t = 1 is the source-distribution end.
    """

    def __init__(self, x_features: int, freqs: int = 2, in_ch: int = 1, mod_ch: int = 128,
                 hidden_dim=512, num_blocks=4, unet=True, droprate=0.2, **kwargs):
        """
        Args:
            x_features: image side length. The non-UNet path sizes ``cnn_fc``
                from ``x_features // 4`` and reshapes ``x_features**2 * in_ch``
                outputs to ``(in_ch, H, W)``, so it needs ``H == W == x_features``.
            freqs: number of sinusoidal time frequencies (embedding size ``2*freqs``).
            in_ch: number of image channels.
            mod_ch: working channel count of the UNet path; ``GroupNorm(8, mod_ch)``
                requires it to be divisible by 8.
            hidden_dim: width of the CNN feature projection (non-UNet path).
            num_blocks: number of residual blocks (UNet path).
            unet: use the residual-block path (True) or the CNN + MLP path.
            droprate: passed as the fourth positional argument of ``ResBlock_nz``
                (presumably a dropout rate -- confirm against ``unet``).
            **kwargs: forwarded to ``MLP`` (non-UNet path only).
        """
        super().__init__()

        self.register_buffer('freqs', torch.arange(1, freqs + 1) * torch.pi)
        self.x_features = x_features
        self.in_ch = in_ch
        self.unet = unet
        self.base_ch = mod_ch
        self.num_blocks = num_blocks

        if unet:
            self.input_proj = nn.Conv2d(self.in_ch, self.base_ch, 3, padding=1)
            self.res_blocks = EmbedSequential_nz(
                *[ResBlock_nz(self.base_ch, self.base_ch, 2 * freqs, droprate) for _ in range(self.num_blocks)])

            self.output_proj = nn.Sequential(
                nn.GroupNorm(8, self.base_ch),
                nn.SiLU(),
                nn.Conv2d(self.base_ch, self.in_ch, 3, padding=1),
            )
        else:
            # Shape comments below assume a 28 x 28 input (x_features == 28).
            self.cnn = nn.Sequential(
                nn.Conv2d(in_ch, 32, kernel_size=3, stride=1, padding=1),  # 32 x 28 x 28
                nn.Softplus(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # 32 x 14 x 14
                nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # 64 x 14 x 14
                nn.Softplus(),
                nn.MaxPool2d(kernel_size=2, stride=2),  # 64 x 7 x 7
            )
            # Project the flattened CNN features to hidden_dim.
            self.cnn_fc = nn.Linear(64 * (x_features // 4) * (x_features // 4), hidden_dim)

            # Map [features, time embedding] to one value per pixel and channel.
            self.combine_fc = MLP(hidden_dim + 2 * freqs, x_features**2 * in_ch, **kwargs)

    def forward(self, t: Tensor, x: Tensor) -> Tensor:
        """Evaluate the vector field at time ``t`` for the batch ``x``.

        Args:
            t: time tensor; a 0-d tensor or shape ``(1,)`` is broadcast to the
               whole batch, shape ``(batch,)`` gives one time per sample.
            x: images of shape ``(batch, in_ch, H, W)``.

        Returns:
            A tensor meant to match ``x``'s shape. The non-UNet path reshapes
            explicitly; the UNet path assumes ``ResBlock_nz`` preserves spatial size.
        """
        t = self.freqs * t[..., None]
        temb = torch.cat((t.cos(), t.sin()), dim=-1)
        temb = temb.expand(*x.shape[:1], -1)  # (batch, 2 * freqs)

        if self.unet:
            xemb = self.input_proj(x)
            xemb = self.res_blocks(xemb, temb)
            return self.output_proj(xemb)

        batch_size = x.size(0)
        feats = self.cnn(x)
        feats = self.cnn_fc(feats.view(batch_size, -1))  # (batch, hidden_dim)
        combined_features = torch.cat([feats, temb], dim=-1)  # (batch, hidden_dim + 2 * freqs)
        vector_field = self.combine_fc(combined_features)  # (batch, x_features**2 * in_ch)
        return vector_field.view(batch_size, self.in_ch, x.size(2), x.size(3))

    def _forward(self, t: Tensor, x: Tensor) -> Tensor:
        """Return ``(v, v)`` for ``v = forward(t, x)`` (tuple-returning variant of ``forward``)."""
        out = self.forward(t, x)
        return out, out

    @staticmethod
    def _time_grid(start, end, like: Tensor) -> Tensor:
        """Return the 1-D integration grid ``[start, end]`` in ``like``'s dtype and device.

        ``start``/``end`` may be Python numbers or one-element tensors.
        """
        return torch.tensor([float(start), float(end)], dtype=like.dtype, device=like.device)

    def encode(self, x: Tensor) -> Tensor:
        """Integrate from t = 0 to t = 1 (dopri5, atol = rtol = 1e-8) and return the state at t = 1."""
        # `odeint` is used with torchdiffeq's signature, as in
        # GRUVelocityModel.decode: it takes a time tensor (not t0, t1) and returns
        # the solution at every requested time, hence the trailing [-1].
        return odeint(
            self,
            x,
            self._time_grid(0.0, 1.0, x),
            method="dopri5",
            atol=1e-8, rtol=1e-8)[-1]

    def decode(self, x: Tensor, t=None) -> Tensor:
        """Integrate from time ``t`` (default 1) back to 0 and return the state at t = 0.

        Uses the adjoint dopri5 solver (atol = rtol = 1e-8), so gradients w.r.t.
        the model parameters are obtained by solving the adjoint ODE.
        """
        if t is None:
            t = 1.
        xt = odeint_adjoint(
            self,
            x,
            self._time_grid(t, 0.0, x),
            adjoint_params=tuple(self.parameters()),
            method="dopri5",
            atol=1e-8, rtol=1e-8)[-1]
        return xt

    def decode_with_trajectory(self, x: Tensor, t=None, num_points=100) -> tuple:
        """Solve the ODE on ``linspace(0, t, num_points)`` and record every vector-field call.

        Returns:
            ``(final_state, trajectory, times)``: ``final_state`` is the solution
            at time ``t``; ``trajectory`` stacks detached copies of every state
            the solver passed to the vector field, and ``times`` the matching
            evaluation times. These are the solver's internal function
            evaluations, not the solution at the ``num_points`` grid points.

        NOTE(review): this integrates 0 -> t, the opposite direction of
        ``decode`` (t -> 0). Confirm which direction callers expect.
        """
        if t is None:
            t = 1.

        time_points = torch.linspace(0., t, num_points, device=x.device, dtype=x.dtype)

        trajectory = []
        recorded_times = []

        def ode_func_with_recording(s, state):
            # Record each state the solver evaluates, then defer to the model.
            trajectory.append(state.detach().clone())
            recorded_times.append(s.item())
            return self(s, state)

        xt = odeint_adjoint(
            ode_func_with_recording,
            x,
            time_points,
            adjoint_params=tuple(self.parameters()),
            atol=1e-8, rtol=1e-8)
        trajectory_tensor = torch.stack(trajectory)
        recorded_times_tensor = torch.tensor(recorded_times)
        return xt[-1], trajectory_tensor, recorded_times_tensor

    @staticmethod
    def exact_trace(f, y):
        """Exact per-sample trace of ``d f / d y``.

        Uses one backward pass per non-batch element of ``y`` (O(D) passes),
        keeping the graph (``create_graph=True``) so the result is differentiable.
        Returns a tensor of shape ``(batch,)``.
        """
        dims = y.size()[1:]
        tr_dzdx = 0.0
        for idcs in itertools.product(*(range(d) for d in dims)):
            batch_idcs = (slice(None),) + idcs
            tr_dzdx += torch.autograd.grad(f[batch_idcs].sum(), y, create_graph=True)[0][batch_idcs]
        return tr_dzdx

    @staticmethod
    def hutch_trace(f, y, e):
        """Hutchinson estimate of the per-sample trace of ``d f / d y`` with probe ``e``.

        ``e`` should have ``f``'s shape, e.g. Rademacher (+1/-1) noise.
        ``sum_except_batch`` is defined outside this file (presumably via
        ``from unet import *``) and is assumed to sum over all non-batch dims.
        """
        e_dzdx = torch.autograd.grad(f, y, e, create_graph=True)[0]
        e_dzdx_e = e_dzdx * e
        return sum_except_batch(e_dzdx_e)

    def log_prob(self, x: Tensor, t, source) -> Tensor:
        """Estimate ``log p_t(x)`` with the instantaneous change-of-variables formula.

        Integrates ``(x, ladj)`` from time ``t`` to 1, where ``d ladj / ds`` is a
        Hutchinson estimate of ``tr(d v / d x)``. Returns
        ``source.log_prob(x_1).sum(dim=(1, 2, 3)) + ladj_1``.

        Args:
            x: batch of shape ``(batch, C, H, W)``; the final sum requires 4-D.
            t: start time (Python number or one-element tensor), different from 1.
            source: distribution whose ``log_prob`` is element-wise over C, H, W.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        x = x.clone().detach().requires_grad_(True)
        # Rademacher probe (+1/-1), drawn once and reused for the whole solve.
        # Created in x's dtype/device so it is a valid grad_outputs tensor.
        e = torch.randint_like(x, 0, 2) * 2 - 1

        def augmented(s: Tensor, state) -> tuple:
            z, _ = state
            with torch.enable_grad():
                z.requires_grad_(True)
                dz = self(s, z)
                trace = self.hutch_trace(dz, z, e)
            # The trace is integrated scaled by 1e-3 and rescaled by 1e3 below.
            # Presumably this keeps its magnitude in check for the adaptive
            # step-size control.
            return dz, trace * 1e-3

        ladj = x.new_zeros(x.shape[0])
        zs, ladjs = odeint_adjoint(
            augmented,
            (x, ladj),
            self._time_grid(t, 1.0, x),
            adjoint_params=tuple(self.parameters()),
            atol=1e-7, rtol=1e-7)
        # Pixels and channels are scored independently, then summed per sample.
        return source.log_prob(zs[-1]).sum(dim=(1, 2, 3)) + ladjs[-1] * 1e3



class FlowMatchingLoss_marginal(nn.Module):
    """Flow-matching regression loss with one random time shared by the batch.

    For a batch ``x`` (the t = 0 end) and a fresh draw ``x1`` from ``prior``
    (the t = 1 end) it forms

        x_t = (1 - t) * x + (sig_min + (1 - sig_min) * t) * x1
        u_t = (1 - sig_min) * x1 - x

    and returns the mean squared error between ``vt(t, x_t)`` and ``u_t``.

    Args:
        vt: velocity network, called as ``vt(t, x_t)`` with ``t`` of shape ``(1,)``.
        prior: object exposing ``sample(shape)``; ``len(x)`` samples are drawn.
        sig_min: noise scale kept at t = 0.
        beta, alpha: accepted for interface compatibility; unused.
        eps: stored as ``self.eps``; not used by ``forward``.
    """

    def __init__(self, vt: nn.Module, prior, sig_min=1e-4, beta=0.1, alpha=0.1, eps=1e-8):
        super().__init__()
        self.vt = vt
        self.prior = prior
        self.sig_min = sig_min
        self.eps = eps

    def forward(self, x: Tensor, eps=1e-8, hard=False) -> Tensor:
        """Return the scalar loss for a 4-D batch ``x``; ``eps``/``hard`` are unused."""
        # Unpacking enforces a 4-D input (ValueError otherwise).
        _n, _c, _h, _w = x.shape

        # A single time value, broadcast to x.shape[:-1] + (1,).
        time = torch.rand(1).to(x.device)
        t = torch.ones_like(x[..., 0, None]) * time

        noise = self.prior.sample((len(x),)).to(x.device)
        sig = self.sig_min
        x_t = (1 - t) * x + (sig + (1 - sig) * t) * noise
        target = (1 - sig) * noise - x
        prediction = self.vt(time, x_t)

        per_row = (prediction - target).square().mean(-1)
        return per_row.mean()

class FlowMatchingLoss_marginal_LDS(nn.Module):
    """Sequence variant of the flow-matching loss with a standard-normal source.

    ``x`` is 5-D, ``(S, B, C, H, W)``; the t = 1 end is ``randn`` noise of the
    same shape. With one random ``t`` shared by the whole batch it forms

        x_t = (1 - t) * x + (sig_min + (1 - sig_min) * t) * x1
        u_t = (1 - sig_min) * x1 - x

    and returns the mean squared error between ``vt(t, x_t)`` and ``u_t``.

    Args:
        vt: velocity network, called as ``vt(t, x_t)`` with ``t`` of shape ``(1,)``.
        sig_min: noise scale kept at t = 0.
        beta, alpha: accepted for interface compatibility; unused.
        eps: stored as ``self.eps``; not used by ``forward``.
    """

    def __init__(self, vt: nn.Module, sig_min=1e-8, beta=0.1, alpha=0.1, eps=1e-8):
        super().__init__()
        self.vt = vt
        self.sig_min = sig_min
        self.eps = eps

    def forward(self, x: Tensor, eps=1e-8, hard=False) -> Tensor:
        """Return the scalar loss for a 5-D batch ``x``; ``eps``/``hard`` are unused."""
        # Unpacking enforces a 5-D input (ValueError otherwise).
        seq_len, batch, ch, height, width = x.shape

        # A single time value, broadcast to x.shape[:-1] + (1,).
        time = torch.rand(1).to(x.device)
        t = torch.ones_like(x[..., 0, None]) * time

        noise = torch.randn((seq_len, batch, ch, height, width), device=x.device)
        sig = self.sig_min
        x_t = (1 - t) * x + (sig + (1 - sig) * t) * noise
        target = (1 - sig) * noise - x
        prediction = self.vt(time, x_t)

        per_row = (prediction - target).square().mean(-1)
        return per_row.mean()




class GRUVelocityModel(nn.Module):
    """Time-conditioned velocity field over image sequences, built on a GRU.

    Sequences are laid out sequence-first, ``(S, B, C, H, W)``. Each frame is
    flattened and encoded by a two-layer MLP, concatenated with a sinusoidal
    embedding of the flow time, passed through a (non-batch-first) GRU, and
    mapped back to ``C*H*W`` values per frame by an ``MLP``.

    Args:
        seq_length: sequence length S (stored on the instance).
        n_channels: channels per frame C.
        img_size: frame side length; frames are flattened as
            ``n_channels * img_size * img_size`` values, i.e. square frames.
        num_time_frequencies: number of sinusoidal frequencies ``k*pi`` for
            ``k = 1..num_time_frequencies``; the embedding has twice this size.
        frame_feature_dim: size of the per-frame encoding.
        rnn_hidden_dim: GRU hidden size.
        num_rnn_layers: number of stacked GRU layers.
        **kwargs: forwarded to ``MLP`` for the output head.
    """

    def __init__(self, seq_length, n_channels, img_size,
                 num_time_frequencies=16,
                 frame_feature_dim=64,
                 rnn_hidden_dim=128,
                 num_rnn_layers=1,
                 **kwargs):
        super().__init__()
        self.seq_length = seq_length  # S
        self.n_channels = n_channels  # C
        self.img_size = img_size      # H == W
        self.frame_input_dim = n_channels * img_size * img_size
        self.num_time_frequencies = num_time_frequencies

        # Sinusoidal time-embedding frequencies, shape (num_time_frequencies,).
        self.register_buffer('freqs', torch.arange(1, num_time_frequencies + 1) * torch.pi)

        actual_time_embed_dim = 2 * num_time_frequencies

        # Per-frame encoder on flattened frames.
        self.frame_encoder = nn.Sequential(
            nn.Linear(self.frame_input_dim, frame_feature_dim * 2),
            nn.Softplus(),
            nn.Linear(frame_feature_dim * 2, frame_feature_dim)
        )

        # GRU over the sequence axis; batch_first=False expects (S, B, F).
        self.rnn = nn.GRU(input_size=frame_feature_dim + actual_time_embed_dim,
                          hidden_size=rnn_hidden_dim,
                          num_layers=num_rnn_layers,
                          batch_first=False)

        # Per-frame head producing the velocity for every pixel/channel.
        self.output_mlp = MLP(rnn_hidden_dim, self.frame_input_dim, **kwargs)

    def _encode_time_sinusoidal(self, t_flow_b):
        """Map times of shape ``(B, 1)`` to ``[cos(k*pi*t), sin(k*pi*t)]``, shape ``(B, 2*F)``."""
        scaled_time = t_flow_b * self.freqs  # (B, F) by broadcasting
        return torch.cat((torch.cos(scaled_time), torch.sin(scaled_time)), dim=-1)

    def forward(self, t_flow_b, x_sequence_sb):
        """Return the velocity for every frame.

        Args:
            t_flow_b: flow times, shape ``(B, 1)`` (a ``(1,)`` tensor also
                broadcasts to the whole batch).
            x_sequence_sb: sequences of shape ``(S, B, C, H, W)``.

        Returns:
            Tensor of shape ``(S, B, C, H, W)``.
        """
        S, B, C, H, W = x_sequence_sb.shape

        frames_flat = x_sequence_sb.reshape(S * B, self.frame_input_dim)
        frame_features_sb = self.frame_encoder(frames_flat).reshape(S, B, -1)

        # Same time embedding for every frame of a sequence.
        t_emb_b = self._encode_time_sinusoidal(t_flow_b)
        t_emb_sb = t_emb_b.unsqueeze(0).expand(S, B, -1)
        rnn_input_sb = torch.cat((frame_features_sb, t_emb_sb), dim=2)

        rnn_outputs_sb, _ = self.rnn(rnn_input_sb)

        rnn_outputs_flat = rnn_outputs_sb.reshape(S * B, self.rnn.hidden_size)
        velocities_flat = self.output_mlp(rnn_outputs_flat)
        return velocities_flat.reshape(S, B, C, H, W)

    @torch.no_grad()
    def decode(self,
               x0_sb,
               t=None,
               solver_method='dopri5',
               rtol=1e-5,
               atol=1e-5,
               num_eval_points=2,
               device=None):
        """Integrate the learned ODE from time ``t`` (default 1) back to 0.

        Args:
            x0_sb: state at time ``t``, shape ``(S, B, C, H, W)``.
            t: start time; the solve runs ``t -> 0``.
            solver_method, rtol, atol: forwarded to ``odeint``.
            num_eval_points: unused; kept for backward compatibility.
            device: if given, the model is moved there; otherwise the device of
                the model's first parameter is used. ``x0_sb`` is moved to this
                device too.

        Returns:
            The state at time 0, shape ``(S, B, C, H, W)``, with S leading.

        Runs without gradient tracking and in eval mode. The caller's
        train/eval mode is restored afterwards.
        """
        if t is None:
            t = 1

        if device is None:
            device = next(self.parameters()).device
        else:
            self.to(device)

        y0_sb = x0_sb.to(device)
        batch_size = y0_sb.shape[1]
        # Build the time grid next to the state, so torchdiffeq does not have to
        # coerce it across devices or dtypes.
        time_grid = torch.tensor([float(t), 0.0], dtype=y0_sb.dtype, device=y0_sb.device)

        def ode_dynamics(t_scalar, x_sb):
            # The solver passes a 0-d time; the model expects one time per batch item.
            return self.forward(t_scalar.expand(batch_size, 1), x_sb)

        was_training = self.training
        self.eval()
        try:
            solution = odeint(
                ode_dynamics,
                y0_sb,
                time_grid,
                method=solver_method,
                rtol=rtol, atol=atol,
            )
        finally:
            # Sampling in the middle of training must not leave dropout etc. disabled.
            self.train(was_training)

        return solution[-1]

