import os

from hyper_params import hp
import numpy as np
import matplotlib.pyplot as plt
import PIL
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from encoder import EncoderGCN, biLstm
from decoder import DecoderRNN
from utils.sketch_processing import make_graph
from utils.Gumble import gumbel_softmax

def train_sample_bivariate_normal(mu_x: torch.Tensor, mu_y: torch.Tensor,
                            sigma_x: torch.Tensor, sigma_y: torch.Tensor,
                            rho_xy: torch.Tensor, greedy=False):
    """Draw one differentiable sample per row from a batch of bivariate normals.

    All five parameter tensors are expected to share one shape (the caller in
    this file passes column vectors produced by ``.view(-1, 1)``).

    :param mu_x: means along x.
    :param mu_y: means along y.
    :param sigma_x: standard deviations along x.
    :param sigma_y: standard deviations along y.
    :param rho_xy: correlation coefficients between x and y.
    :param greedy: when True, skip sampling and return the means unchanged.
    :return: ``(mu_x, mu_y)`` if ``greedy``; otherwise a pair of flattened 1-D
        tensors ``(x, y)`` holding the samples.
    """
    # Follow the file-wide hp.use_cuda switch instead of forcing CUDA, so the
    # function also works on CPU-only runs.
    if hp.use_cuda:
        mu_x, mu_y, sigma_x, sigma_y, rho_xy = (
            t.cuda() for t in (mu_x, mu_y, sigma_x, sigma_y, rho_xy)
        )
    if greedy:
        return mu_x, mu_y

    # Temperature scales the variance, hence sqrt(temperature) on the std-devs.
    temp_scale = hp.temperature ** 0.5
    sigma_x = sigma_x * temp_scale
    sigma_y = sigma_y * temp_scale

    # Reparameterised sampling through the Cholesky factor of the covariance:
    #   x = mu_x + sigma_x * e1
    #   y = mu_y + sigma_y * (rho * e1 + sqrt(1 - rho^2) * e2)
    # Multiplying the noise by the covariance matrix itself (as was done
    # before) yields samples with covariance cov @ cov.T instead of cov.
    eps_x = torch.randn_like(mu_x)
    eps_y = torch.randn_like(mu_y)
    # clamp guards against a tiny negative value when |rho| is numerically 1.
    ortho = torch.sqrt(torch.clamp(1 - rho_xy ** 2, min=0))
    x = mu_x + sigma_x * eps_x
    y = mu_y + sigma_y * (rho_xy * eps_x + ortho * eps_y)

    return x.reshape(-1), y.reshape(-1)

