import torch
import torch.nn as nn
import torch.nn.functional as F
import mynn
from torchvision.utils import make_grid, save_image
from utils import norm, CUDA
import os
from torch.nn.utils import spectral_norm as SN


class Gen(mynn.Sequential):
    """Generator: a ``mynn.Sequential`` stack that tracks its anchored layers.

    Direct children that are instances of ``mynn.Anchored`` are collected at
    construction time, and anchor-related calls (``reset_anchor``,
    ``reg_bias_align``, ``set_bias``) are fanned out to them.

    Args:
        *args, **kwargs: forwarded unchanged to ``mynn.Sequential``.
        z_sampler: stored as-is on the instance; this class never calls it
            (``GAN.train`` calls it to draw latent batches).
        n_anchor: total number of anchors available to the generator.
    """

    def __init__(self, *args, z_sampler=None, n_anchor=50, **kwargs):
        super().__init__(*args, **kwargs)
        # Only direct children are inspected; nested containers are not searched.
        self.anchored_layers = [l for l in self._modules.values()
                                    if isinstance(l, mynn.Anchored)]
        self.z_sampler = z_sampler
        self.n_anchor = n_anchor
        # Populated by reset_anchor().
        self.anchor_idx = None
        self.data_per_anchor = None
        self.batch_size = None

    def reset_anchor(self, anchor_idx, data_per_anchor):
        """Select the active anchors and propagate them to every anchored layer.

        Args:
            anchor_idx: either an ``int`` (that many anchors are drawn at
                random, without replacement, from ``range(n_anchor)``) or an
                explicit sequence/1-D tensor of anchor indices.
            data_per_anchor: number of samples generated per active anchor.
        """
        if isinstance(anchor_idx, int):
            anchor_idx = torch.randperm(self.n_anchor)[:anchor_idx]
        self.anchor_idx = anchor_idx
        self.data_per_anchor = data_per_anchor
        # One batch holds data_per_anchor samples for each active anchor.
        # (The previous code squared data_per_anchor, ignoring the anchor count.)
        self.batch_size = len(self.anchor_idx) * self.data_per_anchor

        for l in self.anchored_layers:
            l.reset_anchor(self.anchor_idx, self.data_per_anchor)

    def reg_bias_align(self):
        """Return the sum over anchored layers of ``log(layer.reg_bias_align())``.

        With no anchored layers this is the plain integer ``0`` (empty ``sum``).
        """
        regs = [torch.log(l.reg_bias_align()) for l in self.anchored_layers]
        return sum(regs)

    def set_bias(self, **pinv_kwargs):
        """Call ``set_bias(**pinv_kwargs)`` on every anchored layer."""
        for l in self.anchored_layers:
            l.set_bias(**pinv_kwargs)

    def encode(self, x, n_steps=1000, lr=0.01, verbose=False, return_all=False):
        """Project each sample of ``x`` onto the generator by optimizing latents.

        All ``n_anchor`` anchors are activated with one sample each. For every
        ``x[i]`` a latent of shape ``(n_anchor, *self.in_shape)`` is initialised
        with small Gaussian noise and optimized with Adam to minimise the mean
        squared L2 distance between the generator output and ``x[i]``. The
        latent row belonging to the closest anchor is kept.

        Args:
            x: batch of targets, indexed along dim 0.
            n_steps: number of Adam steps per target (``0`` is allowed).
            lr: Adam learning rate.
            verbose: print the final mean squared distance for each target.
            return_all: also return the per-target diagnostics.

        Returns:
            The stacked best latents (one per target). If ``return_all`` is
            true, a tuple ``(best_latents, info)`` where ``info`` has keys
            ``ais`` (index of the closest anchor), ``norms`` (per-anchor
            distances, detached) and ``zs`` (all optimized latents, detached).
        """
        self.reset_anchor(anchor_idx=torch.arange(self.n_anchor), data_per_anchor=1)
        ais, z_mins, norms, zs = [], [], [], []
        for xi in range(x.shape[0]):
            target = x[xi].detach()
            z = (0.01 * torch.randn(self.n_anchor, *self.in_shape, device=CUDA())).requires_grad_(True)
            prj_opt = torch.optim.Adam([z], lr=lr)
            for _ in range(n_steps):
                diff = self(z) - target
                step_norms = torch.norm(diff.view(self.n_anchor, -1), dim=1)
                dist = torch.mean(step_norms ** 2)
                prj_opt.zero_grad()
                # NOTE: backward() also accumulates into the generator
                # parameters' .grad; only z is updated by prj_opt.
                dist.backward()
                prj_opt.step()

            # Measure the distances at the *final* z so the chosen anchor
            # matches the latent that is returned (the in-loop values are one
            # optimizer step stale and do not exist when n_steps == 0).
            with torch.no_grad():
                diff = self(z) - target
                anchor_norms = torch.norm(diff.view(self.n_anchor, -1), dim=1)
                dist = torch.mean(anchor_norms ** 2)

            ai = torch.argmin(anchor_norms)
            ais.append(ai)
            z_mins.append(z[ai].clone().detach())
            norms.append(anchor_norms)
            zs.append(z.clone().detach())
            if verbose:
                print(dist)

        best = torch.stack(z_mins, dim=0)
        if return_all:
            return best, dict(ais=ais, norms=norms, zs=zs)
        return best



