import argparse
import json
import math
import random
import re
import warnings
from collections import OrderedDict

import numpy as np
import pandas as pd
import scipy.stats
import torch
# import wandb
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.preprocessing import PowerTransformer
from torch import nn
from torch.utils.data import Dataset
from tqdm import tqdm

import mint
from mint.model.esm2 import ESM2

# Suppress every Python warning for the whole process. This is global: it also
# hides warnings raised by torch, sklearn, scipy, etc., not just by this script.
warnings.filterwarnings("ignore")

class FlabDataset(Dataset):
    """Map-style dataset over a FLAb CSV of paired antibody chains.

    Item ``i`` is the tuple ``(heavy, light, target)`` taken from row ``i`` of the
    ``heavy`` and ``light`` columns and of the caller-chosen ``target_col``.
    """

    def __init__(self, csv_path, target_col):
        super().__init__()
        frame = pd.read_csv(csv_path, sep=",")
        # Materialize each column once as a plain Python list for cheap indexing.
        self.heavy, self.light, self.target = (
            frame[column].tolist() for column in ("heavy", "light", target_col)
        )

    def __len__(self):
        """Return the number of rows read from the CSV."""
        return len(self.heavy)

    def __getitem__(self, index):
        """Return ``(heavy, light, target)`` for row ``index``."""
        heavy, light, target = self.heavy, self.light, self.target
        return heavy[index], light[index], target[index]


class DesautelsCollateFn:
    """Collate ``(heavy, light, target)`` items into token, chain-id and label tensors.

    Each chain is encoded as ``"<cls>" + sequence + "<eos>"`` with mint's ESM-1b
    alphabet, right-padded to the longest chain of its kind in the batch, and the
    heavy and light token matrices are concatenated along the sequence axis.
    """

    def __init__(self, truncation_seq_length=None):
        self.alphabet = mint.data.Alphabet.from_architecture("ESM-1b")
        # Optional maximum encoded length (counting <cls>/<eos>); longer encodings
        # are randomly cropped to exactly this many tokens. None/0 disables it.
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, batches):
        """Return ``(chains, chain_ids, labels)`` for a list of dataset items.

        ``chains`` is the int64 matrix ``[heavy_tokens | light_tokens]``;
        ``chain_ids`` has the same shape and holds 0 over the heavy block and 1 over
        the light block (padding included); ``labels`` stacks the targets along a
        new leading axis.
        """
        heavy_chain, light_chain, labels = zip(*batches)
        chains = [self.convert(c) for c in (heavy_chain, light_chain)]
        # Chain index i (0 = heavy, 1 = light) for every token position.
        chain_ids = [torch.ones(c.shape, dtype=torch.int32) * i for i, c in enumerate(chains)]
        chains = torch.cat(chains, -1)
        chain_ids = torch.cat(chain_ids, -1)
        labels = torch.from_numpy(np.stack(labels, 0))
        return chains, chain_ids, labels

    def convert(self, seq_str_list):
        """Encode sequences into a right-padded int64 token matrix of shape (batch, max_len).

        'J' (the IUPAC Leu/Ile code) is rewritten to 'L' before encoding. When
        truncation is enabled, over-long encodings are cropped to a random window of
        exactly ``truncation_seq_length`` tokens (which may cut off <cls>/<eos>).
        """
        batch_size = len(seq_str_list)
        seq_encoded_list = [
            self.alphabet.encode("<cls>" + seq_str.replace("J", "L") + "<eos>")
            for seq_str in seq_str_list
        ]
        if self.truncation_seq_length:
            for i, seq in enumerate(seq_encoded_list):
                if len(seq) > self.truncation_seq_length:
                    # random.randint is inclusive on BOTH ends, so the last start
                    # that still yields a full-length window is len - truncation.
                    start = random.randint(0, len(seq) - self.truncation_seq_length)
                    seq_encoded_list[i] = seq[start : start + self.truncation_seq_length]
        max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
        if self.truncation_seq_length:
            assert max_len <= self.truncation_seq_length
        tokens = torch.empty((batch_size, max_len), dtype=torch.int64)
        tokens.fill_(self.alphabet.padding_idx)

        for i, seq_encoded in enumerate(seq_encoded_list):
            tokens[i, : len(seq_encoded)] = torch.tensor(seq_encoded, dtype=torch.int64)
        return tokens


