from torch import nn
from torch.nn import functional as F
from Model.ViT import ViT
from Model.Resnet import ResnetEncoder, ResnetDecoder
from Model.VQ import SequenceGaussianVectorQuantizer, SequenceFSQ, SequenceVectorQuantizer
from .types_ import *
from prettytable import PrettyTable


def count_parameters(model):
    """Print a per-tensor table of trainable parameters and return their total.

    Only parameters whose ``requires_grad`` flag is set are listed and counted.

    Args:
        model: an ``nn.Module`` (anything exposing ``named_parameters()``).

    Returns:
        int: total number of trainable parameter elements.
    """
    table = PrettyTable(["Modules", "Parameters"])
    trainable = [(name, tensor.numel())
                 for name, tensor in model.named_parameters()
                 if tensor.requires_grad]

    for name, count in trainable:
        table.add_row([name, count])

    total_params = sum(count for _, count in trainable)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params


def create_model(config):
    """Build the model selected by ``config.model_type`` on ``config.device``.

    Supported ``model_type`` values are ``'vit_bump'``, ``'hybird_bump'`` (the
    spelling is kept for compatibility with existing configs), ``'resnet_bump'``,
    ``'vit_fsq'`` and ``'vit_vq'``. All other hyper-parameters are read from
    attributes of ``config``. ``sigma`` is optional for the *_bump models and
    defaults to 0.01. The trainable-parameter table is printed via
    ``count_parameters`` before returning.

    Args:
        config: object exposing ``model_type``, ``device`` and the
            hyper-parameters required by the selected model.

    Returns:
        The constructed ``nn.Module``, already moved to ``config.device``.

    Raises:
        ValueError: if ``config.model_type`` is not one of the supported values.
    """
    supported = ('vit_bump', 'hybird_bump', 'resnet_bump', 'vit_fsq', 'vit_vq')
    model_type = config.model_type

    if model_type == 'vit_bump':
        model = ViT_Bump(embedding_dim=config.embedding_dim,
                         num_embeddings=config.num_embeddings,
                         beta=config.beta,
                         img_height=config.img_height,
                         img_width=config.img_width,
                         dim=config.dim,
                         depth=config.depth,
                         p_in_encoder=config.p_in_encoder,
                         p_in_decoder=config.p_in_decoder,
                         compress_rate=config.compress_rate,
                         sigma=getattr(config, 'sigma', 0.01)  # default sigma=0.01
                         ).to(config.device)
    elif model_type == 'hybird_bump':
        model = Hybird_Bump(in_channels=3,
                            embedding_dim=config.embedding_dim,
                            num_embeddings=config.num_embeddings,
                            beta=config.beta,
                            residual_blocks=config.residual_blocks,
                            hidden_dims=config.hidden_dims,
                            img_height=config.img_height,
                            img_width=config.img_width,
                            sigma=getattr(config, 'sigma', 0.01)  # default sigma=0.01
                            ).to(config.device)
    elif model_type == 'resnet_bump':
        model = Resnet_Bump(in_channels=3,
                            embedding_dim=config.embedding_dim,
                            num_embeddings=config.num_embeddings,
                            beta=config.beta,
                            residual_blocks=config.residual_blocks,
                            hidden_dims=config.hidden_dims,
                            img_height=config.img_height,
                            img_width=config.img_width,
                            sigma=getattr(config, 'sigma', 0.01)  # default sigma=0.01
                            ).to(config.device)
    elif model_type == 'vit_fsq':
        model = Vit_FSQ(d_out=config.d_out,
                        beta=config.beta,
                        img_height=config.img_height,
                        img_width=config.img_width,
                        dim=config.dim,
                        depth=config.depth,
                        p_in_encoder=config.p_in_encoder,
                        device=config.device).to(config.device)
    elif model_type == 'vit_vq':
        model = ViT_VQ(embedding_dim=config.embedding_dim,
                       num_embeddings=config.num_embeddings,
                       beta=config.beta,
                       img_height=config.img_height,
                       img_width=config.img_width,
                       dim=config.dim,
                       depth=config.depth,
                       p_in_encoder=config.p_in_encoder,
                       p_in_decoder=config.p_in_decoder,
                       compress_rate=config.compress_rate,
                       ).to(config.device)
    else:
        # Without this branch `model` would be unbound below and surface as an
        # opaque UnboundLocalError instead of a configuration error.
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {', '.join(supported)}"
        )

    count_parameters(model)
    return model


