import os
import pickle as pkl

import numpy as np
import pandas as pd
import torch

from casanovo.denovo.datasets import AnnotatedSpectrumDataset, SplitAwareAnnotatedSpectrumDataset
from casanovo.denovo.model import Spec2Pep
from casanovo.denovo.tokenizer import MassAwareMskbPeptideTokenizer

# flush=True pushes the start-up message out immediately instead of leaving it
# in the stdout buffer.
print("Started Script", flush=True)

# Spectra dataset location and the pickle file handed to the dataset as
# `split_indices`.
lance_full = "mskb_casanovo_data/combined/spectra_82c0124b_combined.lance"
split_path = "mskb_casanovo_data/combined/test_indices_combined_Train-Cfix_Mox_Val-Mox.pkl"

# Common prefix of every checkpoint path listed below.
weights_path_BASE = "weights/"

# Checkpoints to evaluate, one entry per model:
#   name         -> stem of the result file written for this model
#   weights_path -> checkpoint file to load
#   is_gan       -> True when the weights sit under the "denovo_model." prefix
#                   of the checkpoint's state_dict; False when the checkpoint
#                   is loaded directly with Spec2Pep.load_from_checkpoint
models = [
    dict(
        name="Scalar_Mass_prefix_correct_GAN_MSV_V1_189_859",
        weights_path=f"{weights_path_BASE}open_GAN_MTL_scalar_1e6/epoch=0-step=1200000-val_loss_validation_dev_loss=189.859.ckpt",
        is_gan=True,
    ),
    dict(
        name="Scalar_Mass_prefix_correct_MTL_MSV_V1_174_721_b_y_ions",
        weights_path=f"{weights_path_BASE}mse_depthCasanovo_open_training_scalar_mass_prediction_mixed_real_sim_basic_MTL_on_b_y_sim/epoch=0-step=600000-train_MSELoss=0.000-valid_MSELoss=174.721.ckpt",
        is_gan=False,
    ),
]

# Mass assigned to every token the tokenizer knows about. The table is composed
# from three groups; the merge below keeps the keys in this exact order.
# NOTE(review): values look like monoisotopic residue masses in Da - confirm.

# Unmodified amino acids, plus carbamidomethylated cysteine.
_PLAIN_RESIDUES = {
    "G": 57.021464,
    "A": 71.037114,
    "S": 87.032028,
    "P": 97.052764,
    "V": 99.068414,
    "T": 101.047670,
    "C[Carbamidomethyl]": 160.030649,  # Cys 103.009185 + 57.021464
    "I": 113.084064,
    "L": 113.084064,
    "N": 114.042927,
    "D": 115.026943,
    "Q": 128.058578,
    "K": 128.094963,
    "E": 129.042593,
    "M": 131.040485,
    "H": 137.058912,
    "F": 147.068414,
    "R": 156.101111,
    "Y": 163.063329,
    "W": 186.079313,
}

# Amino acids carrying a side-chain modification.
_MODIFIED_RESIDUES = {
    "M[Oxidation]": 147.035400,   # Met 131.040485 + 15.994915 (oxidation)
    "N[Deamidated]": 115.026943,  # Asn 114.042927 + 0.984016 (deamidation)
    "Q[Deamidated]": 129.042593,  # Gln 128.058578 + 0.984016 (deamidation)
}

# N-terminal modifications (token names end with "-").
_N_TERMINAL_MODS = {
    "[Acetyl]-": 42.010565,         # acetylation
    "[Carbamyl]-": 43.005814,       # carbamylation, "+43.006"
    "[Ammonia-loss]-": -17.026549,  # NH3 loss
    "[+25.980265]-": 25.980265,
}

residues = {**_PLAIN_RESIDUES, **_MODIFIED_RESIDUES, **_N_TERMINAL_MODS}

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the script can still run on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Tokenizer shared by the dataset and by every evaluated model. It is built from
# the `residues` mapping with reverse=True and isoleucine->leucine replacement
# switched on.
anno_tokenizer = MassAwareMskbPeptideTokenizer(
    residues=residues,
    reverse=True,
    replace_isoleucine_with_leucine=True,
)

# Batch fields that are moved to the inference device when they are tensors.
keys = ['precursor_mz', 'precursor_charge', 'mz_array', 'intensity_array', 'seq']
# Token-index -> token-string lookup. It depends only on the tokenizer, so it is
# built once instead of once per batch.
indexing_ar = np.array(anno_tokenizer.reverse_index)
# Column layout of the per-position result table; used to build an empty table
# when the loader yields no batches.
result_columns = ['target_AA', 'score_AA', 'pred_AA', 'pred_mass', 'true_mass']
# Make sure the output directory exists up front, so a missing folder cannot
# throw away the results of a finished inference run.
os.makedirs("seqResults", exist_ok=True)

