import numpy as np
import torch
import os
from collections import OrderedDict
from torch.autograd import Variable
import util.util3d as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks3d
from . import networks_2g_st
from unet import IPN_V2, Unet
from .WUNet import DynamicWaveletUNet3D as WUNet3D

from .hfc import generator
import torch.nn.functional as F

def total_variation_loss_3d(x):
    """Return an anisotropic total-variation penalty for a 5-D tensor.

    For each of the last three axes (indices 2, 3 and 4) the mean absolute
    difference between neighbouring elements is computed. The three means are
    summed and the sum is divided by ``x.shape[0] * x.shape[1]``.

    An axis of length 1 has no neighbouring pairs; it contributes 0 instead of
    the NaN that the mean of an empty difference would produce.

    Args:
        x: Tensor indexed with five dimensions; the first two sizes are used
            as batch size and channel count for the final division.

    Returns:
        A zero-dimensional tensor holding the loss value.
    """
    batch_size, channels = x.shape[0], x.shape[1]

    def _mean_abs_step(axis):
        # Mean |x[i + 1] - x[i]| along `axis`; 0 when there is nothing to compare.
        length = x.shape[axis]
        if length < 2:
            return x.new_zeros(())
        upper = x.narrow(axis, 1, length - 1)
        lower = x.narrow(axis, 0, length - 1)
        return torch.mean(torch.abs(upper - lower))

    tv_loss = sum(_mean_abs_step(axis) for axis in (2, 3, 4))

    # NOTE(review): torch.mean already averages over the batch and channel
    # axes, so this division normalises a second time. It is kept so the loss
    # scale stays what existing training runs expect.
    return tv_loss / (batch_size * channels)


