import copy
import torch
import argparse
import importlib
import numpy as np
from tqdm import tqdm
from pathlib import Path
from sklearn.manifold import TSNE

import matplotlib.pyplot as plt
import plotly.express as px
from plotly import graph_objects as go

from torch.utils.data import DataLoader
from torchvision.transforms import GaussianBlur
from torch.nn import functional as F

from vae import VAE
from dgnn.utils import halfplane2disk

def plot_latent(args, data, vae):
    """Scatter-plot every latent dimension in two coordinate systems.

    The dataset is pushed through ``vae.encoder`` and ``vae.encoder_layer``;
    channel 1 of the returned mean is replaced by ``exp(0.5 * value)`` before
    anything is plotted.  For each of the ``args.latent_dim`` dimensions two
    PNGs are written under ``figures/``: one after mapping the points with
    ``halfplane2disk(latents, args.c)`` and one using the raw
    (channel 0, channel 1) coordinates.  The correlation between channel 1
    and the negated ``data.features`` is printed per dimension.

    Args:
        args: namespace providing ``batch_size``, ``device``, ``c``,
            ``latent_dim`` and ``dataset``.
        data: dataset exposing ``features`` (used to colour the points) that
            can be iterated through a ``DataLoader``.
        vae: model exposing ``encoder`` and ``encoder_layer``.
    """
    # The Breakout-specific ``data.filter_dataset(3.8)`` call is currently
    # disabled in this script.
    point_colors = data.features
    loader = DataLoader(
        data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=1,
    )

    def _write_scatter(points, dim, out_path):
        # One equal-aspect scatter of (channel 0, channel 1) for ``dim``.
        figure = px.scatter(
            x=points[:, dim, 0],
            y=points[:, dim, 1],
            color=point_colors,
        )
        figure.update_yaxes(scaleanchor='x', scaleratio=1)
        figure.update_traces(marker=dict(size=5))
        figure.write_image(out_path)

    encoded = []
    with torch.no_grad():
        for x in tqdm(loader):
            mean, _ = vae.encoder_layer(vae.encoder(x.to(args.device)))
            # presumably channel 1 is a log-variance turned into a std-dev
            # here -- TODO confirm against the distribution module
            mean[..., 1] = (0.5 * mean[..., 1]).exp()
            encoded.append(mean.detach().cpu())

        latents = torch.concat(encoded, dim=0)
        latents_disk = halfplane2disk(latents, args.c).cpu().numpy()

    for d in range(args.latent_dim):
        _write_scatter(
            latents_disk, d, f'figures/latent_disk_{args.dataset}_{d:02}.png'
        )

    latents = latents.numpy()
    print(point_colors.shape, latents.shape)
    for d in range(args.latent_dim):
        corr = np.corrcoef(latents[..., d, 1], -point_colors)[0, 1]
        print(f"Dim {d} correlation {corr}")
        _write_scatter(latents, d, f'figures/latent_{args.dataset}_{d:02}.png')

    print("done")

def plot_tsne(args, data, vae):
    """Project the flattened latent means to 2-D with t-SNE and save a scatter.

    Every batch is encoded with ``vae.encoder`` / ``vae.encoder_layer``, the
    returned mean is flattened per sample, and the stacked result is embedded
    with ``TSNE(n_components=2, random_state=0)``.  Points are coloured by
    ``data.features`` and the figure is written to
    ``figures/tnse_<args.dist>.png``.

    Args:
        args: namespace providing ``batch_size``, ``device`` and ``dist``.
        data: dataset exposing ``features`` that can be iterated through a
            ``DataLoader``.
        vae: model exposing ``encoder`` and ``encoder_layer``.
    """
    tsne = TSNE(n_components=2, random_state=0)
    data_loader = DataLoader(
        data, batch_size=args.batch_size, shuffle=False, num_workers=1
    )

    latents = []
    with torch.no_grad():
        for x in data_loader:
            x = x.to(args.device)
            mean, _ = vae.encoder_layer(vae.encoder(x))
            # Flatten everything after the batch axis so t-SNE sees one
            # feature vector per sample.
            latents.append(mean.view(mean.size(0), -1))
    latents = torch.concat(latents, dim=0).cpu().numpy()

    latents_2d = tsne.fit_transform(latents)

    plt.scatter(x=latents_2d[:, 0], y=latents_2d[:, 1], c=data.features[:])
    # The scatter has no labelled artists, so ``plt.legend()`` only produced
    # a warning; a colour bar is what actually explains the ``c`` values.
    plt.colorbar()
    # NOTE: the "tnse" spelling is kept so existing output paths stay stable.
    plt.savefig(f"figures/tnse_{args.dist}.png")
    plt.close()

    print("done")
    return