class Resnet_Bump(nn.Module):
    """ResNet autoencoder with a Gaussian sequence vector quantizer.

    Inputs are frame sequences shaped [B, S, C, H, W]. Every frame is encoded by
    ``ResnetEncoder`` and refined by a residual fully-connected layer applied to
    the flattened feature map. The features are quantized by
    ``SequenceGaussianVectorQuantizer``, then decoded by a second residual FC
    layer followed by ``ResnetDecoder``.
    """

    def __init__(self, in_channels: int, embedding_dim: int, num_embeddings: int,
                 hidden_dims: List = None, beta: float = 0.1,
                 img_height: int = 32, img_width: int = 80,
                 residual_blocks: int = 6, sigma: float = 0.01, ):
        """Build the ResNet encoder/decoder, residual FC layers and quantizer.

        Args:
            in_channels: channels of each input frame.
            embedding_dim: channels of the encoded feature map (and of the
                codebook vectors given to the quantizer).
            num_embeddings: number of codebook entries.
            hidden_dims: ResNet stage widths. ``None`` selects ``[256, 256]``,
                matching the default used by ``Hybird_Bump``.
            beta: weight of the VQ loss in ``loss_function``.
            img_height: input frame height.
            img_width: input frame width.
            residual_blocks: forwarded to ``ResnetEncoder`` / ``ResnetDecoder``.
            sigma: forwarded to ``SequenceGaussianVectorQuantizer``.
        """
        super().__init__()

        # `len(hidden_dims)` is needed below, so a None default must be resolved
        # before anything uses it.
        if hidden_dims is None:
            hidden_dims = [256, 256]

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.img_height = img_height
        self.img_width = img_width
        self.beta = beta
        self.sigma = sigma

        self.resnet_encoder = ResnetEncoder(in_channels, hidden_dims, embedding_dim, residual_blocks)
        self.resnet_decoder = ResnetDecoder(embedding_dim, hidden_dims, residual_blocks)

        # The feature-map size assumes every ResNet stage halves H and W.
        # NOTE(review): confirm against ResnetEncoder.
        self.feature_height = img_height // (2 ** len(hidden_dims))
        self.feature_width = img_width // (2 ** len(hidden_dims))
        feature_dim = embedding_dim * self.feature_height * self.feature_width

        # Residual FC refinement over the whole flattened feature map (encoder side).
        self.pre_fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feature_dim, feature_dim),
            nn.LeakyReLU(),
        )

        # Residual FC refinement applied to quantized features (decoder side).
        self.post_fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feature_dim, feature_dim),
            nn.LeakyReLU(),
        )

        self.vq_layer = SequenceGaussianVectorQuantizer(num_embeddings, embedding_dim, sigma)

    def encoder(self, x: Tensor) -> Tensor:
        """Encode [B, S, C, H, W] frames into [B, S, C2, H2, W2] features."""
        B, S, C, H, W = x.shape
        x = x.view(-1, C, H, W)  # [B*S, C, H, W]
        x = self.resnet_encoder(x)  # [B*S, C2, H2, W2]
        _, C2, H2, W2 = x.shape
        res = self.pre_fc(x)  # [B*S, C2*H2*W2]
        res = res.view(-1, C2, H2, W2)
        x = x + res
        x = x.view(B, S, C2, H2, W2)  # [B, S, C2, H2, W2]
        return x

    def decoder(self, x: Tensor) -> Tensor:
        """Decode [B, S, C2, H2, W2] features back into [B, S, C, H, W] frames."""
        B, S, C2, H2, W2 = x.shape
        x = x.view(-1, C2, H2, W2)  # [B*S, C2, H2, W2]
        res = self.post_fc(x)  # [B*S, C2*H2*W2]
        res = res.view(-1, C2, H2, W2)
        x = x + res
        x = self.resnet_decoder(x)  # [B*S, C, H, W]
        _, C, H, W = x.shape
        x = x.view(B, S, C, H, W)  # [B, S, C, H, W]
        return x

    def forward(self, input: Tensor, label: Tensor) -> List[Tensor]:
        """Encode, quantize (conditioned on ``label``) and decode ``input``.

        Returns:
            ``[reconstruction, input, vq_loss]``, the argument order expected by
            ``loss_function``.
        """
        x = self.encoder(input)
        x_q, vq_loss = self.vq_layer(x, label)  # [B, S, C2, H2, W2]
        x = self.decoder(x_q)
        return [x, input, vq_loss]

    def loss_function(self, *args) -> dict:
        """Compute ``mse(recons, input) + beta * vq_loss``.

        Args:
            *args: ``(recons, input, vq_loss)`` as returned by ``forward``.

        Returns:
            dict with keys ``'loss'``, ``'Reconstruction_Loss'`` and ``'VQ_Loss'``.
        """
        recons, input, vq_loss = args
        recons_loss = F.mse_loss(recons, input)
        loss = recons_loss + vq_loss * self.beta
        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'VQ_Loss': vq_loss}

    def generate(self, obs_init: Tensor, pos_init: Tensor, pos_new: Tensor) -> Tensor:
        """Encode ``obs_init`` and decode the quantizer's ``test_forward`` output.

        NOTE(review): the meaning of ``pos_init`` / ``pos_new`` is defined by
        ``SequenceGaussianVectorQuantizer.test_forward``; not visible here.
        """
        x = self.encoder(obs_init)
        x_q = self.vq_layer.test_forward(x, pos_init, pos_new)
        x = self.decoder(x_q)
        return x