def upgrade_state_dict(state_dict):
    """Strip a leading 'encoder.sentence_encoder.' or 'encoder.' prefix from every key.

    Only a prefix at the very start of a key is removed (the longer prefix is tried
    first); keys without such a prefix, and any later occurrence of 'encoder.'
    inside a key, are left untouched. Values are passed through unchanged.
    """
    prefixes = ["encoder.sentence_encoder.", "encoder."]
    # Group the alternation so "^" anchors EVERY prefix (a bare "^a|b" anchors only
    # "a"), and escape the dots so they match literal "." only.
    pattern = re.compile("^(?:" + "|".join(re.escape(p) for p in prefixes) + ")")
    return {pattern.sub("", name, count=1): param for name, param in state_dict.items()}


class FlabWrapper(nn.Module):
    """ESM2 encoder that mean-pools final-layer residue representations.

    Builds an ``ESM2`` from ``cfg`` (``encoder_layers``, ``encoder_embed_dim``,
    ``encoder_attention_heads``, ``token_dropout``), loads weights from
    ``checkpoint_path`` and sets ``requires_grad=False`` on the token embedding,
    the ``*_norm_after`` parameters, the LM head and the first
    ``floor(cfg.encoder_layers * freeze_percent)`` transformer layers.
    """

    def __init__(
        self, cfg, checkpoint_path, freeze_percent=0.0, use_multimer=True, device="cuda:0"
    ):
        super().__init__()
        self.cfg = cfg
        self.model = ESM2(
            num_layers=cfg.encoder_layers,
            embed_dim=cfg.encoder_embed_dim,
            attention_heads=cfg.encoder_attention_heads,
            token_dropout=cfg.token_dropout,
            use_multimer=use_multimer,
        )
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=device)

        if use_multimer:
            # Weights live under "state_dict" with a leading "model." (presumably the
            # attribute name of the ESM2 inside the training wrapper -- confirm).
            # Strip only the prefix; an inner "model." substring is kept intact.
            prefix = "model."
            new_checkpoint = OrderedDict(
                (key[len(prefix):] if key.startswith(prefix) else key, value)
                for key, value in checkpoint["state_dict"].items()
            )
        else:
            # ESM-style checkpoint: weights under "model" with encoder prefixes.
            new_checkpoint = upgrade_state_dict(checkpoint["model"])
        self.model.load_state_dict(new_checkpoint)
        self._freeze_parameters(freeze_percent)

    def _freeze_parameters(self, freeze_percent):
        """Disable gradients for embeddings/heads and the first floor(N * freeze_percent) layers.

        Only ever sets ``requires_grad=False``; every other parameter is left as is.
        """
        num_frozen = math.floor(self.cfg.encoder_layers * freeze_percent)
        for name, param in self.model.named_parameters():
            if "embed_tokens.weight" in name or "_norm_after" in name or "lm_head" in name:
                param.requires_grad = False
                continue
            parts = name.split(".")
            # Assumes layer parameters are named "<container>.<idx>.…" with a 0-based
            # idx (nn.ModuleList numbering) -- TODO confirm against mint's ESM2.
            # "idx < num_frozen" freezes exactly num_frozen layers, so
            # freeze_percent=0.0 freezes none and 1.0 freezes all.
            if len(parts) > 1 and parts[1].isdigit() and int(parts[1]) < num_frozen:
                param.requires_grad = False

    def forward(self, chains, chain_ids):
        """Mean of final-layer representations over tokens that are not cls/eos/padding.

        Returns one vector per row of ``chains``; a row containing no such tokens
        yields an all-zero vector instead of NaN.
        """
        mask = (
            (~chains.eq(self.model.cls_idx))
            & (~chains.eq(self.model.eos_idx))
            & (~chains.eq(self.model.padding_idx))
        )
        final_layer = self.cfg.encoder_layers
        chain_out = self.model(chains, chain_ids, repr_layers=[final_layer])["representations"][
            final_layer
        ]
        sum_masked = (chain_out * mask.unsqueeze(-1)).sum(dim=1)
        # Clamp so rows with only special tokens divide by 1 rather than 0.
        mask_counts = mask.sum(dim=1, keepdim=True).clamp(min=1).float()
        return sum_masked / mask_counts