class Dis(mynn.Sequential):
    """Discriminator: a ``mynn.Sequential`` stack with a selectable loss form.

    Args:
        *args, **kwargs: forwarded unchanged to ``mynn.Sequential``.
        form: ``'wgan_sn'`` (default) wraps every direct ``nn.Linear`` /
            ``nn.Conv2d`` child in spectral normalization and uses the
            Wasserstein-style mean-output loss; ``'orig'`` uses binary
            cross-entropy on logits. Any other value is accepted here but
            makes ``loss_toward`` raise ``NotImplementedError``.
    """

    def __init__(self, *args, form='wgan_sn', **kwargs):
        super().__init__(*args, **kwargs)
        self.form = form
        if self.form == 'wgan_sn':
            # Index-based loop because layers are replaced in place.
            for li in range(len(self)):
                if isinstance(self[li], (nn.Linear, nn.Conv2d)):
                    self[li] = SN(self[li])

    def loss_toward(self, output, toward):
        """Loss that pushes discriminator ``output`` toward a label.

        Args:
            output: discriminator outputs (logits for ``form='orig'``).
            toward: ``'real'`` or ``'fake'``.

        Returns:
            A scalar loss tensor.

        Raises:
            AssertionError: if ``toward`` is not ``'real'`` or ``'fake'``.
            NotImplementedError: if ``self.form`` is not a supported form.
        """
        assert toward in ['real', 'fake']

        if self.form == 'orig':
            # *_like already matches output's device and dtype; forcing the
            # global device here could put the target on a different device
            # than the logits.
            if toward == 'real':
                target = torch.ones_like(output)
            else:
                target = torch.zeros_like(output)
            return F.binary_cross_entropy_with_logits(output, target)
        elif self.form == 'wgan_sn':
            return -torch.mean(output) if toward == 'real' else torch.mean(output)
        else:
            raise NotImplementedError