################################# load and prepare data
class SketchesDataset:
    """In-memory collection of stroke-3 sketches that serves random batches.

    Each entry of ``category`` is a file name under ``path`` that is loaded
    with ``np.load``; the array stored under key ``mode`` is used.  Sketches
    are filtered by length, clipped, cast to float32 and normalised.

    Attributes set by ``__init__``:
        sketches: list of per-sketch arrays handed to ``make_graph``.
        sketches_normed: list of per-sketch arrays used for the padded batch.
        labels: 1-D float tensor, one category index per kept sketch.
        Nmax: length of the longest kept sketch.
    """

    def __init__(self, path: str, category: list, mode="train"):
        self.sketches = None
        self.sketches_normed = None
        self.max_sketches_len = 0
        self.path = path
        self.category = category
        self.mode = mode

        kept_sketches = []
        kept_labels = []
        raw_total = 0
        for i, c in enumerate(self.category):
            dataset = np.load(os.path.join(self.path, c), encoding='latin1', allow_pickle=True)
            raw = dataset[self.mode]
            raw_total += len(raw)
            # Purify per category so that labels are generated for exactly the
            # sketches that survive filtering.  Building labels from the raw
            # count (as before) left them misaligned with the sketch lists.
            purified = self.purify(raw)
            kept_sketches.extend(purified)
            kept_labels.append(torch.ones(len(purified)) * i)
            print(f"dataset: {c} added. labels: {i}")
        self.labels = torch.cat(kept_labels)
        print(f"length of trainset: {raw_total}")

        data_sketches = kept_sketches
        # NOTE(review): list.copy() is shallow, and normalize() divides the
        # arrays in place, so self.sketches ends up holding the same
        # normalised arrays as self.sketches_normed.  Kept as-is because
        # changing it would alter what make_graph receives — confirm intent.
        self.sketches = data_sketches.copy()
        self.sketches_normed = self.normalize(data_sketches)
        self.Nmax = self.max_size(data_sketches)  # length of the longest sketch

    def max_size(self, sketches):
        """Return the number of points of the longest sketch in ``sketches``."""
        return max(len(sketch) for sketch in sketches)

    def purify(self, sketches):
        """Drop too-short / too-long sketches and clip extreme offsets.

        A sketch is kept when ``hp.min_seq_length < len <= hp.max_seq_length``.
        Kept sketches are clipped to [-1000, 1000] and converted to float32.

        :return: list of float32 arrays.
        """
        data = []
        for sketch in sketches:
            if hp.max_seq_length >= sketch.shape[0] > hp.min_seq_length:
                sketch = np.minimum(sketch, 1000)  # remove large gaps
                sketch = np.maximum(sketch, -1000)
                data.append(np.array(sketch, dtype=np.float32))
        return data

    def calculate_normalizing_scale_factor(self, sketches):
        """Return the standard deviation over every value of every point.

        NOTE(review): whole rows are pooled, so the pen-state column
        contributes to the statistic as well as the two offset columns.
        """
        rows = [point for sketch in sketches for point in sketch]
        return np.std(np.array(rows))

    def normalize(self, sketches):
        """Divide the (delta_x, delta_y) columns of every sketch by the scale factor.

        The arrays are modified in place; the returned list holds the same
        array objects that were passed in.
        """
        scale_factor = self.calculate_normalizing_scale_factor(sketches)
        print("scale_factor:", scale_factor)
        data = []
        for sketch in sketches:
            sketch[:, 0:2] /= scale_factor
            data.append(sketch)
        return data

    def _to_stroke5(self, sketch):
        """Pad one stroke-3 sketch to ``(self.Nmax, 5)`` with one-hot pen states."""
        len_seq = len(sketch[:, 0])
        padded = np.zeros((self.Nmax, 5))
        padded[:len_seq, :2] = sketch[:, :2]

        # Expand the single pen column into one-hot columns 2 and 3.
        padded[:len_seq - 1, 2] = 1 - sketch[:-1, 2]
        padded[:len_seq, 3] = sketch[:, 2]

        # From the last real point onwards the row is (x, y, 0, 0, 1).
        padded[(len_seq - 1):, 4] = 1
        padded[len_seq - 1, 2:4] = 0
        return padded

    @staticmethod
    def _to_tensor(array):
        """Convert a numpy array to a float tensor, on CUDA when hp.use_cuda is set."""
        tensor = torch.from_numpy(array)
        if hp.use_cuda:
            tensor = tensor.cuda()
        return tensor.float()

    def make_batch(self, batch_size):
        """Sample ``batch_size`` sketches (with replacement) and build model inputs.

        :param batch_size: number of sketches to draw.
        :return: tuple ``(batch, lengths, graphs, adjs, batch_labels)`` where
            ``batch`` stacks the padded sketches along dim 1 (Nmax, batch, 5),
            ``lengths`` are the unpadded sketch lengths, ``graphs`` / ``adjs``
            stack the ``make_graph`` outputs along dim 0, and ``batch_labels``
            is a list of per-sketch label tensors.
        """
        batch_idx = np.random.choice(len(self.sketches_normed), batch_size)
        batch_sketches = [self.sketches_normed[idx] for idx in batch_idx]
        batch_sketches_graphs = [self.sketches[idx] for idx in batch_idx]
        batch_labels = [self.labels[idx] for idx in batch_idx]

        # lengths are the original sketch lengths, not the padded length.
        lengths = [len(sketch[:, 0]) for sketch in batch_sketches]
        sketches = [self._to_stroke5(sketch) for sketch in batch_sketches]

        graphs = []
        adjs = []
        for each_sketch in batch_sketches_graphs:
            graph_tensor, adj_matrix = make_graph(each_sketch, graph_num=hp.graph_number,
                                                  graph_picture_size=hp.graph_picture_size,
                                                  mask_prob=hp.mask_prob)
            graphs.append(graph_tensor)
            adjs.append(adj_matrix)

        batch = self._to_tensor(np.stack(sketches, 1))  # (Nmax, batch_size, 5)
        graphs = self._to_tensor(np.stack(graphs, 0))
        adjs = self._to_tensor(np.stack(adjs, 0))

        return batch, lengths, graphs, adjs, batch_labels




