from __future__ import absolute_import, print_function
import os
import torch

import torch.optim as optimizer
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import FashionMNIST, MNIST
from torchvision.utils import save_image

import numpy as np
import data
import sys

import statistics as stats
from pathlib import Path
from tqdm import tqdm

from options.testing_options import TestOptions
from utils.utils import *
from models.proto_VAE import Discriminator, Proto_classification

### Runtime configuration: parse the command-line options, select the compute
### device and decide whether the uncertainty loss is used.
opt_parser = TestOptions()
opt = opt_parser.parse(is_print=True)
use_cuda = opt.UseCUDA
# Build the device from the GPU index itself so that any index is accepted.
# Previously only indices 0 and 1 bound `device`; any other value left it
# undefined and the script failed later with a NameError.
device = torch.device(('cuda:' + str(opt.GPU_num)) if use_cuda else 'cpu')
print('Using CUDA' if use_cuda else 'Using CPU')
# 1 -> train with the uncertainty loss, 0 -> plain cross-entropy.
Uncertain = 1 if opt.Uncertain == 'Yes' else 0

### Script-wide settings
# Fixed values
num_class = 10      # number of target classes
img_crop_size = 0
summing_dist = 1    # forwarded to the model as `sum_dist`
# Values taken from the parsed options
num_prototypes = opt.Proto_num
batch_size_in = opt.BatchSize

### Setting distance metric
# Map the option string to the distance function handed to the model.
# The branches are mutually exclusive, and an unknown name is rejected here
# instead of leaving `distance` undefined until its first use.
if opt.Metric == 'Dkl':
    distance = kl_divergence_metric
elif opt.Metric == 'JSD':
    distance = jensen_shannon_distance
elif opt.Metric == 'JTD':
    distance = jensen_tsallis_distance
elif opt.Metric == 'WSR':
    distance = wasserstein_distance
else:
    raise ValueError(
        "Unknown Metric '" + str(opt.Metric) + "'; expected one of 'Dkl', 'JSD', 'JTD', 'WSR'."
    )

### Identifier of this model configuration (used in checkpoint and folder names)
model_setting = get_model_setting(opt)

### Working directories, both resolved against the current working directory
model_dir = os.path.join(os.getcwd(), "saved_model")  # checkpoints and generated images
data_dir = os.path.join(os.getcwd(), "data")          # dataset root

### Input preprocessing
# Resize, centre-crop to 32x32, then convert to a tensor; the decoder outputs
# elsewhere in this script are viewed as 32x32 images as well.
_preprocess_steps = [
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocess_steps)

### Loading dataset
# `chnum_in_` is the number of image channels passed to the model constructor.
if opt.Dataset == 'FashionMNIST':
    dataset = FashionMNIST(data_dir, download=True, transform=transform)
    chnum_in_ = 1
elif opt.Dataset == 'MNIST':
    dataset = MNIST(data_dir, download=True, transform=transform)
    chnum_in_ = 1
else:
    # Fail here rather than with a NameError on `dataset` further down.
    raise ValueError(
        "Unsupported Dataset '" + str(opt.Dataset) + "'; expected 'FashionMNIST' or 'MNIST'."
    )

### Divide for train, validation
if opt.Validation == 'Yes':
    n_val = int(len(dataset) * opt.Val_percent)
    # Derive the training size from the remainder. Truncating both products
    # independently can make the two sizes sum to less than len(dataset),
    # which random_split rejects.
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val])
    print('Train, Val datas are divided by', n_train, n_val)
    train_loader = DataLoader(train_set, batch_size=batch_size_in, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=batch_size_in, shuffle=False, drop_last=True)
    # Second, independently shuffled pass over the training data; the training
    # loop draws the discriminator's batches from it.
    # NOTE(review): an earlier remark here said this extra loader may not be needed.
    D_loader = DataLoader(train_set, batch_size=batch_size_in, shuffle=True, drop_last=True)
