"""Adaptation of code from: https://github.com/Cranial-XIX/CAGrad"""
from tqdm import tqdm
import argparse
import numpy as np
import torch
import os
import csv

from optimizer import *
from utils import *
from toy import Toy
from plot import *

# Create the command-line parser for the toy multi-objective optimisation task.
parser = argparse.ArgumentParser(description='Arguments for Toy MOO task')

# general
parser.add_argument('--seed', type=int, default=2024, help='random seed')
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--num_epochs', type=int, default=10000, help='number of epochs')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate for the model')
parser.add_argument('--fd_eps', default=1e-4, type=float, help='finite-difference epsilon')
parser.add_argument('--number', default=0, type=int,
                    help='experiment id; outputs go to results/<number>/')
parser.add_argument('--sigma1', default=0.05, type=float,
                    help='exponent of the lr decay 1 / (epoch + 1) ** sigma1')
parser.add_argument('--toy', default='toy', type=str, help='prefix for saved result files')
parser.add_argument('--init_num', default=3, type=int, help='number of initial points')
# Initial points are 2-D; nargs=2 + type=float lets them be given on the
# command line (e.g. `--init1 6.5 6.5`) while keeping the list defaults.
parser.add_argument('--init1', default=[6.5, 6.5], nargs=2, type=float,
                    help='first initial point (x y)')
parser.add_argument('--init2', default=[-6.5, 5.5], nargs=2, type=float,
                    help='second initial point (x y)')
parser.add_argument('--init3', default=[-10, -5], nargs=2, type=float,
                    help='third initial point (x y)')

# MoDo
parser.add_argument('--gamma_modo', type=float, default=0.001, help='learning rate of lambda')
parser.add_argument('--rho_modo', type=float, default=0.0, help='regularization parameter')

args = parser.parse_args()

### Define the problem ###
# Toy two-objective problem instance, configured from the parsed arguments.
F = Toy(args)

# Method name -> solver callable. mgd / pcgrad / modo come from the star
# imports at the top of the file.
maps = dict(
    mgd=mgd,
    pcgrad=pcgrad,
    modo=modo,
)

def _write_column_csv(path, values):
    """Write ``values`` to ``path`` as a one-column CSV (one value per row)."""
    with open(path, 'w', newline='') as f:
        csv.writer(f).writerows([val] for val in values)


