import json
import os
import pandas as pd
import re
import torch
import math
import numpy as np

from tqdm import tqdm

from train.prepare import prepare
from utils.data_utils import data_collate, move_tensors, get_intensity
from MorphoMNIST.morphomnist.measure import measure_image

# User configuration: set these three variables before running.

# Root directory containing "saves/<run_name>/" folders (checkpoints, log.txt, CSVs).
artifact_path = "/home/ubuntu/Downloads/variational-causal-inference/artifact"

# Batch size of the validation DataLoader used to evaluate each checkpoint.
batch_size = 256

# If True, also compute thickness/intensity errors for every checkpoint (slow).
# The best checkpoint of each run is evaluated with measurements either way.
measure = False

def measure_thickness_and_intensity(batch):
    """Return a DataFrame of per-image ``thickness`` and ``intensity`` measurements.

    Dimension 1 of *batch* is averaged out (presumably the channel axis of an
    ``(N, C, H, W)`` image tensor -- TODO confirm) and the result is moved to
    CPU before measuring. The result has one row per image, in batch order.
    ``thickness`` comes from MorphoMNIST's ``measure_image`` and ``intensity``
    comes from ``get_intensity``.
    """
    images = batch.mean(axis=1).cpu()

    def _measure_one(image):
        morphometrics = measure_image(image, verbose=False)
        return {"thickness": morphometrics.thickness, "intensity": get_intensity(image)}

    return pd.DataFrame.from_dict([_measure_one(image) for image in images])

