import numpy as np 
from pyemma.coordinates import tica
from argparse import ArgumentParser
import os 
import torch 
from datasets.aldp import ALDPDataset
from openmmtools.testsystems import AlanineDipeptideImplicit
import yaml
from utils.utils import get_model ,save_prediction_pdb
from torch_geometric.loader import DataLoader
from tqdm import tqdm 
from utils.plotting import save_ramachandran_plot

# Script to analyze the latent space of a trained model with TICA
def make_tica_plot(data: np.ndarray, time_lag: int, dim: int, args):
    """Fit TICA on a latent-space trajectory and save a plot of the first two components.

    Args:
        data: Encoded frames stacked along axis 0; assumed to be a single
            time-ordered trajectory of shape (n_frames, n_latent) -- TODO confirm.
        time_lag: Lag time, in frames, passed to ``pyemma.coordinates.tica``.
        dim: Number of TICA components to keep. The plot uses components 0
            and 1, so the projection must have at least two columns.
        args: Run namespace; ``args.log_dir`` and ``args.run_name`` select the
            output directory, and the namespace is forwarded to
            ``save_ramachandran_plot``.

    Raises:
        ValueError: if the TICA projection has fewer than two components.
    """
    tc = tica(data, lag=time_lag, dim=dim)
    # get_output() maps the complete input. Taking next(tc.iterator()) would
    # only yield the first chunk, which can silently truncate long trajectories.
    traj = tc.get_output()[0]
    if traj.ndim != 2 or traj.shape[1] < 2:
        raise ValueError(
            f"TICA projection has shape {traj.shape}; at least two components "
            f"are required for the plot (got dim={dim})."
        )
    save_ramachandran_plot(
        traj[:, 0],
        traj[:, 1],
        0,
        args,
        name="tica_latent",
        path=os.path.join(args.log_dir, args.run_name),
    )












if __name__ == "__main__":
    parser = ArgumentParser()
    # Model arguments
    parser.add_argument("--model_dir", type=str, default="workdir", help="Directory where models are stored")
    parser.add_argument("--run_name", type=str, default="run", help="Name of the run")
    

    # TICA arguments
    parser.add_argument("--time_lag", type=int, default=10, help="Time lag for tica")
    parser.add_argument("--dim", type=int, default=2, help="Dimension of the tica projection")



    args = parser.parse_args()
    
    # load arguments from the model
    config_dict = yaml.load(open(os.path.join(args.model_dir,args.run_name,"model_parameters.yml")), Loader=yaml.FullLoader)
    arg_dict = args.__dict__
    for key, value in config_dict.items():
        if isinstance(value, list):
            for v in value:
                arg_dict[key].append(v)
        else:
            arg_dict[key] = value

    # backwards compatibility
    if "no_propagator" not in arg_dict:
        arg_dict["no_propagator"] = False
    if "propagator_type" not in arg_dict:
        arg_dict["propagator_type"] = "linear"
        arg_dict["no_propagator"] = True
    if "sequence_length" not in arg_dict:
        arg_dict["sequence_length"] = 1

    
    testsystem = AlanineDipeptideImplicit(constraints=None)
    args.wandb = False
    args.save_pdb = False
    args.device = "cpu"
    args.testsystem = testsystem
    
    dataset = ALDPDataset(args,testsystem)
    
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
    # load the model
    model = get_model(args, coordinate_transform = dataset.coordinate_transform)
    state_dict = torch.load(os.path.join(args.model_dir, args.run_name, "best_model.pt"))
    model.load_state_dict(state_dict)
    model.eval()


    latent_space_batches = []
    with torch.no_grad():
        print("Computing forward pass of encoder")
        for data ,*_ in tqdm(loader): 
            latent_space = model.encoder(data)
            if args.no_vae:
                # take mu output (if --no_vae is set this is the latent space - logvar is not trained)
                latent_space = latent_space[0]
            else:
                # reparameterize if vae
                latent_space = model.reparameterize(latent_space[0],latent_space[1])
            latent_space_batches.append(latent_space)
        latent_space = torch.cat(latent_space_batches, dim=0)
        

    latent_space = latent_space.numpy()
    make_tica_plot(latent_space, args.time_lag, args.dim, args)


    # also save pdb prediction of model 
    run_dir = os.path.join(args.log_dir, args.run_name)
    save_prediction_pdb(model, next(iter(loader))[0], dataset.coordinate_transform, args, run_dir, dataset.inverse_transform)