import torch
import torch.nn as nn
import einops
from einops.layers.torch import Rearrange
from einops import rearrange
import pdb
from torch.distributions import Bernoulli
import math
from .helpers import (
    SinusoidalPosEmb,
    Downsample1d,
    Upsample1d,
    Conv1dBlock,
)
import transformers
from .GPT2 import GPT2Model
from typing import Any, Tuple

class Residual(nn.Module):
    """Wrap a callable so that its input is added back onto its output.

    ``forward`` computes ``fn(x, *args, **kwargs) + x``.
    """

    def __init__(self, fn):
        super().__init__()
        # Plain attribute assignment: if ``fn`` is an nn.Module it is registered
        # as a submodule under the name ``fn``.
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """Call the wrapped ``fn`` on ``x`` (forwarding extra arguments) and add ``x``."""
        branch = self.fn(x, *args, **kwargs)
        return branch + x

class PreNorm(nn.Module):
    """Normalise the input with an affine ``InstanceNorm2d`` before calling ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        # ``dim`` is the number of channels handed to InstanceNorm2d.
        self.norm = nn.InstanceNorm2d(dim, affine=True)

    def forward(self, x):
        """Return ``fn(norm(x))``."""
        normalised = self.norm(x)
        return self.fn(normalised)

class LinearAttention(nn.Module):
    """Multi-head linear attention over the flattened positions of a 4D feature map.

    Queries, keys and values come from one bias-free 1x1 convolution. Keys are
    softmax-normalised over the flattened ``(h w)`` axis, a per-head context
    matrix is formed from keys and values, and that context is applied to the
    queries. No scaling factor is applied to the queries.
    """

    def __init__(self, dim, heads = 4, dim_head = 128):
        super().__init__()
        self.heads = heads
        inner_dim = heads * dim_head
        # A single projection yields q, k and v stacked along the channel axis.
        self.to_qkv = nn.Conv2d(dim, 3 * inner_dim, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        """Attend over spatial positions.

        ``x`` must be 4-dimensional (it is unpacked as ``b, c, h, w``); the
        result has the same spatial size and ``dim`` channels.
        """
        _, _, height, width = x.shape
        projected = self.to_qkv(x)
        query, key, value = rearrange(
            projected,
            'b (qkv heads c) h w -> qkv b heads c (h w)',
            heads=self.heads,
            qkv=3,
        )
        # Normalise keys across positions, then summarise values per key channel.
        key = key.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', key, value)
        attended = torch.einsum('bhde,bhdn->bhen', context, query)
        merged = rearrange(
            attended,
            'b heads c (h w) -> b (heads c) h w',
            heads=self.heads,
            h=height,
            w=width,
        )
        return self.to_out(merged)


class GlobalMixing(nn.Module):
    """Linear-attention style mixing of all positions of a 4D feature map.

    A bias-free 1x1 convolution produces stacked q/k/v channels. The keys are
    softmax-normalised over the flattened spatial axis and contracted with the
    values into a per-head context, which is then applied to the queries and
    projected back to ``dim`` channels.
    """

    def __init__(self, dim, heads = 4, dim_head = 128):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(
            in_channels=dim, out_channels=hidden_dim * 3, kernel_size=1, bias=False
        )
        self.to_out = nn.Conv2d(in_channels=hidden_dim, out_channels=dim, kernel_size=1)

    def forward(self, x):
        """Mix information across positions; ``x`` is unpacked as ``b, c, h, w``."""
        b, c, h, w = x.shape
        q, k, v = rearrange(
            self.to_qkv(x),
            'b (qkv heads c) h w -> qkv b heads c (h w)',
            heads=self.heads,
            qkv=3,
        )
        # Softmax runs over the last axis, i.e. the flattened (h w) positions.
        context = torch.einsum('bhdn,bhen->bhde', k.softmax(dim=-1), v)
        mixed = torch.einsum('bhde,bhdn->bhen', context, q)
        mixed = rearrange(mixed, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(mixed)

class QKVAttention(nn.Module):
    """Multi-head dot-product attention over a sequence plus extra encoder tokens.

    The queries come from ``x``; keys and values are the encoder keys/values
    followed by the keys/values of ``x`` itself along the sequence axis.

    Channel layout (unchanged from the previous einops patterns):
      * ``x``          : ``[batch, (qkv heads c), length]`` with ``qkv = 3``
      * ``encoder_kv`` : ``[batch, (kv heads c), enc_length]`` with ``kv = 2``
      * result         : ``[batch, (heads c), length]``

    Fix: the attention logits are scaled by ``1 / sqrt(c)``, where ``c`` is the
    per-head channel count. The earlier code derived the scale from the
    sequence length (last axis of ``q``), so the scaling changed with the
    sequence length instead of with the dimensionality of the dot product.
    """

    def __init__(self, heads = 4) -> None:
        super().__init__()

        self.heads = heads

    def _split_heads(self, tensor, groups):
        """Split ``[b, (groups heads c), n]`` into ``groups`` tensors of ``[(b heads), c, n]``."""
        batch, width, length = tensor.shape
        channels = width // (groups * self.heads)
        # reshape raises RuntimeError if width is not divisible by groups * heads.
        stacked = tensor.reshape(batch, groups, self.heads, channels, length)
        return tuple(
            part.reshape(batch * self.heads, channels, length)
            for part in stacked.unbind(dim=1)
        )

    def forward(self, x, encoder_kv):
        """Return the attention output for ``x`` conditioned on ``encoder_kv``.

        Args:
            x: stacked q/k/v projections, ``[batch, 3 * heads * c, length]``.
            encoder_kv: stacked k/v projections of the conditioning tokens,
                ``[batch, 2 * heads * c, enc_length]``.

        Returns:
            Tensor of shape ``[batch, heads * c, length]``.
        """
        batch, _, length = x.shape
        q, k, v = self._split_heads(x, 3)
        encoder_k, encoder_v = self._split_heads(encoder_kv, 2)

        # Encoder tokens are placed before the sequence's own keys/values.
        k = torch.cat([encoder_k, k], dim=-1)
        v = torch.cat([encoder_v, v], dim=-1)

        # q and k are each scaled by c ** -0.25 so that their product carries
        # the usual 1 / sqrt(c) factor (c = channels per head = q.shape[1]).
        channels = q.shape[1]
        scale = 1 / math.sqrt(math.sqrt(channels))
        weight = torch.einsum('bct,bcs->bts', q * scale, k * scale)
        # Softmax in float32, then cast back to the logits' dtype.
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = torch.einsum("bts,bcs->bct", weight, v)
        # '(b heads) c n -> b (heads c) n'
        return out.reshape(batch, self.heads * channels, length)

class ResidualAttentionBlock(nn.Module):
    """Attention over a 1D sequence, conditioned on ``cond``, with a residual connection.

    ``x`` is projected to stacked q/k/v channels by a 1x1 convolution, ``cond``
    is projected to stacked k/v channels, and both are handed to
    ``QKVAttention``. The result is added to ``x`` (passed through a 1x1
    convolution when the channel counts differ).

    ``kernel_size`` and ``mish`` are accepted but not used by this block.
    """

    def __init__(self, inp_channels, out_channels, encoder_channels, heads=4, kernel_size=5, mish=True) -> None:
        super().__init__()
        assert out_channels % heads == 0

        qkv_projection = nn.Conv1d(inp_channels, out_channels * 3, 1)
        attention = QKVAttention(heads)
        self.blocks = nn.ModuleList([qkv_projection, attention])

        self.encoder_kv = nn.Conv1d(encoder_channels, out_channels * 2, 1)

        # Match channel counts on the skip path only when they differ.
        if inp_channels != out_channels:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)
        else:
            self.residual_conv = nn.Identity()

    def forward(self, x, cond):
        """Return ``attention(x, cond) + residual(x)``.

        A ``cond`` with fewer than three dimensions gets a trailing axis of
        length 1 so it can be fed to the 1D convolution.
        """
        if cond.dim() < 3:
            cond = cond.unsqueeze(-1)

        qkv_projection, attention = self.blocks
        context_kv = self.encoder_kv(cond)
        attended = attention(qkv_projection(x), context_kv)

        return attended + self.residual_conv(x)

class ResidualTemporalBlock(nn.Module):
    """Two ``Conv1dBlock`` layers with an additive embedding and a residual path.

    The embedding ``t`` is passed through activation -> linear and broadcast
    over the last axis before being added to the output of the first
    convolution block. ``horizon`` is accepted but not used by this block.
    """

    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5, mish=True):
        super().__init__()

        first_conv = Conv1dBlock(inp_channels, out_channels, kernel_size, mish)
        second_conv = Conv1dBlock(out_channels, out_channels, kernel_size, mish)
        self.blocks = nn.ModuleList([first_conv, second_conv])

        activation = nn.Mish() if mish else nn.SiLU()

        self.time_mlp = nn.Sequential(
            activation,
            nn.Linear(embed_dim, out_channels),
            # Add a trailing axis so the embedding broadcasts along the sequence.
            Rearrange('batch t -> batch t 1'),
        )

        if inp_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)

    def forward(self, x, t):
        '''
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]
            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        conv_in, conv_out = self.blocks
        hidden = conv_in(x) + self.time_mlp(t)
        hidden = conv_out(hidden)

        return hidden + self.residual_conv(x)

class TemporalUnet(nn.Module):
    """1D U-Net over the horizon axis, built from ``ResidualTemporalBlock`` stages.

    The diffusion timestep is embedded by ``time_mlp``. When
    ``returns_condition`` is true, ``returns`` is embedded by ``returns_mlp``
    and concatenated to the time embedding; a Bernoulli mask with keep
    probability ``1 - condition_dropout`` can zero that embedding per sample.

    ``cond_dim`` (constructor) and ``cond`` (forward) are accepted but not used
    here; ``calc_energy`` is stored on the instance but not read by ``forward``.
    """

    def __init__(
        self,
        horizon,
        transition_dim,
        cond_dim,
        dim=128,
        dim_mults=(1, 4, 8),
        returns_condition=True,
        condition_dropout=0.25,
        calc_energy=False,
        kernel_size=5,
    ):
        super().__init__()

        # Channel widths per level: the input width, then dim * multiplier for
        # each entry of dim_mults (e.g. transition_dim, 128*1, 128*4, 128*8).
        dims = [transition_dim] + [dim * mult for mult in dim_mults]
        # Consecutive (in, out) pairs, e.g. [(transition_dim, 128), (128, 512), (512, 1024)].
        in_out = list(zip(dims[:-1], dims[1:]))
        print(f'[ models/temporal ] Channel dimensions: {in_out}')

        mish = True
        act_fn = nn.Mish()  # f(x) = x * tanh(softplus(x)), softplus(x) = log(1 + exp(x))

        self.time_dim = dim
        self.returns_dim = dim

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),  # sinusoidal position embedding comes first
            nn.Linear(dim, dim * 4),
            act_fn,
            nn.Linear(dim * 4, dim),
        )

        self.returns_condition = returns_condition
        self.condition_dropout = condition_dropout
        self.calc_energy = calc_energy

        embed_dim = dim
        if self.returns_condition:
            # The first linear layer fixes the returns input width at 16.
            self.returns_mlp = nn.Sequential(
                nn.Linear(16, dim),
                act_fn,
                nn.Linear(dim, dim * 4),
                act_fn,
                nn.Linear(dim * 4, dim),
            )
            self.mask_dist = Bernoulli(probs=1 - self.condition_dropout)
            # Time and returns embeddings are concatenated in forward().
            embed_dim = 2 * dim

        def temporal_block(channels_in, channels_out, length):
            # All residual blocks share the embedding width, kernel size and activation.
            return ResidualTemporalBlock(
                channels_in,
                channels_out,
                embed_dim=embed_dim,
                horizon=length,
                kernel_size=kernel_size,
                mish=mish,
            )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for level, (dim_in, dim_out) in enumerate(in_out):
            is_last = level >= (num_resolutions - 1)

            stage = [
                temporal_block(dim_in, dim_out, horizon),
                temporal_block(dim_out, dim_out, horizon),
            ]
            # The deepest level keeps its resolution.
            stage.append(nn.Identity() if is_last else Downsample1d(dim_out))
            self.downs.append(nn.ModuleList(stage))

            if not is_last:
                horizon = horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = temporal_block(mid_dim, mid_dim, horizon)
        self.mid_block2 = temporal_block(mid_dim, mid_dim, horizon)

        for level, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # This loop runs num_resolutions - 1 times, so ``level`` never
            # reaches num_resolutions - 1: every up stage gets an Upsample1d.
            is_last = level >= (num_resolutions - 1)

            stage = [
                # Input channels are doubled by the skip concatenation in forward().
                temporal_block(dim_out * 2, dim_in, horizon),
                temporal_block(dim_in, dim_in, horizon),
            ]
            stage.append(nn.Identity() if is_last else Upsample1d(dim_in))
            self.ups.append(nn.ModuleList(stage))

            if not is_last:
                horizon = horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=kernel_size, mish=mish),
            nn.Conv1d(dim, transition_dim, 1),
        )

    @torch.jit.ignore
    def init_mask_dist(self, sample_shape: Tuple[int, int]):
        """Draw a fresh Bernoulli keep-mask of ``sample_shape`` with keep probability ``1 - condition_dropout``."""
        keep_prob = 1 - self.condition_dropout
        return Bernoulli(probs=keep_prob).sample(sample_shape)

    def forward(self, x, cond, time, returns=None, use_dropout: bool=True, force_dropout: bool=False):
        '''
            x : [ batch x horizon x transition ]
            returns : fed to nn.Linear(16, dim), so its last dimension must be 16
                      (the previous docstring described it as [batch x horizon] -- TODO confirm)
            cond : unused
            use_dropout : multiply the returns embedding by a sampled per-sample keep mask
            force_dropout : zero the returns embedding entirely
            output : same layout as x, [ batch x horizon x transition ]
        '''
        x = x.permute(0, 2, 1)

        t = self.time_mlp(time)

        if self.returns_condition:
            assert returns is not None
            returns_embed = self.returns_mlp(returns)
            if use_dropout:
                batch = returns_embed.size(0)
                # One keep/drop decision per sample, broadcast across the embedding.
                mask = self.mask_dist.sample(sample_shape=(batch, 1)).to(returns_embed.device)
                returns_embed = mask * returns_embed
            if force_dropout:
                returns_embed = 0 * returns_embed
            t = torch.cat([t, returns_embed], dim=-1)

        skips = []

        for stage in self.downs:
            first, second, resample = stage
            x = first(x, t)
            x = second(x, t)
            # Saved before resampling; consumed in reverse order on the way up.
            skips.append(x)
            x = resample(x)

        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

        for stage in self.ups:
            first, second, resample = stage
            x = torch.cat((x, skips.pop()), dim=1)
            x = first(x, t)
            x = second(x, t)
            x = resample(x)

        x = self.final_conv(x)

        return x.permute(0, 2, 1)

class AttTemporalUnet(nn.Module):
    """1D U-Net over the horizon axis that interleaves residual and attention blocks.

    Each stage is: temporal block -> two ``ResidualAttentionBlock`` layers ->
    temporal block -> resampling. The concatenated time/returns embedding is
    used both as the additive embedding of the temporal blocks and as the
    conditioning input of the attention blocks.

    ``returns_condition`` is stored but ``forward`` always embeds ``returns``;
    ``cond`` is accepted by ``forward`` but not used.
    """

    def __init__(
        self,
        horizon,
        transition_dim,
        dim=128,
        dim_mults=(1, 2, 4, 8),
        returns_condition=True,
        condition_dropout=0.1,
        kernel_size=5,
    ):
        super().__init__()

        # Channel widths per level: the input width, then dim * multiplier for
        # each entry of dim_mults.
        dims = [transition_dim] + [dim * mult for mult in dim_mults]
        # Consecutive (in, out) channel pairs, one per resolution level.
        in_out = list(zip(dims[:-1], dims[1:]))
        print(f'[ models/temporal ] Channel dimensions: {in_out}')

        mish = True
        act_fn = nn.Mish()  # f(x) = x * tanh(softplus(x)), softplus(x) = log(1 + exp(x))

        self.time_dim = dim
        # Time and returns embeddings (dim each) are concatenated in forward().
        embed_dim = 2 * dim

        self.use_ln = True

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),  # sinusoidal position embedding comes first
            nn.Linear(dim, dim * 4),
            act_fn,
            nn.Linear(dim * 4, dim),
            nn.LayerNorm(dim) if self.use_ln else nn.Identity(),
        )

        self.returns_condition = returns_condition
        self.condition_dropout = condition_dropout

        # The first linear layer fixes the returns input width at 16.
        self.returns_mlp = nn.Sequential(
            nn.Linear(16, dim),
            act_fn,
            nn.Linear(dim, dim * 4),
            act_fn,
            nn.Linear(dim * 4, dim),
            nn.LayerNorm(dim) if self.use_ln else nn.Identity(),
        )
        self.mask_dist = Bernoulli(probs=1 - self.condition_dropout)

        def temporal_block(channels_in, channels_out, length):
            # All temporal blocks share the embedding width, kernel size and activation.
            return ResidualTemporalBlock(
                channels_in,
                channels_out,
                embed_dim=embed_dim,
                horizon=length,
                kernel_size=kernel_size,
                mish=mish,
            )

        def attention_block(channels):
            # Channel-preserving attention conditioned on the joint embedding.
            return ResidualAttentionBlock(channels, channels, encoder_channels=embed_dim)

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for level, (dim_in, dim_out) in enumerate(in_out):
            is_last = level >= (num_resolutions - 1)

            stage = [
                temporal_block(dim_in, dim_out, horizon),
                attention_block(dim_out),
                attention_block(dim_out),
                temporal_block(dim_out, dim_out, horizon),
            ]
            # The deepest level keeps its resolution.
            stage.append(nn.Identity() if is_last else Downsample1d(dim_out))
            self.downs.append(nn.ModuleList(stage))

            if not is_last:
                horizon = horizon // 2

        mid_dim = dims[-1]
        self.mid_block1 = temporal_block(mid_dim, mid_dim, horizon)
        self.mid_block2 = attention_block(mid_dim)
        self.mid_block3 = attention_block(mid_dim)
        self.mid_block4 = temporal_block(mid_dim, mid_dim, horizon)

        for level, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            # This loop runs num_resolutions - 1 times, so ``level`` never
            # reaches num_resolutions - 1: every up stage gets an Upsample1d.
            is_last = level >= (num_resolutions - 1)

            stage = [
                # Input channels are doubled by the skip concatenation in forward().
                temporal_block(dim_out * 2, dim_in, horizon),
                attention_block(dim_in),
                attention_block(dim_in),
                temporal_block(dim_in, dim_in, horizon),
            ]
            stage.append(nn.Identity() if is_last else Upsample1d(dim_in))
            self.ups.append(nn.ModuleList(stage))

            if not is_last:
                horizon = horizon * 2

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=kernel_size, mish=mish),
            nn.Conv1d(dim, transition_dim, 1),
        )

        # Applied to the raw input in forward() when use_ln is set.
        self.ln = nn.LayerNorm(transition_dim)

    def forward(self, x, cond, time, returns, use_dropout=True, force_dropout=False):
        '''
            x : [ batch x horizon x transition ]
            returns : fed to nn.Linear(16, dim), so its last dimension must be 16
                      (the previous docstring described it as [batch x horizon] -- TODO confirm)
            cond : unused
            use_dropout : multiply the returns embedding by a sampled per-sample keep mask
            force_dropout : zero the returns embedding entirely
            output : same layout as x, [ batch x horizon x transition ]
        '''
        if self.use_ln:
            x = self.ln(x)
        x = einops.rearrange(x, 'b h t -> b t h')

        t = self.time_mlp(time)

        returns_embed = self.returns_mlp(returns)
        if use_dropout:
            batch = returns_embed.size(0)
            # One keep/drop decision per sample, broadcast across the embedding.
            mask = self.mask_dist.sample(sample_shape=(batch, 1)).to(returns_embed.device)
            returns_embed = mask * returns_embed
        if force_dropout:
            returns_embed = 0 * returns_embed

        t = torch.cat([t, returns_embed], dim=-1)
        skips = []

        for stage in self.downs:
            resnet, attention, attention2, resnet2, resample = stage
            x = resnet(x, t)
            x = attention(x, t)
            x = attention2(x, t)
            x = resnet2(x, t)
            # Saved before resampling; consumed in reverse order on the way up.
            skips.append(x)
            x = resample(x)

        for mid_block in (self.mid_block1, self.mid_block2, self.mid_block3, self.mid_block4):
            x = mid_block(x, t)

        for stage in self.ups:
            resnet, attention, attention2, resnet2, resample = stage
            x = torch.cat((x, skips.pop()), dim=1)
            x = resnet(x, t)
            x = attention(x, t)
            x = attention2(x, t)
            x = resnet2(x, t)
            x = resample(x)

        x = self.final_conv(x)
        return einops.rearrange(x, 'b t h -> b h t')

class TransformerNoise(nn.Module):
    """Noise predictor that runs a GPT-2 backbone over embedded trajectory tokens.

    The time embedding and the (optionally masked) representation embedding
    are prepended as two extra tokens to the embedded trajectory, giving a
    sequence of ``horizon + 2`` tokens. The transformer output is projected
    back to ``obs_dim`` features and ``horizon`` steps.

    Notes grounded in this class:
      * The transformer is configured with ``n_embd=hidden_size`` while the
        tokens it receives have ``dim`` features, and ``final1`` expects
        ``dim`` input features.
      * ``position_emb`` is registered as a parameter but not used in ``forward``.
      * ``cond`` is accepted by ``forward`` but not used.
    """

    def __init__(
        self,
        horizon,
        obs_dim,
        z_dim,
        dim=128,
        hidden_size=128,
    ):
        super().__init__()
        self.dim = dim
        self.horizon = horizon
        self.hidden_size = hidden_size

        gpt2_config = transformers.GPT2Config(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            n_embd=hidden_size,
            n_layer=4,
            n_head=2,
            n_inner=4 * 256,
            activation_function='mish',
            n_positions=1024,
            n_ctx=1023,
            resid_pdrop=0.1,
            attn_pdrop=0.1,
        )

        # Diffusion-timestep embedding.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )
        # Embedding of the conditioning representation (z_dim -> dim).
        self.return_mlp = nn.Sequential(
            nn.Linear(z_dim, dim),
            nn.Mish(),
            nn.Linear(dim, 4 * dim),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        # Keep the conditioning embedding with probability 0.8 when dropout is on.
        self.mask_dist = Bernoulli(probs=0.8)
        self.transformer = GPT2Model(gpt2_config)
        # Normalises over the token axis (two extra tokens + horizon steps).
        self.embed_ln = nn.LayerNorm(horizon + 2)
        self.embed_x = ResidualTemporalBlock(obs_dim, dim, embed_dim=dim, horizon=horizon, kernel_size=5, mish=True)
        self.final1 = torch.nn.Linear(dim, obs_dim)
        self.final2 = torch.nn.Linear(horizon + 2, horizon)

        self.position_emb = nn.Parameter(torch.zeros(1, 24 + horizon, hidden_size))

    def forward(self, x, cond, time, repre, use_dropout=True, force_dropout=False):
        """Predict noise for ``x``.

        Shape comments below are illustrative values carried over from the
        original annotations (they look like batch 64, dim 128, horizon 100,
        obs_dim 11 -- TODO confirm against the caller).

        Args:
            x: trajectory, ``[batch, horizon, obs_dim]`` (axis 1 is used as the
                sequence length for the attention mask).
            cond: unused.
            time: diffusion timestep input for ``time_mlp``.
            repre: conditioning representation with ``z_dim`` features.
            use_dropout: sample a per-sample Bernoulli(0.8) keep mask for ``repre``.
            force_dropout: zero the ``repre`` embedding (overrides the sampled mask).
        """
        t = self.time_mlp(time)  # e.g. [64, 128]
        batch_size, seq_length = x.shape[0], x.shape[1]
        repre = self.return_mlp(repre)  # e.g. [64, 128]

        mask = 1
        if use_dropout:
            mask = self.mask_dist.sample(sample_shape=(repre.shape[0], 1)).to(x.device)
        if force_dropout:
            mask = 0
        # A small constant is added after masking; then a token axis is appended.
        repre = ((repre * mask) + 1e-8).unsqueeze(-1)

        x = einops.rearrange(x, 'b h t -> b t h')
        x_embed = self.embed_x(x, t)  # e.g. [64, 128, 100]

        # Token order along the last axis: time, representation, trajectory steps.
        tokens = torch.cat((t.unsqueeze(-1), repre, x_embed), dim=-1)  # e.g. [64, 128, 102]
        tokens = self.embed_ln(tokens)
        tokens = einops.rearrange(tokens, 'b h t -> b t h')  # e.g. [64, 102, 128]

        # Every token (2 extra + seq_length) is attended to.
        token_mask = torch.ones((batch_size, 2 + seq_length), dtype=torch.long, device=x.device)

        transformer_outputs = self.transformer(inputs_embeds=tokens, attention_mask=token_mask)
        hidden = transformer_outputs['last_hidden_state']  # e.g. [64, 102, 128]
        hidden = self.final1(hidden)  # e.g. [64, 102, 11]
        hidden = einops.rearrange(hidden, 'b h t -> b t h')  # e.g. [64, 11, 102]
        noise = self.final2(hidden)  # e.g. [64, 11, 100]
        # Back to the input layout, e.g. [64, 100, 11].
        return einops.rearrange(noise, 'b h t -> b t h')