def reconstruct_mu(args, data, vae):
    """Sweep channel 0 of the latent mean and save the decoded images.

    ``data.filter_class(0)`` is applied to the dataset in place.  For the
    first samples of the loader (indices 0..10) the encoder mean is shifted
    by -5 on channel 0 and then decoded 20 times, adding 0.5 to channel 0
    after each decode.  Each sigmoid-activated reconstruction is written to
    ``figures/mu_<dataset>_<sample>_<step>.png``.

    Args:
        args: namespace providing ``device`` and ``dataset``.
        data: dataset exposing ``filter_class``; mutated by this call.
        vae: model exposing ``encoder``, ``encoder_layer``, ``decoder_layer``
            and ``decoder``.
    """
    data.filter_class(0)
    loader = DataLoader(data, batch_size=1, shuffle=False, num_workers=1)

    with torch.no_grad():
        for idx, x in enumerate(loader):
            if idx > 10:
                break

            mean, _ = vae.encoder_layer(vae.encoder(x.to(args.device)))
            mean[..., 0] -= 5  # start the sweep five units below the encoding

            for inc in tqdm(range(20)):
                logits = vae.decoder(vae.decoder_layer(mean))
                image = torch.sigmoid(logits).cpu().numpy()

                plt.imshow(image[0, 0, :, :])
                plt.savefig(f"figures/mu_{args.dataset}_{idx:04d}_{inc:03d}.png")
                plt.close()

                mean[..., 0] += 0.5

    print("done")

def with_hierarchy(args, data, vae):
    """Plot the first 3000 encoded points per latent dimension, coloured by ``data.colors``.

    The entries of ``data.colors`` are converted to strings and passed to
    plotly as the colour argument.  Channel 1 of the encoder mean is replaced
    by ``exp(0.5 * value)`` before plotting.  One PNG per latent dimension is
    written to ``figures/hierarchy_<dataset>_<dim>.png``.

    Args:
        args: namespace providing ``batch_size``, ``device``, ``latent_dim``
            and ``dataset``.
        data: dataset exposing ``colors`` that can be iterated through a
            ``DataLoader``.
        vae: model exposing ``encoder`` and ``encoder_layer``.
    """
    labels = [str(value) for value in data.colors]

    loader = DataLoader(
        data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=1,
    )

    encoded = []
    with torch.no_grad():
        for x in tqdm(loader):
            mean, _ = vae.encoder_layer(vae.encoder(x.to(args.device)))
            mean[..., 1] = (0.5 * mean[..., 1]).exp()
            encoded.append(mean.detach().cpu())

    # The half-plane -> disk mapping (halfplane2disk) is not applied here.
    points = torch.concat(encoded, dim=0).numpy()

    limit = 3000  # only the first 3000 samples are drawn
    for d in range(args.latent_dim):
        figure = px.scatter(
            x=points[:limit, d, 0],
            y=points[:limit, d, 1],
            color=labels[:limit],
        )
        figure.update_yaxes(scaleanchor='x', scaleratio=1)
        figure.update_traces(marker=dict(size=5))
        figure.write_image(f'figures/hierarchy_{args.dataset}_{d:02}.png')

    print("done")


