import os
import re
import numpy as np
import torch
import argparse
from utilities import get_gd_directory, get_flow_directory
import matplotlib.pyplot as plt

def get_lr_values(path):
    """Return the learning rates encoded in the entry names of ``path``.

    Every entry of ``path`` whose name starts with ``lr_<number>`` contributes
    ``<number>`` as a float; all other entries are ignored. ``<number>`` may be
    an integer, a decimal, or use scientific notation (``lr_1e-05``), which is
    how Python renders small floats when they are formatted into a name.

    Args:
        path: Directory to scan (passed to ``os.listdir``).

    Returns:
        The parsed learning rates as a list of floats in ascending order.
        Empty when no entry matches.

    Raises:
        OSError: Propagated from ``os.listdir`` (e.g. ``path`` does not exist).
    """
    # ``match`` anchors at the start of the name only, so trailing suffixes
    # after the number are tolerated. The optional exponent group keeps
    # "lr_1e-05" from being truncated to 1.0.
    pattern = re.compile(r'lr_(\d+\.?\d*(?:[eE][-+]?\d+)?)')
    matches = map(pattern.match, os.listdir(path))
    return sorted(float(m.group(1)) for m in matches if m)

def _load_if_present(directory, prefix, name):
    """Return ``torch.load`` of ``<directory>/<prefix><name>``, or ``None`` if that file is absent."""
    path = os.path.join(directory, prefix + name)
    if not os.path.exists(path):
        return None
    return torch.load(path)


def _select_norm(table, norm_type):
    """Pick the column of a loaded norm table that corresponds to ``norm_type``.

    ``1`` selects column 0, ``"fro"`` or ``2`` column 1 and ``"nuc"`` column 2.
    A missing table (``None``) or any other ``norm_type`` yields the ``[-1]``
    placeholder.
    """
    if table is None:
        return [-1]
    if norm_type == 1:
        return table[:, 0]
    if norm_type == "fro" or norm_type == 2:
        return table[:, 1]
    if norm_type == "nuc":
        return table[:, 2]
    return [-1]


def _curve_figure(series_all, labels, line_colors, x_step=None):
    """Open a new 7x4 figure and draw one line per series.

    With ``x_step`` set, the i-th value of a series is drawn at ``i * x_step``;
    otherwise matplotlib's default 0, 1, 2, ... x positions are used.
    """
    plt.figure(figsize=(7, 4))
    for series, label, color in zip(series_all, labels, line_colors):
        if x_step is None:
            plt.plot(series, c=color, label=label)
        else:
            plt.plot(np.arange(0, len(series) * x_step, x_step), series, c=color, label=label)


def _sharpness_figure(shaps_all, etas, labels, line_colors, x_step, with_threshold):
    """Open a new figure with the sharpness curves and, optionally, the 2/eta lines."""
    plt.figure(figsize=(7, 4))
    for shaps, eta, label, color in zip(shaps_all, etas, labels, line_colors):
        plt.plot(np.arange(0, len(shaps) * x_step, x_step), shaps, c=color, label=label)
        if with_threshold:
            plt.axhline(2 / eta, color=color, linestyle="--", alpha=.7)
    if with_threshold:
        # Drawn at y=-1 and then cut off by the y-limit below: it exists only
        # to give the legend a neutral-coloured "2/eta" entry.
        plt.axhline(-1, c="lightgray", linestyle="--", label=r"$2/\eta$")
    _, top = plt.ylim()
    plt.ylim(0, top)


def _finish_figure(path, xlabel, ylabel, legend_loc=None, log_y=False, show=False):
    """Label the current figure, save it to ``path`` at 300 dpi and optionally show it."""
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    if log_y:
        plt.yscale("log")
    if legend_loc is not None:
        plt.legend(loc=legend_loc)
    plt.grid(True, linestyle='-')
    plt.tight_layout()
    plt.savefig(path, dpi=300)
    if show:
        plt.show()