# Load the "train" split once at import time; Model.train and
# Model.conditional_generation read this module-level dataset.
sketch_dataset = SketchesDataset(hp.data_location, hp.category, "train")
# Publish the longest sketch length on hp so the model code can size its
# sequences with hp.Nmax.
hp.Nmax = sketch_dataset.Nmax


def sample_bivariate_normal(mu_x: torch.Tensor, mu_y: torch.Tensor,
                            sigma_x: torch.Tensor, sigma_y: torch.Tensor,
                            rho_xy: torch.Tensor, greedy=False):
    """Sample one (x, y) point from a single bivariate normal distribution.

    Every argument is a one-element tensor; it is converted to a Python float
    with ``.item()`` before use.

    :param greedy: when True, return the means ``(mu_x, mu_y)`` without sampling.
    :return: a pair of scalars ``(x, y)``.
    """
    # Unwrap the one-element tensors into plain floats.
    mean_x, mean_y, std_x, std_y, rho = (
        t.item() for t in (mu_x, mu_y, sigma_x, sigma_y, rho_xy)
    )
    if greedy:
        return mean_x, mean_y

    # The temperature scales the variance, so the std-devs get sqrt(temperature).
    std_x = std_x * np.sqrt(hp.temperature)
    std_y = std_y * np.sqrt(hp.temperature)

    covariance = [
        [std_x * std_x, rho * std_x * std_y],
        [rho * std_x * std_y, std_y * std_y],
    ]
    draw = np.random.multivariate_normal([mean_x, mean_y], covariance, 1)[0]
    return draw[0], draw[1]


def make_image(sequence, epoch, name='_output_'):
    """Render a sketch as line plots and save it as a JPEG under ./model_save/.

    :param sequence: 2-D array whose columns are (x, y, pen); a row with
        ``pen > 0`` ends the current stroke.
    :param epoch: value inserted at the start of the output file name.
    :param name: suffix placed between the epoch and the ``.jpg`` extension.
    """
    # Split after every row whose pen column is positive: one chunk per stroke.
    strokes = np.split(sequence, np.where(sequence[:, 2] > 0)[0] + 1)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    for s in strokes:
        ax.plot(s[:, 0], -s[:, 1])  # negate y so the sketch is drawn upright
    canvas = fig.canvas
    canvas.draw()
    # buffer_rgba() replaces the deprecated canvas.tostring_rgb(); dropping the
    # alpha channel afterwards yields the same RGB image.
    rgba = np.asarray(canvas.buffer_rgba())
    pil_image = PIL.Image.fromarray(rgba).convert('RGB')
    out_dir = "./model_save/"
    os.makedirs(out_dir, exist_ok=True)  # don't fail on a fresh checkout
    name = out_dir + str(epoch) + name + '.jpg'
    pil_image.save(name, "JPEG")
    plt.close("all")


################################# encoder and decoder modules