def reconstruct_sigma(args, data, vae, index):
    """Decode one sample repeatedly while inflating channel 1 of its latent mean.

    ``data.data[index]`` is reshaped to ``(-1, 1, 28, 28)`` and encoded once.
    On each of 200 steps channel 1 is exponentiated, increased by 0.1 and
    mapped back with ``log``; the updated latent is decoded and the image is
    written to ``figures/recon_<dataset>_<step>.png``.

    Args:
        args: namespace providing ``device`` and ``dataset``.
        data: dataset exposing a raw ``data`` array indexable by ``index``.
        vae: model exposing ``encoder``, ``encoder_layer``, ``decoder_layer``
            and ``decoder``.
        index: position of the sample inside ``data.data``.
    """
    sample = torch.Tensor(data.data[index]).view(-1, 1, 28, 28)
    # NOTE(review): plot_blur rescales its sample with (sample - 0.5) * 2;
    # that rescaling is disabled here -- confirm which range the model expects.

    batch = sample.to(args.device)
    with torch.no_grad():
        mean, _ = vae.encoder_layer(vae.encoder(batch))

        for step in tqdm(range(200)):
            # Add 0.1 in the exponentiated domain, then return to log-space.
            mean[..., 1] = (mean[..., 1].exp() + 0.1).log()

            recon = vae.decoder(vae.decoder_layer(mean))
            # The decoder output is treated as logits elsewhere in this file
            # (binary_cross_entropy_with_logits, sigmoid before imshow), so
            # squash it the same way before plotting.
            recon = torch.sigmoid(recon).cpu().numpy()

            plt.imshow(recon[0, 0, :, :])
            plt.savefig(f"figures/recon_{args.dataset}_{step:03}.png")
            plt.close()
    print("done")
    return

def error_correlation(args, data, vae):
    """Print, per class, the correlation between latent spread and reconstruction loss.

    For every distinct value in ``data.features`` a deep copy of the dataset
    is filtered with ``filter_class`` and encoded.  Each sample contributes

    * the sum over ``dim=1`` of ``sigma ** 2`` where ``sigma`` is
      ``exp(0.5 * logvar)`` for ``args.dist == "EuclideanNormal"`` and
      ``exp(0.5 * mean[..., 1])`` otherwise, and
    * its binary-cross-entropy-with-logits reconstruction loss averaged over
      the last three axes.

    The Pearson correlation of the two series is printed for each class.
    Reconstructions of the very first batch are also written to
    ``figures/recon_error_<dataset>_<b>.png``.

    Args:
        args: namespace providing ``batch_size``, ``device``, ``dist`` and
            ``dataset``.
        data: dataset exposing ``features`` and ``filter_class``; it is deep
            copied, so the caller's object is not filtered.
        vae: model exposing ``encoder``, ``encoder_layer``, ``decoder_layer``
            and ``decoder``.
    """
    pristine = copy.deepcopy(data)
    # Example reconstructions are written once only: the file names do not
    # encode the class or batch, so later batches would just overwrite them.
    examples_written = False

    for c in np.unique(data.features):
        class_data = copy.deepcopy(pristine)
        class_data.filter_class(c)
        data_loader = DataLoader(
            class_data, batch_size=args.batch_size, shuffle=False, num_workers=1
        )

        sigmas, recon_losses = [], []
        with torch.no_grad():
            for x in data_loader:
                x = x.to(args.device)
                mean, logvar = vae.encoder_layer(vae.encoder(x))
                x_generated = vae.decoder(vae.decoder_layer(mean))
                recon_loss = F.binary_cross_entropy_with_logits(
                    x_generated,
                    x,
                    reduction='none'
                )
                recon_loss = recon_loss.mean(dim=[-1, -2, -3])  # one value per sample
                if args.dist == "EuclideanNormal":
                    sigma = (0.5 * logvar).exp()
                else:
                    sigma = (0.5 * mean[..., 1]).exp()
                sigmas.append(sigma.pow(2).sum(dim=1))
                recon_losses.append(recon_loss)

                if not examples_written:
                    for b in tqdm(range(x_generated.size(0))):
                        plt.imshow(torch.sigmoid(x_generated[b, 0, :, :]).detach().cpu().numpy())
                        plt.savefig(f"figures/recon_error_{args.dataset}_{b:03}.png")
                        plt.close()
                    examples_written = True
                # A leftover debugging ``exit()`` used to terminate the process
                # here, so the correlation below was never reached.

            sigmas = torch.concat(sigmas, dim=0)
            recon_losses = torch.concat(recon_losses, dim=0)

        corr = np.corrcoef(sigmas.cpu().numpy(), recon_losses.cpu().numpy())[0, 1]
        print(f"** class {c}:: Correlation btw. sigma and loss {corr}")
        print()
    print("done")
    return