def run_all(args, number, seeds=1):
    """Optimise from each initial point with each method and save the results.

    For each of the initial points ``args.init1``, ``args.init2`` and
    ``args.init3`` and each method (currently only ``"modo"``), ``seeds`` runs
    of ``args.num_epochs`` Adam iterations are performed on the point ``x``.
    Each run uses a LambdaLR schedule ``1 / (epoch + 1) ** args.sigma1``.
    Output goes to ``results/<args.number>/``, one set of files per init ``i``:

    * ``<toy><i>-runs<seeds>.pt``: dict method -> list of trajectories.
    * ``<toy><i>-error_norm-runs<seeds>.pt`` and
      ``<toy><i>-true_dir_norm-runs<seeds>.pt``: dict method -> list of tensors.
    * ``{PS,PS_em,CA_erros,gen_erros}_<i>_toy<number>_<bs>_<lr>_<gamma>_.csv``:
      per-iteration metrics.

    Args:
        args: parsed command-line namespace.
        number: unused; output paths are keyed on ``args.number``. Kept for
            call compatibility.
        seeds: number of repeated runs per (init, method). Run ``k`` is seeded
            with ``args.seed + k``.
    """
    folder_name = "results/" + str(args.number) + "/"
    csv_suffix = (f'_toy{args.number}_{args.batch_size}_{args.lr}_'
                  f'{args.gamma_modo}_.csv')

    all_traj = {}
    all_traj_error_norm = {}
    all_traj_true_dir_norm = {}

    # The initial positions.
    inits = [
        torch.Tensor(args.init1),
        torch.Tensor(args.init2),
        torch.Tensor(args.init3),
    ]

    for i, init in enumerate(inits):
        print(f'\ninit:{init}\n')
        # Per-iteration metrics written to CSV below. NOTE(review): they are
        # reset for every seed/method, so only the last run's values are saved.
        PS, PS_em, gen_erros, CA_erros = [], [], [], []

        for m in tqdm(["modo"]):
            all_traj[m] = []
            all_traj_error_norm[m] = []
            all_traj_true_dir_norm[m] = []
            solver = maps[m]

            for seed in range(seeds):
                # Re-seed every run so each configuration is reproducible on
                # its own, independent of what ran before it in the sweep.
                # Assumes the stochastic gradients in F draw from torch or
                # numpy RNGs; TODO confirm against Toy.
                torch.manual_seed(args.seed + seed)
                np.random.seed(args.seed + seed)

                traj = []
                # NOTE(review): never appended to; saved as an empty tensor.
                traj_error_norm = []
                traj_true_dir_norm = []
                PS, PS_em, gen_erros, CA_erros = [], [], [], []

                x = init.clone()
                x.requires_grad = True

                opt = torch.optim.Adam([x], lr=args.lr)
                decay = lambda epoch: 1 / (epoch + 1) ** args.sigma1
                scheduler = torch.optim.lr_scheduler.LambdaLR(
                    opt, lr_lambda=decay)

                lambd = 0.5 * torch.ones([2, 1])  # initial lambda weights
                gamma = args.gamma_modo  # step size of the lambda update
                rho = args.rho_modo  # regularisation parameter

                for it in range(args.num_epochs):
                    traj.append(x.detach().numpy().copy())

                    # Two stochastic gradient samples from half-batches each.
                    _, grads1 = F(x, True, 'stoch', args.batch_size // 2)
                    _, grads2 = F(x, True, 'stoch', args.batch_size // 2)
                    g, lambd = solver(grads1, grads2, lambd, gamma, rho)

                    G_emp_norm, G_norm, CA_erro = perform(F, x, lambd, maps)
                    gen_erro = abs(G_norm - G_emp_norm)
                    PS.append(G_norm)
                    PS_em.append(G_emp_norm)
                    traj_true_dir_norm.append(G_emp_norm)
                    gen_erros.append(gen_erro)
                    CA_erros.append(CA_erro)

                    if it % 500 == 0:
                        print(f'it: {it}, PS: {G_norm}, PS_EM: {G_emp_norm}, gen_error: {gen_erro}, CA_error: {CA_erro}')

                    # Apply the solver's combined direction as the gradient.
                    opt.zero_grad()
                    x.grad = g
                    opt.step()
                    scheduler.step()

                all_traj[m].append(torch.tensor(np.array(traj)))
                all_traj_error_norm[m].append(torch.tensor(traj_error_norm))
                all_traj_true_dir_norm[m].append(torch.tensor(traj_true_dir_norm))

        os.makedirs(folder_name, exist_ok=True)

        stem = folder_name + args.toy + str(i)
        torch.save(all_traj, f"{stem}-runs{seeds}.pt")
        torch.save(all_traj_error_norm, f"{stem}-error_norm-runs{seeds}.pt")
        torch.save(all_traj_true_dir_norm, f"{stem}-true_dir_norm-runs{seeds}.pt")

        for prefix, values in (('PS_', PS), ('PS_em_', PS_em),
                               ('CA_erros_', CA_erros),
                               ('gen_erros_', gen_erros)):
            _write_column_csv(folder_name + prefix + str(i) + csv_suffix, values)


def plot_results(args, index=None, seeds=1, max_steps=10000):
    """Plot the saved trajectories of the three initial points on the task contour.

    Loads ``./results/<args.number>/<toy>{0,1,2}-runs<seeds>.pt`` as written by
    ``run_all``. For every method key found, it draws the first run from each
    init, truncated to ``max_steps`` points, with ``plot_contour``. The image
    name is ``./results/<args.number>/imgs/<toy><index>``.

    Args:
        args: parsed command-line namespace.
        index: suffix appended to the image name; ``None`` means no suffix.
        seeds: the ``seeds`` value used in ``run_all`` (selects the file name).
        max_steps: maximum number of trajectory points to plot per init.
    """
    # plot3d(F)
    levels = [-20, -18, -15, -13, -10, -5, 0, 3, 5, 10]
    # plot_contour(F, task=1, levels=levels, name="./imgs/_toy_task_1")
    # plot_contour(F, task=2, levels=levels, name="./imgs/_toy_task_2")

    # `args.toy + index` raised TypeError for the default index=None.
    suffix = '' if index is None else index

    run_dir = "./results/" + str(args.number) + "/"
    os.makedirs(run_dir, exist_ok=True)

    # weights_only=False unpickles arbitrary objects; only load files this
    # script wrote itself, never untrusted ones.
    trajs = [
        torch.load(run_dir + args.toy + f"{k}-runs{seeds}.pt", weights_only=False)
        for k in range(3)
    ]

    key_list = list(trajs[0].keys())
    print('\n Loaded keys:\n')
    print(key_list)
    print()

    img_dir = run_dir + "imgs/"
    os.makedirs(img_dir, exist_ok=True)

    for method in key_list:
        # NOTE(review): the image name does not include `method`, so with
        # several methods each plot overwrites the previous one.
        plot_contour(F, task=0, levels=levels,
                     traj=[t[method][0][:max_steps] for t in trajs],
                     plotbar=(method == "modo"),
                     name=img_dir + args.toy + suffix, args=args)

if __name__ == "__main__":
    # Sweep over initial points, learning rates and MoDo lambda step sizes.
    # For each configuration: generate trajectories, then plot them on the
    # contour and along the Pareto front (population and empirical).
    args.number = 0
    out_path = "./results/" + str(args.number) + "/imgs/"
    # Each gamma appears once. The original list had 0.0005 twice, which
    # re-ran an identical configuration and overwrote its own outputs.
    gammas = [0.0005, 0.00075, 0.0001, 0.001, 0.005, 0.01, 0.05, 0.075, 0.1]
    lrs = [0.005, 0.01, 0.05, 0.1, 0.5, 1.0]

    for init1 in [[8.5, 7.5]]:
        args.init1 = init1
        for init2 in [[-7.5, 6.5]]:
            args.init2 = init2
            # Alternative sweep: [[-10, -5], [-7.5, -3.5], [7.5, -3.5], [10, -5]]
            for init3 in [[-10.5, -2.5]]:
                args.init3 = init3
                for lr in lrs:
                    args.lr = lr
                    for gamma in gammas:
                        args.gamma_modo = gamma
                        index = '_' + str(args.number) + '_' + str(args.batch_size) + '_' + str(args.lr) + '_' + str(args.gamma_modo)
                        run_all(args, args.number)

                        # Plot trajectories
                        plot_results(args, index)

                        # Plot the Pareto trajectory in the Pareto front
                        plot_2d_pareto("modo", out_path=out_path, data_type='pop', args=args, index=index)
                        plot_2d_pareto("modo", out_path=out_path, data_type='emp', args=args, index=index)



