import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import math
import numpy as np
import seaborn as sns
from data.generate_data import gen_data
from data.data_preprocess import block_split,uniform_split
import os
from models.MLP import ModuloClassifier,ModuloClassifier_noEmb
from models.simple_transformer import Causal_Transformer
from utils.custom_opt import adam_emb_wd
from experiments.trainer  import trainer_mlp,trainer_transformer
from utils.plotting import plot_accuracy,plot_accuracy_21
from torch.utils.data import DataLoader, TensorDataset
from data.generate_data import ArithmeticDataset

import os
import torch
import matplotlib.pyplot as plt
from matplotlib import rcParams


# --- Reproducibility ---------------------------------------------------------
# Seed the CPU and every CUDA RNG, and pin cuDNN to deterministic kernels.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# NOTE(review): GPU index 3 is hard-coded when CUDA is available; `device` is
# not referenced anywhere else in the code shown here.
device = torch.device("cuda:3") if torch.cuda.is_available() else torch.device("cpu")

# --- Plot styling ------------------------------------------------------------
# Serif (Times) text at 18 pt for every text element, 10x6-inch default figure.
rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times"],
    **{
        key: 18
        for key in (
            "font.size",
            "figure.titlesize",
            "axes.titlesize",
            "axes.labelsize",
            "legend.fontsize",
            "xtick.labelsize",
            "ytick.labelsize",
        )
    },
    "figure.figsize": (10, 6),
})

# --- Experiment grid to plot -------------------------------------------------
datasets = ["modp"]  # other names previously listed here: "amodp", "dmodp"
sampling_methods = ["random", "skew", "uniform"]
learning_rates = [0.001]
batch_sizes = [512]
scale_values = [1]
def load_results(dataset_dir, sampling_methods, learning_rates, batch_sizes, split_values):
    """
    Collect train/test accuracy histories from MLP checkpoints in ``dataset_dir``.

    Only ``.pth`` files whose name starts with ``"mlp"`` and does not contain
    ``"adam_10"`` are considered. Hyper-parameters are parsed from the
    underscore-separated filename:

    * index 2 is the learning rate;
    * index 8 is the batch size;
    * index 10 is the sampling method;
    * the last field (before the extension) is the scale.

    Files whose names do not fit this layout are skipped.

    Args:
        dataset_dir: Directory to scan (non-recursively).
        sampling_methods: Sampling-method names to keep.
        learning_rates: Learning rates to keep (compared as floats).
        batch_sizes: Batch sizes to keep.
        split_values: Scale values to keep.

    Returns:
        Nested dict ``grouped[lr][bs][scale][sample]``. Each leaf is a list of
        ``{"train": ..., "test": ..., "filename": ...}`` dicts, ordered by
        filename. Every requested combination is present, possibly with an
        empty list.
    """
    grouped_results = {
        lr: {
            bs: {
                # Use the caller-supplied scales (previously the global
                # `scale_values` was read here and the parameter was ignored).
                scale: {sample: [] for sample in sampling_methods}
                for scale in split_values
            }
            for bs in batch_sizes
        }
        for lr in learning_rates
    }
    count = 0
    # os.listdir order is arbitrary; sort so the per-combination run lists (and
    # which run a caller treats as the "first" one) are deterministic.
    for filename in sorted(os.listdir(dataset_dir)):
        if not (filename.endswith(".pth") and filename.startswith("mlp")):
            continue
        if "adam_10" in filename:
            continue
        parts = filename.split("_")
        try:
            lr = float(parts[2])
            bs = int(parts[8])
            sample = parts[10]
            scale = int(parts[-1].split(".")[0])
        except (IndexError, ValueError):
            # Not named according to the expected checkpoint scheme.
            continue
        if (
            lr not in learning_rates
            or bs not in batch_sizes
            or scale not in split_values
            or sample not in sampling_methods
        ):
            continue
        print(count, filename)
        count += 1
        checkpoint_path = os.path.join(dataset_dir, filename)
        # Load onto the CPU so checkpoints written on a GPU can be read on any
        # machine; only the accuracy histories are needed for plotting.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        grouped_results[lr][bs][scale][sample].append({
            "train": checkpoint['train_acc'],
            "test": checkpoint['test_acc'],
            "filename": filename,
        })
    return grouped_results