def error_correlation1(args, data, vae):
    """Print the correlation between summed latent channel 1 and reconstruction loss.

    Each sample is encoded and decoded from its mean.  The per-sample value
    on one side is ``mean[..., 1]`` summed over ``dim=1`` (no exponential is
    applied); on the other side it is the binary-cross-entropy-with-logits
    loss averaged over the last three axes.  Their Pearson correlation over
    the whole dataset is printed.

    Args:
        args: namespace providing ``batch_size`` and ``device``.
        data: dataset that can be iterated through a ``DataLoader``.
        vae: model exposing ``encoder``, ``encoder_layer``, ``decoder_layer``
            and ``decoder``.
    """
    loader = DataLoader(
        data,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=1,
    )

    sigma_chunks = []
    loss_chunks = []
    with torch.no_grad():
        for x in loader:
            x = x.to(args.device)
            mean, _ = vae.encoder_layer(vae.encoder(x))
            logits = vae.decoder(vae.decoder_layer(mean))

            per_pixel = F.binary_cross_entropy_with_logits(
                logits, x, reduction='none'
            )
            per_sample_loss = per_pixel.mean(dim=[-1, -2, -3])

            # Raw channel-1 values are used here, not exp(0.5 * value).
            sigma_chunks.append(mean[..., 1].sum(dim=1))
            loss_chunks.append(per_sample_loss)

        sigmas = torch.concat(sigma_chunks, dim=0)
        recon_losses = torch.concat(loss_chunks, dim=0)

    corr = np.corrcoef(sigmas.cpu().numpy(), recon_losses.cpu().numpy())[0, 1]
    print(f"Correlation btw. sigma and loss {corr}")

def plot_blur(args, data, vae, index=0):
    """Encode one sample at six Gaussian-blur strengths and plot the latents.

    ``data.data[index]`` is reshaped to ``(-1, 1, 28, 28)``, rescaled with
    ``(sample - 0.5) * 2`` and blurred with ``GaussianBlur(kernel_size=9)``
    for ``sigma`` in 1..6.  Each blurred image is saved to
    ``figures/ex_<dataset>_<sigma>.png``.  The six images are encoded as one
    batch, channel 1 of the mean is replaced by ``exp(0.5 * value)``, and one
    scatter per latent dimension (coloured by blur strength) is written to
    ``figures/blur_<dataset>_<dim>.png``.

    Args:
        args: namespace providing ``device``, ``latent_dim`` and ``dataset``.
        data: dataset exposing a raw ``data`` array indexable by ``index``.
        vae: model exposing ``encoder`` and ``encoder_layer``.
        index: position of the sample inside ``data.data``.
    """
    sample = torch.Tensor(data.data[index]).view(-1, 1, 28, 28)
    sample = (sample - 0.5) * 2

    intensity = list(range(1, 7))
    batch = []
    for ii in intensity:
        t_sample = GaussianBlur(kernel_size=9, sigma=ii)(sample)
        plt.imshow(t_sample[0, 0, :, :])
        plt.savefig(f"figures/ex_{args.dataset}_{ii:02}.png")
        plt.close()

        batch.append(t_sample)
    # Stack the blurred copies and fold them back into a plain image batch.
    batch = torch.stack(batch, dim=0).to(args.device).view(-1, 1, 28, 28)

    with torch.no_grad():
        mean, _ = vae.encoder_layer(vae.encoder(batch))
    mean[..., 1] = (0.5 * mean[..., 1]).exp()

    latents = mean.cpu().numpy()
    for d in range(args.latent_dim):
        fig = px.scatter(
            x=latents[:, d, 0],
            y=latents[:, d, 1],
            color=intensity
        )

        fig.update_yaxes(
            scaleanchor='x',
            scaleratio=1
        )
        fig.update_traces(marker=dict(size=5))

        fig.write_image(f'figures/blur_{args.dataset}_{d:02}.png')

    print("done")
    return

