import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
def plot_accuracy(train_acc, test_acc, args, save_dir):
    """Plot train/test accuracy curves on a log-scaled x-axis and save them as PNG.

    The image is written to ``save_dir + 'acc.png'``. The path is built by plain
    string concatenation, so ``save_dir`` must end with a path separator if it
    names a directory. The figure is left open after saving.

    Args:
        train_acc: Sequence of training accuracies, one entry per x-axis point.
        test_acc: Sequence of test accuracies, one entry per x-axis point.
        args: Run configuration. Only ``batch_size`` (x-axis scale) and
            ``sample`` (legend text) are read.
        save_dir: Prefix prepended verbatim to ``'acc.png'``.
    """
    plt.figure(figsize=(12, 8))
    sns.set_style("whitegrid", {'axes.edgecolor': 'white', 'axes.facecolor': 'lightgray'})
    sns.set_context("talk")

    # Derive x positions from each series' own length rather than args.epochs,
    # so a history with an extra or missing entry still plots instead of
    # raising on an x/y length mismatch.
    # NOTE(review): x is index * batch_size but is labeled 'Optimization Steps';
    # confirm the intended unit. The first point lands at x = 0, which a
    # logarithmic axis cannot display.
    train_x = [i * args.batch_size for i in range(len(train_acc))]
    test_x = [i * args.batch_size for i in range(len(test_acc))]

    # Distinct colors for the two curves; log scale for the x-axis.
    plt.plot(train_x, train_acc, label=f'Train accuracy (sample={args.sample})', color='navy', linestyle='-', linewidth=2)
    plt.plot(test_x, test_acc, label=f'Test accuracy (sample={args.sample})', color='crimson', linestyle='-', linewidth=2)
    plt.xscale('log')

    plt.xlabel('Optimization Steps')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    # The output format must match the '.png' extension. A mismatched
    # format='pdf' writes PDF bytes into a file that image viewers treat as PNG.
    plt.savefig(save_dir + 'acc.png', format='png', bbox_inches='tight')

def plot_accuracy_21(train_acc1, test_acc1, train_acc2, test_acc2, args1, args2, save_dir):
    """Overlay train/test accuracy curves of two runs and save them as PNG.

    Run 1 is labeled "random split" and run 2 "block split". Each run's x
    positions are scaled by its own ``batch_size``. The image is written to
    ``save_dir + 'acc.png'`` by plain string concatenation, so ``save_dir``
    must end with a path separator if it names a directory. The figure is
    left open after saving.

    Args:
        train_acc1, test_acc1: Accuracy sequences of run 1.
        train_acc2, test_acc2: Accuracy sequences of run 2.
        args1, args2: Configurations of run 1 / run 2. Only ``batch_size`` is read.
        save_dir: Prefix prepended verbatim to ``'acc.png'``.
    """
    plt.figure(figsize=(12, 8))
    sns.set_style("whitegrid", {'axes.edgecolor': 'white', 'axes.facecolor': 'lightgray'})
    sns.set_context("talk")

    # (values, run config, legend label, color, line style), in plotting order.
    # Solid lines are train curves and dashed lines are test curves.
    curves = [
        (train_acc1, args1, 'Train accuracy 1, random split', 'navy', '-'),
        (test_acc1, args1, 'Test accuracy 1, random split', 'crimson', '--'),
        (train_acc2, args2, 'Train accuracy 2, block split', 'green', '-'),
        (test_acc2, args2, 'Test accuracy 2, block split', 'orange', '--'),
    ]
    for values, run_args, label, color, linestyle in curves:
        # x is derived from the series length (not run_args.epochs), so a
        # history with an extra or missing entry still plots.
        # NOTE(review): the first point lands at x = 0, which the log axis
        # below cannot display.
        x = [i * run_args.batch_size for i in range(len(values))]
        plt.plot(x, values, label=label, color=color, linestyle=linestyle, linewidth=2)
    plt.xscale('log')

    plt.xlabel('Optimization Steps')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    plt.savefig(save_dir + 'acc.png', format='png', bbox_inches='tight')

def plot_accuracy_mv(train_acc,test_acc,args,ws,save_dir):
    plt.figure(figsize=(12, 8))
    sns.set_style("whitegrid", {'axes.edgecolor': 'white', 'axes.facecolor': 'lightgray'})
    sns.set_context("talk")

    # Applying moving average to remove noise
    window_size = ws
    train_acc_smoothed = np.convolve(train_acc, np.ones(window_size)/window_size, mode='valid')
    test_acc_smoothed = np.convolve(test_acc, np.ones(window_size)/window_size, mode='valid')

    # Plotting with more distinct colors and line styles for better representation, using log scale for x-axis
    plt.plot([i * args.batch_size for i in range(0, len(train_acc_smoothed))], train_acc_smoothed, label=f'Train accuracy (sample={args.sample})', color='navy', linestyle='-', linewidth=2)
    plt.plot([i * args.batch_size for i in range(0, len(test_acc_smoothed))], test_acc_smoothed, label=f'Test accuracy (sample={args.sample})', color='crimson', linestyle='-', linewidth=2)
    plt.xscale('log')   

    plt.xlabel('Optimization Steps')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    # plt.savefig(save_dir + 'acc.pdf', format='pdf', bbox_inches='tight')

def plot_number_distribution(train_data, p):
    """Display a bar chart of how many training rows contain each number 0..p-1.

    A row counts toward number ``n`` when ``n`` appears in column 0, column 2,
    or both. Matches are combined with an elementwise OR, so each row is
    counted at most once per number. The chart is displayed with
    ``plt.show()`` and is not saved.

    Args:
        train_data: 2-D array-like supporting ``[:, j]`` slicing, elementwise
            ``==``/``|``, and ``.sum().item()`` (e.g. a NumPy array or torch
            tensor).
        p: Number of distinct values to count, starting at 0.
    """
    def rows_containing(value):
        # Boolean mask of rows whose first or third entry equals `value`.
        hit = (train_data[:, 0] == value) | (train_data[:, 2] == value)
        return hit.sum().item()

    counts = list(map(rows_containing, range(p)))

    plt.figure(figsize=(12, 6))
    plt.bar(range(p), counts)
    plt.xlabel('Number')
    plt.ylabel('Count in positions 1 or 3')
    plt.title('Distribution of numbers in positions 1 and 3 of training data')
    plt.show()
