import time
from timm.layers import Mlp
import math
import torch
import torch.nn as nn
from einops import rearrange
import numpy as np
from functools import partial

from .MlpResNet import MlpResNet, MPMlpResNet, MpScaleMlpResNet, MpScaleMlpResNet_WithDis, DDPMScaleMlpResNet

class LearnedPosEmb(nn.Module):
    """Fourier feature embedding with a learnable frequency matrix.

    ``forward`` computes ``f = (2*pi * x) @ kernel.T`` and returns
    ``[cos(f), sin(f)]`` concatenated on the last dimension. The output width
    is therefore ``2 * (output_size // 2)``, which is one less than
    ``output_size`` when it is odd.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.output_size = output_size
        n_freqs = output_size // 2
        # Frequencies start as N(0, 0.2**2) samples and are trained with the model.
        self.kernel = nn.Parameter(torch.randn(n_freqs, input_size) * 0.2)

    def forward(self, x):
        angles = torch.matmul(2 * torch.pi * x, self.kernel.t())
        cos_part, sin_part = angles.cos(), angles.sin()
        return torch.cat((cos_part, sin_part), dim=-1)


def sum_flat(x):
    """Sum over every dimension except the leading (batch) one.

    Args:
        x: tensor whose dim 0 is treated as the batch axis.

    Returns:
        A tensor of shape ``(x.shape[0],)`` for inputs with two or more dims.
        Tensors with fewer than two dims have nothing to reduce per sample and
        are returned unchanged.
    """
    if x.dim() < 2:
        # ``torch.sum(x, dim=[])`` reduces over *all* dims. That would collapse
        # a 1-D per-sample vector into a single scalar.
        return x
    return torch.sum(x, dim=list(range(1, x.dim())))


def stopgrad(x):
    """Return ``x`` cut out of the autograd graph (a stop-gradient).

    The result shares data with ``x`` but carries no gradient history.
    """
    detached = x.detach()
    return detached


def adaptive_l2_loss(error, gamma=0.5, c=1e-3):
    """Adaptively weighted squared-error loss.

    For each sample, ``d`` is the mean of ``error ** 2`` over all non-batch
    dims. Each ``d`` is weighted by ``1 / (d + c) ** (1 - gamma)``. The weight
    is detached, so it only rescales the gradient of ``d``. The batch mean of
    ``weight * d`` is returned.

    Args:
        error: residual tensor with the batch on dim 0.
        gamma: weighting exponent; ``gamma = 1`` gives a plain mean squared error.
        c: small constant that keeps the weight finite when ``d`` is near zero.
    """
    reduce_dims = tuple(range(1, error.ndim))
    per_sample = (error ** 2).mean(dim=reduce_dims)
    weight = 1.0 / (per_sample + c).pow(1.0 - gamma)
    return (weight.detach() * per_sample).mean()

def cosine_beta_schedule(timesteps, s=0.008):
    """Return ``timesteps`` betas from a cosine noise schedule.

    ``alpha_bar(i) = cos^2(((i / T) + s) / (1 + s) * pi / 2)`` for
    ``i = 0..T``, normalised so that ``alpha_bar(0) = 1``. Each beta is
    ``1 - alpha_bar(i) / alpha_bar(i - 1)``, clipped to ``[0, 0.999]``.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1) / timesteps
    alpha_bar = torch.cos((grid + s) / (1 + s) * math.pi * 0.5) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]
    ratios = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(1 - ratios, 0, 0.999)


def extract(a, x_shape):
    """Reshape ``a`` so it broadcasts against a tensor of shape ``x_shape``.

    Keeps dim 0 of ``a`` and appends ``len(x_shape) - 1`` singleton dims.
    For example, a ``(B,)`` tensor becomes ``(B, 1, 1)`` for a 3-D
    ``x_shape``.
    """
    batch, *_rest = a.shape
    trailing = [1] * (len(x_shape) - 1)
    return a.reshape(batch, *trailing)