@torch.no_grad()
def get_embeddings(model, loader, device="cuda"):
    """Run ``model`` over every ``(chains, chain_ids, target)`` batch of ``loader``.

    The model is moved to ``device`` and switched to eval mode for the pass; its
    previous train/eval mode is restored afterwards, even on error.

    Returns:
        ``(embeddings, targets)``: the per-batch model outputs and the float32
        targets, each concatenated along the first axis as NumPy arrays.
    """
    model.to(device)
    was_training = model.training
    # Inference: disable train-only behaviour (e.g. dropout) so outputs are deterministic.
    model.eval()

    embeddings = []
    targets = []
    try:
        for chains, chain_ids, target in tqdm(loader):
            embedding = model(chains.to(device), chain_ids.to(device))
            embeddings.append(embedding.detach().cpu().numpy())
            # Targets are not needed on the device; float() keeps the float32 output.
            targets.append(target.float().cpu().numpy())
    finally:
        model.train(was_training)

    return np.concatenate(embeddings), np.concatenate(targets)


def gaussian_transform(y):
    """Power-transform ``y`` towards a standardized Gaussian shape.

    ``y`` is reshaped into a single feature column and a fresh sklearn
    ``PowerTransformer`` with default settings is fitted on it. Returns the
    transformed array with shape ``(y.size, 1)``.
    """
    transformer = PowerTransformer()
    column = y.reshape(-1, 1)
    return transformer.fit_transform(column)


def cross_validate(embeddings, targets, scale_all=True):
    """Nested cross-validated ridge regression from ``embeddings`` to ``targets``.

    An outer shuffled 10-fold KFold estimates generalization. Within each outer
    training split, a 5-fold GridSearchCV picks the ridge ``alpha`` from
    {1, 1e-1, ..., 1e-6, 0} by R^2. Targets are power-transformed with sklearn's
    ``PowerTransformer``:

    * ``scale_all=True``: one transform fitted on all targets before splitting.
    * ``scale_all=False``: a transform fitted on each outer training split and
      applied unchanged to the matching test split.

    Returns:
        dict with keys ``R2_avg``, ``R2_std``, ``Pearson_avg``, ``Pearson_std``
        (mean / population std over the outer folds).
    """
    # Ridge penalties [1, 0.1, ..., 1e-6] plus 0 (unregularized least squares).
    lambda_grid = np.logspace(0, -6, num=7).tolist()
    lambda_grid.append(0)
    param_grid = {"alpha": lambda_grid}

    outer_cv = KFold(n_splits=10, shuffle=True, random_state=0)
    inner_cv = KFold(n_splits=5, shuffle=True, random_state=0)

    # Embeddings are used as-is (no feature scaling is applied).
    X = embeddings

    if scale_all:
        targets = gaussian_transform(targets)

    # Inner CV chooses alpha on each outer training split; refit=True (default)
    # then retrains on that whole split with the best alpha.
    clf = GridSearchCV(estimator=Ridge(), param_grid=param_grid, cv=inner_cv, scoring="r2")

    outer_scores = []
    outer_corrs = []

    for train_idx, test_idx in tqdm(outer_cv.split(X), total=outer_cv.get_n_splits()):
        X_train, X_test = X[train_idx], X[test_idx]
        Y_train, Y_test = targets[train_idx], targets[test_idx]

        if not scale_all:
            # Fit on the training targets only and apply the SAME transform to the
            # test targets. Fitting the test fold separately would put predictions
            # and test labels on different scales and leak test statistics.
            transformer = PowerTransformer()
            Y_train = transformer.fit_transform(np.reshape(Y_train, (-1, 1)))
            Y_test = transformer.transform(np.reshape(Y_test, (-1, 1)))

        clf.fit(X_train, Y_train)
        Y_pred = clf.best_estimator_.predict(X_test)

        # Targets are a single (n, 1) column here, hence the [:, 0] for pearsonr.
        outer_scores.append(r2_score(Y_test, Y_pred))
        outer_corrs.append(scipy.stats.pearsonr(Y_test[:, 0], Y_pred[:, 0])[0])

    print("Average R2 across all folds:", np.mean(outer_scores))
    print("R2 Standard deviation across all folds:", np.std(outer_scores))
    print("Average pearson correlation across all folds:", np.mean(outer_corrs))
    print("pearson correlation standard deviation across all folds:", np.std(outer_corrs))
    print("\n")

    return {
        "R2_avg": np.mean(outer_scores),
        "R2_std": np.std(outer_scores),
        "Pearson_avg": np.mean(outer_corrs),
        "Pearson_std": np.std(outer_corrs),
    }