class GAN(object):
    """Bundles a generator and a discriminator and runs adversarial training.

    Args:
        gen: generator; called as ``gen(z)``. ``train`` additionally uses its
            ``z_sampler``, ``n_anchor``, ``reset_anchor``, ``reg_bias_align``,
            ``parameters`` and ``state_dict``.
        dis: discriminator; called as ``dis(x)`` and must provide
            ``loss_toward``. ``train`` additionally uses ``in_shape``,
            ``parameters`` and ``state_dict``.
        vis: plotting object exposing ``line(...)`` and ``images(...)``.
        nick: run name; dumps are written under ``dumps/<nick>/``.
    """

    def __init__(self, gen, dis, vis, nick):
        self.gen = gen
        self.dis = dis
        self.vis = vis
        self.nick = nick

    def gen_loss(self, z):
        """Generator loss: push the discriminator's output on fakes toward 'real'."""
        x_fake = self.gen(z)
        d_fake = self.dis(x_fake)
        return self.dis.loss_toward(d_fake, toward='real')

    def dis_loss(self, x_real, z):
        """Discriminator loss: fakes toward 'fake' plus reals toward 'real'."""
        x_fake = self.gen(z)
        d_fake = self.dis(x_fake)
        d_real = self.dis(x_real)
        fake_term = self.dis.loss_toward(d_fake, toward='fake')
        real_term = self.dis.loss_toward(d_real, toward='real')
        return fake_term + real_term

    def train(self, real_loader, anchor_per_batch, data_per_anchor, lamb_bias_align,
              n_epoch=30, n_dis=1, n_gen=1, anchor_reset_period=1,
              img_size=None, dump_period=1, vis_period=20):
        """Alternate discriminator and generator updates over ``real_loader``.

        Each outer iteration performs ``n_dis`` discriminator steps (each
        consuming one real batch) followed by ``n_gen`` generator steps. Every
        generator step is an adversarial update followed by a separate update
        on ``lamb_bias_align * gen.reg_bias_align()``. An epoch ends when the
        loader is exhausted.

        Args:
            real_loader: iterable yielding ``(x_real, _)`` pairs.
            anchor_per_batch: number of anchors activated per batch.
            data_per_anchor: samples per active anchor; the latent batch size
                is ``anchor_per_batch * data_per_anchor``.
            lamb_bias_align: weight of the bias-alignment regularizer.
            n_epoch: number of passes over ``real_loader``.
            n_dis: discriminator steps per outer iteration (must be >= 1).
            n_gen: generator steps per outer iteration.
            anchor_reset_period: re-draw the active anchors every this many
                outer iterations.
            img_size: if not ``None``, samples are reshaped to
                ``(-1, img_size, img_size)`` images and visualized / saved.
            dump_period: save images and state dicts every this many epochs.
            vis_period: plot buffered losses every this many outer iterations.

        Raises:
            ValueError: if ``n_dis < 1`` (no real batch would ever be consumed,
                so an epoch could never end).
        """
        from itertools import count
        import time

        if n_dis < 1:
            raise ValueError("n_dis must be >= 1, got {}".format(n_dis))

        batch_size = anchor_per_batch * data_per_anchor
        dump_dir = os.path.join('dumps', self.nick)

        opt_dis = torch.optim.Adam(self.dis.parameters(), lr=0.0002)
        opt_gen = torch.optim.Adam(self.gen.parameters(), lr=0.0002)

        def zero_grad():
            # Both nets receive gradients from every backward pass, so both
            # are cleared before each one.
            opt_dis.zero_grad()
            opt_gen.zero_grad()

        tic = time.time()
        dis_loss_buff = []
        reg_bias_align_buff = []

        # Fixed latents reused at every dump, covering all anchors with
        # `more` times the per-anchor sample count used in training.
        more = 2
        z_sample = self.gen.z_sampler(self.gen.n_anchor, more*data_per_anchor)

        for ep in range(1, n_epoch+1):
            print("epoch {} - {:.6} s".format(ep, time.time()-tic))
            tic = time.time()
            real_iter = iter(real_loader)

            for i in count():
                if i % anchor_reset_period == 0:
                    self.gen.reset_anchor(anchor_per_batch, data_per_anchor)

                # ---- Train Dis ----
                exhausted = False
                for _i_dis in range(n_dis):
                    # Only the loader running dry ends the epoch; any other
                    # StopIteration raised further down is a real error.
                    try:
                        x_real, _labels = next(real_iter)
                    except StopIteration:
                        exhausted = True
                        break
                    x_real = x_real.view(-1, *self.dis.in_shape)
                    x_real = x_real.cuda() if CUDA() == 'cuda' else x_real

                    dis_loss = self.dis_loss(x_real, self.gen.z_sampler(batch_size))
                    zero_grad()
                    dis_loss.backward()
                    opt_dis.step()
                if exhausted:
                    break
                # Detach so the buffer keeps only the value, not the graph.
                dis_loss_buff.append(dis_loss.detach())

                # ---- Train Gen ----
                for _i_gen in range(n_gen):
                    gen_loss = self.gen_loss(self.gen.z_sampler(batch_size))
                    zero_grad()
                    gen_loss.backward()
                    opt_gen.step()

                    # The regularizer gets its own optimizer step.
                    zero_grad()
                    reg_bias_align = lamb_bias_align * self.gen.reg_bias_align()
                    reg_bias_align.backward()
                    opt_gen.step()
                if n_gen > 0:
                    reg_bias_align_buff.append(reg_bias_align.detach())

                if i % vis_period == 0:
                    self.vis.line(torch.stack(dis_loss_buff), name='dis_loss', update='append')
                    dis_loss_buff = []
                    if reg_bias_align_buff:
                        self.vis.line(torch.stack(reg_bias_align_buff), name='reg_bias_align', update='append')
                        reg_bias_align_buff = []

                    if img_size is not None:
                        self.vis.images(x_real/2+0.5, id='x_real', nrow=data_per_anchor)
                        x_sample = self.gen(self.gen.z_sampler(anchor_per_batch, data_per_anchor))
                        # Visualization only: drop the autograd graph.
                        x_sample = x_sample.view(batch_size, -1, img_size, img_size).detach()
                        self.vis.images(x_sample/2+0.5, id='x_sample', nrow=data_per_anchor)

            if ep % dump_period == 0:
                # save_image / torch.save do not create missing directories.
                os.makedirs(dump_dir, exist_ok=True)
                if img_size is not None:
                    self.gen.reset_anchor(anchor_idx=torch.arange(self.gen.n_anchor), data_per_anchor=more*data_per_anchor)
                    x_sample = self.gen(z_sample).view(self.gen.n_anchor*more*data_per_anchor, -1, img_size, img_size).detach()
                    save_image(x_sample, os.path.join(dump_dir, 'sample_ep{}.png'.format(ep)), nrow=more*data_per_anchor, normalize=True)
                    self.vis.images(x_sample/2+0.5, id='x_sample_all', nrow=more*data_per_anchor)
                torch.save(self.gen.state_dict(), os.path.join(dump_dir, 'gen_ep{}.dump'.format(ep)))
                torch.save(self.dis.state_dict(), os.path.join(dump_dir, 'dis_ep{}.dump'.format(ep)))