class MPScalseFlowhead(nn.Module):
    """Flow-based action head with one-step sampling, built on ``MpScaleMlpResNet``.

    The wrapped network is called as ``model(sample, timestep, global_cond, r)``
    and predicts a velocity-like quantity ``u(z, t, r)``.

    - Training (``forward``) regresses ``u`` onto the target given by the
      identity ``u = v - (t - r) * du/dt``.
    - Sampling (``generate``) takes a single step from ``t = 1`` to ``r = 0``.

    NOTE(review): "Scalse" looks like a typo for "Scale". The name is kept
    because other modules may import the class under it.
    """

    def __init__(
            self,
            action_size: int = 16,
            num_blocks: int = 3,
            input_dim: int = 256,
            hidden_dim: int = 256,
            time_dim: int = 32,
            time_hidden_dim: int = 256
    ):
        """Build the velocity network and set the training hyper-parameters.

        Args:
            action_size: width of the network output. It is also added to the
                network's input width.
            num_blocks, hidden_dim, time_dim, time_hidden_dim: forwarded to
                ``MpScaleMlpResNet``.
            input_dim: presumably the width of ``global_cond`` (TODO confirm).
                The network input width is
                ``input_dim + action_size + time_hidden_dim``.
        """
        super().__init__()
        self.time_hidden_dim = time_hidden_dim
        self.action_size = action_size

        self.model = MpScaleMlpResNet(num_blocks=num_blocks,
                                      input_dim=input_dim + action_size + time_hidden_dim,
                                      hidden_dim=hidden_dim, output_size=action_size,
                                      time_dim=time_dim, time_hidden_dim=time_hidden_dim)

        # Kept for backward compatibility. forward() places the sampled times
        # on the input's device instead of on this fixed device.
        self.device = 'cuda'

        # Fraction of each batch trained with r == t.
        self.flow_ratio = 0.50
        # ['lognorm', mu, sigma] -> logit-normal times; ['uniform'] is also supported.
        self.time_dist = ['lognorm', -0.4, 1.0]
        # NOTE(review): cfg_ratio and cfg_uncond are not read anywhere in this class.
        self.cfg_ratio = 0.10
        cfg_scale = 2.0

        self.cfg_uncond = 'u'
        # Guidance weight used to build the target velocity in forward(); None disables it.
        self.w = cfg_scale

    def sample_t_r(self, batch_size, device):
        """Sample per-example time pairs ``(t, r)`` with ``t >= r``.

        For each example, two values are drawn from ``self.time_dist``:
        uniform on ``[0, 1)``, or the logistic of ``N(mu, sigma^2)``.
        ``t`` takes the larger value and ``r`` the smaller. Then
        ``int(flow_ratio * batch_size)`` randomly chosen examples get ``r = t``.
        NumPy's global RNG is used.

        Returns:
            Two float32 tensors of shape ``(batch_size,)`` on ``device``.

        Raises:
            ValueError: if ``self.time_dist[0]`` is not ``'uniform'`` or ``'lognorm'``.
        """
        kind = self.time_dist[0]
        if kind == 'uniform':
            samples = np.random.rand(batch_size, 2).astype(np.float32)
        elif kind == 'lognorm':
            mu, sigma = self.time_dist[-2], self.time_dist[-1]
            normal_samples = np.random.randn(batch_size, 2).astype(np.float32) * sigma + mu
            samples = 1 / (1 + np.exp(-normal_samples))
        else:
            # Fail loudly instead of with an UnboundLocalError on ``samples``.
            raise ValueError(f"unsupported time distribution: {kind!r}")

        t_np = np.maximum(samples[:, 0], samples[:, 1])
        r_np = np.minimum(samples[:, 0], samples[:, 1])

        num_selected = int(self.flow_ratio * batch_size)
        indices = np.random.permutation(batch_size)[:num_selected]
        r_np[indices] = t_np[indices]

        t = torch.tensor(t_np, device=device)
        r = torch.tensor(r_np, device=device)
        return t, r

    def forward(self, cur_action, condition):
        """Compute the training loss for a batch of clean actions.

        Args:
            cur_action: clean action tensor with exactly three dims, batch
                first (presumably ``(B, horizon, action_size)``, TODO confirm).
            condition: conditioning passed to the network as ``global_cond``.

        Returns:
            A pair ``(loss, loss_dict)``:

            - ``loss`` is the differentiable adaptively-weighted loss.
            - ``loss_dict`` holds the Python floats ``'bc_loss'`` and
              ``'mse_val'`` for logging.
        """
        batch_size, _, _ = cur_action.shape

        x = cur_action
        # Sample times on the data's device rather than on the hard-coded
        # ``self.device``, so the head also works on CPU or a non-default GPU.
        t, r = self.sample_t_r(batch_size, x.device)
        t_ = rearrange(t, "b -> b 1 1")
        r_ = rearrange(r, "b -> b 1 1")
        e = torch.randn_like(x)

        # Linear path from data (t=0) to noise (t=1); its velocity is e - x.
        z = (1 - t_) * x + t_ * e
        v = e - x

        if self.w is not None:
            # Guided target velocity: mix the path velocity with the network's
            # instantaneous prediction u(z, t, t), computed without gradient.
            # NOTE(review): this is the *conditional* prediction. No
            # unconditional branch or condition dropout is applied here;
            # confirm this is intended.
            with torch.no_grad():
                u_t = self.model(
                    sample=z,
                    timestep=t,
                    global_cond=condition,
                    r=t)
            v_hat = self.w * v + (1 - self.w) * u_t
        else:
            v_hat = v

        model_partial = partial(self.model, global_cond=condition)
        # Directional derivative of u along (dz/dt, dt/dt, dr/dt) = (v_hat, 1, 0),
        # i.e. the total derivative du/dt. create_graph=True keeps ``u``
        # differentiable so the loss can backpropagate into the network.
        u, dudt = torch.autograd.functional.jvp(
            lambda z_in, t_in, r_in: model_partial(sample=z_in, timestep=t_in, r=r_in),
            (z, t, r),
            (v_hat, torch.ones_like(t), torch.zeros_like(r)),
            create_graph=True,
        )

        # Regression target from u = v - (t - r) * du/dt, treated as a constant.
        u_tgt = v_hat - (t_ - r_) * dudt

        error = u - stopgrad(u_tgt)
        loss = adaptive_l2_loss(error)

        mse_val = (stopgrad(error) ** 2).mean()

        loss_dict = {
            'bc_loss': loss.item(),
            'mse_val': mse_val.item()
        }

        return loss, loss_dict

    def generate(self, cond_data, sample):
        """Map ``sample`` to an output with a single network evaluation.

        Evaluates the network at ``t = 1``, ``r = 0`` and returns
        ``sample - model(sample, t=1, r=0)``, i.e. one step of size
        ``t - r = 1``. ``sample`` is presumably noise shaped like the training
        actions (TODO confirm with callers).
        """
        t = torch.ones((cond_data.shape[0],), device=cond_data.device)
        r = torch.zeros((cond_data.shape[0],), device=cond_data.device)

        sample = sample - self.model(sample=sample,
                                     timestep=t,
                                     global_cond=cond_data,
                                     r=r,
                                     )

        return sample