import matplotlib.pyplot as plt

def plot_train_test_curves(results, lr, bs, scale=1, n_train=97 * 96 * 0.3):
    """
    Plot train and test accuracy curves for each sampling method.

    One figure is drawn for the given learning rate ``lr`` and batch size ``bs``.
    Only the first run stored for each sampling method is drawn; any further
    runs in ``results`` are ignored. The x-axis is log-scaled.

    Args:
        results: Nested dict ``results[lr][bs][scale][sample]``. Each leaf is a
            list of run dicts with ``"train"`` and ``"test"`` accuracy sequences.
        lr: Learning-rate key to plot.
        bs: Batch-size key to plot. It is also used to convert logged points
            into optimizer steps.
        scale: Scale key to plot. Defaults to 1, the value that used to be
            hard-coded.
        n_train: Training-set size used to derive steps per logged point. The
            default matches the previously hard-coded ``(97*96)*0.3``
            (presumably p=97 with a 30% train split — TODO confirm).
    """
    plt.figure(figsize=(12, 7))
    colors = {'random': '#1f78b4', 'skew': '#c51b7d', 'uniform': '#4d9221'}
    linestyles = {'train': '-', 'test': '--'}
    alpha_vals = {'train': 0.85, 'test': 0.7}
    markers = {'train': 'o', 'test': 's'}

    # Assumes one accuracy value was logged per epoch — TODO confirm.
    # Clamp to >= 1: if bs exceeds the training-set size, int(n_train / bs) is 0.
    # Every point would then land on x=0, which a log-scaled axis cannot show.
    steps_per_epoch = max(1, int(n_train / bs))

    for sample, color in colors.items():
        result_list = results.get(lr, {}).get(bs, {}).get(scale, {}).get(sample, [])
        if not result_list:
            continue
        run = result_list[0]  # only the first run per sampling method is plotted
        for split in ('train', 'test'):
            curve = run[split]
            # Build each curve's x-axis from its own length. Train and test
            # histories of different lengths then no longer make plt.plot raise.
            steps = [i * steps_per_epoch for i in range(1, len(curve) + 1)]
            plt.plot(
                steps, curve,
                label=f"{sample} ({split.capitalize()})",
                color=color,
                linestyle=linestyles[split],
                linewidth=2.5,
                alpha=alpha_vals[split],
                marker=markers[split],
                markersize=4,
                markerfacecolor=color,
                markeredgewidth=0.0,
            )

    plt.xlabel("Optimization steps", fontsize=16)
    plt.ylabel("Accuracy", fontsize=16)
    plt.title(f"Train & Test Accuracy for Sampling Methods\n(lr={lr}, batch size={bs})", fontsize=18)
    plt.xscale("log")
    plt.grid(True, which='both', linestyle=':', linewidth=1, alpha=0.5)
    plt.legend(fontsize=13)
    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
    # Driver: for each dataset, load its checkpoints once, then draw one figure
    # per (learning rate, batch size) pair.
    for dataset in datasets:
        checkpoint_dir = f"checkpoints/{dataset}_old"
        if not os.path.isdir(checkpoint_dir):
            # Skip a missing dataset. Otherwise os.listdir raises
            # FileNotFoundError and aborts the plots for every dataset.
            print(f"Skipping {dataset}: checkpoint directory {checkpoint_dir!r} not found")
            continue
        results = load_results(checkpoint_dir, sampling_methods, learning_rates, batch_sizes, scale_values)
        for lr in learning_rates:
            for bs in batch_sizes:
                plot_train_test_curves(results, lr, bs)