import os

import argparse
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10
from torchvision.utils import save_image
from siren_pytorch import Sine
from tensorboardX import SummaryWriter
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt
import time

# Module-level padding configuration shared by ``to_img`` and ``visualize``.
# When ``circular_pad`` is True, inputs are wrap-around padded and every
# spatial side grows by ``pad_size`` pixels; otherwise no padding is applied.
circular_pad = False
pad_size = 1 if circular_pad else 0
    
def psnr(output, truth):
    """Return the peak signal-to-noise ratio, in dB, of ``output`` against ``truth``.

    The peak is the largest absolute value found in ``truth``; the error term
    is the root-mean-square of ``output - truth``. The result is therefore
    ``20 * log10(max|truth| / rmse)``, which is ``inf`` when the inputs match.
    """
    rmse = (output - truth).abs().pow(2).mean().sqrt()
    peak = truth.abs().max()
    return 20 * torch.log10(peak / rmse)

def ssim(
    ref: torch.Tensor,
    x: torch.Tensor,
    multichannel: bool = False,
    data_range = None,
    **kwargs,
):
    """Compute the structural similarity index metric. Does not preserve autograd.

    Based on implementation of Wang et. al. [1]_, delegated to
    ``structural_similarity``.

    Unless ``multichannel`` is set, both inputs are first reduced to magnitude
    images, ``sqrt(real**2 + imag**2)`` over the last axis. No intensity
    normalization is applied; use ``data_range`` to control the dynamic range.

    Args:
        ref (torch.Tensor): The target. Shape (...)x2, the last axis holding
            the real and imaginary parts.
        x (torch.Tensor): The prediction. Same shape as `ref`.
        multichannel (bool, optional): If `True`, computes ssim for real and
            imaginary channels separately and then averages the two.
        data_range (float or str, optional): The data range of the input image
            (distance between minimum and maximum possible values). May be a
            number, or one of the strings ``"range"`` / ``"ref-range"``
            (``ref.max() - ref.min()``), ``"maxval"`` / ``"ref-maxval"``
            (``ref.max()``), ``"x-range"`` (``x.max() - x.min()``) or
            ``"x-maxval"`` (``x.max()``). ``None`` (or any other value) is
            passed to ``structural_similarity`` unchanged.
        **kwargs: ``gaussian_weights`` (default ``True``), ``sigma`` (default
            ``1.5``) and ``use_sample_covariance`` (default ``False``) are
            forwarded to ``structural_similarity``; other keys are ignored.

    Returns:
        The value returned by ``structural_similarity``.

    Raises:
        AssertionError: If the last axis of ``ref`` or ``x`` is not of size 2.
    """
    # NOTE(review): ``structural_similarity`` is not imported in this file. It
    # looks like ``skimage.metrics.structural_similarity`` and must be imported
    # for this function to run. Newer scikit-image releases deprecate
    # ``multichannel=`` in favour of ``channel_axis`` -- confirm the installed
    # version.
    gaussian_weights = kwargs.get("gaussian_weights", True)
    sigma = kwargs.get("sigma", 1.5)
    use_sample_covariance = kwargs.get("use_sample_covariance", False)

    assert ref.shape[-1] == 2
    assert x.shape[-1] == 2

    if not multichannel:
        # Magnitude of the (real, imag) pair stored in the last axis. Computed
        # with plain torch: the ``cplx`` helper used previously was never
        # imported, so this branch raised NameError.
        ref = ref.pow(2).sum(dim=-1).sqrt()
        x = x.pow(2).sum(dim=-1).sqrt()

    # detach()/cpu() so inputs that require grad or live on the GPU can be
    # converted to numpy. squeeze() drops every size-1 axis, including a
    # batch axis of size 1.
    x = x.detach().cpu().squeeze().numpy()
    ref = ref.detach().cpu().squeeze().numpy()

    if data_range in ("range", "ref-range"):
        data_range = ref.max() - ref.min()
    elif data_range in ("ref-maxval", "maxval"):
        data_range = ref.max()
    elif data_range == "x-range":
        data_range = x.max() - x.min()
    elif data_range == "x-maxval":
        data_range = x.max()

    return structural_similarity(
        ref,
        x,
        data_range=data_range,
        gaussian_weights=gaussian_weights,
        sigma=sigma,
        use_sample_covariance=use_sample_covariance,
        multichannel=multichannel
    )

def add_noise(img, noise_std):
    """Return ``img`` plus i.i.d. zero-mean Gaussian noise with std ``noise_std``.

    The noise is drawn directly on ``img``'s device. Previously it was always
    moved to the CPU, which made the addition fail for GPU inputs. For CPU
    inputs the random draws are unchanged. ``img`` itself is not modified.

    Args:
        img (torch.Tensor): Clean input tensor.
        noise_std (float): Standard deviation of the additive noise.

    Returns:
        torch.Tensor: ``img + noise``, with the same shape as ``img``.
    """
    noise = torch.randn(img.size(), device=img.device) * noise_std
    return img + noise

def to_img(x, dim, chans, num_imgs):
    """Reshape a flat batch ``x`` into image layout.

    Returns a view of shape ``(batch, chans, side, num_imgs * side)``, where
    ``side = dim + 2 * pad_size`` accounts for the module-level padding.
    """
    side = dim + 2 * pad_size
    batch = x.size(0)
    return x.view(batch, chans, side, side * num_imgs)

