import torch
import numpy as np
import argparse
import json
from utilities import get_flow_directory, get_gd_directory
import os

# usage: e.g. python src/make_lr_schedule.py cifar10-5k fc-relu mse 43 1.0 --loss_goal 0.01

def generate_data_points(s_gf, s_0):
    """Build the original 20-point learning-rate grid.

    The first 12 points are evenly spaced on [0.5/s_gf, 6/s_gf] (both ends
    included). The remaining 8 continue beyond 6/s_gf in equal increments of
    (2/s_0 - 6/s_gf) / 8, so the last one lands on 2/s_0 up to float rounding.
    If 2/s_0 < 6/s_gf that increment is negative and those 8 points decrease.

    Args:
        s_gf: sharpness value anchoring the fine range (the script passes the
            maximum of the loaded gradient-flow eigenvalues).
        s_0: sharpness value setting the far end (the script passes the first
            loaded eigenvalue).

    Returns:
        1-D float ndarray of 20 values: the 12 fine points, then the 8 coarse.
    """
    lower_fine = 2 / (4 * s_gf)    # == 0.5 / s_gf
    upper_fine = 12 / (2 * s_gf)   # == 6 / s_gf
    fine_part = np.linspace(lower_fine, upper_fine, 12)

    # 6/s_gf is already the last fine point, so the coarse part starts one
    # increment above it and takes 8 increments in total.
    increment = (2 / s_0 - 6 / s_gf) / 8
    base = 6 / s_gf
    coarse_part = np.linspace(base + increment, base + 8 * increment, 8)

    return np.concatenate([fine_part, coarse_part])

def generate_data_points_fine(s_gf, s_0, max_lr=-1, divide=1):
    """Build the fine learning-rate grid, starting at 0.5/s_gf.

    Consecutive points are 0.5 / s_gf / divide apart.

    Args:
        s_gf: sharpness value setting both the start point and the spacing.
        s_0: unused; accepted so all grid generators share one signature.
        max_lr: exclusive upper bound of the grid. The sentinel -1 (default)
            selects the fixed range that ends at 6/s_gf inclusive.
        divide: subdivision factor; the spacing is divided by it (2 roughly
            doubles the number of points).

    Returns:
        1-D float ndarray of ascending learning rates (possibly empty when
        max_lr <= 0.5/s_gf).
    """
    start = 0.5 / s_gf
    spacing = start / divide

    if max_lr == -1:
        # arange excludes its stop; half a spacing of slack keeps 6/s_gf in.
        stop = 6 / s_gf + spacing / 2
    else:
        stop = max_lr

    return np.arange(start, stop, spacing)

def generate_data_points_coarse(s_gf, s_0, max_lr=-1, divide=1):
    """Build the coarse learning-rate grid around the anchor 6/s_gf.

    The spacing is (2/s_0 - 6/s_gf) / 8 / divide. The returned grid has two
    parts:
      * below the anchor: anchor - k*spacing for k = 1, 2, ... while the
        value stays above 0, in ascending order;
      * from the anchor upward: by default 9 evenly spaced points from 6/s_gf
        to 6/s_gf + 8*spacing (== 2/s_0 when divide == 1); if max_lr is given,
        points anchor + k*spacing strictly below max_lr.

    Args:
        s_gf: sharpness value defining the anchor 6/s_gf.
        s_0: sharpness value defining the far end 2/s_0.
        max_lr: exclusive upper bound for the upper part; the sentinel -1
            (default) selects the fixed 9-point range instead.
        divide: subdivision factor; the spacing is divided by it.

    Returns:
        1-D float ndarray: the below-anchor points followed by the upper part.
    """
    anchor = 6 / s_gf
    spacing = (2 / s_0 - anchor) / 8 / divide

    # Walk downward from just below the anchor toward 0, then flip to ascending.
    below = np.arange(anchor - spacing, 0, -spacing)[::-1]

    if max_lr == -1:
        above = np.linspace(anchor, anchor + 8 * spacing, 9)
    else:
        above = np.arange(anchor, max_lr, spacing)

    return np.concatenate((below, above))