class TransProModel(BaseModel):
    def name(self):
        return 'TransProModel'

    def __init__(self, opt):
        """Build the generator and, in training mode, everything used to train it.

        Always created: ``self.netG`` (a ``WUNet3D``).

        Created only when ``opt.isTrain`` is true: the discriminator
        ``self.netD``, four frozen networks loaded from ``./pretrain_weights``
        (``netG_t``, ``netG_ilm``, ``netG_bm``, ``net``), the image pool, the
        loss criteria and the two Adam optimizers.

        Saved generator/discriminator weights are restored through
        ``self.load_network`` when testing or when ``opt.continue_train`` is set.

        Args:
            opt: Option object. Fields read here: ``isTrain``, ``ngf``,
                ``no_lsgan``, ``input_nc``, ``output_nc``, ``ndf``, ``netD``,
                ``n_layers_D``, ``norm``, ``init_type``, ``init_gain``,
                ``continue_train``, ``which_epoch``, ``pool_size``, ``lr`` and
                ``beta1``.
        """
        # self.device, self.gpu_ids and self.optimizers are used below without
        # being assigned here — presumably BaseModel.__init__ provides them.
        BaseModel.__init__(self, opt)
        self.isTrain = opt.isTrain

        # Call order kept exactly as before; presumably init_wavelet_params()
        # has to run after the generic weight_init() — TODO confirm.
        self.netG = WUNet3D(opt.ngf, wavelet='bior4.4', device=self.device)
        self.netG.weight_init(mean=0.0, std=0.02)
        self.netG.init_wavelet_params()
        self.netG.to(device=self.device)

        # Parameters whose name contains 'wavelet_bypass' get their own
        # optimizer group (lower learning rate) further down.
        bypass_params = []
        other_params = []
        for name, param in self.netG.named_parameters():
            if 'wavelet_bypass' in name:
                bypass_params.append(param)
            else:
                other_params.append(param)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks3d.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                            opt.netD, opt.n_layers_D, opt.norm, opt.init_type,
                                            opt.init_gain, self.gpu_ids, use_sigmoid)

            # NOTE(review): the frozen networks below are left in training
            # mode (no .eval() call). If they contain BatchNorm or Dropout
            # layers their statistics/outputs still change during training —
            # confirm that this is intended.

            # load 2D Generator for HFC module
            self.netG_t = self._load_frozen_net(generator(opt.ngf), "./pretrain_weights/hfc.pth")
            self.netG_ilm = self._load_frozen_net(Unet(), "./pretrain_weights/ilm_opl.pth")
            self.netG_bm = self._load_frozen_net(Unet(), "./pretrain_weights/opl_bm.pth")

            block_size = [256, 100, 100]
            self.net = self._load_frozen_net(IPN_V2(1, 32, 32, 2, block_size),
                                             "./pretrain_weights/vpg.pth")

        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks_2g_st.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_D = torch.optim.Adam([
                {'params': self.netD.parameters(), 'lr': opt.lr, 'betas': (opt.beta1, 0.999)},
            ])

            # The wavelet-bypass group trains at a tenth of the base rate.
            self.optimizer_G = torch.optim.Adam([
                {'params': other_params, 'lr': opt.lr},
                {'params': bypass_params, 'lr': opt.lr * 0.1}
            ], betas=(opt.beta1, 0.999))

            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

        print('---------- Networks initialized -------------')
        networks3d.print_network(self.netG)
        if self.isTrain:
            networks3d.print_network(self.netD)
        print('-----------------------------------------------')

    def _load_frozen_net(self, net, weight_path):
        """Move ``net`` to ``self.device``, load ``weight_path`` and freeze it.

        Args:
            net: Freshly constructed ``torch.nn.Module``.
            weight_path: Path of a state-dict file readable by ``torch.load``.

        Returns:
            The same network, on ``self.device``, with ``requires_grad`` set
            to ``False`` on every parameter.
        """
        net = net.to(self.device)
        net.load_state_dict(torch.load(weight_path, map_location=self.device))
        for p in net.parameters():
            p.requires_grad = False
        return net
        
        

    def set_input(self, input):
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].permute(1,0,2,3,4).to(self.device,dtype=torch.float) #torch.Size([1, 1, 256, 256, 256])
        self.real_B = input['B' if AtoB else 'A'].permute(1,0,2,3,4).to(self.device,dtype=torch.float)
        self.real_A_proj = torch.mean(self.real_A,3) #torch.Size([1, 1, 256, 256])
        self.real_A_proj = self.Norm(self.real_A_proj)
        self.real_B_proj = torch.mean(self.real_B,3)
        self.real_B_proj = self.Norm(self.real_B_proj)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.fake_B = self.netG.forward(self.real_A) # torch.Size([1, 1, 256, 256, 256])
        self.fake_B_proj_t = self.netG_t.forward(self.real_A_proj)
        self.fake_B_proj_s = torch.mean(self.fake_B,3)
        self.fake_B_proj_s = self.Norm(self.fake_B_proj_s)


    def test(self):
        with torch.no_grad():
            self.fake_B = self.netG.forward(self.real_A)
        

    # get image paths
    def get_image_paths(self):
        #return "blksdf"
        return self.image_paths

    def backward_D(self, dataset, iter_count):
        apply_augmentation = (iter_count % 20 == 0)
        
        if apply_augmentation:
            real_B_aug, fake_B_aug = dataset.dataset.get_augmented_batch(self.real_B, self.fake_B)
            real_B_aug = real_B_aug.to(self.device,dtype=torch.float)
            fake_B_aug = fake_B_aug.to(self.device,dtype=torch.float)
            real_A = self.real_A
            fake_AB = torch.cat((real_A, fake_B_aug), 1)
            real_AB = torch.cat((real_A, real_B_aug), 1)
        else:
            fake_AB = torch.cat((self.real_A, self.fake_B), 1)
            real_AB = torch.cat((self.real_A, self.real_B), 1)


        # Fake
        fake_AB = self.fake_AB_pool.query(fake_AB)
        self.pred_fake = self.netD.forward(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)
           
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        current_d_loss = self.loss_D.item()

        dataset.dataset.update_discriminator_loss(current_d_loss)

        self.loss_D.backward()

    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD.forward(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        # Segmentation loss
        with torch.no_grad():
            fake_B_permuted = self.Norm(self.fake_B.permute(1, 0, 3, 4, 2))
            real_B_permuted = self.Norm(self.real_B.permute(1, 0, 3, 4, 2))

            fake_logits, _ = self.net(fake_B_permuted)
            real_logits, _ = self.net(real_B_permuted)
        fake_probs = F.softmax(fake_logits, dim=1)
        real_probs = F.softmax(real_logits, dim=1)
        self.fake_B_seg = fake_probs[0]
        self.real_B_seg = real_probs[0]

        self.loss_G_L1_pm_st = self.criterionL1(self.fake_B_proj_s, self.fake_B_proj_t) * self.opt.lambda_C
        self.loss_G_L1_seg = self.criterionL1(self.fake_B_seg, self.real_B_seg) * self.opt.lambda_C
        
        self.loss_tv = total_variation_loss_3d(self.fake_B) * 0.25

        self.loss_G_ilm = self.criterionL1(self.netG_ilm(self.fake_B_proj_s), self.netG_ilm(self.real_B_proj)) * 0.25
        self.loss_G_bm = self.criterionL1(self.netG_bm(self.fake_B_proj_s), self.netG_bm(self.real_B_proj)) * 0.25
        self.loss_proj = self.criterionL1(self.fake_B_proj_s, self.real_B_proj) * 0.25

        self.loss_G = self.loss_G_GAN + self.loss_G_L1 + self.loss_G_L1_seg


        self.loss_G.backward()

    def optimize_parameters(self, dataset, iter_count):
        self.forward()
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.backward_D(dataset, iter_count)
        self.optimizer_D.step()

        self.set_requires_grad(self.netD, False) 
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        
    def get_current_errors(self):
        return OrderedDict([
            ('G_GAN', self.loss_G_GAN.item()),
            ('G_L1', self.loss_G_L1.item()),
            ('G_loss_tv', self.loss_tv.item()),
            ('G_L1_seg', self.loss_G_L1_seg.item()), 
            ('G_ilm', self.loss_G_ilm.item()),       
            ('G_bm', self.loss_G_bm.item()),         
            ('proj', self.loss_proj.item()),         
            ('D_real', self.loss_D_real.item()),
            ('D_fake', self.loss_D_fake.item())
        ])

    def get_current_visuals(self):
        """Return the current input, output and target volumes as images.

        Each of ``real_A``, ``fake_B`` and ``real_B`` is converted with
        ``util.tensor2im3d`` and stored under its own name, in that order.
        """
        visuals = OrderedDict()
        for key in ('real_A', 'fake_B', 'real_B'):
            volume = getattr(self, key)
            visuals[key] = util.tensor2im3d(volume.data)
        return visuals

    def save(self, label):
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)

    def Norm(self, a):
        max_ = torch.max(a)
        min_ = torch.min(a)
        a_0_1 = (a-min_)/(max_-min_)
        return (a_0_1-0.5)*2