def blur_correlation(args, data, vae):
    """Print how latent means and sigmas correlate with Gaussian-blur strength.

    All of ``data.data`` is rescaled with ``(sample - 0.5) * 2`` and blurred
    with ``GaussianBlur(kernel_size=9)`` for ``sigma`` in 1..5.  The blurred
    copies are encoded in one pass.  For every latent dimension and every
    sample, the series of values across blur strengths (offset by its first
    entry) is correlated with the blur strengths; the per-dimension average
    and the overall average are printed, first for the means and then for the
    sigmas.

    ``sigma`` is ``exp(0.5 * logvar)`` when ``args.dist == "EuclideanNormal"``;
    otherwise the mean is channel 0 and sigma is ``exp(0.5 * channel 1)`` of
    the encoder output.

    Args:
        args: namespace providing ``device``, ``dist`` and ``latent_dim``.
        data: dataset exposing a raw ``data`` array.
        vae: model exposing ``encoder`` and ``encoder_layer``.
    """
    sample = torch.Tensor(data.data)
    sample = (sample - 0.5) * 2

    intensity = list(range(1, 6))
    blurred = [
        GaussianBlur(kernel_size=9, sigma=level)(sample) for level in intensity
    ]
    # Interleave the blur levels per sample, then flatten to a batch of images.
    batch = torch.stack(blurred, dim=1).to(args.device)
    batch = batch.view(-1, sample.size(1), sample.size(2))
    print('batch', batch.size())

    with torch.no_grad():
        mean, logvar = vae.encoder_layer(vae.encoder(batch))

    if args.dist == "EuclideanNormal":
        sigma = (0.5 * logvar).exp()
    else:
        sigma = (0.5 * mean[..., 1]).exp()
        mean = mean[..., 0]
    print('mean sigma', mean.size(), sigma.size())

    num_samples = sample.size(0)
    num_levels = len(intensity)
    # After the transpose the axes are (latent dim, sample, blur level).
    means = mean.cpu().numpy().reshape([num_samples, num_levels, -1]).transpose(2, 0, 1)
    sigmas = sigma.cpu().numpy().reshape([num_samples, num_levels, -1]).transpose(2, 0, 1)
    ins: np.ndarray = np.array(intensity)

    def _summed_dimension_corr(stack, label):
        # Average the per-sample correlation for each latent dimension, print
        # it, and return the sum of those averages over all dimensions.
        grand_total = 0
        for d in range(args.latent_dim):
            running = 0
            for b in tqdm(range(stack.shape[1])):
                series = stack[d, b, :]
                series = series - series[0]
                running += np.corrcoef(series, ins)[0, 1]
            running /= stack.shape[1]
            grand_total += running
            print(f"Dimension {d} : Correlation btw. {label} and intensity: {running}.")
        return grand_total

    result_mean = _summed_dimension_corr(means, "means")
    result_sigma = _summed_dimension_corr(sigmas, "sigma")

    print(f"** Total mean of correlation btw. means : {result_mean/args.latent_dim} and sigmas : {result_sigma/args.latent_dim}")
    print("done")

def anomaly_detection(args, data1, data2, vae):
    """Plot the latents of two datasets together, coloured by their origin.

    ``data2`` is encoded first and labelled ``"Anomal_dataset"``; ``data1``
    follows and is labelled ``"MNIST"``.  Channel 1 of each encoder mean is
    replaced by ``exp(0.5 * value)``.  One scatter per latent dimension is
    written to ``figures/anomaly_<dataset>_<dim>.png``.

    Args:
        args: namespace providing ``batch_size``, ``device``, ``latent_dim``
            and ``dataset``.
        data1: reference dataset (always labelled ``"MNIST"`` in the plot).
        data2: dataset plotted as the anomalous one.
        vae: model exposing ``encoder`` and ``encoder_layer``.
    """
    loader_reference = DataLoader(
        data1, batch_size=args.batch_size, shuffle=False, num_workers=1
    )
    loader_anomalous = DataLoader(
        data2, batch_size=args.batch_size, shuffle=False, num_workers=1
    )

    # Label order must match the encoding order below: data2 first, then data1.
    labels = ["Anomal_dataset"] * len(data2) + ["MNIST"] * len(data1)

    encoded = []

    def _encode_all(loader):
        # Append the transformed encoder mean of every batch to ``encoded``.
        for x in tqdm(loader):
            mean, _ = vae.encoder_layer(vae.encoder(x.to(args.device)))
            mean[..., 1] = (0.5 * mean[..., 1]).exp()
            encoded.append(mean.detach().cpu())

    with torch.no_grad():
        _encode_all(loader_anomalous)
        _encode_all(loader_reference)

    # The half-plane -> disk mapping (halfplane2disk) is not applied here.
    points = torch.concat(encoded, dim=0).numpy()

    for d in range(args.latent_dim):
        figure = px.scatter(
            x=points[:, d, 0],
            y=points[:, d, 1],
            color=labels,
        )
        figure.update_yaxes(scaleanchor='x', scaleratio=1)
        figure.update_traces(marker=dict(size=5))
        figure.write_image(f'figures/anomaly_{args.dataset}_{d:02}.png')

    print("done")

