
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import sys
import yaml


sys.path.append("./")  # change this path to your current work directory
from LFAE.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d

import torch





class AutoEncoder(nn.Module):
    """Auto-encoder whose latent code can be warped by a predicted flow.

    Combines an ``Encoder``/``Decoder`` pair with a ``FlowPredictor``. The
    mode is selected by ``model_params['flow_predictor_params']['flow_type']``:

    * ``'z'``  -- plain auto-encoding; the flow predictor is not used in
      ``forward``.
    * ``'dz'`` -- the flow predictor output is added to the source latent.
    * any other value (e.g. ``'flow'``) -- the flow predictor output is used
      as a sampling grid for ``F.grid_sample`` on the source latent.
    """

    def __init__(self, data_params, model_params, train_params, is_train=True):
        super(AutoEncoder, self).__init__()
        self.data_params = data_params
        self.model_params = model_params
        self.train_params = train_params
        self.encoder = Encoder(num_channels=data_params['num_channels'],
                               block_expansion=model_params['block_expansion'],
                               num_blocks=model_params['num_down_blocks'],
                               max_features=model_params['max_features'])
        self.decoder = Decoder(num_channels=data_params['num_channels'],
                               num_bottleneck_blocks=model_params['num_bottleneck_blocks'],
                               block_expansion=model_params['block_expansion'],
                               num_blocks=model_params['num_up_blocks'],
                               max_features=model_params['max_features'],
                               skips=model_params['skips'])
        self.flow_predictor = FlowPredictor(**model_params['flow_predictor_params'])

        self.loss_weights = model_params['loss_weights']

        # Explicitly put the sub-modules in training mode when requested.
        self.is_train = is_train
        if self.is_train:
            self.encoder.train()
            self.decoder.train()
            self.flow_predictor.train()

    def forward(self, x):
        """Reconstruct (``'z'``) or warp-and-decode a source/driving pair.

        In ``'z'`` mode ``x`` is passed straight to the encoder (assumed
        ``[batch, channels, height, width]`` -- TODO confirm). Otherwise ``x``
        is indexed on dim 2: frame 0 is the source and frame 1 the driving
        frame. The source latent is warped by the predicted flow and decoded.
        """
        if self.model_params['flow_predictor_params']['flow_type'] == 'z':
            z, _ = self.encoder(x)
            return self.decoder(z)
        x_source = x[:, :, 0, ...]
        x_driving = x[:, :, 1, ...]
        z0, _ = self.encoder(x_source)
        flows = self.flow_predictor(x_source, x_driving)
        z_dri = self.apply_optical(z0, flows)
        return self.decoder(z_dri)

    def encode(self, x):
        """Encode every frame of a 5-D clip independently.

        Args:
            x: ``[batch, channels, frames, height, width]``.

        Returns:
            Latents laid out as ``[batch, latent_channels, frames, h', w']``.
        """
        b, c, f, h, w = x.shape
        # Fold frames into the batch so the 2-D encoder sees each frame.
        x = x.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w)
        out, _ = self.encoder(x)
        _, c, h, w = out.shape
        return out.reshape(b, f, c, h, w).permute(0, 2, 1, 3, 4)

    def predict_flow(self, x_source, x_driving):
        """Predict one flow per driving frame relative to a single source frame.

        Args:
            x_source: ``[batch, channels, 1, height, width]``; repeated along
                the frame axis to match ``x_driving``.
            x_driving: ``[batch, channels, frames, height, width]``.

        Returns:
            ``[batch, flow_channels, frames, h', w']``.
        """
        b, c, f, h, w = x_driving.shape
        x_source = x_source.repeat(1, 1, f, 1, 1)
        x_source = x_source.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w)
        x_driving = x_driving.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w)
        out = self.flow_predictor(x_source, x_driving)
        _, c, h, w = out.shape
        return out.reshape(b, f, c, h, w).permute(0, 2, 1, 3, 4)

    def decode_from_flow(self, z0, flows):
        """Warp per-frame latents with per-frame flows and decode them.

        Args:
            z0: ``[batch, latent_channels, frames, height, width]``.
            flows: ``[batch, flow_channels, frames, height, width]``. The flow
                predictor emits 2 channels in ``'flow'`` mode and its raw
                encoder features in ``'dz'`` mode. ``z0`` and ``flows`` are
                assumed to have the same ``batch`` and ``frames``.

        Returns:
            ``[batch, out_channels, frames, H, W]``.
        """
        b, c, f, h, w = z0.shape
        z0 = z0.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w)
        b, c, f, h, w = flows.shape
        flows = flows.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w)
        z = self.apply_optical(z0, flows)
        out = self.decoder(z)
        _, c, h, w = out.shape
        return out.reshape(b, f, c, h, w).permute(0, 2, 1, 3, 4)

    def apply_optical(self, z0, flows):
        """Apply ``flows`` to the latent ``z0`` (both 4-D, channels-first).

        In ``'dz'`` mode the flow is an additive offset. Otherwise ``flows``
        (``[batch, 2, h, w]``) is resized bilinearly to ``z0``'s spatial size
        if needed. It is then used as a normalised sampling grid for
        ``F.grid_sample``. In that branch the result is always float32.
        """
        if self.model_params['flow_predictor_params']['flow_type'] == 'dz':
            return z0 + flows
        _, _, h_flow, w_flow = flows.shape
        _, _, h, w = z0.shape
        if (h_flow, w_flow) != (h, w):
            # align_corners=False is the default; stating it silences the
            # PyTorch UserWarning without changing results.
            flows = F.interpolate(flows, size=(h, w), mode='bilinear', align_corners=False)
        # grid_sample expects the grid as [batch, h, w, 2].
        grid = flows.permute(0, 2, 3, 1)
        return F.grid_sample(z0.to(torch.float32), grid.to(torch.float32), align_corners=False)

    def mse_loss(self, pyramide_real, pyramide_generated):
        """Average over ``model_params['scales']`` of the per-scale MSE.

        Both pyramids are dicts keyed ``'prediction_<scale>'``.

        Raises:
            ValueError: if ``model_params['scales']`` is empty.
        """
        scales = self.model_params["scales"]
        if not scales:
            raise ValueError("model_params['scales'] must be non-empty")
        loss = 0
        for scale in scales:
            real = pyramide_real['prediction_' + str(scale)]
            generated = pyramide_generated['prediction_' + str(scale)]
            loss += torch.mean((real - generated) ** 2)
        # Normalise by the same list iterated above (there is no self.scales).
        return loss / len(scales)


