import math
from pathlib import Path
from random import random
from functools import partial
from multiprocessing import cpu_count
from typing import List, Callable

import torch
from torch import nn, einsum
from torch.special import expm1
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torch.optim import Adam
from torchvision import transforms as T, utils

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange


class FITBlock(nn.Module):
    """
    An FIT block with context conditioning through an additional cross attention before reading patches
    See https://arxiv.org/abs/2305.12689 for more details.

    One forward pass runs, in order:
      1. local self-attention inside every group of patches (if local layers are present);
      2. latents cross-attending to the context (if global layers are present and dim_context is set);
      3. the "read": each group of latents cross-attends to its own group of patches,
         optionally summed with a zero-initialized camera control branch;
      4. global self-attention over all latents;
      5. the "write": each group of patches cross-attends to its own group of latents.
    Steps 2-5 are skipped entirely when global_layers_count == 0.
    """
    def __init__(
        self,
        dim: int,
        dim_latent: int,
        dim_latent_conditioning: int,
        dim_context: int,
        groups_count: List[int],
        group_patches: List[int],
        group_patches_total: int,
        groups_count_total: int,
        latents_per_group: int,
        local_layers_count: int,
        global_layers_count: int,
        use_camera_conditioning: bool = False,
        block_config: dict = None,
        drop_units: float = 0.0,
    ):
        """
        :param dim: the number of input and output dimensions
        :param dim_latent: the number of dimensions of the latents and of latent conditionings
        :param dim_latent_conditioning: the number of dimensions of the latent conditioning
        :param dim_context: the number of dimensions of the context information, eg. text embeddings.
                            None disables the context read
        :param groups_count: the total number of groups in (time, height, width)
        :param group_patches: the size of each group in patches in (time, height, width)
        :param group_patches_total: the number of patches in each group
        :param groups_count_total: the number of groups
        :param latents_per_group: the number of latents per group
        :param local_layers_count: the number of local layers per block
        :param global_layers_count: the number of global layers per block
        :param use_camera_conditioning: whether camera conditioning is enabled. The camera control branch is
                                        only built and used when "use_patch_latents_controlnet" is also enabled
        :param block_config: optional dictionary with extra block options. Only the key
                             "use_patch_latents_controlnet" (default False) is read. None is treated as empty
        :param drop_units: the drop_units value forwarded to the feed forward modules of the local and
                           global layers. The read, write and context feed forwards always use 0.0
        """

        super().__init__()
        self.dim = dim
        self.dim_latent = dim_latent
        self.dim_latent_conditioning = dim_latent_conditioning
        self.dim_context = dim_context
        self.groups_count = groups_count
        self.group_patches = group_patches
        self.group_patches_total = group_patches_total
        self.groups_count_total = groups_count_total
        self.latents_per_group = latents_per_group
        self.local_layers_count = local_layers_count
        self.global_layers_count = global_layers_count
        self.use_camera_conditioning = use_camera_conditioning
        self.drop_units = drop_units

        # Add controlnet
        if block_config is None:
            block_config = {}
        self.use_patch_latents_controlnet = block_config.get("use_patch_latents_controlnet", False)

        # NOTE: the submodules below are registered in a fixed order on purpose. Reordering them would change
        # the parameter order and the order in which the random initialization is drawn.

        # No need to read without global layers
        if self.global_layers_count > 0:
            # Read operation
            self.latents_attend_to_patches = Attention(dim_latent, dim_context = dim, norm=True, norm_context=True)
            self.latents_cross_attn_ff = FeedForward(dim_latent, drop_units=0.0)

            # Add camera fine-tuning
            if self.use_camera_conditioning and self.use_patch_latents_controlnet:
                self.latents_attend_to_patches_camera_control = Attention(dim_latent, dim_context = dim, norm=True, norm_context=True)
                self.latents_cross_attn_ff_camera_control = FeedForward(dim_latent, drop_units=0.0)
                self.latents_cross_attn_out_camera_control = torch.nn.Conv1d(dim_latent, dim_latent, kernel_size=1, bias=False)
                # Zero-initialize the output layer so that the camera branch contributes nothing at initialization.
                for param in self.latents_cross_attn_out_camera_control.parameters():
                    nn.init.zeros_(param)

            # Context read
            if self.dim_context is not None:
                self.latents_attend_to_context = Attention(dim_latent, dim_context = dim_context, norm=True, norm_context=True)
                self.latents_context_attn_ff = FeedForward(dim_latent, drop_units=0.0)

        # No need for local computation if there are no local layers
        if self.local_layers_count > 0:
            # Local layers
            self.local_layers = nn.ModuleList([])
            for _ in range(self.local_layers_count):
                self.local_layers.append(nn.ModuleList([
                    Attention(dim, norm=True,),
                    FeedForward(dim, drop_units=drop_units)
                ]))

        # No need for global computation if there are no global layers
        if self.global_layers_count > 0:
            # Global layers
            self.global_layers = nn.ModuleList([])
            for _ in range(self.global_layers_count):
                self.global_layers.append(nn.ModuleList([
                    Attention(dim_latent, norm=True),
                    FeedForward(dim_latent, drop_units=drop_units)
                ]))

        # We write only if there are global layers that did some work on the latents
        if self.global_layers_count > 0:
            # Write operation
            self.patches_attend_to_latents = Attention(dim, dim_context = dim_latent, norm=True, norm_context=True)
            self.patches_cross_attn_ff = FeedForward(dim, drop_units=0.0)

    def _camera_control_latents(self, group_latents: torch.Tensor, group_patches: torch.Tensor, camera: torch.Tensor):
        """
        Computes the camera control residual that is added to the latents after the read operation.

        :param group_latents ((batch_size groups), latents_per_group, latent_channels) tensor with the grouped latents before the read
        :param group_patches ((batch_size groups), group_patches_total, patch_channels) tensor with the grouped patches
        :param camera (batch_size, patch_count, patch_channels) tensor with camera embeddings
        :return ((batch_size groups), latents_per_group, latent_channels) tensor with the camera control residual
        """
        if camera is None:
            raise ValueError(
                "FITBlock was built with camera conditioning and use_patch_latents_controlnet enabled, "
                "but forward() received camera=None"
            )

        group_camera_patches = rearrange(camera, 'b (g p) c -> (b g) p c', g=self.groups_count_total, p=self.group_patches_total)
        # The camera embeddings are summed with the patches, so both must have the same number of channels
        group_camera_patches = group_camera_patches + group_patches # [(b g) p c]

        group_camera_latents = self.latents_attend_to_patches_camera_control(group_latents, group_camera_patches) + group_latents # [(b g) l c]
        group_camera_latents = self.latents_cross_attn_ff_camera_control(group_camera_latents) + group_camera_latents # [(b g) l c]

        # Conv1d expects channels first (N, C, L) while the latents are channels last (N, L, C).
        # Transpose around the 1x1 convolution so that it mixes the latent channels.
        group_camera_latents = group_camera_latents.transpose(1, 2) # [(b g) c l]
        group_camera_latents = self.latents_cross_attn_out_camera_control(group_camera_latents) # [(b g) c l]
        return group_camera_latents.transpose(1, 2) # [(b g) l c]

    def forward(self, patches: torch.Tensor, latents: torch.Tensor, conditioning_latents: torch.Tensor, context: torch.Tensor, camera: torch.Tensor = None):
        """
        :param patches (batch_size, patch_count, patch_channels) tensor with patches where patch_count = (tgid hgid wgid tgsize hgsize wgsize)
        :param latents (batch_size, latent_count, latent_channels) tensor with latents
        :param conditioning_latents (batch_size, conditioning_latent_count, latent_conditioning_channels) tensor with latent conditioning information. E.g. diffusion time.
                                    Accepted for interface compatibility, it is not used by this block
        :param context (batch_size, context_channels) tensor with context information. None if context is not present
        :param camera (batch_size, patch_count, patch_channels) tensor with plucker camera embeddings. None if camera is not present.
                      Required when both camera conditioning and use_patch_latents_controlnet are enabled and global layers are present
        :return tuple (patches, latents) with the same shapes as the corresponding inputs.
                The latents are returned unchanged when there are no global layers
        :raises ValueError: if the camera control branch is active but camera is None
        """
        camera_control_enabled = self.use_camera_conditioning and self.use_patch_latents_controlnet

        # Applies the local network
        group_patches = rearrange(patches, 'b (g p) c -> (b g) p c', g=self.groups_count_total, p=self.group_patches_total)
        # Applies local layers only if they are present
        if self.local_layers_count > 0:
            for attn, ff in self.local_layers:
                group_patches = attn(group_patches) + group_patches
                group_patches = ff(group_patches) + group_patches

        # Applies all layers related to global computation
        if self.global_layers_count > 0:
            if self.dim_context is not None:
                latents = self.latents_attend_to_context(latents, context) + latents
                latents = self.latents_context_attn_ff(latents) + latents

            group_latents = rearrange(latents, 'b (g l) c -> (b g) l c', g=self.groups_count_total, l=self.latents_per_group)

            # Camera conditioning: compute the camera-augmented latents from the latents before the read
            group_camera_latents = None
            if camera_control_enabled:
                group_camera_latents = self._camera_control_latents(group_latents, group_patches, camera)

            # Each group of latents reads from the corresponding group of patches
            group_latents = self.latents_attend_to_patches(group_latents, group_patches) + group_latents
            group_latents = self.latents_cross_attn_ff(group_latents) + group_latents

            # Camera conditioning: apply the camera-augmented latents to the main FIT latents.
            if group_camera_latents is not None:
                group_latents = group_latents + group_camera_latents

            latents = rearrange(group_latents, '(b g) l c -> b (g l) c', g=self.groups_count_total, l=self.latents_per_group)

            # Applies global self attention to all latents
            for attn, ff in self.global_layers:
                latents = attn(latents) + latents
                latents = ff(latents) + latents

            # Each group of patches attends to the respective group of latents
            group_latents = rearrange(latents, 'b (g l) c -> (b g) l c', g=self.groups_count_total, l=self.latents_per_group)
            group_patches = self.patches_attend_to_latents(group_patches, group_latents) + group_patches
            group_patches = self.patches_cross_attn_ff(group_patches) + group_patches

        patches = rearrange(group_patches, '(b g) p c -> b (g p) c', g=self.groups_count_total, p=self.group_patches_total)

        return patches, latents