class Model:
    """Training wrapper around a graph encoder and an RNN decoder for sketches.

    Owns both networks, one Adam optimizer per network, and a frozen ``biLstm``
    network (``rnn_zoo``) used as a feature extractor for the perceptual loss.
    Reads hyper-parameters from the module-level ``hp`` and data from the
    module-level ``sketch_dataset``.
    """

    def __init__(self):
        encoder = EncoderGCN(hp.graph_number, hp.graph_picture_size, hp.out_f_num, hp.Nz,
                             bias_need=False)
        decoder = DecoderRNN()
        if hp.use_cuda:
            encoder = encoder.cuda()
            decoder = decoder.cuda()
        self.encoder: nn.Module = encoder
        self.decoder: nn.Module = decoder
        self.encoder_optimizer = optim.Adam(self.encoder.parameters(), hp.lr)
        self.decoder_optimizer = optim.Adam(self.decoder.parameters(), hp.lr)
        self.eta_step = hp.eta_min
        self.rnn_zoo = self.get_rnn()

    @staticmethod
    def _map_location():
        """Return the ``map_location`` for ``torch.load`` matching hp.use_cuda."""
        return None if hp.use_cuda else "cpu"

    @staticmethod
    def _pen_token(state, count):
        """Return a ``(1, count, 5)`` tensor that repeats the 5-value row ``state``."""
        token = torch.stack([torch.Tensor(state)] * count)
        if hp.use_cuda:
            token = token.cuda()
        return token.unsqueeze(0)

    def get_rnn(self):
        """Load the pretrained ``biLstm`` from ./encoderRNN.pth and freeze it.

        :return: the network with ``requires_grad`` disabled on every parameter.
        """
        encoder = biLstm()
        # Respect hp.use_cuda instead of forcing CUDA, so that the frozen
        # network lives on the same device as the batches it is fed.
        if hp.use_cuda:
            encoder = encoder.cuda()
        encoder.load_state_dict(torch.load(f"./encoderRNN.pth", map_location=self._map_location()))
        for params in encoder.parameters():
            params.requires_grad = False
        return encoder

    def lr_decay(self, optimizer: optim.Optimizer):
        """Multiply every learning rate above ``hp.min_lr`` by ``hp.lr_decay``.

        :return: the same optimizer object, modified in place.
        """
        for param_group in optimizer.param_groups:
            if param_group['lr'] > hp.min_lr:
                param_group['lr'] *= hp.lr_decay
        return optimizer

    def make_target(self, batch, lengths):
        """Build the reconstruction targets for one batch.

        :param batch: padded sketches, (Nmax, batch_size, 5).
        :param lengths: unpadded length of every sketch in the batch.
        :return: ``(mask, dx, dy, p)`` — ``mask`` is (Nmax + 1, batch_size) with
            ones on the first ``length`` steps of each sketch; ``dx`` / ``dy``
            are the offsets repeated ``hp.M`` times on a new last dim; ``p``
            holds the three pen-state columns.
        """
        eos = self._pen_token([0, 0, 0, 0, 1], batch.size()[1])
        batch = torch.cat([batch, eos], 0)

        mask = torch.zeros(hp.Nmax + 1, batch.size()[1])
        for indice, length in enumerate(lengths):
            mask[:length, indice] = 1
        if hp.use_cuda:
            mask = mask.cuda()

        dx = torch.stack([batch.data[:, :, 0]] * hp.M, 2)
        dy = torch.stack([batch.data[:, :, 1]] * hp.M, 2)
        p = torch.stack([batch.data[:, :, 2], batch.data[:, :, 3], batch.data[:, :, 4]], 2)
        return mask, dx, dy, p

    def get_perceptual_loss(self, x, x_hat, labels):
        """Compare two sequence batches in the feature space of the frozen RNN.

        Both inputs are passed through ``self.rnn_zoo``; the second and fourth
        outputs are compared with MSE, weighted by 128 and 1024 respectively.

        :param labels: accepted for interface compatibility; not used.
        """
        criterion = nn.MSELoss()
        _, mu_list, _, ph_list = self.rnn_zoo(x, hp.batch_size)
        _, mu_hat_list, _, ph_hat_list = self.rnn_zoo(x_hat, hp.batch_size)
        return criterion(ph_list, ph_hat_list) * 1024 + criterion(mu_list, mu_hat_list) * 128

    def train(self, epoch):
        """Run one optimisation step on a freshly sampled batch.

        Also prints the losses, decays both learning rates, renders a sample
        every 500 steps and saves checkpoints every 2000 steps (never at 0).
        """
        self.encoder.train()
        self.decoder.train()
        batch, lengths, graphs, adjs, labels = sketch_dataset.make_batch(hp.batch_size)

        # encode: the first output is the latent code fed to the decoder.
        z, self.mu, self.sigma, _ = self.encoder(graphs, adjs)

        # Prepend a start-of-sequence token and append z to every time step.
        sos = self._pen_token([0, 0, 1, 0, 0], hp.batch_size)
        batch_init = torch.cat([sos, batch], 0)
        z_stack = torch.stack([z] * (hp.Nmax + 1))
        inputs = torch.cat([batch_init, z_stack], 2)

        # decode:
        (self.pi, self.mu_x, self.mu_y, self.sigma_x, self.sigma_y,
         self.rho_xy, self.q, h, c) = self.decoder(inputs, z)

        # Perceptual loss between the reconstruction and the input (+ eos).
        rec_seq = self.s2s()
        eos = self._pen_token([0, 0, 0, 0, 1], batch.size()[1])
        inseq = torch.cat([batch, eos], 0)
        perceptual_loss = self.get_perceptual_loss(rec_seq, inseq, labels)

        # prepare targets and optimizers:
        mask, dx, dy, p = self.make_target(batch, lengths)
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        # update eta for the KL term (the KL loss itself is not part of `loss`):
        self.eta_step = 1 - (1 - hp.eta_min) * (hp.R ** epoch)

        LR = self.reconstruction_loss(mask, dx, dy, p, epoch)
        loss = (1 - hp.mask_prob) * LR + hp.mask_prob * perceptual_loss
        loss.backward()
        # clip_grad_norm_ is the in-place replacement for the deprecated
        # clip_grad_norm.
        nn.utils.clip_grad_norm_(self.encoder.parameters(), hp.grad_clip)
        nn.utils.clip_grad_norm_(self.decoder.parameters(), hp.grad_clip)
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()

        print('gcn, epoch -> ', epoch, 'loss', loss.item(), 'LR', LR.item(),'LP',perceptual_loss.item())
        self.encoder_optimizer = self.lr_decay(self.encoder_optimizer)
        self.decoder_optimizer = self.lr_decay(self.decoder_optimizer)

        if epoch == 0:
            return
        if epoch % 500 == 0:
            self.conditional_generation(epoch)
        if epoch % 2000 == 0:
            self.save(epoch)

    def bivariate_normal_pdf(self, dx, dy):
        """Evaluate the bivariate normal density at ``(dx, dy)``.

        Uses the distribution parameters currently stored on ``self``
        (``mu_x``, ``mu_y``, ``sigma_x``, ``sigma_y``, ``rho_xy``).
        """
        z_x = ((dx - self.mu_x) / self.sigma_x) ** 2
        z_y = ((dy - self.mu_y) / self.sigma_y) ** 2
        z_xy = (dx - self.mu_x) * (dy - self.mu_y) / (self.sigma_x * self.sigma_y)
        z = z_x + z_y - 2 * self.rho_xy * z_xy
        exp = torch.exp(-z / (2 * (1 - self.rho_xy ** 2)))
        norm = 2 * np.pi * self.sigma_x * self.sigma_y * torch.sqrt(1 - self.rho_xy ** 2)
        return exp / norm

    def reconstruction_loss(self, mask, dx, dy, p, epoch):
        """Return the stroke + pen-state negative log-likelihood per batch item.

        :param epoch: accepted for interface compatibility; not used.
        """
        pdf = self.bivariate_normal_pdf(dx, dy)
        # stroke offsets: masked mixture log-likelihood
        LS = -torch.sum(mask * torch.log(1e-3 + torch.sum(self.pi * pdf, 2))) / float(hp.batch_size)
        # pen state: cross-entropy against the one-hot targets
        LP = -torch.sum(p * torch.log(1e-3 + self.q)) / float(hp.batch_size)
        return LS + LP

    def kullback_leibler_loss(self):
        """Return the weighted KL term, floored at ``hp.KL_min``.

        ``self.sigma`` enters the formula both directly and through ``exp``.
        """
        LKL = -0.5 * torch.sum(1 + self.sigma - self.mu ** 2 - torch.exp(self.sigma)) \
              / float(hp.Nz * hp.batch_size)
        KL_min = torch.Tensor([hp.KL_min])
        if hp.use_cuda:
            KL_min = KL_min.cuda()
        KL_min = KL_min.detach()
        return hp.wKL * self.eta_step * torch.max(LKL, KL_min)

    def save(self, epoch):
        """Write encoder and decoder state dicts under ``./{hp.model_save}/``."""
        # Create the directory up front so a first save cannot fail on it.
        os.makedirs(f'./{hp.model_save}', exist_ok=True)
        torch.save(self.encoder.state_dict(),
                   f'./{hp.model_save}/encoderRNN_epoch_{epoch}.pth')
        torch.save(self.decoder.state_dict(),
                   f'./{hp.model_save}/decoderRNN_epoch_{epoch}.pth')

    def load(self, encoder_name, decoder_name):
        """Load encoder and decoder state dicts from the given file paths."""
        # map_location lets checkpoints saved on GPU load on CPU-only runs.
        saved_encoder = torch.load(encoder_name, map_location=self._map_location())
        saved_decoder = torch.load(decoder_name, map_location=self._map_location())
        self.encoder.load_state_dict(saved_encoder)
        self.decoder.load_state_dict(saved_decoder)

    def conditional_generation(self, epoch):
        """Encode one random sketch, decode it step by step and save the image.

        Both networks are switched to evaluation mode and left in it.
        """
        batch, lengths, graphs, adjs, _ = sketch_dataset.make_batch(1)
        # evaluation mode (e.g. disables dropout):
        self.encoder.train(False)
        self.decoder.train(False)

        seq_x = []
        seq_y = []
        seq_z = []
        # No gradients are needed while sampling.
        with torch.no_grad():
            z, _, _, _ = self.encoder(graphs, adjs)
            s = self._pen_token([0, 0, 1, 0, 0], 1)  # start-of-sequence token
            hidden_cell = None
            for i in range(hp.Nmax):
                # current point concatenated with z
                step_input = torch.cat([s, z.unsqueeze(0)], 2)
                (self.pi, self.mu_x, self.mu_y, self.sigma_x, self.sigma_y,
                 self.rho_xy, self.q, hidden, cell) = self.decoder(step_input, z, hidden_cell)
                hidden_cell = (hidden, cell)
                # sample the next point from the predicted parameters:
                s, dx, dy, pen_down, eos = self.sample_next_state()
                seq_x.append(dx)
                seq_y.append(dy)
                seq_z.append(pen_down)
                if eos:
                    print(i)
                    break

        # Offsets are accumulated into absolute coordinates for plotting.
        x_sample = np.cumsum(seq_x, 0)
        y_sample = np.cumsum(seq_y, 0)
        z_sample = np.array(seq_z)
        sequence = np.stack([x_sample, y_sample, z_sample]).T
        make_image(sequence, epoch)

    def sample_next_state(self):
        """Sample the next point from the decoder outputs stored on ``self``.

        :return: ``(next_state, x, y, pen_flag, eos_flag)`` where ``next_state``
            is a (1, 1, 5) tensor, ``pen_flag`` is True when the second pen
            state was drawn and ``eos_flag`` when the third one was.
        """

        def adjust_temp(pi_pdf):
            # Re-normalise the probabilities after dividing their logs by the
            # temperature.
            pi_pdf = np.log(1e-3 + np.abs(pi_pdf)) / hp.temperature
            pi_pdf = np.exp(pi_pdf)
            pi_pdf /= (pi_pdf.sum())
            return pi_pdf

        # mixture component:
        pi = adjust_temp(self.pi.data[0, 0, :].cpu().numpy())
        pi_idx = np.random.choice(hp.M, p=pi)
        # pen state:
        q = adjust_temp(self.q.data[0, 0, :].cpu().numpy())
        q_idx = np.random.choice(3, p=q)
        # parameters of the selected component:
        mu_x = self.mu_x.data[0, 0, pi_idx]
        mu_y = self.mu_y.data[0, 0, pi_idx]
        sigma_x = self.sigma_x.data[0, 0, pi_idx]
        sigma_y = self.sigma_y.data[0, 0, pi_idx]
        rho_xy = self.rho_xy.data[0, 0, pi_idx]
        x, y = sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, greedy=False)

        next_state = torch.zeros(5)
        next_state[0] = x
        next_state[1] = y
        next_state[q_idx + 2] = 1
        if hp.use_cuda:
            next_state = next_state.cuda()
        return next_state.view(1, 1, -1), x, y, q_idx == 1, q_idx == 2

    def s2s(self):
        """Turn the stored decoder outputs into a deterministic 5-value sequence.

        For every step the mean of the most likely mixture component and the
        one-hot of the most likely pen state are concatenated on the last dim.
        """
        # Take the number of components from the tensor itself rather than a
        # hard-coded 20, so the method follows whatever the decoder emits.
        n_mix = self.pi.shape[-1]
        pi_idx = F.one_hot(torch.argmax(self.pi, dim=-1), n_mix)
        q_idx = F.one_hot(torch.argmax(self.q, dim=-1), 3)
        mu_x = (pi_idx * self.mu_x).sum(dim=-1, keepdim=True)
        mu_y = (pi_idx * self.mu_y).sum(dim=-1, keepdim=True)
        # Cast the integer one-hot so all concatenated parts share one dtype.
        return torch.cat([mu_x, mu_y, q_idx.to(mu_x.dtype)], dim=-1)

    def train_sample_next_state(self):
        """Differentiable counterpart of ``sample_next_state`` using Gumbel-softmax.

        NOTE(review): reads ``self.train_pi``, ``self.train_q``,
        ``self.train_mu_x`` etc., which are not assigned anywhere in this
        class — the caller must set them first; confirm before use.

        :return: ``(x, y, q_idx)``.
        """

        def adjust_temp(pi_pdf):
            # Same temperature adjustment as sample_next_state, per row.
            pi_pdf = torch.log(1e-3 + torch.abs(pi_pdf)) / hp.temperature
            pi_pdf = torch.exp(pi_pdf)
            return pi_pdf / (pi_pdf.sum(dim=-1, keepdim=True))

        # mixture component weights:
        pi = adjust_temp(self.train_pi[:, 0, :].view(-1, hp.M))
        pi_idx = gumbel_softmax(pi)

        # pen state:
        q = adjust_temp(self.train_q[:, 0, :].view(-1, 3))
        q_idx = gumbel_softmax(q)

        def select(values):
            # Weight the per-component values by pi_idx and reduce to one column.
            return (values.view(-1, hp.M) * pi_idx).sum(dim=-1).view(-1, 1)

        mu_x = select(self.train_mu_x)
        mu_y = select(self.train_mu_y)
        sigma_x = select(self.train_sigma_x)
        sigma_y = select(self.train_sigma_y)
        rho_xy = select(self.train_rho_xy)

        x, y = train_sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, greedy=False)
        return x, y, q_idx