def write_gt_info(run_name, checkpoint_name, measure=False):
    """Evaluate one checkpoint's counterfactual ground-truth losses and log them to CSV.

    Loads ``<artifact_path>/saves/<run_name>/<checkpoint_name>`` and rebuilds the
    model and datasets with ``prepare``. It then runs ``model.loss`` over every
    batch of the ``"valid"`` split and collects the ``"GT MAE CF Loss"`` and
    ``"GT MSE CF Loss"`` tensors from the loss log, with NaN entries removed.
    Their mean and median are upserted, keyed by the epoch parsed from
    *checkpoint_name*, into ``gt.csv`` (or ``gt_measure.csv`` when *measure* is
    true) inside the run directory.

    Args:
        run_name: Sub-directory of ``<artifact_path>/saves`` holding the run.
        checkpoint_name: Checkpoint file name inside the run directory. If it
            contains no ``epoch=<n>``, the CSV row's epoch is ``None``.
        measure: If true, also measure thickness/intensity of the ground-truth
            and generated counterfactual images. Their absolute/squared error
            statistics are added to the logged row.

    Raises:
        RuntimeError: If the validation loader yields no batches.
    """
    run_dir = os.path.join(artifact_path, "saves", run_name)

    # The checkpoint is unpacked as (state_dict, args); those args rebuild the model.
    state_dict, args = torch.load(os.path.join(run_dir, checkpoint_name), map_location="cpu")
    _, model, datasets = prepare(args, state_dict=state_dict, device='cuda')
    datasets.update(
        {
            "valid_loader": torch.utils.data.DataLoader(
                datasets["valid"],
                batch_size=batch_size,
                shuffle=False,
                collate_fn=(lambda batch: data_collate(batch, nb_dims=datasets["valid"].nb_dims)),
            )
        }
    )

    model.eval()

    # Per-batch chunks, concatenated once after the loop (avoids quadratic re-copying).
    abs_chunks = []
    sq_chunks = []
    # Per-sample measurement errors, one array per batch, keyed "<attr>_<abs|sq>".
    measured = {
        f"{attr}_{kind}": []
        for attr in ("thickness", "intensity")
        for kind in ("abs", "sq")
    }

    print(f"Checking gt loss for run {run_name} and checkpoint {checkpoint_name}...")
    for batch_idx, batch in enumerate(datasets["valid_loader"]):
        outcomes, treatments, covariates, cf_treatments, cf_outcomes = move_tensors(*batch, device='cuda')
        _, g_log = model.loss(outcomes, treatments, covariates, cf_treatments, cf_outcomes)

        if measure:
            gt_measurements = measure_thickness_and_intensity(cf_outcomes)
            cf_measurements = measure_thickness_and_intensity(g_log["cf_outcomes_out"])

            # Keep only samples where BOTH images were measured successfully.
            # A NaN on either side would otherwise turn the aggregated mean/median into NaN.
            valid = gt_measurements.notna().all(axis=1) & cf_measurements.notna().all(axis=1)
            gt_measurements = gt_measurements[valid]
            cf_measurements = cf_measurements[valid]

            for attr in ("thickness", "intensity"):
                diff = gt_measurements[attr].to_numpy() - cf_measurements[attr].to_numpy()
                measured[f"{attr}_abs"].append(np.abs(diff))
                measured[f"{attr}_sq"].append(diff ** 2)

        gt_abs_losses = g_log["GT MAE CF Loss"]
        gt_sq_losses = g_log["GT MSE CF Loss"]
        gt_abs_losses = gt_abs_losses[~gt_abs_losses.isnan()]
        gt_sq_losses = gt_sq_losses[~gt_sq_losses.isnan()]
        abs_chunks.append(gt_abs_losses)
        # Squared losses are accumulated from the MSE tensor for every batch, including the first.
        sq_chunks.append(gt_sq_losses)

        print(batch_idx, gt_abs_losses.median())
        print(batch_idx, gt_sq_losses.median())

    if not abs_chunks:
        raise RuntimeError(
            f"validation loader yielded no batches for run {run_name}, checkpoint {checkpoint_name}"
        )

    abs_losses = torch.cat(abs_chunks, dim=0)
    sq_losses = torch.cat(sq_chunks, dim=0)

    def write_to_df(pth, epoch, loss_dict, measure=False):
        """Upsert *loss_dict* as the row for *epoch* in ``gt.csv`` / ``gt_measure.csv`` under *pth*."""
        df_rel_pth = 'gt_measure.csv' if measure else 'gt.csv'
        df_path = os.path.join(pth, df_rel_pth)

        if os.path.exists(df_path):
            df = pd.read_csv(df_path)
        else:
            df = pd.DataFrame(columns=list(loss_dict.keys()))

        if epoch in df['epoch'].values:
            # Overwrite every non-epoch column of the existing row(s) for this epoch.
            for key, value in loss_dict.items():
                if key == "epoch":
                    continue
                df.loc[df['epoch'] == epoch, key] = value
        else:
            new_row = pd.DataFrame([loss_dict])
            # DataFrame.append was removed in pandas 2.0; pd.concat replaces it.
            df = new_row if df.empty else pd.concat([df, new_row], ignore_index=True)

        df.to_csv(df_path, index=False)

    def extract_epoch_number(string):
        """Return the integer following ``epoch=`` in *string*, or None if absent."""
        match = re.search(r'epoch=(\d+)', string)
        return int(match.group(1)) if match else None

    epoch = extract_epoch_number(checkpoint_name)

    loss_dict = {
        "epoch": epoch,
        "mean_abs_loss": abs_losses.mean().item(),
        "median_abs_loss": abs_losses.median().item(),
        "mean_sq_loss": sq_losses.mean().item(),
        "median_sq_loss": sq_losses.median().item(),
    }

    if measure:
        # Insertion order (thickness MAE, intensity MAE, thickness MSE, intensity MSE)
        # fixes the CSV column order.
        for loss_name, kind in (("mae", "abs"), ("mse", "sq")):
            for attr in ("thickness", "intensity"):
                values = np.concatenate(measured[f"{attr}_{kind}"], axis=0)
                loss_dict[f"{attr}_median_{loss_name}_loss"] = np.median(values)
                loss_dict[f"{attr}_mean_{loss_name}_loss"] = values.mean()

    write_to_df(run_dir, epoch, loss_dict, measure=measure)