# Evaluate every configured checkpoint on the same split and write one pickled
# DataFrame per model to seqResults/<name>.pkl.
for model_config in models:
    weights_path = model_config["weights_path"]
    is_gan_model = model_config["is_gan"]
    outname = model_config["name"]

    # A fresh dataset/loader pair is created for every model.
    anno_dataset = SplitAwareAnnotatedSpectrumDataset(
        None,
        'seq',
        anno_tokenizer,
        256,
        shuffle=False,
        split_indices=split_path,
        path=lance_full
    )

    dl = torch.utils.data.DataLoader(anno_dataset,
                                    shuffle=None,
                                    num_workers=0,
                                    pin_memory=True,)

    if is_gan_model:
        # GAN checkpoints keep the de novo model under the "denovo_model."
        # prefix of the state dict: strip the prefix and load only those
        # entries into a freshly constructed Spec2Pep.
        model = Spec2Pep(
            dropout=0.18,
            max_charge=10,
        )

        checkpoint = torch.load(weights_path, map_location=device)
        checkpoint_state_dict = checkpoint["state_dict"]

        encoder_decoder_state_dict = {
            k.removeprefix("denovo_model."): v
            for k, v in checkpoint_state_dict.items()
            if k.startswith("denovo_model.")
        }
        model.load_state_dict(encoder_decoder_state_dict, strict=True)
    else:
        model = Spec2Pep.load_from_checkpoint(weights_path, map_location=device, strict=False)

    # Place the model on the inference device in both branches (a no-op when it
    # is already there) and switch to evaluation mode so that dropout does not
    # perturb the predictions being scored.
    model = model.to(device)
    model.eval()

    # Point the model and its decoder at the evaluation tokenizer and rebuild
    # the decoder's per-token mass table to match (0.0 for tokens that have no
    # entry in the residue table).
    model.tokenizer = anno_tokenizer
    model.decoder.tokenizer = anno_tokenizer
    model.decoder._mass_lookup_table = [anno_tokenizer.residues.get(a, 0.0) for a in anno_tokenizer.reverse_index]

    print(f"All initialized for file {outname}", flush=True)

    # Per-batch tables are collected here and concatenated once at the end;
    # concatenating inside the loop copies the growing table on every batch.
    batch_frames = []
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for counter, batch in enumerate(dl):
            for key in keys:
                try:
                    batch[key] = batch[key].to(device)
                except (KeyError, AttributeError):
                    # Field is absent or is not a tensor: leave it untouched.
                    # Any other failure (e.g. a CUDA error) is allowed to
                    # propagate instead of being silently ignored.
                    continue

            pred_scores, pred_mass, true_tokens, true_masses = model._forward_step(batch)
            pred_tokens = torch.argmax(pred_scores, dim=-1)

            # Keep only positions with a positive true mass.
            # NOTE(review): presumably this drops padding positions - confirm.
            mask = true_masses > 0
            label_AA = pred_tokens[mask]
            # Score the model assigned to the token it predicted at each position.
            score_AA = pred_scores.gather(dim=2, index=pred_tokens.unsqueeze(-1)).squeeze(-1)[mask]
            target = true_tokens[mask] if true_tokens is not None else torch.zeros_like(score_AA, dtype=torch.int32)

            msk_np = mask.detach().cpu().numpy()
            # Reshape to the mask's shape rather than calling a bare squeeze():
            # squeeze() would also drop a size-1 batch or sequence axis, after
            # which the mask no longer lines up with the array.
            pred_mass_np = pred_mass.detach().cpu().numpy().reshape(msk_np.shape)

            data = {
                'target_AA': indexing_ar[target.detach().cpu().numpy()],
                'score_AA': score_AA.detach().cpu().numpy(),
                'pred_AA': indexing_ar[label_AA.detach().cpu().numpy()],
                'pred_mass': pred_mass_np[msk_np],
                'true_mass': true_masses.detach().cpu().numpy()[msk_np]
            }
            batch_frames.append(pd.DataFrame(data))

            print(f"Done with iteration {counter}", flush=True)

    if batch_frames:
        super_df = pd.concat(batch_frames, axis=0)
    else:
        # The loader produced nothing: still write an (empty) result table
        # instead of failing on a missing DataFrame.
        super_df = pd.DataFrame(columns=result_columns)

    # True where the predicted token equals the target token.
    super_df['label_AA'] = super_df['target_AA'] == super_df['pred_AA']

    with open(f"seqResults/{outname}.pkl", "wb") as f:
        pkl.dump(super_df, f)