elif opt.Validation == 'No':
    train_set = dataset
    train_loader = DataLoader(dataset, batch_size=batch_size_in, shuffle=True, drop_last=True)
    D_loader = DataLoader(dataset, batch_size=batch_size_in, shuffle=True, drop_last=True)
else:
    # Any other value would leave the loaders undefined.
    raise ValueError("Validation must be 'Yes' or 'No', got '" + str(opt.Validation) + "'.")

### Loading model structure
# Both supported variants share the same classifier/VAE network; FactorVAE
# additionally needs a discriminator on the latent code.
if opt.ModelName in ('FactorVAE', 'BetaVAE'):
    model = Proto_classification(chnum_in_, num_prototypes, num_class, opt.ModelName, opt.z_dim, sum_dist=summing_dist)
    if opt.ModelName == 'FactorVAE':
        D = Discriminator(opt.z_dim).to(device)
else:
    # The old fallback `model = []` only postponed the failure to
    # `model.parameters()`, with a misleading AttributeError.
    raise ValueError(
        "Wrong Name. Unsupported ModelName '" + str(opt.ModelName) + "'; expected 'FactorVAE' or 'BetaVAE'."
    )

### Classification criterion and optimizers
# The criterion follows the uncertainty switch set during option parsing.
if Uncertain == 1:
    criterion = uncertain_loss
elif Uncertain == 0:
    criterion = torch.nn.CrossEntropyLoss()

# Optimizer for the classifier/VAE network.
optim = optimizer.Adam(
    model.parameters(),
    lr=opt.lr,
    betas=(0.9, 0.999),
    weight_decay=0.00005,
)

# The discriminator exists only for FactorVAE and has its own learning rate and betas.
if opt.ModelName == 'FactorVAE':
    optim_D = optimizer.Adam(
        D.parameters(),
        lr=opt.lr_D,
        betas=(0.5, 0.9),
        weight_decay=0.00005,
    )

### Loading saved model
# A non-zero Epoch_restart names the checkpoint epoch to resume from.
epoch_restart = opt.Epoch_restart
# Compare by value: `is not 0` tests object identity, which only works by
# accident for small ints and triggers a SyntaxWarning on recent Python.
if epoch_restart != 0 and model_dir is not None:
    model_name = os.path.join(model_dir, model_setting + '_' + str(epoch_restart) + '.pt')
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # restored on the device selected for this run.
    model.load_state_dict(torch.load(model_name, map_location=device))
    print('Model Loaded.')
    if opt.ModelName == 'FactorVAE':
        discrim_name = os.path.join(model_dir, model_setting + '_D' + str(epoch_restart) + '.pt')
        D.load_state_dict(torch.load(discrim_name, map_location=device))
        print('Discriminator Loaded.')

# Send models to device
model.to(device)