# Runs to evaluate: each entry is a directory name under "<artifact_path>/saves".
# Earlier runs are kept commented out so they can be re-enabled quickly.
runs_to_check = [
    # --- runs stamped 2024.01.09 ---
    # "morphomnist-sch-abl-both-0_2024.01.09_13:55:55",
    # "morphomnist-sch-abl-both-1_2024.01.09_13:55:53",
    # "morphomnist-sch-abl-both-2_2024.01.09_13:55:51",
    # "morphomnist-sch-abl-both-3_2024.01.09_13:55:50",
    # "morphomnist-sch-abl-both-4_2024.01.09_13:55:55",
    # "morphomnist-sch-abl-kl-0_2024.01.09_13:55:55",
    # "morphomnist-sch-abl-kl-1_2024.01.09_13:55:55",
    # "morphomnist-sch-abl-kl-2_2024.01.09_13:55:48",
    # "morphomnist-sch-abl-kl-3_2024.01.09_13:55:55",
    # "morphomnist-sch-abl-kl-4_2024.01.09_13:55:49",
    # "morphomnist-sch-abl-om1-0_2024.01.09_13:55:48",
    # "morphomnist-sch-abl-om1-1_2024.01.09_13:55:50",
    # "morphomnist-sch-abl-om1-2_2024.01.09_13:55:55",
    # "morphomnist-sch-abl-om1-3_2024.01.09_13:55:56",
    # "morphomnist-sch-abl-om1-4_2024.01.09_13:55:51",
    # "morphomnist-sch-om1-sm-0_2024.01.09_13:55:55",
    # "morphomnist-sch-om1-sm-1_2024.01.09_13:55:49",
    # "morphomnist-sch-om1-sm-2_2024.01.09_13:55:55",
    # "morphomnist-sch-om1-sm-3_2024.01.09_13:55:55",
    # "morphomnist-sch-om1-sm-4_2024.01.09_13:55:55",

    # --- runs stamped 2024.03.24 ---
    # "morphomnist-sch-abl-both-0_2024.03.24_01:23:39",
    # "morphomnist-sch-abl-both-1_2024.03.24_01:23:39",
    # "morphomnist-sch-abl-both-2_2024.03.24_01:23:39",
    # "morphomnist-sch-abl-both-3_2024.03.24_01:23:39",
    # "morphomnist-sch-abl-both-4_2024.03.24_01:23:33",
    # "morphomnist-sch-abl-both-5_2024.03.24_01:26:03",
    # "morphomnist-sch-abl-kl-0_2024.03.24_01:23:39",
    # "morphomnist-sch-abl-kl-1_2024.03.24_01:23:39",
    # "morphomnist-sch-abl-kl-2_2024.03.24_01:23:38",
    # "morphomnist-sch-abl-kl-3_2024.03.24_01:23:33",
    # "morphomnist-sch-abl-kl-4_2024.03.24_01:23:39",
    # "morphomnist-sch-abl-kl-5_2024.03.24_01:26:02",
    # "morphomnist-sch-om1-sm-0_2024.03.24_01:23:33",
    # "morphomnist-sch-om1-sm-1_2024.03.24_01:23:39",
    # "morphomnist-sch-om1-sm-2_2024.03.24_01:23:40",
    # "morphomnist-sch-om1-sm-3_2024.03.24_01:23:40",
    # "morphomnist-sch-om1-sm-4_2024.03.24_01:23:39",
    # "morphomnist-sch-om1-sm-5_2024.03.24_01:26:02",

    # --- "lin" runs stamped 2024.04 ---
    # "morphomnist-lin-sch-om1-sm-0_2024.04.03_23:29:19",
    # "morphomnist-lin-sch-om1-sm-0_2024.04.15_18:36:15",
    # "morphomnist-lin-sch-om1-sm-1_2024.04.03_23:29:19",
    # "morphomnist-lin-sch-om1-sm-1_2024.04.15_18:36:15",
    # "morphomnist-lin-sch-om1-sm-2_2024.04.15_18:36:15",
    # "morphomnist-lin-sch-om1-sm-3_2024.04.15_18:36:15",
    # "morphomnist-lin-sch-om1-sm-4_2024.04.15_18:36:15",
    # "morphomnist-full-lin-sch-abl-kl-0_2024.04.15_19:10:53",
    # "morphomnist-full-lin-sch-abl-kl-1_2024.04.15_19:10:55",
    # "morphomnist-full-lin-sch-abl-kl-2_2024.04.20_20:12:34",
    # "morphomnist-full-lin-sch-abl-kl-3_2024.04.20_20:12:34",

    # --- currently active ---
    "morphomnist-full-lin-sch-abl-kl-4_2024.04.20_20:12:35",
    "morphomnist-full-lin-sch-abl-kl-5_2024.04.20_20:12:35",
]