def vis_image(img, writer, epoch, args, num_imgs, cur_im):
    """Min-max normalize ``img`` to [0, 1] and log it as an image via ``writer``.

    The caller's tensor is left untouched. A constant image is logged as all
    zeros instead of NaN.

    Args:
        img (torch.Tensor): Image data to visualize. For
            ``args.dataset == 'fastmri'`` a channel axis is inserted at dim 1;
            otherwise it is reshaped with ``to_img`` and only the first
            sample is logged.
        writer: Summary writer exposing ``add_image(tag, img, step)``.
        epoch (int): Step passed to ``writer.add_image``.
        args: Namespace providing ``dataset`` and, for non-fastMRI data,
            ``dim`` and ``chans``.
        num_imgs (int): Forwarded to ``to_img``, which uses it as a multiplier
            of the image width.
        cur_im (str): Tag under which the image is logged.
    """
    num_to_vis = 1
    # Work on a detached copy: the previous in-place ``-=`` / ``/=`` silently
    # rescaled the caller's tensor (and could break autograd).
    img = img.detach()
    img = img - img.min()
    peak = img.max()
    if peak > 0:  # a constant image would otherwise become 0 / 0 = NaN
        img = img / peak
    if args.dataset == 'fastmri':
        x = img.unsqueeze(1)
    else:
        x_all = to_img(img.cpu(), args.dim, args.chans, num_imgs)
        x = x_all[:num_to_vis]
    x_grid = torchvision.utils.make_grid(x, nrow=1, pad_value=1)
    writer.add_image(cur_im, x_grid, epoch)

def _model_device(model):
    """Return the device of ``model``'s first parameter, falling back to CUDA."""
    params = model.parameters() if hasattr(model, "parameters") else iter(())
    first = next(params, None)
    return first.device if first is not None else torch.device("cuda")


def _save_pca_components(pca, args, n_components, disp_components):
    """Plot the first ``disp_components`` PCA components as images and save them."""
    side = args.dim + 2 * pad_size
    # NOTE(review): components are reshaped as (H, W, C), while the inputs were
    # flattened from torch's channel-first layout. For chans > 1 this
    # presumably scrambles channels -- confirm against the model's feature
    # layout.
    pca_reshaped = pca.components_.reshape(n_components, side, side, args.chans)
    fig, axs = plt.subplots(1, disp_components)
    try:
        for i in range(disp_components):
            if args.dataset == 'MNIST':
                axs[i].imshow(pca_reshaped[i, :, :, 0] * 255, cmap='gray')
            elif args.dataset == 'CIFAR-10':
                # NOTE(review): imshow clips float RGB data to [0, 1], so the
                # *255 scaling likely saturates most pixels -- confirm intent.
                axs[i].imshow(pca_reshaped[i, :, :, :] * 255, cmap='gray')
        fig.savefig(args.log_dir + '/pca_components' + str(args.noise_std) + '.png')
    finally:
        plt.close(fig)  # avoid leaking figures across calls


def _save_tsne(feats, labels, args):
    """Save a 2-D t-SNE scatter of ``feats``, coloured by labels 0-9."""
    embedding = TSNE(n_components=2).fit_transform(feats)
    # Converted once instead of twice per class inside the loop.
    label_arr = labels.cpu().detach().numpy()
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        for cls in range(10):
            mask = label_arr == cls
            ax.scatter(embedding[mask, 0], embedding[mask, 1],
                       color='C' + str(cls), label=str(cls))
        ax.legend()
        fig.savefig(args.log_dir + '/' + args.problem
                    + '_tsne_vis_feats_out_twolayerlin_noise_'
                    + str(args.noise_std) + '.png')
    finally:
        plt.close(fig)  # avoid leaking figures across calls


def visualize(model, testset, args):
    """Plot PCA components and a t-SNE embedding of model features for one test batch.

    Takes the first ``(images, labels)`` batch from ``testset`` and
    optionally circular-pads it. It flattens every sample after the batch
    axis, adds Gaussian noise with std ``args.noise_std`` and runs the result
    through ``model``, which must return ``(output, features)``.

    The features are projected onto 20 principal components. For
    ``args.problem == 'primal'`` the first 4 components are reshaped to image
    size and saved as ``pca_components<noise_std>.png``. A 2-D t-SNE of the
    PCA projections, coloured by label (classes 0-9), is saved as
    ``<problem>_tsne_vis_feats_out_twolayerlin_noise_<noise_std>.png``. Both
    files go into ``args.log_dir``.

    The noisy batch is moved to the device of ``model``'s parameters, or to
    CUDA if the model has none.
    """
    # Nothing is logged to the writer here. It is kept because SummaryWriter
    # presumably creates ``args.log_dir``, which the savefig calls write
    # into. It is closed in ``finally`` so its event file is not leaked.
    writer = SummaryWriter(log_dir=args.log_dir)
    try:
        vis_img, vis_labels = next(iter(testset))
        if circular_pad:
            vis_img = nn.functional.pad(vis_img, 4 * (pad_size,), mode='circular')
        # Flatten everything after the batch axis: shape (batch, 1, -1).
        vis_img = vis_img.view(vis_img.size(0), 1, -1)
        vis_noisy_img = add_noise(vis_img, args.noise_std).to(_model_device(model))
        # Inference only; building an autograd graph here just wastes memory.
        with torch.no_grad():
            _, vis_feats = model(vis_noisy_img)
        feats = np.squeeze(vis_feats.cpu().detach().numpy())

        n_components = 20
        disp_components = 4
        pca = PCA(n_components=n_components)
        feats_pca = pca.fit_transform(feats)

        if args.problem == 'primal':
            _save_pca_components(pca, args, n_components, disp_components)

        _save_tsne(feats_pca, vis_labels, args)
    finally:
        writer.close()
    