### Test mode: decode prototypes and latent interpolations to PNG files, then exit
if opt.Train == 'Test':
    with torch.no_grad():
        model.eval()
        # Every image produced in test mode is written to <model_dir>/Test.
        image_dir = os.path.join(model_dir, 'Test')
        os.makedirs(image_dir, exist_ok=True)

        def _decode_line(mu_start, mu_end, n_steps):
            """Decode n_steps + 1 evenly spaced latent means on the segment mu_start -> mu_end.

            Returns the decoded images concatenated along dim 0, on the CPU.
            """
            delta = mu_end - mu_start
            frames = []
            for step in range(n_steps + 1):
                # step runs over 0..n_steps, so both end points are decoded
                latent = mu_start + delta / n_steps * step
                frames.append(model.decoder(latent.view(1, opt.z_dim, 1, 1)).data)
            return torch.cat(frames, dim=0).cpu()

        # The first z_dim columns of each prototype row are its latent mean.
        proto_means = model.prototypes[:, :opt.z_dim].clone().detach()

        # Decode all prototypes: one overview grid plus one file per prototype.
        proto_save = model.decoder(proto_means.view(num_prototypes, opt.z_dim, 1, 1).to(device)).cpu()
        save_image(proto_save.view(num_prototypes, chnum_in_, 32, 32), image_dir + '/' + 'proto.png', nrow=5, padding=20)
        for idx in range(opt.Proto_num):
            save_image(proto_save[idx].view(chnum_in_, 32, 32), image_dir + '/' + 'proto' + str(idx + 1) + '.png')

        ### Prototype -> prototype interpolation, one strip per ordered pair
        constructing_connection = 1
        if constructing_connection == 1:
            print('Printing prototype connection')
            inter_num = 8
            for i in range(opt.Proto_num):
                for j in range(opt.Proto_num):
                    strip = _decode_line(proto_means[i], proto_means[j], inter_num)
                    save_image(
                        strip.view(strip.shape[0], chnum_in_, 32, 32),
                        image_dir + '/' + 'proto_connect' + str(i + 1) + '_' + str(j + 1) + '.png',
                        nrow=9,
                    )

        ### Input -> prototype interpolation, using the first sample of the first batch only
        print('Printing input prototype connection')
        constructing_input_connection = 1
        if constructing_input_connection == 1:
            for x, _ in train_loader:
                x = x.to(device)

                if opt.ModelName == 'BetaVAE':
                    _, dis_input, _, mu_x, _ = model(x, opt.ModelName, distance, sum_dist=summing_dist)
                elif opt.ModelName == 'FactorVAE':
                    _, dis_input, _, mu_x, _, _ = model(x, opt.ModelName, distance, sum_dist=summing_dist)
                print(dis_input[0])

                # Save the chosen input next to its reconstruction from the encoded mean.
                mu_x = mu_x[0, :, :, :].view(1, opt.z_dim, 1, 1)
                x_recon = model.decoder(mu_x).data
                x_input = x[0].view(32, 32).data
                save_image(x_input, image_dir + '/' + 'Input.png')
                save_image(x_recon.view(x_recon.shape[0], chnum_in_, 32, 32), image_dir + '/' + 'Input_reshape.png')
                mu_x = mu_x.view(opt.z_dim)

                for i in range(opt.Proto_num):
                    target_mu = proto_means[i]
                    # Fine strip (11 frames) and coarse strip (5 frames) towards prototype i.
                    fine = _decode_line(mu_x, target_mu, 10)
                    coarse = _decode_line(mu_x, target_mu, 4)
                    save_image(
                        fine.view(fine.shape[0], chnum_in_, 32, 32),
                        image_dir + '/' + 'deep_in_pro_connect' + str(i + 1) + '.png',
                        nrow=11,
                    )
                    save_image(
                        coarse.view(coarse.shape[0], chnum_in_, 32, 32),
                        image_dir + '/' + 'in_pro_connect' + str(i + 1) + '.png',
                        nrow=5,
                    )
                # Only the first batch is needed.
                break
    print('Image saved for Test:' + str(model_setting))
    sys.exit()

### Training state
if opt.ModelName == 'FactorVAE':
    D.train()
    # Constant class-index targets (0 / 1) for the discriminator's cross-entropy terms.
    zeros = torch.zeros(batch_size_in, dtype=torch.long, device=device)
    ones = torch.ones_like(zeros)