if __name__ == "__main__":
    # Build learning-rate grids from a finished gradient-flow run: read the
    # recorded eigenvalues ("eigs") and training loss ("train_loss") from the
    # flow directory, derive the old/coarse/fine grids, and -- only if the
    # final flow loss reached --loss_goal -- write them to lr_schedule.json in
    # the parent of the GD run directory.
    parser = argparse.ArgumentParser(description="Load eigenvalues from flow directory.")
    parser.add_argument("dataset", type=str, help="which dataset to use")
    parser.add_argument("arch_id", type=str, help="which network architecture to use")
    parser.add_argument("loss", type=str, choices=["ce", "mse"], help="which loss function to use")
    # No `default`: a plain positional is always required, so argparse would ignore it.
    parser.add_argument("seed", type=int, help="the random seed")
    parser.add_argument("tick", type=float, help="tick value for flow directory")
    parser.add_argument("--loss_goal", type=float, help="Loss goal threshold for convergence check")
    parser.add_argument("--es", type=float, default=-1, help="Early stopping prefix")
    parser.add_argument("--max_lr_coarse", type=float, default=-1, help="Upper limit for the learning rate for coarse, instead of the fixed schedule.")
    parser.add_argument("--divide_coarse", type=float, default=1, help="Double->2,Triple->3/... etc. the points in the coarse setting.")
    parser.add_argument("--max_lr_fine", type=float, default=-1, help="Upper limit for the learning rate for fine, instead of the fixed schedule.")
    parser.add_argument("--divide_fine", type=int, default=1, help="Double->2, Triple->3/... etc. the points in the fine setting.")
    parser.add_argument("--init_scaling", type=float, help="multiply initial weights by this", default=1.0)
    args = parser.parse_args()

    # Resolve the flow directory and the directory where the schedule is saved.
    flow_directory = os.path.expanduser(get_flow_directory(dataset=args.dataset, arch_id=args.arch_id,
                                                           seed=args.seed, loss=args.loss, tick=args.tick))
    # dirname() drops the last path component of the lr=0 GD directory;
    # presumably that component is learning-rate specific -- confirm against
    # utilities.get_gd_directory.
    gd_directory = os.path.dirname(os.path.expanduser(get_gd_directory(dataset=args.dataset, arch_id=args.arch_id,
                                                                       seed=args.seed, loss=args.loss, opt="gd", lr=0)))
    if args.init_scaling != 1.0:
        flow_directory = flow_directory + f"/scaling_{args.init_scaling}"
        gd_directory = gd_directory + f"/scaling_{args.init_scaling}"

    # When --es matches --loss_goal, the flow files carry the loss goal as a
    # filename prefix, e.g. 0.01 -> "0_01_".
    if args.es == args.loss_goal:
        prefix = str(args.loss_goal).replace(".", "_") + "_"
    else:
        prefix = ""

    # s_gf: largest recorded eigenvalue; s_0: the first recorded entry.
    # NOTE(review): .item() on gf_sharpness[0] assumes one value per saved
    # step -- it raises if several eigenvalues were stored per step.
    gf_sharpness = torch.load(f"{flow_directory}/{prefix}eigs")
    s_gf = gf_sharpness.max().item()
    s_0 = gf_sharpness[0].item()

    results = {}
    results["s_gf"] = s_gf
    results["s_0"] = s_0
    results["loss_goal"] = args.loss_goal

    print("s_gf = ", s_gf)
    print("s_0 = ", s_0)

    # Old 20-point schedule: indices 0-11 are the fine part (ending at 6/s_gf),
    # indices 12-19 the coarse part.
    data_points = np.round(generate_data_points(s_gf, s_0), 4)
    coarse = np.round(generate_data_points_coarse(s_gf, s_0, max_lr=args.max_lr_coarse, divide=args.divide_coarse), 4)
    fine = np.round(generate_data_points_fine(s_gf, s_0, max_lr=args.max_lr_fine, divide=args.divide_fine), 4)

    # .tolist() yields plain Python floats: the JSON file is unchanged, but
    # print(results) no longer shows NumPy scalar reprs such as np.float64(...).
    results["data_points_rounded"] = data_points.tolist()
    results["coarse_rounded"] = coarse.tolist()
    results["fine_rounded"] = fine.tolist()

    # Only save a schedule when the flow run reached the requested loss.
    train_loss = torch.load(f"{flow_directory}/{prefix}train_loss")
    if args.loss_goal is None:
        print("No --loss_goal given; cannot check GF convergence, schedule not saved.")
    elif train_loss[-1].item() <= args.loss_goal:
        print("Old schedule")
        print(data_points)

        print("\nFull coarse grid:")
        print(coarse, "\n")
        print("\nCoarse grid, added values:")
        # Compare against 6/s_gf (index 11) plus the old coarse points.
        print(np.setdiff1d(coarse, data_points[11:]))

        print("\nFull fine grid:")
        print(fine, "\n")
        print("\n Fine grid, added values:")
        # Compare against the old fine points (indices 0-11).
        print(np.setdiff1d(fine, data_points[:12]))

        os.makedirs(gd_directory, exist_ok=True)
        schedule_path = f"{gd_directory}/lr_schedule.json"
        with open(schedule_path, "w") as fp:
            json.dump(results, fp)
        print(f"saved learning rate schedule to {schedule_path} \n")

        print(results)

    else:
        print("GF not converged!")