def find_best_checkpoint(run_to_check_full_path, metric="mean_sq_loss", verbose=True, csv_name="gt.csv"):
    """Return the file name of the checkpoint whose epoch minimises *metric*.

    Reads ``<run_to_check_full_path>/<csv_name>``, which must have an ``epoch``
    column and a *metric* column. It picks the row with the smallest *metric*.
    It then returns the single file in *run_to_check_full_path* whose name ends
    with ``epoch=<best_epoch>.pt``.

    Args:
        run_to_check_full_path: Run directory holding the CSV and the checkpoints.
        metric: CSV column to minimise.
        verbose: If true, print a progress message.
        csv_name: Metrics file inside the run directory (e.g. ``"gt_measure.csv"``).

    Raises:
        ValueError: If zero or several checkpoint files match the best epoch.
    """
    if verbose:
        print(f"finding best checkpoint for run {run_to_check_full_path}...")

    df = pd.read_csv(os.path.join(run_to_check_full_path, csv_name))

    min_idx = df[metric].argmin()
    min_epoch = int(df.iloc[min_idx]['epoch'])

    # Search the directory that was passed in (not a module-level path), and
    # match the suffix exactly so e.g. "epoch=5.pt.bak" is not picked up.
    suffix = f"epoch={min_epoch}.pt"
    matches = [name for name in os.listdir(run_to_check_full_path) if name.endswith(suffix)]
    if len(matches) != 1:
        raise ValueError(
            f"expected exactly one checkpoint ending in {suffix!r} in "
            f"{run_to_check_full_path}, found {matches}"
        )

    return matches[0]

if __name__ == "__main__":
    # Evaluate every checkpoint of every configured run, then re-evaluate the
    # best checkpoint (by find_best_checkpoint's default metric) with
    # thickness/intensity measurements enabled.
    for run in runs_to_check:
        print(f"measuring run {run}...")
        full_pth = os.path.join(
            artifact_path,
            "saves",
            run
        )

        # os.listdir order is arbitrary; sort so checkpoints are processed (and
        # their CSV rows appended) in a reproducible order.
        checkpoints = sorted(x for x in os.listdir(full_pth) if x.endswith(".pt"))
        #checkpoints = [x for x in checkpoints if int(x.split(".")[0].split("=")[-1]) <= 100]
        for checkpoint in checkpoints:
            print(f"    measuring ckpt {checkpoint}")
            # TODO: skip checkpoints whose results are already cached in the CSV.
            with torch.no_grad():
                write_gt_info(run_name=run, checkpoint_name=checkpoint, measure=measure)

        best_checkpoint = find_best_checkpoint(full_pth)
        print(f"measuring best ckpt {best_checkpoint}")
        # When `measure` is already True, the loop above measured every
        # checkpoint, including the best one, so re-running would be redundant.
        if not measure:
            # TODO: skip if the measured result is already cached.
            with torch.no_grad():
                write_gt_info(run_name=run, checkpoint_name=best_checkpoint, measure=True)