class FlowPredictor(nn.Module):
    """Predict a latent-space flow from a (source, driving) frame pair.

    The two frames are concatenated channel-wise and run through an
    ``Encoder``. In ``'dz'`` mode the encoder features are returned as-is.
    In ``'flow'`` mode a per-pixel ``nn.Linear`` projects them to 2 channels.
    The projection layer is only created for ``'flow'``. With any other
    ``flow_type``, ``forward`` reaches the projection branch and fails on the
    missing ``fc``.
    """

    def __init__(self, num_channels, block_expansion, num_blocks=3, max_features=256, flow_type='dz'):
        super().__init__()
        self.flow_type = flow_type
        # Source and driving frames are stacked, doubling the input channels.
        self.encoder = Encoder(2 * num_channels, block_expansion, num_blocks, max_features)
        if flow_type == 'flow':
            bottleneck_channels = min(max_features, block_expansion * 2 ** num_blocks)
            self.fc = nn.Linear(bottleneck_channels, 2)

    def forward(self, source_image, driven_image):
        """Return flow features for ``[batch, channels, h, w]`` frame pairs."""
        stacked = torch.cat([source_image, driven_image], dim=1)
        features, _skips = self.encoder(stacked)
        if self.flow_type == 'dz':
            return features
        # nn.Linear acts on the last axis: go channels-last, project, go back.
        projected = self.fc(features.permute(0, 2, 3, 1))
        return projected.permute(0, 3, 1, 2)


class Encoder(nn.Module):
    """
    Hourglass Encoder.

    A ``SameBlock2d`` stem followed by ``num_blocks`` ``DownBlock2d`` stages.
    Stage ``k`` has ``min(max_features, block_expansion * 2**k)`` channels,
    with ``k = 0`` being the stem. Each ``DownBlock2d`` presumably reduces
    spatial resolution -- its definition lives in ``LFAE.modules.util``.
    """

    def __init__(self, num_channels, block_expansion, num_blocks=3, max_features=256):
        super().__init__()
        widths = [min(max_features, block_expansion * (2 ** level)) for level in range(num_blocks + 1)]
        self.first = SameBlock2d(in_features=num_channels, out_features=widths[0], kernel_size=(7, 7), padding=(3, 3))
        self.down_blocks = nn.ModuleList([
            DownBlock2d(c_in, c_out, kernel_size=(3, 3), padding=(1, 1))
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ])

    def forward(self, source_image):
        """Return ``(deepest_features, [stem_out, stage_1_out, ...])``."""
        features = self.first(source_image)
        collected = [features]
        for stage in self.down_blocks:
            features = stage(features)
            collected.append(features)
        return features, collected


class Decoder(nn.Module):
    """
    Hourglass Decoder.

    ``num_bottleneck_blocks`` ``ResBlock2d`` layers at the deepest width are
    followed by ``num_blocks`` ``UpBlock2d`` stages. Each stage narrows the
    channels from ``min(max_features, block_expansion * 2**(k+1))`` to
    ``min(max_features, block_expansion * 2**k)``. A final 7x7 conv maps to
    ``num_channels``.

    ``skips`` is stored but not used by ``forward`` (no skip connections are
    applied). ``out_filters`` is computed for external callers and is not
    used inside this class.
    """

    def __init__(self, num_channels, num_bottleneck_blocks, block_expansion, num_blocks=3, max_features=256, skips=False):
        super(Decoder, self).__init__()

        self.bottleneck = torch.nn.Sequential()
        in_features = min(max_features, block_expansion * (2 ** num_blocks))
        for i in range(num_bottleneck_blocks):
            self.bottleneck.add_module('r' + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))

        up_blocks = []
        for i in range(num_blocks):
            in_features = min(max_features, block_expansion * (2 ** (num_blocks - i)))
            out_features = min(max_features, block_expansion * (2 ** (num_blocks - i - 1)))
            up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
        self.up_blocks = nn.ModuleList(up_blocks)

        self.out_filters = block_expansion + in_features
        # The last stage (or the bottleneck when num_blocks == 0) emits
        # min(max_features, block_expansion) channels. Clamp the same way here,
        # otherwise the conv mismatches whenever block_expansion > max_features.
        final_features = min(max_features, block_expansion)
        self.final = nn.Conv2d(final_features, num_channels, kernel_size=(7, 7), padding=(3, 3))

        self.skips = skips

    def forward(self, z):
        """Decode latent ``z`` (channels-first, deepest width) to an image."""
        out = self.bottleneck(z)
        for up_block in self.up_blocks:
            out = up_block(out)
        return self.final(out)
    

