import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import wandb


def save_ramachandran_plot(phis, psis, iteration, args, name="", path=None):
    """Render a 2D histogram of (phi, psi) angles and save it as a PNG.

    The image is written to ``path`` when given, otherwise to
    ``args.data_dir``. The file name is always
    ``{name}_ramachandran_{data_temperature}_{data_save_frequency}_{data_size}.png``.
    When ``args.wandb`` is truthy, the saved image is also logged to
    Weights & Biases.

    Args:
        phis: Sequence of phi values, plotted on the x axis.
        psis: Sequence of psi values, plotted on the y axis.
        iteration: Currently unused; kept so existing callers keep working.
        args: Object providing ``data_dir``, ``data_temperature``,
            ``data_save_frequency``, ``data_size`` and ``wandb``.
        name: Prefix of the output file name and part of the wandb key.
        path: Optional output directory overriding ``args.data_dir``.

    Note:
        The histogram range is fixed to [-pi, pi] on both axes, so values
        outside that interval are not counted (angles are presumably in
        radians -- confirm against callers).
    """
    out_dir = args.data_dir if path is None else path
    file_name = (
        f'{name}_ramachandran_{args.data_temperature}'
        f'_{args.data_save_frequency}_{args.data_size}.png'
    )
    image_path = os.path.join(out_dir, file_name)

    # savefig does not create missing directories; an empty out_dir means
    # the current working directory, which needs no creation.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Ramachandran plot
    fig = plt.figure(figsize=(10, 10))
    try:
        plt.hist2d(
            phis,
            psis,
            bins=64,
            norm=matplotlib.colors.LogNorm(),
            range=[[-np.pi, np.pi], [-np.pi, np.pi]],
        )
        plt.xticks(fontsize=20)
        plt.yticks(fontsize=20)
        # Raw strings: '\p' is an invalid escape sequence in a normal literal.
        plt.xlabel(r'$\phi$', fontsize=24)
        plt.ylabel(r'$\psi$', fontsize=24)
        plt.savefig(image_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    # TODO: Configure wandb at some point
    if args.wandb:
        # Default location: key is prefixed with 'ramachandran_'.
        # Explicit path: key is the bare name (the original comment described
        # this case as "logging latent tica").
        metric_name = f'ramachandran_{name}' if path is None else f'{name}'
        wandb.log({metric_name: wandb.Image(image_path, caption=f"{metric_name}")})