#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   WTM.py
@Time    :   2020/10/06 17:13:43
@Author  :   Leilan Zhang
@Version :   1.0
@Contact :   zhangleilan@gmail.com
@Desc    :   None
'''

import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from octis.models.SlicedWTM.models.wae import WAE
from octis.models.SlicedWTM.utils import evaluate_topic_quality, smooth_curve
import random

class SWTM:
    def __init__(self, bow_dim=10000, n_topic=20, device=None,
                 taskname=None, dropout=0.0, batch_size=256, learning_rate=1e-3,
                 num_epochs=100, log_every=5, beta=1.0, dist='gmm_std', loss_type='sbtsw',
                 num_projections=100, n_trees=100, delta=1, p=2, kappa=50):
        self.bow_dim = bow_dim
        self.n_topic = n_topic
        self.wae = WAE(encode_dims=[bow_dim, 1024, 512, n_topic], decode_dims=[n_topic, 512, bow_dim], dropout=dropout, nonlin='relu')
        self.device = device
        self.id2token = None
        self.dist = dist
        self.dropout = dropout
        self.taskname = taskname
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.log_every = log_every
        self.beta = beta
        self.p = p
        self.kappa = kappa

        self.num_projections = num_projections
        self.n_trees = n_trees
        self.delta = delta
        self.loss_type = loss_type

        if device != None:
            self.wae = self.wae.to(device)

    def train(self, train_data, test_data=None,
              ckpt=None, verbose=False, topK=10):
        self.wae.train()
        self.id2token = {v: k for k,v in train_data.dictionary.token2id.items()}
        data_loader = DataLoader(train_data, batch_size=self.batch_size,shuffle=True, num_workers=4, collate_fn=train_data.collate_fn)

        optimizer = torch.optim.Adam(self.wae.parameters(), lr=self.learning_rate)
        #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

        if ckpt:
            self.load_model(ckpt["net"])
            optimizer.load_state_dict(ckpt["optimizer"])
            start_epoch = ckpt["epoch"] + 1
        else:
            start_epoch = 0

        trainloss_lst, valloss_lst = [], []
        c_v_lst, c_w2v_lst, c_uci_lst, c_npmi_lst, mimno_tc_lst, td_lst = [], [], [], [], [], []
        for epoch in range(start_epoch, self.num_epochs):
            epochloss_lst = []
            for iter, data in enumerate(data_loader):
                optimizer.zero_grad()

                txts, bows = data
                bows = bows.to(self.device)

                bows_recon, z = self.wae(bows)
                if self.dist == "dirichlet":
                    z = F.softmax(z, dim=1)
                
                prior = self.wae.sample(dist=self.dist, batch_size=len(bows), ori_data=bows).to(self.device)

                logsoftmax = torch.log_softmax(bows_recon, dim=1)
                rec_loss = -1.0 * torch.sum(bows*logsoftmax)

                if self.n_trees is None:
                    nlines = self.num_projections
                else:
                    nlines = self.num_projections // self.n_trees

                if self.loss_type == "sbtsw":
                    ot_loss = self.wae.sbtsw_loss(z, prior, p=self.p, delta=self.delta, ntrees=self.n_trees, nlines=nlines)

                s = torch.sum(bows)/len(bows)
                lamb = (5.0*s*torch.log(torch.tensor(1.0 *bows.shape[-1]))/torch.log(torch.tensor(2.0)))
                ot_loss = ot_loss * lamb

                loss = rec_loss + ot_loss * self.beta

                loss.backward()
                optimizer.step()

                trainloss_lst.append(loss.item()/len(bows))
                epochloss_lst.append(loss.item()/len(bows))
                # if (iter+1) % 10 == 0:
                #     print(f'Epoch {(epoch+1):>3d}\tIter {(iter+1):>4d}\tLoss:{loss.item()/len(bows):<.7f}\tRec Loss:{rec_loss.item()/len(bows):<.7f}\t OT loss:{ot_loss.item()/len(bows):<.7f}')
            #scheduler.step()
            if (epoch+1) % self.log_every == 0:
                save_name = f'./ckpt/WTM_{self.taskname}_tp{self.n_topic}_{self.dist}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}_{epoch+1}.ckpt'
                checkpoint = {
                    "net": self.wae.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch,
                    "param": {
                        "bow_dim": self.bow_dim,
                        "n_topic": self.n_topic,
                        "taskname": self.taskname,
                        "dist": self.dist,
                        "dropout": self.dropout
                    }
                }
                torch.save(checkpoint,save_name)
                print(f'Epoch {(epoch+1):>3d}\tLoss:{sum(epochloss_lst)/len(epochloss_lst):<.7f}')
                print('\n'.join([str(lst) for lst in self.show_topic_words()]))
                print('='*30)
                # smth_pts = smooth_curve(trainloss_lst)
                # plt.plot(np.array(range(len(smth_pts)))*log_every, smth_pts)
                # plt.xlabel('epochs')
                # plt.title('Train Loss')
                # plt.savefig('wlda_trainloss.png')
                if test_data!=None:
                    c_v,c_w2v,c_uci,c_npmi,mimno_tc, td = self.evaluate(test_data,calc4each=False)
                    c_v_lst.append(c_v), c_w2v_lst.append(c_w2v), c_uci_lst.append(c_uci),c_npmi_lst.append(c_npmi), mimno_tc_lst.append(mimno_tc), td_lst.append(td)
        # scrs = {'c_v':c_v_lst,'c_w2v':c_w2v_lst,'c_uci':c_uci_lst,'c_npmi':c_npmi_lst,'mimno_tc':mimno_tc_lst,'td':td_lst}
        '''
        for scr_name,scr_lst in scrs.items():
            plt.cla()
            plt.plot(np.array(range(len(scr_lst)))*log_every,scr_lst)
            plt.savefig(f'wlda_{scr_name}.png')
        '''
        # plt.cla()
        # for scr_name,scr_lst in scrs.items():
        #     if scr_name in ['c_v','c_w2v','td']:
        #         plt.plot(np.array(range(len(scr_lst)))*log_every,scr_lst,label=scr_name)
        # plt.title('Topic Coherence')
        # plt.xlabel('epochs')
        # plt.legend()
        # plt.savefig(f'wlda_tc_scores.png')


    def evaluate(self, test_data, calc4each=False):
        topic_words = self.show_topic_words()
        return evaluate_topic_quality(topic_words, test_data, taskname=self.taskname, calc4each=calc4each)


    def inference_by_bow(self, doc_bow):
        # doc_bow: torch.tensor [vocab_size]; optional: np.array [vocab_size]
        if isinstance(doc_bow,np.ndarray):
            doc_bow = torch.from_numpy(doc_bow)
        doc_bow = doc_bow.to(self.device)
        with torch.no_grad():
            self.wae.eval()
            theta = F.softmax(self.wae.encode(doc_bow),dim=1)
            return theta.detach().cpu().numpy()


    def inference(self, doc_tokenized, dictionary,normalize=True):
        doc_bow = torch.zeros(1,self.bow_dim)
        for token in doc_tokenized:
            try:
                idx = dictionary.token2id[token]
                doc_bow[0][idx] += 1.0
            except:
                print(f'{token} not in the vocabulary.')
        doc_bow = doc_bow.to(self.device)
        with torch.no_grad():
            self.wae.eval()
            theta = self.wae.encode(doc_bow)
            if normalize:
                theta = F.softmax(theta,dim=1)
            return theta.detach().cpu().squeeze(0).numpy()
    

    def get_doc_topic_distribution(self, dataset, n_samples=20, batch_size=256):
        self.wae.eval()
        data_loader = DataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=dataset.collate_fn,
                                 worker_init_fn=SWTM.seed_worker
                                 )
        
        final_thetas = []
        for _ in range(n_samples):
            with torch.no_grad():
                collect_theta = []
                for iter, data in enumerate(data_loader):
                    txts, bows = data
                    bows = bows.to(self.device)
                    theta = self.wae.encode(bows)
                    collect_theta.extend(theta.cpu().numpy().tolist())
                
                final_thetas.append(np.array(collect_theta))     
        return np.sum(final_thetas, axis=0) / n_samples


    def get_embed(self,train_data, num=1000):
        self.wae.eval()
        data_loader = DataLoader(train_data, batch_size=512,shuffle=False, num_workers=4, collate_fn=train_data.collate_fn)
        embed_lst = []
        txt_lst = []
        cnt = 0
        for data_batch in data_loader:
            txts, bows = data_batch
            embed = self.inference_by_bow(bows)
            embed_lst.append(embed)
            txt_lst.append(txts)
            cnt += embed.shape[0]
            if cnt>=num:
                break
        embed_lst = np.concatenate(embed_lst,axis=0)[:num]
        txt_lst = np.concatenate(txt_lst,axis=0)[:num]
        return txt_lst, embed_lst


    def get_topic_word_dist(self,normalize=True):
        self.wae.eval()
        with torch.no_grad():
            idxes = torch.eye(self.n_topic).to(self.device)
            word_dist = self.wae.decode(idxes)  # word_dist: [n_topic, vocab.size]
            if normalize:
                word_dist = F.softmax(word_dist,dim=1)
            return word_dist.detach().cpu().numpy()

    def show_topic_words(self, topic_id=None, topK=15, dictionary=None):
        self.wae.eval()
        topic_words = []
        idxes = torch.eye(self.n_topic).to(self.device)
        word_dist = self.wae.decode(idxes)
        word_dist = F.softmax(word_dist, dim=1)
        vals, indices = torch.topk(word_dist, topK, dim=1)
        vals = vals.cpu().tolist()
        indices = indices.cpu().tolist()
        if self.id2token==None and dictionary!=None:
            self.id2token = {v:k for k,v in dictionary.token2id.items()}
        if topic_id == None:
            for i in range(self.n_topic):
                topic_words.append([self.id2token[idx] for idx in indices[i]])
        else:
            topic_words.append([self.id2token[idx] for idx in indices[topic_id]])
        return topic_words
    
    def load_model(self, model):
        self.wae.load_state_dict(model)
    
    @staticmethod
    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

if __name__ == '__main__':
    model = WAE(encode_dims=[1024, 512, 256, 20],
                decode_dims=[20, 128, 768, 1024])
    model = model.cuda()
    inpt = torch.randn(234, 1024).cuda()
    out, mu, log_var = model(inpt)
    print(out.shape)
    print(mu.shape)