def main(dataset: str, arch_id: str, loss: str, ind="flow", seed: int = 0, norm_type=1,
         show=False, es=-1, eig_freq=1, all=False):
    """Plot sharpness, weight norms and losses of saved training runs.

    The curves are read with ``torch.load`` from the files ``eigs``, ``norms``,
    ``norms_from_init``, ``train_loss`` and ``test_loss`` inside each run
    directory. A missing file is replaced by the placeholder series ``[-1]``.
    Figures are written below ``figures/``.

    Args:
        dataset: Dataset name; forwarded to the directory helpers and used in
            output file names.
        arch_id: Architecture identifier; used like ``dataset``.
        loss: Loss name; used like ``dataset``.
        ind: ``"flow"`` to plot the run in the flow directory, otherwise a
            position in the sorted list returned by ``get_lr_values``. The
            position is ignored when ``all`` is true.
        seed: Passed to ``torch.manual_seed`` and to the directory helpers.
        norm_type: ``1``, ``2``/``"fro"`` or ``"nuc"``; selects column 0, 1 or 2
            of the saved norm tables. Other values give the placeholder series.
        show: When true, call ``plt.show()`` after the figures are saved.
        es: ``-1`` reads unprefixed files and writes to ``.../final/``. Any other
            value reads files prefixed with ``str(es)`` ('.' replaced by '_')
            plus '_' and writes to ``.../es_<es>/``.
        eig_freq: x-axis spacing between consecutive saved sharpness / norm
            values.
        all: When true (and ``ind`` is not ``"flow"``), plot the learning rates
            at the hard-coded positions ``[8, 15, 19, 22]``.

    Raises:
        IndexError: A requested position does not exist in the list of
            learning rates found on disk.
    """
    torch.manual_seed(seed)

    # Indices for iterate plotting with option --all
    inds = [8, 15, 19, 22]
    prefix = "" if es == -1 else str(es).replace(".", "_") + "_"
    is_flow = ind == "flow"

    if is_flow:
        tick = 1.0
        directory = os.path.expanduser(get_flow_directory(dataset, arch_id, seed, loss, tick))
        # The flow run has no learning rate; keep the lists aligned with a
        # single placeholder entry so the shared load/plot code below works.
        etas = [None]
        directories = [directory]
    else:
        directory = os.path.expanduser(os.path.dirname(get_gd_directory(dataset, 0, arch_id, seed, "gd", loss, None)))
        print(directory)

        # Scan the directory once instead of once per selected index.
        available = get_lr_values(directory)
        if all:
            print(available)
            wanted = inds  # hard coded indices
        else:
            wanted = [ind]
        try:
            etas = [available[i] for i in wanted]
        except IndexError:
            raise IndexError(
                f"learning-rate index out of range: requested {wanted}, "
                f"but only {len(available)} learning rates were found in {directory}"
            ) from None
        if all:
            print(etas)

        directories = [os.path.expanduser(get_gd_directory(dataset, eta, arch_id, seed, "gd", loss, None)) for eta in etas]

    shaps_all = []
    norms_all = []
    norms_from_init_all = []
    losses_all = []
    tests_all = []

    for directory in directories:
        eigs = _load_if_present(directory, prefix, "eigs")
        # Only the first entry of each saved element is plotted as sharpness.
        shaps_all.append(np.array([-1] if eigs is None else [s[0] for s in eigs]))

        norms_all.append(np.array(_select_norm(_load_if_present(directory, prefix, "norms"), norm_type)))
        norms_from_init_all.append(
            np.array(_select_norm(_load_if_present(directory, prefix, "norms_from_init"), norm_type))
        )

        train = _load_if_present(directory, prefix, "train_loss")
        losses_all.append(np.array([-1] if train is None else train))
        test = _load_if_present(directory, prefix, "test_loss")
        tests_all.append(np.array([-1] if test is None else test))

    colors = {
        "converged": "#00CC03",
        "maximum": "#FA0400",
        "bound": "#0081D1",
        "flow": "#FF7F0F",
        "heuristic": "#CC00F5",
        "goal": "#70C8FF",
        "converged_last": "darkgreen",
    }

    # Use the function's own arguments (not a module-level ``args``) so that
    # main() also works when it is imported and called from other code.
    stage = "final" if es == -1 else f"es_{es}"
    out_dir = f"figures/{dataset}_{arch_id}_{loss}_seed{seed}/{stage}"
    os.makedirs(out_dir, exist_ok=True)
    # "init_etas_<...>" is a file-name prefix inside out_dir, not a directory.
    plot_dirname = f"{out_dir}/init_etas_" + ("flow" if is_flow else str(etas))
    print("saving to ", plot_dirname)

    plt.rcParams['font.size'] = 14
    plt.rcParams['mathtext.fontset'] = 'stix'
    plt.rcParams['font.family'] = 'STIXGeneral'

    cs = [colors["converged_last"], colors["converged"], colors["bound"], colors["maximum"]]

    # Raw strings: "\e" is not a valid escape sequence in a normal literal.
    if is_flow:
        tight_labels = ["flow"]
        spaced_labels = ["flow"]
    else:
        tight_labels = [rf"$\eta={eta}$" for eta in etas]
        spaced_labels = [rf"$\eta = {eta}$" for eta in etas]

    # Sharpness, with and without legend
    _sharpness_figure(shaps_all, etas, tight_labels, cs, eig_freq, not is_flow)
    _finish_figure(plot_dirname + "sharpness.png", "iteration", "sharpness", legend_loc="upper right")
    _sharpness_figure(shaps_all, etas, tight_labels, cs, eig_freq, not is_flow)
    _finish_figure(plot_dirname + "_no_legend_sharpness.png", "iteration", "sharpness", show=show)

    # Norm, with and without legend
    _curve_figure(norms_all, spaced_labels, cs, x_step=eig_freq)
    _finish_figure(plot_dirname + "norms.png", "iteration", "norm", legend_loc="lower right", show=show)
    _curve_figure(norms_all, spaced_labels, cs, x_step=eig_freq)
    _finish_figure(plot_dirname + "no_legend_norms.png", "iteration", "norm", show=show)

    # Norm from initialization, with and without legend
    _curve_figure(norms_from_init_all, spaced_labels, cs, x_step=eig_freq)
    _finish_figure(plot_dirname + "norms_from_init.png", "iteration", "norm from initialization",
                   legend_loc="lower right", show=show)
    _curve_figure(norms_from_init_all, spaced_labels, cs, x_step=eig_freq)
    _finish_figure(plot_dirname + "no_legend_norms_from_init.png", "iteration", "norm from initialization",
                   show=show)

    # Train and test loss in one figure
    plt.figure(figsize=(7, 4))
    if all:
        for losses, tests, label, c in zip(losses_all, tests_all, tight_labels, cs):
            plt.plot(losses, c=c, label=label)
            plt.plot(tests, c=c, linestyle="--")
        # Off-screen proxy lines (hidden by the y-limit) that only provide
        # neutral-coloured "test" / "train" legend entries.
        plt.axhline(-1, c="lightgray", linestyle="--", label="test")
        plt.axhline(-1, c="lightgray", label="train")
        _, top = plt.ylim()
        plt.ylim(0, top)
    else:
        plt.plot(losses_all[0], c=colors["converged"], label=f"train loss, {spaced_labels[0]}")
        plt.plot(tests_all[0], c=colors["maximum"], label=f"test loss,  {spaced_labels[0]}")
    # The legend is drawn in both branches; without it the labelled train and
    # test curves of the single-run figure cannot be told apart.
    _finish_figure(plot_dirname + "loss.png", "iteration", "loss", legend_loc="upper right", show=show)

    # Test loss (log scale)
    _curve_figure(tests_all, tight_labels, cs)
    _finish_figure(plot_dirname + "test_loss.png", "iteration", "test loss", legend_loc="upper right", log_y=True)

    # Train loss (log scale), with and without legend
    _curve_figure(losses_all, tight_labels, cs)
    _finish_figure(plot_dirname + "train_loss.png", "iteration", "train loss", legend_loc="upper right",
                   log_y=True, show=show)
    _curve_figure(losses_all, tight_labels, cs)
    _finish_figure(plot_dirname + "_no_legend_train_loss.png", "iteration", "train loss", log_y=True)

    # Summary figure: the last train-loss plot with a descriptive title. The
    # title is added and the file saved before anything is shown, so the saved
    # image has the same content whether or not ``show`` is set.
    summary = f"Activation: {arch_id}, loss: {loss}, data: {dataset}, seed: {seed}, "
    summary += "flow" if is_flow else f"eta: {etas[-1]}"
    plt.suptitle(summary)
    plt.tight_layout()
    plot_filename = f"figures/iterates_{dataset}_{arch_id}_{loss}_seed{seed}_ind_{ind}.png"
    plt.savefig(plot_filename, dpi=300)
    # Displayed unconditionally: also brings up every figure not yet shown.
    plt.show()

