from model_vc import Generator
import torch
import torch.nn.functional as F
import time
import datetime
import os
import torch.nn as nn
from AdversarialClassifier import AdversarialClassifier

# Expose only GPU index 0 to this process. This unconditionally overwrites
# any CUDA_VISIBLE_DEVICES value inherited from the calling environment.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

class Solver(object):
    """Training driver for the ``Generator`` model.

    Builds the generator and its Adam optimizer, optionally restores both
    from a checkpoint, and runs the optimisation loop in :meth:`train`.

    Class attributes (override on a subclass or instance to customise):
        ckpt_dir: directory where ``<step>.pth`` checkpoints are read/written.
        save_step: a checkpoint is written every ``save_step`` iterations.
        learning_rate: learning rate handed to Adam.
    """

    ckpt_dir = './ckpt'
    save_step = 50000
    learning_rate = 0.0001

    def __init__(self, vcc_loader, config):
        """Initialize configurations.

        Args:
            vcc_loader: iterable yielding ``(x_real, emb_org, label)`` batches.
            config: object exposing ``lambda_cd``, ``dim_neck``, ``dim_emb``,
                ``dim_pre``, ``dim_mel``, ``batch_size``, ``num_iters``,
                ``log_step`` and ``restore_step``.
        """

        # Data loader.
        self.vcc_loader = vcc_loader

        # Model configurations.
        self.lambda_cd = config.lambda_cd
        self.dim_neck = config.dim_neck
        self.dim_emb = config.dim_emb
        self.dim_pre = config.dim_pre
        self.dim_mel = config.dim_mel

        # Training configurations.
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters

        # Miscellaneous.
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.use_cuda else 'cpu')
        self.log_step = config.log_step
        self.restore_step = config.restore_step

        # Build the model and the classification criterion.
        self.build_model()
        self.advloss = nn.CrossEntropyLoss()

    def build_model(self):
        """Create the generator and optimizer, restoring them if requested.

        When ``self.restore_step`` is truthy, the checkpoint
        ``<ckpt_dir>/<restore_step>.pth`` is loaded into both objects.
        """
        self.G = Generator(self.dim_mel, self.dim_neck, self.dim_emb, self.dim_pre)
        # Move the model first so that restored optimizer state is placed on
        # the same device as the parameters it belongs to.
        self.G.to(self.device)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.learning_rate)
        if self.restore_step:
            self._restore_checkpoint(self.restore_step)

    def _checkpoint_path(self, step):
        """Return the checkpoint file path for iteration ``step``."""
        return os.path.join(self.ckpt_dir, '{}.pth'.format(step))

    def _restore_checkpoint(self, step):
        """Load model and optimizer state saved at iteration ``step``."""
        ckpt = torch.load(self._checkpoint_path(step), map_location=self.device)
        # load_state_dict() updates the object in place; its return value is
        # not the module/optimizer, so it must not be assigned back.
        self.G.load_state_dict(ckpt['model'])
        self.g_optimizer.load_state_dict(ckpt['optimizer'])

    def _save_checkpoint(self, step):
        """Write model and optimizer state for iteration ``step``."""
        torch.save(
            {
                "model": self.G.state_dict(),
                "optimizer": self.g_optimizer.state_dict(),
            },
            self._checkpoint_path(step),
        )

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()

    # =====================================================================================================================================#

    def train(self):
        """Run the training loop from ``restore_step`` up to ``num_iters``.

        Each iteration fetches one batch (restarting the loader when it is
        exhausted), computes the combined generator loss, and takes one
        optimizer step. Losses are printed every ``log_step`` iterations and
        a checkpoint is written every ``save_step`` iterations.
        """
        # Set data loader.
        data_loader = self.vcc_loader
        data_iter = iter(data_loader)
        os.makedirs(self.ckpt_dir, exist_ok=True)

        # Print logs in specified order
        keys = ['G/loss_id', 'G/loss_id_psnt', 'spk_loss', 'content_advloss', 'code_loss']

        # Training mode does not change inside the loop, so set it once.
        self.G.train()

        # Start training.
        print('Start training...')
        start_time = time.time()
        first_iter = self.restore_step or 0
        for i in range(first_iter, self.num_iters):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #

            # Fetch data; only loader exhaustion triggers a restart, so real
            # loading errors and Ctrl-C are no longer silently retried.
            try:
                x_real, emb_org, label = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x_real, emb_org, label = next(data_iter)
            x_real = x_real.to(self.device)
            emb_org = emb_org.to(self.device)
            # Targets must live on the same device as the predictions.
            label = label.to(self.device)

            # =================================================================================== #
            #                               2. Train the generator                                #
            # =================================================================================== #

            x_identic, x_identic_psnt, encoder_output, spk_pred, content_pred = self.G(x_real, None, self.dim_neck)
            # Second pass over the generated output, to compare its code with
            # the code produced from the real input.
            _, _, encoder_output_recon, _, _ = self.G(x_identic_psnt, None, self.dim_neck)

            spk_loss = self.advloss(spk_pred, label)
            content_adv_loss = self.advloss(content_pred, label)
            g_loss_id = F.mse_loss(x_identic.squeeze(1), x_real)
            g_loss_id_psnt = F.mse_loss(x_identic_psnt.squeeze(1), x_real)
            g_loss_cd = F.l1_loss(encoder_output, encoder_output_recon)

            # Backward and optimize.
            g_loss = (
                g_loss_id
                + g_loss_id_psnt
                + self.lambda_cd * spk_loss
                + self.lambda_cd * content_adv_loss
                + g_loss_cd
            )
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()

            # Logging.
            loss = {
                'G/loss_id': g_loss_id.item(),
                'G/loss_id_psnt': g_loss_id_psnt.item(),
                'spk_loss': spk_loss.item(),
                'content_advloss': content_adv_loss.item(),
                'code_loss': g_loss_cd.item(),
            }

            # =================================================================================== #
            #                                 3. Miscellaneous                                    #
            # =================================================================================== #

            # Print out training information.
            if (i + 1) % self.log_step == 0:
                # Whole seconds give an H:MM:SS string without slicing off
                # a microsecond suffix.
                et = str(datetime.timedelta(seconds=int(time.time() - start_time)))
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.num_iters)
                for tag in keys:
                    log += ", {}: {:.4f}".format(tag, loss[tag])
                print(log)
            if (i + 1) % self.save_step == 0:
                self._save_checkpoint(i + 1)