# Per-epoch validation logs and the global optimisation step counter.
Acc_list, recon_list = [], []
global_step = 0
# Main training loop: one pass over the training data per epoch, an optional
# validation pass, and a checkpoint / log / image dump every 100 epochs.
for epoch in range(epoch_restart + 1, opt.Epoch + 1):
    model.train()

    batch = tqdm(train_loader, total=len(train_set) // batch_size_in)
    train_loss = []
    recon = []
    accuracy = []

    # Loop-invariant annealing horizon for the uncertainty loss
    # (50 x the number of full batches per epoch).
    n_batches = len(train_set) // batch_size_in
    annealing_step = 50 * n_batches

    # Train with train data. D_loader provides a second, independently
    # shuffled batch that is used only for the FactorVAE discriminator.
    for (x, y), (x_2, _) in zip(batch, D_loader):
        x = x.to(device)
        y = y.to(device)

        if opt.ModelName == 'BetaVAE':
            logits, proto_distance, x_recon, mu, logvar = model(x, opt.ModelName, distance, sum_dist=summing_dist)
        if opt.ModelName == 'FactorVAE':
            logits, proto_distance, x_recon, mu, logvar, z = model(x, opt.ModelName, distance, sum_dist=summing_dist)

        recon_loss = reconstruction_loss(x, x_recon)
        total_kld = kl_divergence(mu, logvar)
        pred = logits.argmax(dim=1, keepdim=True)
        correct = pred.eq(y.view_as(pred)).sum().item()
        acc = correct / x.shape[0]

        if opt.ModelName == 'BetaVAE':
            vae_loss = recon_loss + opt.Beta * total_kld
        if opt.ModelName == 'FactorVAE':
            D_z = D(z)
            # TODO: check against the other loss terms whether this should be a sum or a mean.
            vae_tc_loss = (D_z[:, :1] - D_z[:, 1:]).mean()
            vae_loss = recon_loss + total_kld + opt.Gamma * vae_tc_loss

        if Uncertain == 0:
            total_loss = opt.lambda_class * criterion(logits, y) + opt.lambda_vae * vae_loss
        if Uncertain == 1:
            # Using uncertainty: clamped, exponentiated logits act as evidence.
            evidence = torch.exp(torch.clamp(logits, min=-10, max=10))
            alpha = evidence + 1

            y_hot = F.one_hot(y, num_classes=num_class)
            class_loss = torch.mean(criterion(y_hot, alpha, global_step, annealing_step, device))
            total_loss = opt.lambda_class * class_loss + opt.lambda_vae * vae_loss

        # For printing loss
        train_loss.append(total_loss.detach().item())
        recon.append(recon_loss.detach().item())
        accuracy.append(acc)
        batch.set_description('Epoch:' + str(epoch) + ' Train Loss:' + str(stats.mean(train_loss)) + ' Classification Acc(%):' + str(100*stats.mean(accuracy)))

        # Backprop for the classifier/VAE. The graph is not needed afterwards
        # because the discriminator loss below is built from detached latents.
        optim.zero_grad()
        total_loss.backward()

        if opt.ModelName == 'FactorVAE':
            x_2 = x_2.to(device)
            z_2 = model(x_2, opt.ModelName, distance, sum_dist=summing_dist, no_dec=True)
            z_perm = permute_dims(z_2).detach()
            D_z_perm = D(z_perm)
            # Score a detached copy of z. Re-using the attached D_z here would
            # back-propagate the discriminator loss into the VAE parameters,
            # and optim.step() below would then apply it to the VAE.
            D_z_for_D = D(z.detach())
            D_tc_loss = 0.5*(F.cross_entropy(D_z_for_D, zeros) + F.cross_entropy(D_z_perm, ones))

            # zero_grad also drops the gradients that total_loss.backward()
            # deposited on the discriminator through vae_tc_loss.
            optim_D.zero_grad()
            D_tc_loss.backward()
            optim_D.step()
        optim.step()

        global_step += 1

    # Test with validation data
    if opt.Validation == 'Yes':
        with torch.no_grad():
            model.eval()

            val_batch = tqdm(val_loader, total=len(val_set) // batch_size_in)
            val_acc = []
            for x2, y2 in val_batch:
                x2 = x2.to(device)
                y2 = y2.to(device)

                if opt.ModelName == 'BetaVAE':
                    logits2, _, x_recon2, _, _ = model(x2, opt.ModelName, distance, sum_dist=summing_dist)
                if opt.ModelName == 'FactorVAE':
                    logits2, _, x_recon2, _, _, _ = model(x2, opt.ModelName, distance, sum_dist=summing_dist)

                recon_loss2 = reconstruction_loss(x2, x_recon2)
                val_pred = logits2.argmax(dim=1, keepdim=True)
                correct2 = val_pred.eq(y2.view_as(val_pred)).sum().item()
                acc2 = correct2 / x2.shape[0]

                val_acc.append(acc2)
                val_batch.set_description('Epoch:' + str(epoch) + ' Validation Accuracy:' + str(100*stats.mean(val_acc)))
        # Log only when a validation pass produced results. These appends used
        # to run unconditionally and raised NameError with Validation == 'No'
        # (and StatisticsError when the validation loader yielded no batch).
        if val_acc:
            Acc_list.append(100*stats.mean(val_acc))
            # Reconstruction error of the last validation batch of this epoch.
            recon_list.append(recon_loss2.detach().item())

    if epoch % 100 == 0:
        # Saving model. Create the checkpoint folder first so that a fresh run
        # does not fail at its first checkpoint.
        os.makedirs(model_dir, exist_ok=True)
        model_file = os.path.join(model_dir, model_setting + '_' + str(epoch) + '.pt')
        torch.save(model.state_dict(), model_file)
        if opt.ModelName == 'FactorVAE':
            discrim_file = os.path.join(model_dir, model_setting + '_D' + str(epoch) + '.pt')
            torch.save(D.state_dict(), discrim_file)
        print('Model saved for epoch:' + str(epoch))
        # Make folder to save image
        image_dir = os.path.join(model_dir, model_setting)
        os.makedirs(image_dir, exist_ok=True)
        # Saving accuracy, reconstruction error and prototype mu & sigma.
        # Context managers close each log even if a write fails.
        log_acc_dir = os.path.join(image_dir, 'Accuracy.txt')
        with open(log_acc_dir, 'a') as f:
            f.write('Epoch:' + str(epoch) + ' Validation Accuracy[20]:' + str(Acc_list[-20:]) + '%\n')
        log_recon_dir = os.path.join(image_dir, 'Reconstruction_error.txt')
        with open(log_recon_dir, 'a') as f:
            f.write('Epoch:' + str(epoch) + ' Validation Recon_error[20]:' + str(recon_list[-20:]) + '%\n')
        log_dir = os.path.join(image_dir, 'Mu_Sigma.txt')
        with open(log_dir, 'a') as f:
            f.write('Epoch:' + str(epoch) + '%\n')
            # First z_dim columns of the prototypes are logged as Mu; the
            # remaining columns are halved and exponentiated and logged as Sigma.
            f.write('Mu:' + str(model.prototypes[:, :opt.z_dim].clone().detach().cpu().numpy()) + '%\n')
            f.write('Sigma:' + str(model.prototypes[:, opt.z_dim:].clone().detach().div(2).exp().cpu().numpy()) + '%\n')
        # Saving reconstructed images (last training batch of this epoch)
        save_recon_image(x, x_recon, image_dir, str(epoch), chnum_in_)
        with torch.no_grad():
            # Saving images decoded from 64 standard-normal latent samples
            sample = torch.randn(64, opt.z_dim, 1, 1).to(device)
            sample = model.decoder(sample).cpu()
            save_image(sample.view(64, chnum_in_, 32, 32), image_dir + '/' + 'sample_' + str(epoch) + '.png')
            # Saving images decoded from the prototype means
            proto_save = model.prototypes[:, :opt.z_dim].clone().detach().view(num_prototypes, opt.z_dim, 1, 1).to(device)
            proto_save = model.decoder(proto_save).cpu()
            save_image(proto_save.view(num_prototypes, chnum_in_, 32, 32), image_dir + '/' + 'proto_' + str(epoch) + '.png')

        print('Image saved for epoch:' + str(epoch))

