import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.base_network import BaseNetwork
from models.networks.normalization import get_nonspade_norm_layer
from models.networks.architecture import ResnetBlock as ResnetBlock
from models.networks.architecture import MATResnetBlock as MATResnetBlock
from models.networks.architecture import GMapping


class S2PGenerator(BaseNetwork):
    """Generator that renders an image from an encoded state vector and a background image.

    The state vector is used twice: once through ``fc`` to seed the coarse
    feature map, and once through ``state_map`` (a ``GMapping``) to build the
    conditioning map ``seg`` that is handed to every ``MATResnetBlock``.
    The background image is concatenated onto both.
    """

    # Ordered (substring, state dimension) pairs. The first substring found in
    # ``opt.env_type`` wins, exactly like an if/elif chain.
    _STATE_DIMS = (
        ('cheetah', 17),
        ('walker', 24),
        ('ballincup', 8),
        ('cartpole', 5),
        ('finger', 9),
        ('reacher', 6),
    )
    # Number of input features per state dimension.
    # NOTE(review): presumably the width of a per-dimension encoding applied
    # by the caller -- confirm against the data pipeline.
    _FEATURES_PER_STATE = 21
    # Channel count the mapped state code is expanded to in ``forward``.
    # assumes GMapping's default output width is 512 -- TODO confirm.
    _STYLE_DIM = 512
    # Number of 2x upsampling steps performed in ``forward``.
    _NUM_UP_LAYERS = 5

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Set this generator's option defaults on ``parser`` and return it.

        ``semantic_nc`` is the channel count of the conditioning map built in
        ``forward``: the mapped state code plus the 3 background channels.
        """
        parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
        parser.set_defaults(semantic_nc=S2PGenerator._STYLE_DIM + 3)
        return parser

    def __init__(self, opt):
        """Build the network.

        Args:
            opt: options object; this constructor reads ``ngf``, ``env_type``,
                ``crop_size`` and ``aspect_ratio`` and passes ``opt`` on to
                every ``MATResnetBlock``.

        Raises:
            ValueError: if ``opt.env_type`` names no supported environment.
        """
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)
        self.state_num = self._resolve_state_num(opt.env_type)

        state_features = self.state_num * self._FEATURES_PER_STATE

        # Each branch yields 8 * nf channels; they are concatenated in
        # ``forward`` to form the 16 * nf input of ``head_0``.
        self.fc = nn.Linear(state_features, 8 * nf)
        self.fc2d = nn.Conv2d(3, 8 * nf, 3, padding=1)

        self.state_map = GMapping(latent_size=state_features)

        self.head_0 = MATResnetBlock(16 * nf, 16 * nf, opt)

        self.G_middle_0 = MATResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_1 = MATResnetBlock(16 * nf, 16 * nf, opt)

        self.up_0 = MATResnetBlock(16 * nf, 8 * nf, opt)
        self.up_1 = MATResnetBlock(8 * nf, 4 * nf, opt)
        self.up_2 = MATResnetBlock(4 * nf, 2 * nf, opt)
        self.up_3 = MATResnetBlock(2 * nf, 1 * nf, opt)

        final_nc = nf

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

    @classmethod
    def _resolve_state_num(cls, env_type):
        """Return the state dimension for ``env_type`` or raise ``ValueError``."""
        for keyword, state_num in cls._STATE_DIMS:
            if keyword in env_type:
                return state_num
        supported = ', '.join(keyword for keyword, _ in cls._STATE_DIMS)
        raise ValueError(
            f"Unsupported env_type {env_type!r}; expected it to contain one of: {supported}"
        )

    def compute_latent_vector_size(self, opt):
        """Return ``(sw, sh)``, the width and height of the coarsest feature map.

        The width is ``opt.crop_size`` divided by 2 once per upsampling step;
        the height follows from ``opt.aspect_ratio``.
        """
        sw = opt.crop_size // (2 ** self._NUM_UP_LAYERS)
        sh = round(sw / opt.aspect_ratio)

        return sw, sh

    def forward(self, input, back_imgs):
        """Generate an image batch.

        Args:
            input: encoded state tensor of shape ``(batch, state_num * 21)``.
            back_imgs: background images of shape ``(batch, 3, H, W)``.

        Returns:
            A tensor with 3 channels and values in ``(-1, 1)`` (tanh output).
        """
        batch = input.shape[0]
        # Build the conditioning map at the background's own resolution
        # instead of a fixed 128x128, so the concatenation below is valid for
        # any background size.
        height, width = back_imgs.shape[-2:]

        seg = self.state_map(input)
        seg = seg.unsqueeze(2).unsqueeze(3).expand(batch, self._STYLE_DIM, height, width)
        seg = torch.cat((seg, back_imgs), dim=1)

        # Coarse seed: the projected state broadcast over the latent grid ...
        x1 = self.fc(input)
        x1 = x1.unsqueeze(2).unsqueeze(3).expand(batch, 8 * self.opt.ngf, self.sh, self.sw)

        # ... plus features of the background resized to the same grid.
        x2 = F.interpolate(back_imgs, size=(self.sh, self.sw))
        x2 = self.fc2d(x2)

        x = torch.cat((x1, x2), dim=1)

        x = self.head_0(x, seg)
        x = self.up(x)

        x = self.G_middle_0(x, seg)
        x = self.G_middle_1(x, seg)
        x = self.up(x)

        x = self.up_0(x, seg)
        x = self.up(x)

        x = self.up_1(x, seg)
        x = self.up(x)

        x = self.up_2(x, seg)
        x = self.up(x)

        x = self.up_3(x, seg)
        x = self.conv_img(F.leaky_relu(x, 2e-1))
        # torch.tanh is the supported spelling; F.tanh is deprecated.
        x = torch.tanh(x)

        return x



class S2PLightGenerator(BaseNetwork):
    """Generator that renders an image from an encoded state vector and a background image.

    Unlike a mapping-network variant, the conditioning map ``seg`` is the raw
    encoded state broadcast over the image plane, concatenated with the
    background image.
    """

    # Ordered (substring, state dimension) pairs. The first substring found in
    # ``opt.env_type`` wins, exactly like an if/elif chain.
    _STATE_DIMS = (
        ('cheetah', 17),
        ('walker', 24),
        ('ballincup', 8),
        ('cartpole', 5),
        ('finger', 9),
        ('reacher', 6),
    )
    # Number of input features per state dimension.
    # NOTE(review): presumably the width of a per-dimension encoding applied
    # by the caller -- confirm against the data pipeline.
    _FEATURES_PER_STATE = 21
    # Number of 2x upsampling steps performed in ``forward``.
    _NUM_UP_LAYERS = 5

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Set this generator's option defaults on ``parser`` and return it."""
        parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
        # NOTE(review): no ``semantic_nc`` default is set here, although the
        # conditioning map built in ``forward`` has ``state_num * 21 + 3``
        # channels -- presumably it is configured elsewhere; confirm.

        return parser

    def __init__(self, opt):
        """Build the network.

        Args:
            opt: options object; this constructor reads ``ngf``, ``env_type``,
                ``crop_size`` and ``aspect_ratio`` and passes ``opt`` on to
                every ``MATResnetBlock``.

        Raises:
            ValueError: if ``opt.env_type`` names no supported environment.
        """
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)
        self.state_num = self._resolve_state_num(opt.env_type)

        # Each branch yields 8 * nf channels; they are concatenated in
        # ``forward`` to form the 16 * nf input of ``head_0``.
        self.fc = nn.Linear(self.state_num * self._FEATURES_PER_STATE, 8 * nf)
        self.fc2d = nn.Conv2d(3, 8 * nf, 3, padding=1)

        self.head_0 = MATResnetBlock(16 * nf, 16 * nf, opt)

        self.G_middle_0 = MATResnetBlock(16 * nf, 16 * nf, opt)
        self.G_middle_1 = MATResnetBlock(16 * nf, 16 * nf, opt)

        self.up_0 = MATResnetBlock(16 * nf, 8 * nf, opt)
        self.up_1 = MATResnetBlock(8 * nf, 4 * nf, opt)
        self.up_2 = MATResnetBlock(4 * nf, 2 * nf, opt)
        self.up_3 = MATResnetBlock(2 * nf, 1 * nf, opt)

        final_nc = nf

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

    @classmethod
    def _resolve_state_num(cls, env_type):
        """Return the state dimension for ``env_type`` or raise ``ValueError``."""
        for keyword, state_num in cls._STATE_DIMS:
            if keyword in env_type:
                return state_num
        supported = ', '.join(keyword for keyword, _ in cls._STATE_DIMS)
        raise ValueError(
            f"Unsupported env_type {env_type!r}; expected it to contain one of: {supported}"
        )

    def compute_latent_vector_size(self, opt):
        """Return ``(sw, sh)``, the width and height of the coarsest feature map.

        The width is ``opt.crop_size`` divided by 2 once per upsampling step;
        the height follows from ``opt.aspect_ratio``.
        """
        sw = opt.crop_size // (2 ** self._NUM_UP_LAYERS)
        sh = round(sw / opt.aspect_ratio)

        return sw, sh

    def forward(self, input, back_imgs):
        """Generate an image batch.

        Args:
            input: encoded state tensor of shape ``(batch, state_num * 21)``.
            back_imgs: background images of shape ``(batch, 3, H, W)``.

        Returns:
            A tensor with 3 channels and values in ``(-1, 1)`` (tanh output).
        """
        batch = input.shape[0]
        # Build the conditioning map at the background's own resolution
        # instead of a fixed 128x128, so the concatenation below is valid for
        # any background size.
        height, width = back_imgs.shape[-2:]

        seg = input.unsqueeze(2).unsqueeze(3).expand(
            batch, self.state_num * self._FEATURES_PER_STATE, height, width
        )
        seg = torch.cat((seg, back_imgs), dim=1)

        # Coarse seed: the projected state broadcast over the latent grid ...
        x1 = self.fc(input)
        x1 = x1.unsqueeze(2).unsqueeze(3).expand(batch, 8 * self.opt.ngf, self.sh, self.sw)

        # ... plus features of the background resized to the same grid.
        x2 = F.interpolate(back_imgs, size=(self.sh, self.sw))
        x2 = self.fc2d(x2)

        x = torch.cat((x1, x2), dim=1)

        x = self.head_0(x, seg)
        x = self.up(x)

        x = self.G_middle_0(x, seg)
        x = self.G_middle_1(x, seg)
        x = self.up(x)

        x = self.up_0(x, seg)
        x = self.up(x)

        x = self.up_1(x, seg)
        x = self.up(x)

        x = self.up_2(x, seg)
        x = self.up(x)

        x = self.up_3(x, seg)
        x = self.conv_img(F.leaky_relu(x, 2e-1))
        # torch.tanh is the supported spelling; F.tanh is deprecated.
        x = torch.tanh(x)

        return x
