import numpy as np
import tqdm

import torch
import pandas as pd
import matplotlib.pyplot as plt
import os
import wandb

from locationencoder.locationencoder3d import LocationEncoder3D
from birdsnap.birdsnap_utils import combined_vision_location_encoding_evaluation
from birdsnap.birdsnap_test_snapshots import check_or_update_combined_accuracies_snapshot

from sklearn.metrics import (
    accuracy_score, 
    jaccard_score,
)

class BirdsnapLocationEncoder3D(LocationEncoder3D):
    """
    Birdsnap-specific LightningModule for 3D location encoding, with custom test hooks and snapshotting.

    During testing, per-sample coordinates, logits, predictions and labels are
    collected batch by batch. At the end of the test epoch they are evaluated:
    temporal/spatial accuracy tables, diagnostic plots, combined
    vision-location top-1/top-k accuracies, an accuracy snapshot file, and a
    sample of per-sample predictions plus the label histogram.

    Args:
        *args, **kwargs: Forwarded unchanged to ``LocationEncoder3D``.
        sub_experiment (dict | None): Sub-experiment configuration. Its keys
            are used to build the snapshot file name, and it is passed as
            ``config`` to the snapshot check. When None, snapshotting is skipped.
        args_birdsnap_harmonized: Run arguments; only ``.seed`` is read here.
            When None, snapshotting is skipped.
    """

    # Per-sample arrays collected during testing. Each one is exposed as an
    # attribute of the same name (e.g. ``self.lons``) once the test epoch ends.
    _TEST_BUFFER_NAMES = ("lons", "lats", "times", "vision_logits", "logits", "y_preds", "labels")

    def __init__(self, *args, sub_experiment=None, args_birdsnap_harmonized=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.sub_experiment = sub_experiment
        self.args_birdsnap_harmonized = args_birdsnap_harmonized

    # ------------------------------------------------------------------ #
    # Per-batch test step
    # ------------------------------------------------------------------ #

    def test_step(self, batch: tuple, batch_idx: int) -> dict:
        """
        Test step.

        Runs the encoder on the batch coordinates. Logs the loss and the three
        orthogonality regularizers, and computes per-batch metrics (MAE/RMSE
        for regression, accuracy/IoU otherwise).

        Args:
            batch (tuple): ``(coordinates, vision_logits, label)``.
            batch_idx (int): Batch index (unused).

        Returns:
            dict: Test results. Always contains ``"test_loss"``, ``"logits"``,
            ``"vision_logits"`` and ``"y_pred"`` (a CPU tensor). Also contains
            ``"test_MAE"``/``"test_RMSE"`` for regression, or
            ``"test_accuracy"`` for classification.
        """
        # NOTE(review): named lat_lon_time here, but on_test_batch_end treats
        # column 0 as longitude and column 1 as latitude — confirm the order.
        lat_lon_time, vision_logits, label = batch

        prediction_logits = self.forward(lat_lon_time)
        # loss_fn receives the module itself as its first argument; presumably
        # a plain function stored on the instance rather than a bound method.
        loss = self.loss_fn(self, prediction_logits, label)
        final_ortho_regularizer, space_ortho_regularizer, time_ortho_regularizer = (
            self.compute_ortho_regularizer(prediction_logits)
        )

        self.log("test_final_regularizer", final_ortho_regularizer, on_step=True, on_epoch=True)
        self.log("test_space_regularizer", space_ortho_regularizer, on_step=True, on_epoch=True)
        self.log("test_time_regularizer", time_ortho_regularizer, on_step=True, on_epoch=True)

        if self.regression:
            # Undo the normalization of predictions and labels so metrics are
            # reported in the original target units; rescale the loss by mean_var.
            prediction_logits = prediction_logits * self.stds + self.means
            label = label * self.stds + self.means
            loss = loss * self.mean_var

        self.log("test_loss", loss, on_step=True, on_epoch=True)

        if self.regression:
            return self._regression_test_results(loss, prediction_logits, label, vision_logits)
        return self._classification_test_results(loss, prediction_logits, label, vision_logits)

    def _regression_test_results(self, loss, prediction_logits, label, vision_logits):
        """Compute and log MAE/RMSE for a regression batch, and build the result dict."""
        y_pred = prediction_logits.detach().cpu()
        error = label.detach().cpu().numpy() - y_pred.numpy()
        # NaN targets are ignored: errors are averaged per output column with
        # nanmean, then averaged over columns (works for 1-D and 2-D targets).
        mae = np.mean(np.nanmean(np.abs(error), axis=0))
        rmse = np.sqrt(np.nanmean(error ** 2, axis=0)).mean()

        self.log("test_MAE", mae, on_step=True, on_epoch=True)
        self.log("test_RMSE", rmse, on_step=True, on_epoch=True)

        return {
            "test_loss": loss,
            "test_MAE": mae,
            "test_RMSE": rmse,
            "y_pred": y_pred,
            "logits": prediction_logits,
            "vision_logits": vision_logits,
        }

    def _classification_test_results(self, loss, prediction_logits, label, vision_logits):
        """Compute and log accuracy/IoU for a classification batch, and build the result dict.

        A single output column is treated as a binary task: a logit above 0 is
        the positive class, and IoU uses ``average="binary"``. Otherwise the
        argmax class is predicted and IoU is macro-averaged.
        """
        if prediction_logits.size(1) == 1:
            # squeeze(1), not squeeze(): keep the batch dimension for batches
            # of one. Move to CPU because sklearn cannot read device tensors.
            y_pred = (prediction_logits.squeeze(1) > 0).cpu()
            average = "binary"
        else:
            y_pred = prediction_logits.argmax(-1).cpu()
            average = "macro"

        y_true = label.cpu()
        accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)
        IoU = jaccard_score(y_true=y_true, y_pred=y_pred, average=average)
        self.log("test_accuracy", accuracy, on_step=False, on_epoch=True)
        self.log("test_IoU", IoU, on_step=False, on_epoch=True)

        return {
            "test_loss": loss,
            "test_accuracy": accuracy,
            "logits": prediction_logits,
            "vision_logits": vision_logits,
            "y_pred": y_pred,
        }

    # ------------------------------------------------------------------ #
    # Test-epoch accumulation
    # ------------------------------------------------------------------ #

    def on_test_epoch_start(self):
        """Reset the per-sample test buffers and their exposed attributes."""
        self._test_buffers = {name: [] for name in self._TEST_BUFFER_NAMES}
        for name in self._TEST_BUFFER_NAMES:
            setattr(self, name, np.array([]))

    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """Append this batch's coordinates, logits, predictions and labels to the test buffers.

        Chunks are kept in lists and concatenated once in ``on_test_epoch_end``.
        Concatenating on every batch would be quadratic in the test-set size.
        """
        lon_lat_time, vision_logits, label = batch
        coords = lon_lat_time.detach().cpu().numpy()

        buffers = self._test_buffers
        buffers["lons"].append(coords[:, 0])
        buffers["lats"].append(coords[:, 1])
        buffers["times"].append(coords[:, 2])
        buffers["vision_logits"].append(vision_logits.detach().cpu().numpy())
        buffers["logits"].append(outputs["logits"].detach().cpu().numpy())
        buffers["y_preds"].append(outputs["y_pred"].detach().cpu().numpy())
        buffers["labels"].append(label.detach().cpu().numpy())

    def _finalize_test_buffers(self):
        """Concatenate the collected chunks into ``self.lons``, ``self.lats``, ..., ``self.labels``."""
        buffers = getattr(self, "_test_buffers", {})
        for name in self._TEST_BUFFER_NAMES:
            chunks = buffers.get(name, [])
            setattr(self, name, np.concatenate(chunks, axis=0) if chunks else np.array([]))
        # Coordinates are stored as float64 regardless of the input dtype.
        for name in ("lons", "lats", "times"):
            setattr(self, name, getattr(self, name).astype(np.float64))

    # ------------------------------------------------------------------ #
    # Test-epoch evaluation
    # ------------------------------------------------------------------ #

    def on_test_epoch_end(self):
        """Evaluate and log everything collected over the test epoch."""
        self._finalize_test_buffers()
        if self.labels.size == 0:
            import warnings

            warnings.warn(
                "BirdsnapLocationEncoder3D: no test samples were collected; "
                "skipping test-epoch evaluation."
            )
            return

        temporal_df, spatial_df = self._log_spatiotemporal_accuracy()
        self._log_test_plots(temporal_df, spatial_df)
        combined_accuracies = self._log_combined_accuracies(k=5)
        self._snapshot_combined_accuracies(combined_accuracies)
        self._log_prediction_tables()

    @staticmethod
    def _grouped_accuracy(keys, labels, y_preds):
        """Return ``(unique_keys, accuracy_per_key)`` for samples grouped by ``keys``.

        ``keys`` holds one entry per sample: a scalar (1-D array) or a row
        (2-D array, grouped row-wise). Groups come out in ``np.unique`` order.
        Accuracy is the fraction of samples in the group where
        ``labels == y_preds``.
        """
        axis = 0 if keys.ndim > 1 else None
        unique_keys, inverse = np.unique(keys, axis=axis, return_inverse=True)
        # Flatten defensively: some NumPy 2.0.x releases return a non-1-D inverse.
        inverse = inverse.reshape(-1)
        correct = (labels == y_preds).astype(np.float64)
        hits = np.bincount(inverse, weights=correct, minlength=len(unique_keys))
        counts = np.bincount(inverse, minlength=len(unique_keys))
        return unique_keys, hits / counts

    def _log_spatiotemporal_accuracy(self):
        """Log accuracy per unique time and per unique (lon, lat) pair; return both DataFrames."""
        # --- Temporal accuracy ---
        unique_times, temporal_accuracies = self._grouped_accuracy(self.times, self.labels, self.y_preds)
        temporal_df = pd.DataFrame({"time": unique_times, "temporal_accuracy": temporal_accuracies})
        self.logger.experiment.log({"temporal_accuracy_table": wandb.Table(dataframe=temporal_df)})

        # --- Spatial accuracy ---
        lonlat = np.stack((self.lons, self.lats), axis=-1)
        lonlat_pairs, spatial_accuracies = self._grouped_accuracy(lonlat, self.labels, self.y_preds)
        spatial_df = pd.DataFrame({
            "lons": lonlat_pairs[:, 0],
            "lats": lonlat_pairs[:, 1],
            "spatial_accuracy": spatial_accuracies,
        })
        self.logger.experiment.log({"table of spatial accuracies": wandb.Table(dataframe=spatial_df)})

        self.log("test_spatial_accuracy", np.nanmean(spatial_accuracies), on_step=False, on_epoch=True)
        self.log("test_temporal_accuracy", np.mean(temporal_df["temporal_accuracy"]), on_step=False, on_epoch=True)
        return temporal_df, spatial_df

    @staticmethod
    def _log_wandb_figure(key, fig):
        """Log ``fig`` to wandb under ``key``, then close it. The figure is closed even if logging fails."""
        try:
            wandb.log({key: wandb.Image(fig)})
        finally:
            plt.close(fig)

    def _log_test_plots(self, temporal_df, spatial_df):
        """Log the spatial-accuracy scatter, the time histogram and the temporal-accuracy curve."""
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.set_title("Spatial Accuracy")
        points = ax.scatter(
            spatial_df["lons"], spatial_df["lats"],
            c=spatial_df["spatial_accuracy"], cmap='viridis', s=2, alpha=0.5,
        )
        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")
        fig.colorbar(points, ax=ax, orientation='vertical')
        self._log_wandb_figure("Spatial Accuracy", fig)

        fig, ax = plt.subplots(figsize=(10, 5))
        ax.set_title("Distribution of Time Values")
        ax.hist(self.times, bins=366, color='indianred', alpha=0.7)
        ax.set_xlabel("Time")
        ax.set_ylabel("Frequency")
        self._log_wandb_figure("Distribution of Time Values", fig)

        fig, ax = plt.subplots(figsize=(10, 5))
        ax.set_title("Temporal Accuracy")
        ax.plot(temporal_df["time"], temporal_df["temporal_accuracy"])
        ax.set_xlabel("Time")
        ax.set_ylabel("Temporal Accuracy")
        self._log_wandb_figure("Temporal Accuracy", fig)

    def _log_combined_accuracies(self, k=5):
        """Log top-1/top-k combined, vision-only and location-only accuracies; return them as a dict."""
        (
            top_1_combined_accuracy,
            top_k_combined_accuracy,
            top_1_vision_accuracy,
            top_k_vision_accuracy,
            top_1_location_accuracy,
            top_k_location_accuracy,
        ) = combined_vision_location_encoding_evaluation(self.logits, self.vision_logits, self.labels, k=k)

        wandb.log({
            "Top-1 Combined Accuracy": top_1_combined_accuracy,
            f"Top-{k} Combined Accuracy": top_k_combined_accuracy,
            "Top-1 Vision Accuracy": top_1_vision_accuracy,
            f"Top-{k} Vision Accuracy": top_k_vision_accuracy,
            "Top-1 Location Accuracy": top_1_location_accuracy,
            f"Top-{k} Location Accuracy": top_k_location_accuracy,
        })

        return {
            "top_1_combined_accuracy": top_1_combined_accuracy,
            "top_k_combined_accuracy": top_k_combined_accuracy,
            "top_1_vision_accuracy": top_1_vision_accuracy,
            "top_k_vision_accuracy": top_k_vision_accuracy,
            "top_1_location_accuracy": top_1_location_accuracy,
            "top_k_location_accuracy": top_k_location_accuracy,
        }

    def _snapshot_combined_accuracies(self, combined_accuracies):
        """Check or update the on-disk snapshot of ``combined_accuracies`` for this configuration.

        Skipped with a warning when ``sub_experiment`` or
        ``args_birdsnap_harmonized`` was not provided, since the snapshot file
        name is built from them.
        """
        sub_experiment = self.sub_experiment
        args_birdsnap_harmonized = self.args_birdsnap_harmonized
        if sub_experiment is None or args_birdsnap_harmonized is None:
            import warnings

            warnings.warn(
                "BirdsnapLocationEncoder3D: sub_experiment or args_birdsnap_harmonized "
                "is None; skipping combined-accuracy snapshot."
            )
            return

        os.makedirs("artifacts", exist_ok=True)
        # NOTE(review): the name ends with "_" and has no extension; presumably
        # intentional for existing snapshot files — confirm before changing.
        combined_acc_snapshot_file = (
            f"artifacts/combined_accuracies_"
            f"seed{args_birdsnap_harmonized.seed}_"
            f"combtype_{sub_experiment['combination_type']}_"
            f"timeemb_{sub_experiment['time_embedding_type']}_"
            f"timeembdim_{sub_experiment['time_embedding_dim']}_"
            f"ortho_{sub_experiment['ortho_weight']}_"
            f"ortho_space_{sub_experiment.get('ortho_weight_space', 0)}_"
            f"ortho_time_{sub_experiment.get('ortho_weight_time', 0)}_"
            f"subset_{sub_experiment['subset_fraction']}_"
        )

        check_or_update_combined_accuracies_snapshot(
            accuracies=combined_accuracies,
            config=sub_experiment,
            snapshot_file=combined_acc_snapshot_file,
        )

    def _log_prediction_tables(self):
        """Log the first 1000 per-sample predictions and the per-class label counts."""
        # Sigmoid the location logits and normalize each row to sum to 1.
        # Weight the vision logits by that, then take the argmax as the
        # combined prediction.
        location_embeddings = torch.sigmoid(torch.from_numpy(self.logits)).numpy()
        location_embeddings = location_embeddings / np.sum(location_embeddings, axis=1, keepdims=True)
        combined_predictions = np.argmax(location_embeddings * self.vision_logits, axis=1)

        predictions_df = pd.DataFrame({
            'longitude': self.lons,
            'latitude': self.lats,
            'time': self.times,
            'true_label': self.labels,
            'predicted_label': self.y_preds,
            'combined_prediction': combined_predictions,
        })
        predictions_table = wandb.Table(dataframe=predictions_df.head(1000))
        self.logger.experiment.log({"predictions_sample": predictions_table})

        output_dim = self.combined_encoding_args["output_dim"]
        class_distribution = pd.DataFrame({
            'class': range(output_dim),
            'count': np.bincount(self.labels.astype(int), minlength=output_dim),
        })
        class_dist_table = wandb.Table(dataframe=class_distribution)
        self.logger.experiment.log({"class_distribution": class_dist_table})
