import os, torch
from arguments import get_args
from Record.file_management import read_obj_dumps, load_from_pickle, save_to_pickle, create_directory
from ACState.object_dict import ObjDict
from Buffer.train_test_buffers import generate_buffers
from Model.base_model import InferenceModel
from ActualCausal.train_loop import pretrain, train_inference, test_dataset
from ACState.extractor import regenerate
from ACState.compute_mean_variance import compute_encoder_mean_variance
from Record.wandb_logging import initialize_wandb
from Model.model_utils import save_model, load_model
# from Causal.Training.full_test import test_full, test_full_train

from Environment.Environments.initialize_environment import initialize_environment

from Network.network_utils import pytorch_model
import numpy as np
import sys
import psutil

def _regenerate_state(args, environment, enc_range, enc_dyn):
    """Build a fresh ``(extractor, normalization)`` pair for ``environment``.

    Passes ``args.image_enc.encoding_dim`` as the encoding dimension when
    ``args.train.load_encodings`` is non-empty, and ``-1`` otherwise.
    """
    encoding_dim = args.image_enc.encoding_dim if len(args.train.load_encodings) > 0 else -1
    return regenerate(args, environment, all=True, encoding_dim=encoding_dim, enc_rng=enc_range, enc_dyn=enc_dyn)


def train_all(args):
    """Run the end-to-end training pipeline for the inference model.

    Steps:
    - Initialize logging and the environment, then build the state extractor and model.
    - Obtain train/test buffers. They are loaded from ``args.record.load_intermediate``
      when set, otherwise generated, and saved to ``args.record.save_intermediate``
      when set.
    - Pretrain, then rebuild the extractor.
    - Train the inference model and evaluate it on the test buffer.
    - Save the model to ``args.record.save_dir`` when set.

    Args:
        args: the configuration object returned by ``get_args``. As a side effect,
            this function sets ``args.factor.converged_active_loss_value`` and
            ``args.factor.converged_passive_loss_value``.
    """
    use_cuda = args.torch.cuda
    device = args.torch.gpu if use_cuda else "cpu"

    # initialize wandb logging, None if unused
    wdb_run = initialize_wandb(args)
    # initialize the environment (the returned record handle is not used here)
    environment, _record = initialize_environment(args.environment, args.record)
    # initialize the state handling, and fills args.factor with appropriate values
    enc_range, enc_dyn = compute_encoder_mean_variance(args)
    extractor, normalization = _regenerate_state(args, environment, enc_range, enc_dyn)
    # initialize the model
    model = InferenceModel(args, extractor, normalization, environment)
    model = load_model(model, args.record.load_dir, device=device)

    load_dir = args.record.load_intermediate
    save_dir = args.record.save_intermediate

    # get the train and test buffers
    if len(load_dir) > 0:
        train_buffer, test_buffer = load_from_pickle(os.path.join(load_dir, environment.name + "_traintest.pkl"))
    else:
        train_buffer, test_buffer = generate_buffers(environment, args, extractor, normalization, args.train.train)
    if len(save_dir) > 0:
        save_to_pickle(os.path.join(create_directory(save_dir), environment.name + "_traintest.pkl"), (train_buffer, test_buffer))

    active_like, passive_like = None, None
    if len(load_dir) > 0:
        # best-effort resume from a previously saved pretrained model and its outputs
        try:
            model = load_from_pickle(os.path.join(load_dir, environment.name + "_inter_model.pkl"))
            print("loaded model")
            print(model)
            # honor the cuda flag the same way load_model does above, instead of
            # unconditionally moving the reloaded model onto the GPU
            if use_cuda:
                model.cpu().cuda(device=args.torch.gpu)
            else:
                model.cpu()
            passive_like, active_like = load_from_pickle(os.path.join(load_dir, environment.name + "_pretrain_outputs.pkl"))
        except FileNotFoundError as err:
            # missing intermediates are not fatal; report which file was absent
            print("intermediate pretraining file not found, continuing without it:", err.filename)

    # perform the pretraining step which trains the models to get a baseline performance value
    if args.train.train and args.pretrain.num_iters > 0:
        passive_like, active_like = pretrain(args, model, train_buffer, test_buffer, wdb_run=wdb_run)
    # saving the passive models and weights
    if len(save_dir) > 0:
        save_to_pickle(os.path.join(create_directory(save_dir), environment.name + "_inter_model.pkl"), model)
        save_to_pickle(os.path.join(save_dir, environment.name + "_pretrain_outputs.pkl"), (passive_like, active_like))

    # Scale the mean pretraining likelihoods by the target dimension.
    # Defaults are 3 (active) and 0 (passive) when no outputs are available.
    # TODO: goes inside loading, which isn't implemented
    args.factor.converged_active_loss_value = (np.mean(active_like) if active_like is not None else 3) * extractor.target_dim
    args.factor.converged_passive_loss_value = (np.mean(passive_like) if passive_like is not None else 0) * extractor.target_dim
    print("active like", args.factor.converged_active_loss_value)

    # training the active and interaction models, using a freshly regenerated extractor
    extractor, normalization = _regenerate_state(args, environment, enc_range, enc_dyn)
    model.regenerate(extractor, normalization, environment)

    if args.train.train:
        train_inference(args, model, train_buffer, test_buffer, wdb_run=wdb_run)
    test_dataset(args, model, test_buffer, extractor)
    if len(args.record.save_dir) > 0:
        save_model(model, args.record.save_dir)

if __name__ == '__main__':
    args = get_args()
    print(args)  # print out args for records
    # Only select a CUDA device when GPU execution is requested. Calling
    # torch.cuda.set_device unconditionally fails when CUDA is unavailable.
    if args.torch.cuda:
        torch.cuda.set_device(args.torch.gpu)
    np.set_printoptions(threshold=3000, linewidth=120, precision=4, suppress=True)
    torch.set_printoptions(precision=4, sci_mode=False)

    train_all(args)