import os
import torch
import wandb
import plot
import final_plot
import observability

from experiments.ResolutionExperiment import ResolutionExperiment
from experiments.TokenExperiment import TokenExperiment
from solvers.KSSolver import KSSolver
from solvers.WaveSolver import WaveSolver
from solvers.HeatSolver import HeatSolver
from config import config

if __name__ == '__main__':
    # Script entry point: choose a solver, seed torch from the solver's
    # seed, make sure the per-seed output directory exists, then run one
    # experiment action. In each section below exactly one line is active;
    # the commented-out lines are alternative runs kept for reference.

    # --- Solver selection ----------------------------------------------------
    # Wave solver alternative (paired with a smaller time factor):
    #     Solver = WaveSolver
    #     config['time_factor'] = 0.05
    # KS solver alternative (leaves config['time_factor'] untouched):
    #     Solver = KSSolver
    Solver = HeatSolver
    config['time_factor'] = 0.4
    solver = Solver(config)

    # --- Reproducibility and output location ---------------------------------
    # NOTE(review): the seed is applied after the solver is constructed, so
    # any randomness inside Solver(config) is not covered by it — confirm.
    torch.manual_seed(solver.seed)
    print("Seed: ", solver.seed)
    seed_output_dir = f'output/{solver.seed}'
    os.makedirs(seed_output_dir, exist_ok=True)

    # --- Optional experiment tracking (currently disabled) -------------------
    # wandb.init(project="learning-to-operate", config=config)

    # --- Experiment selection ------------------------------------------------
    # experiment = TokenExperiment(config)
    experiment = ResolutionExperiment(config)

    # --- Experiment action ---------------------------------------------------
    # Training / data-generation alternatives:
    # experiment.adversary_learn(solver)
    # experiment.train_simplified_gan(solver.seed, tokens=False)
    # experiment.train_ks_ar()
    # experiment.train_linear_gan(Solver)
    # experiment.sample_ks('kse-tok-256-16/22', 'gan')
    # experiment.sample_ks_res('kse-res-256-16/22', 'data/generated/kse/gen_22.pt',  'gan')
    # experiment.sample_ks_many('kse-tok-256-16/22', 'kse-res-256-16-v2/200', method='gan')
    # experiment.save_real_ks('test')
    # experiment.save_random_graphs(solver)
    # experiment.torch_learn(solver)
    # experiment.linear_regression_multiple(Solver)
    # experiment.torch_learn_multiple(Solver)
    # experiment.save_multiple_datasets(Solver, 'wave')
    # experiment.lin_alg(solver)
    # experiment.vary_history(Solver, True)
    # experiment.vary_history_many(Solver, True)
    # experiment.alternate_initial_conditions_torch(solver, 12)
    # experiment.vary_solver_parameter(Solver, 'T', [32, 128, 256, 512, 1024])
    #
    # Sampling alternatives. The integer argument is presumably a saved-run
    # identifier (values look like Unix timestamps — confirm); the trailing
    # tag records which solver/experiment setup it belongs to.
    # experiment.sample(solver, 1747127895)  # heat-tok
    # config['op_seed'] = 873282266
    # experiment.sample(solver, 1746564533)  # wave-tok
    # experiment.sample(solver, 1746866987)  # wave-tok-final
    # experiment.sample(solver, 1746884797, data_path=f'data/generated/{solver.name}/')  # wave-res
    generated_data_dir = f'data/generated/{solver.name}/'
    heat_res_run = 1746804099  # heat-res
    experiment.sample(solver, heat_res_run, data_path=generated_data_dir)

    # --- Diagnostics (disabled) ----------------------------------------------
    # observability.observability_check(solver)
    # observability.kse_lie_check()

    # --- Plotting / figure generation (disabled) -----------------------------
    # final_plot.extract_data_from_file('slurm-4495578.out')
    # final_plot.plot_curves('slurm-heat-hist.out', 'slurm-wave-hist.out', 'L1')
    # final_plot.save_vids(solver, f'data/generated/{solver.name}/real_tok.pt', f'data/generated/{solver.name}/gen_tok.pt',
    #                      f'data/generated/{solver.name}/diff_tok.pt', f'data/generated/{solver.name}/real_full_tok.pt', 'tok')
    # final_plot.get_time_correlations_and_plot('data/generated/kse/for_cors/')
    # final_plot.plot_model_err('data/generated/heat/diff_res.pt', 'data/generated/wave/diff_res.pt')
    # final_plot.make_comparison_grid('data/generated/wave/real_tok.pt', 'data/generated/wave/gen_tok.pt', 'data/generated/wave/diff_tok.pt')
    # final_plot.load_and_plot_multi_run('wandb_export.csv')
    # final_plot.plot_kse_errors('output/gan/kse-tok-256-16/')
    # plot.animate_eq(solver)
    # plot.plot_field(solver, config, law='norm')