def calculate_scores(model, dataset_file, device, target_col="fitness", batch_size=64):
    """Embed every antibody in ``dataset_file`` and score a ridge probe on ``target_col``.

    Args:
        model: module called as ``model(chains, chain_ids)`` (e.g. ``FlabWrapper``).
        dataset_file: path to a CSV with ``heavy``, ``light`` and ``target_col`` columns.
        device: device the model is run on.
        target_col: name of the label column (default ``"fitness"``; other exports
            use names such as ``"negative log Kd"`` or ``"negative log expression"``).
        batch_size: number of antibodies embedded per forward pass.

    Returns:
        The score dict produced by ``cross_validate`` (with ``scale_all=True``).
    """
    print(f"Evaluating for dataset file {dataset_file}")

    dataset = FlabDataset(dataset_file, target_col)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, collate_fn=DesautelsCollateFn(), shuffle=False
    )
    embeddings, _ = get_embeddings(model, loader, device=device)
    # Use the targets straight from the CSV rather than the float32 copies returned
    # by get_embeddings; shuffle=False keeps them row-aligned with the embeddings.
    targets = np.array(dataset.target)
    return cross_validate(embeddings, targets, scale_all=True)


def main(args):
    """Evaluate the checkpoint at ``args.checkpoint_path`` on each bundled FLAb dataset.

    Uses ``args.use_multimer`` and ``args.device``. Returns a dict mapping each
    dataset path to the score dict from ``calculate_scores``.
    """
    dataset_files = [
        "datasets/Shanehsazzadeh2023_trastuzumab_zero_kd.csv",
        "datasets/Warszawski2019_d44_Kd_ppl.csv",
        "datasets/Koenig2017_g6_Kd_ppl.csv",
        "datasets/Koenig2017_g6_er_ppl.csv",
    ]

    with open("./data/esm2_t33_650M_UR50D.json", encoding="utf-8") as f:
        cfg = argparse.Namespace(**json.load(f))

    # freeze_percent=1.0: every transformer layer is frozen (embedding extraction only).
    model = FlabWrapper(cfg, args.checkpoint_path, 1.0, args.use_multimer, args.device)

    # TODO: log these to wandb when args.wandb is set (the wandb import is disabled).
    return {
        dataset_file: calculate_scores(model, dataset_file, args.device)
        for dataset_file in dataset_files
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Finetuning on Flab dataset")

    parser.add_argument("--wandb", action="store_true", default=False)
    parser.add_argument("--wandb_key", type=str, default=None)
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default="/new/LLM/MINT/mint.ckpt",
    )
    # A store_true flag defaulting to True could never be switched off. Keep
    # --use_multimer for compatibility and add an explicit opt-out sharing its dest.
    parser.add_argument("--use_multimer", dest="use_multimer", action="store_true")
    parser.add_argument("--no_use_multimer", dest="use_multimer", action="store_false")
    parser.set_defaults(use_multimer=True)
    parser.add_argument("--device", type=str, default="cuda:0")

    args = parser.parse_args()
    main(args)