if __name__ == "__main__":
    # Command-line driver: parse options, build the dataset(s) and the VAE,
    # load the checkpoint, then run the analysis selected by --mode.
    DATASETS = ("MNIST", "Omniglot", "FashionMNIST", "Breakout")
    MODES = (
        "plot_latent", "plot_tsne", "reconstruct_mu", "with_hierarchy",
        "reconstruct_sigma", "error_correlation", "blur_data",
        "blur_correlation", "anomaly_detection",
    )

    parser = argparse.ArgumentParser(add_help=True)
    # ``choices`` turns an unknown dataset/mode into a clear usage error
    # instead of a NameError (dataset) or a silent no-op (mode).
    parser.add_argument('--dataset', type=str, default="MNIST", choices=DATASETS)
    parser.add_argument('--data_dir', type=str, default='data/')
    parser.add_argument('--latent_dim', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--dist', type=str, default="PGMNormal")
    parser.add_argument('--layer', type=str, default="Vanilla")
    parser.add_argument('--device', type=str, default='cuda:0')
    # Without a checkpoint path ``Path(None)`` below would raise a TypeError.
    parser.add_argument('--ckpt_path', type=str, required=True,
                        help="directory containing model.pt")
    parser.add_argument('--c', type=float, default=-0.5)
    parser.add_argument('--mode', type=str, default='with_hierarchy', choices=MODES,
                        help='analysis to run: ' + ', '.join(MODES))
    parser.add_argument('--num_cluster', type=int, default=3)
    parser.add_argument('--sample_index', type=int, default=0)
    parser.add_argument('--anomal_dataset', type=str, default="FashionMNIST",
                        choices=DATASETS)

    args = parser.parse_args()

    np.random.seed(1)
    torch.manual_seed(1)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes the cuDNN algorithm per run, which works
    # against the determinism requested on the line above.
    torch.backends.cudnn.benchmark = False
    # NOTE(review): recent PyTorch releases deprecate this call in favour of
    # torch.set_default_dtype -- confirm before switching.
    torch.set_default_tensor_type(torch.DoubleTensor)

    # Every analysis writes its images below figures/.
    Path('figures').mkdir(parents=True, exist_ok=True)

    # dataset
    task_module = importlib.import_module(f'tasks.{args.dataset}')
    Dataset = task_module.Dataset
    Encoder = task_module.Encoder
    Decoder = task_module.Decoder

    if args.mode == "with_hierarchy":
        data = Dataset(args, is_train=False,
                       with_hierarchy=True, num_cluster=args.num_cluster)
    else:
        data = Dataset(args, is_train=False)

    if args.mode == "anomaly_detection":
        AnomalDataset = importlib.import_module(
            f'tasks.{args.anomal_dataset}'
        ).Dataset
        data2 = AnomalDataset(args, is_train=False)

    # model
    encoder = Encoder(args)
    dist_module = importlib.import_module(f'distributions.{args.dist}')
    encoder_layer = getattr(dist_module, f'{args.layer}EncoderLayer')(args, encoder.output_dim)
    decoder = Decoder(args)
    decoder_layer = getattr(dist_module, f'{args.layer}DecoderLayer')(args)
    vae = VAE(
        args,
        None,
        None,
        encoder,
        encoder_layer,
        decoder,
        decoder_layer,
        None
    ).to(args.device)
    vae.load_state_dict(
        torch.load(
            Path(args.ckpt_path) / 'model.pt',
            map_location=args.device
        )
    )
    # Inference only: put layers such as dropout / batch-norm (if the model
    # has any) into evaluation behaviour.
    vae.eval()

    # main
    if args.mode == "plot_latent":
        plot_latent(args, data, vae)
    elif args.mode == "plot_tsne":
        plot_tsne(args, data, vae)
    elif args.mode == "with_hierarchy":
        with_hierarchy(args, data, vae)
    elif args.mode == "reconstruct_mu":
        reconstruct_mu(args, data, vae)
    elif args.mode == "error_correlation":
        error_correlation(args, data, vae)
    elif args.mode == "blur_data":
        plot_blur(args, data, vae, args.sample_index)
    elif args.mode == "blur_correlation":
        blur_correlation(args, data, vae)
    elif args.mode == "anomaly_detection":
        anomaly_detection(args, data, data2, vae)
    elif args.mode == "reconstruct_sigma":
        reconstruct_sigma(args, data, vae, args.sample_index)