class Hybird_Bump(nn.Module):
    """ResNet + ViT hybrid autoencoder with a Gaussian sequence vector quantizer.

    Each frame of a [B, S, C, H, W] input passes through ``ResnetEncoder`` and
    then a ViT that operates on the downsampled feature map. The quantized
    features are decoded by a second ViT followed by ``ResnetDecoder``.
    """

    def __init__(self, in_channels: int, embedding_dim: int, num_embeddings: int,
                 hidden_dims: List = None, beta: float = 0.1,
                 img_height: int = 32, img_width: int = 80,
                 residual_blocks: int = 6, sigma: float = 0.01, ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.img_height = img_height
        self.img_width = img_width
        self.beta = beta
        self.sigma = sigma

        stages = [256, 256] if hidden_dims is None else hidden_dims

        # Assumes each ResNet stage halves the spatial size. NOTE(review): confirm.
        downsample = 2 ** len(stages)
        self.feature_height = img_height // downsample
        self.feature_width = img_width // downsample

        self.resnet_encoder = ResnetEncoder(in_channels, stages, embedding_dim, residual_blocks)
        self.resnet_decoder = ResnetDecoder(embedding_dim, stages, residual_blocks)

        # Fixed transformer hyper-parameters shared by both feature-level ViTs.
        vit_dim, vit_depth, patch, head_dim = 1024, 2, 5, 64
        vit_kwargs = dict(image_size=(self.feature_height, self.feature_width),
                          p_in=patch, p_out=patch,
                          dim=vit_dim, depth=vit_depth,
                          heads=vit_dim // head_dim, mlp_dim=4 * vit_dim,
                          dim_in=embedding_dim, dim_out=embedding_dim,
                          dim_head=head_dim)
        self.vit_encoder = ViT(**vit_kwargs)
        self.vit_decoder = ViT(**vit_kwargs)

        self.vq_layer = SequenceGaussianVectorQuantizer(num_embeddings, embedding_dim, sigma)

    def encoder(self, x: Tensor) -> Tensor:
        """Encode [B, S, C, H, W] frames into [B, S, C2, H2, W2] features."""
        batch, seq, channels, height, width = x.shape
        frames = x.view(-1, channels, height, width)
        feats = self.vit_encoder(self.resnet_encoder(frames))
        _, c2, h2, w2 = feats.shape
        return feats.view(batch, seq, c2, h2, w2)

    def decoder(self, x: Tensor) -> Tensor:
        """Decode [B, S, C2, H2, W2] features back into [B, S, C, H, W] frames."""
        batch, seq, c2, h2, w2 = x.shape
        latents = x.view(-1, c2, h2, w2)
        frames = self.resnet_decoder(self.vit_decoder(latents))
        _, channels, height, width = frames.shape
        return frames.view(batch, seq, channels, height, width)

    def forward(self, input: Tensor, label: Tensor) -> List[Tensor]:
        """Return ``[reconstruction, input, vq_loss]`` for ``loss_function``."""
        quantized, vq_loss = self.vq_layer(self.encoder(input), label)
        return [self.decoder(quantized), input, vq_loss]

    def loss_function(self, *args) -> dict:
        """Combine reconstruction MSE with the beta-weighted VQ loss."""
        recons, target, vq_loss = args
        recons_loss = F.mse_loss(recons, target)
        return {'loss': recons_loss + vq_loss * self.beta,
                'Reconstruction_Loss': recons_loss,
                'VQ_Loss': vq_loss}

    def generate(self, obs_init: Tensor, pos_init: Tensor, pos_new: Tensor) -> Tensor:
        """Encode ``obs_init`` and decode the quantizer's ``test_forward`` output."""
        latent = self.encoder(obs_init)
        return self.decoder(self.vq_layer.test_forward(latent, pos_init, pos_new))


class ViT_Bump(nn.Module):
    """Pure-ViT autoencoder with a Gaussian sequence vector quantizer.

    The encoder ViT maps each 3-channel frame to an ``embedding_dim``-channel
    feature map. With patch output size ``p_in_encoder // compress_rate``, the
    spatial size is presumably reduced by ``compress_rate``. The decoder ViT
    maps features back to 3 channels, and a final ``nn.Tanh`` bounds the output.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int,
                 beta: float = 0.1,
                 img_height: int = 32, img_width: int = 80,
                 sigma: float = 0.01, dim=768,
                 depth=3,
                 p_in_encoder=8,
                 p_in_decoder=2, compress_rate=4):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.img_height = img_height
        self.img_width = img_width
        self.beta = beta
        self.sigma = sigma
        self.compress_rate = compress_rate

        # Transformer hyper-parameters common to encoder and decoder.
        head_dim = 64
        shared = dict(dim=dim, depth=depth, heads=dim // head_dim,
                      mlp_dim=dim * 4, dim_head=head_dim)

        self.vit_encoder = ViT(image_size=(img_height, img_width),
                               p_in=p_in_encoder,
                               p_out=p_in_encoder // compress_rate,
                               dim_in=3, dim_out=embedding_dim, **shared)

        latent_size = (img_height // compress_rate, img_width // compress_rate)
        self.vit_decoder = nn.Sequential(
            ViT(image_size=latent_size,
                p_in=p_in_decoder,
                p_out=p_in_decoder * compress_rate,
                dim_in=embedding_dim, dim_out=3, **shared),
            nn.Tanh(),
        )

        self.vq_layer = SequenceGaussianVectorQuantizer(num_embeddings, embedding_dim, sigma)

    def encoder(self, x: Tensor) -> Tensor:
        """Encode [B, S, C, H, W] frames into [B, S, C2, H2, W2] features."""
        batch, seq, channels, height, width = x.shape
        feats = self.vit_encoder(x.view(-1, channels, height, width))
        _, c2, h2, w2 = feats.shape
        return feats.view(batch, seq, c2, h2, w2)

    def decoder(self, x: Tensor) -> Tensor:
        """Decode [B, S, C2, H2, W2] features into [B, S, C, H, W] frames."""
        batch, seq, c2, h2, w2 = x.shape
        frames = self.vit_decoder(x.view(-1, c2, h2, w2))  # [B*S, C, H, W]
        _, channels, height, width = frames.shape
        return frames.view(batch, seq, channels, height, width)

    def forward(self, input: Tensor, label: Tensor) -> List[Tensor]:
        """Return ``[reconstruction, input, vq_loss]`` for ``loss_function``."""
        quantized, vq_loss = self.vq_layer(self.encoder(input), label)
        return [self.decoder(quantized), input, vq_loss]

    def loss_function(self, *args) -> dict:
        """Combine reconstruction MSE with the beta-weighted VQ loss."""
        recons, target, vq_loss = args
        recons_loss = F.mse_loss(recons, target)
        return {'loss': recons_loss + vq_loss * self.beta,
                'Reconstruction_Loss': recons_loss,
                'VQ_Loss': vq_loss}

    def generate(self, obs_init: Tensor, pos_init: Tensor, pos_new: Tensor) -> Tensor:
        """Encode ``obs_init`` and decode the quantizer's ``test_forward`` output."""
        latent = self.encoder(obs_init)
        return self.decoder(self.vq_layer.test_forward(latent, pos_init, pos_new))


class ViT_VQ(nn.Module):
    """Pure-ViT autoencoder with a standard sequence vector quantizer.

    Same encoder/decoder layout as ``ViT_Bump``, but without the output
    ``Tanh`` and with ``SequenceVectorQuantizer`` instead of the Gaussian one.
    """

    def __init__(self, embedding_dim: int, num_embeddings: int,
                 beta: float = 0.1,
                 img_height: int = 32, img_width: int = 80,
                 dim=768,
                 depth=3,
                 p_in_encoder=8,
                 p_in_decoder=2, compress_rate=4):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.img_height = img_height
        self.img_width = img_width
        self.beta = beta
        self.compress_rate = compress_rate

        # Transformer hyper-parameters common to encoder and decoder.
        head_dim = 64
        shared = dict(dim=dim, depth=depth, heads=dim // head_dim,
                      mlp_dim=dim * 4, dim_head=head_dim)

        self.vit_encoder = ViT(image_size=(img_height, img_width),
                               p_in=p_in_encoder,
                               p_out=p_in_encoder // compress_rate,
                               dim_in=3, dim_out=embedding_dim, **shared)

        latent_size = (img_height // compress_rate, img_width // compress_rate)
        self.vit_decoder = ViT(image_size=latent_size,
                               p_in=p_in_decoder,
                               p_out=p_in_decoder * compress_rate,
                               dim_in=embedding_dim, dim_out=3, **shared)

        self.vq_layer = SequenceVectorQuantizer(num_embeddings, embedding_dim)

    def encoder(self, x: Tensor) -> Tensor:
        """Encode [B, S, C, H, W] frames into [B, S, C2, H2, W2] features."""
        batch, seq, channels, height, width = x.shape
        feats = self.vit_encoder(x.view(-1, channels, height, width))
        _, c2, h2, w2 = feats.shape
        return feats.view(batch, seq, c2, h2, w2)

    def decoder(self, x: Tensor) -> Tensor:
        """Decode [B, S, C2, H2, W2] features into [B, S, C, H, W] frames."""
        batch, seq, c2, h2, w2 = x.shape
        frames = self.vit_decoder(x.view(-1, c2, h2, w2))  # [B*S, C, H, W]
        _, channels, height, width = frames.shape
        return frames.view(batch, seq, channels, height, width)

    def forward(self, input: Tensor, label: Tensor) -> List[Tensor]:
        """Return ``[reconstruction, input, vq_loss]`` for ``loss_function``."""
        quantized, vq_loss = self.vq_layer(self.encoder(input), label)
        return [self.decoder(quantized), input, vq_loss]

    def loss_function(self, *args) -> dict:
        """Combine reconstruction MSE with the beta-weighted VQ loss."""
        recons, target, vq_loss = args
        recons_loss = F.mse_loss(recons, target)
        return {'loss': recons_loss + vq_loss * self.beta,
                'Reconstruction_Loss': recons_loss,
                'VQ_Loss': vq_loss}

    def generate(self, obs_init: Tensor, pos_init: Tensor, pos_new: Tensor) -> Tensor:
        """Encode ``obs_init`` and decode the quantizer's ``test_forward`` output.

        NOTE(review): assumes ``SequenceVectorQuantizer`` provides
        ``test_forward``; not visible here.
        """
        latent = self.encoder(obs_init)
        return self.decoder(self.vq_layer.test_forward(latent, pos_init, pos_new))


class Vit_FSQ(nn.Module):
    """ViT autoencoder with a sequence finite-scalar quantizer (``SequenceFSQ``).

    The encoder ViT turns each frame into a token sequence [num, d_out], where
    ``num = H * W // p_in_encoder ** 2``. The decoder ViT maps tokens back to a
    3-channel image. The ``patch_in_out`` flags are interpreted by ``ViT``
    (NOTE(review): semantics not visible here).
    """

    def __init__(self, d_out: int,
                 beta: float = 0.1,
                 img_height: int = 32, img_width: int = 80,
                 dim=768,
                 depth=3,
                 p_in_encoder=8,
                 device='cpu'):
        super().__init__()

        self.d_out = d_out
        self.img_height = img_height
        self.img_width = img_width
        self.beta = beta

        # Transformer hyper-parameters common to encoder and decoder.
        head_dim = 64
        shared = dict(dim=dim, depth=depth, heads=dim // head_dim,
                      mlp_dim=dim * 4, dim_head=head_dim)
        image_size = (img_height, img_width)

        self.vit_encoder = ViT(image_size=image_size, p_in=p_in_encoder, p_out=1,
                               dim_in=3, dim_out=d_out,
                               patch_in_out=(True, False), **shared)

        self.vit_decoder = ViT(image_size=image_size, p_in=p_in_encoder, p_out=p_in_encoder,
                               dim_in=d_out, dim_out=3,
                               patch_in_out=(False, True), **shared)

        num_tokens = img_height * img_width // p_in_encoder ** 2
        self.vq_layer = SequenceFSQ(num=num_tokens, d=d_out,
                                    a_dim=[num_tokens * d_out // 8] * 4, device=device)

    def encoder(self, x: Tensor) -> Tensor:
        """Encode [B, S, C, H, W] frames into [B, S, num, d_out] tokens."""
        batch, seq, channels, height, width = x.shape
        tokens = self.vit_encoder(x.view(-1, channels, height, width))  # [B*S, num, d_out]
        _, num, dim_out = tokens.shape
        return tokens.view(batch, seq, num, dim_out)

    def decoder(self, x: Tensor) -> Tensor:
        """Decode [B, S, num, d_out] tokens into [B, S, C, H, W] frames."""
        batch, seq, _, _ = x.shape
        frames = self.vit_decoder(x.view(batch * seq, -1, self.d_out))  # [B*S, C, H, W]
        _, channels, height, width = frames.shape
        return frames.view(batch, seq, channels, height, width)

    def forward(self, input: Tensor, label: Tensor) -> List[Tensor]:
        """Return ``[reconstruction, input, vq_loss]`` for ``loss_function``.

        The quantizer returns three values; the middle one is not used here.
        """
        quantized, _, vq_loss = self.vq_layer(self.encoder(input), label)
        return [self.decoder(quantized), input, vq_loss]

    def loss_function(self, *args) -> dict:
        """Combine reconstruction MSE with the beta-weighted VQ loss."""
        recons, target, vq_loss = args
        recons_loss = F.mse_loss(recons, target)
        return {'loss': recons_loss + vq_loss * self.beta,
                'Reconstruction_Loss': recons_loss,
                'VQ_Loss': vq_loss}

    def generate(self, obs_init: Tensor, pos_init: Tensor, pos_new: Tensor) -> Tensor:
        """Encode ``obs_init`` and decode the quantizer's ``test_forward`` output."""
        tokens = self.encoder(obs_init)
        return self.decoder(self.vq_layer.test_forward(tokens, pos_init, pos_new))