def get_parameter_number(net):
    """Count the parameters of ``net``.

    :param net: object exposing ``parameters()`` (e.g. an ``nn.Module``).
    :return: dict with the element count of all parameters under ``'Total'``
        and of those with ``requires_grad`` set under ``'Trainable'``.
    """
    total_num = 0
    trainable_num = 0
    for param in net.parameters():
        count = param.numel()
        total_num += count
        if param.requires_grad:
            trainable_num += count
    return {'Total': total_num, 'Trainable': trainable_num}


if __name__ == "__main__":
    # Build the model and report the size of both networks.
    model = Model()
    for net in (model.encoder, model.decoder):
        print(get_parameter_number(net))

    # Set epoch_load to a saved step number to resume from that checkpoint.
    epoch_load = 0
    if epoch_load != 0:
        model.load(f'./{hp.model_save}/encoderRNN_epoch_{epoch_load}.pth',
                f'./{hp.model_save}/decoderRNN_epoch_{epoch_load}.pth')

    # Fresh optimizers over the (possibly reloaded) parameters.
    model.encoder_optimizer = optim.Adam(model.encoder.parameters(), hp.lr)
    model.decoder_optimizer = optim.Adam(model.decoder.parameters(), hp.lr)

    for epoch in range(500001):
        if epoch > epoch_load:
            model.train(epoch)
            continue
        # Steps up to and including epoch_load are not trained; only the
        # learning-rate decay is replayed for them.
        model.encoder_optimizer = model.lr_decay(model.encoder_optimizer)
        model.decoder_optimizer = model.lr_decay(model.decoder_optimizer)

    '''
                                           
    model.conditional_generation(0)
    #'''