def int_or_str(value):
    """Argparse ``type=`` converter accepting an integer or a known keyword.

    Args:
        value: The raw command-line string.

    Returns:
        ``int(value)`` when the string parses as an integer; otherwise
        ``value`` unchanged if it is one of ``"nuc"``, ``"fro"`` or ``"flow"``.

    Raises:
        argparse.ArgumentTypeError: ``value`` is neither an integer nor one of
            the accepted keywords.
    """
    try:
        return int(value)
    except ValueError:
        if value in {"nuc", "fro", "flow"}:
            return value
        # ``from None``: the int() failure is an implementation detail, and the
        # message lists every accepted keyword (including 'flow').
        raise argparse.ArgumentTypeError(
            f"Invalid value: {value}. Must be an integer, 'nuc', 'fro', or 'flow'."
        ) from None

if __name__ == "__main__":
    # Command-line entry point: parse the arguments and forward them to main().
    parser = argparse.ArgumentParser(
        description="Plot results for different learning rates.",
    )

    # Positional arguments identifying the run(s) to plot.
    parser.add_argument("dataset", type=str, help="which dataset")
    parser.add_argument("arch_id", type=str, help="which network architecture")
    parser.add_argument(
        "loss",
        type=str,
        choices=["ce", "mse"],
        help="which loss function",
    )
    parser.add_argument(
        "index",
        type=int_or_str,
        help="'flow' or an index for a GD eta, obsolete when --all specified",
    )

    # Optional settings.
    parser.add_argument(
        "--seed",
        type=int,
        help="the random seed used when initializing the network weights",
        default=42,
    )
    parser.add_argument(
        "--norm_type",
        type=int_or_str,
        help="prefered weight norm type, an integer or one of the strings 'nuc', 'fro'.",
        default=1,
    )
    parser.add_argument(
        "--show",
        action=argparse.BooleanOptionalAction,
        help="if 'True', shows each plot",
    )
    parser.add_argument("--es", type=float, help="early stopping loss", default=-1)
    parser.add_argument("--eig_freq", type=int, help="save freq of eig", default=10)
    parser.add_argument(
        "--all",
        action=argparse.BooleanOptionalAction,
        help="Run for all learning rates (hard coded, take care!)",
    )

    args = parser.parse_args()
    main(
        dataset=args.dataset,
        arch_id=args.arch_id,
        loss=args.loss,
        ind=args.index,
        seed=args.seed,
        norm_type=args.norm_type,
        show=args.show,
        es=args.es,
        eig_freq=args.eig_freq,
        all=args.all,
    )