import numpy as np
import torch
from sklearn.metrics import accuracy_score, jaccard_score
import pandas as pd
import wandb
import matplotlib.pyplot as plt

from locationencoder.locationencoder3d import LocationEncoder3D
from utils.classification_utils import combined_vision_location_encoding_evaluation

class InatLocationEncoder3D(LocationEncoder3D):
    """
    Extension of LocationEncoder3D for iNat experiments, implementing test hooks.

    During testing the hooks accumulate the coordinates, vision logits,
    location logits, predictions and labels of every batch as NumPy arrays on
    the instance, and at the end of the test epoch log per-time and
    per-location accuracies, three diagnostic figures and the combined
    vision/location top-k accuracies.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def test_step(self, batch, batch_idx):
        """
        Evaluate one test batch and log its metrics.

        Args:
            batch (tuple): Batch of data containing lat_lon_time, vision_logits, and label.
            batch_idx (int): Batch index.

        Returns:
            dict: Always contains "test_loss", "y_pred", "logits" and
            "vision_logits"; additionally "test_MAE" and "test_RMSE" when
            ``self.regression`` is set, otherwise "test_accuracy".
        """
        lat_lon_time, vision_logits, label = batch
        prediction_logits = self.forward(lat_lon_time)
        loss = self.loss_fn(self, prediction_logits, label)

        # NOTE(review): the regularizer is fed the model output here, whereas
        # InatVisionLocationEncoder3D.training_step feeds it the input
        # coordinates -- confirm which argument compute_ortho_regularizer expects.
        final_ortho_regularizer, space_ortho_regularizer, time_ortho_regularizer = (
            self.compute_ortho_regularizer(prediction_logits)
        )
        self.log("test_final_regularizer", final_ortho_regularizer, on_step=True, on_epoch=True)
        self.log("test_space_regularizer", space_ortho_regularizer, on_step=True, on_epoch=True)
        self.log("test_time_regularizer", time_ortho_regularizer, on_step=True, on_epoch=True)

        if self.regression:
            # Undo the target normalisation so the metrics are in original units.
            prediction_logits = prediction_logits * self.stds + self.means
            label = label * self.stds + self.means
            loss *= self.mean_var
        self.log("test_loss", loss, on_step=True, on_epoch=True)

        if self.regression:
            y_pred = prediction_logits.cpu()
            residual = (label.cpu() - y_pred).detach().numpy()
            # nan-aware mean over the batch axis, then plain mean over targets.
            MAE = np.mean(np.nanmean(np.abs(residual), axis=0), axis=0)
            RMSE = np.sqrt(np.nanmean(residual ** 2, axis=0)).mean()
            self.log("test_MAE", MAE, on_step=True, on_epoch=True)
            self.log("test_RMSE", RMSE, on_step=True, on_epoch=True)
            return {
                "test_loss": loss,
                "test_MAE": MAE,
                "test_RMSE": RMSE,
                "y_pred": y_pred,
                "logits": prediction_logits,
                "vision_logits": vision_logits,
            }

        if prediction_logits.size(1) == 1:
            # Single-logit binary task: threshold at 0. squeeze(-1) keeps the
            # batch axis even for a batch of one, and .cpu() is required
            # because the prediction is handed to scikit-learn below.
            y_pred = (prediction_logits.squeeze(-1) > 0).cpu()
            average = "binary"
        else:
            y_pred = prediction_logits.argmax(-1).cpu()
            average = "macro"

        accuracy = accuracy_score(y_true=label.cpu(), y_pred=y_pred)
        IoU = jaccard_score(y_true=label.cpu(), y_pred=y_pred, average=average)
        self.log("test_accuracy", accuracy, on_step=False, on_epoch=True)
        self.log("test_IoU", IoU, on_step=False, on_epoch=True)
        return {
            "test_loss": loss,
            "test_accuracy": accuracy,
            "logits": prediction_logits,
            "vision_logits": vision_logits,
            "y_pred": y_pred,
        }

    def on_test_epoch_start(self):
        """Reset the per-epoch accumulators to empty arrays."""
        self.lons = np.array([])
        self.lats = np.array([])
        self.times = np.array([])
        self.vision_logits = np.array([])
        self.logits = np.array([])
        self.y_preds = np.array([])
        self.labels = np.array([])

    @staticmethod
    def _extend_rows(accumulated, new_rows):
        """Return new_rows appended to accumulated along axis 0 (new_rows alone if nothing is accumulated yet)."""
        if len(accumulated) > 0:
            return np.concatenate((accumulated, new_rows), axis=0)
        return new_rows

    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """
        Append the coordinates, logits, predictions and labels of one batch to the accumulators.

        Args:
            outputs (dict): Return value of ``test_step``; "logits" and "y_pred" are read.
            batch (tuple): Batch of data (coordinates, vision_logits, label).
            batch_idx (int): Batch index.
            dataloader_idx (int): Dataloader index (unused).
        """
        # NOTE(review): the same tensor is called lat_lon_time in test_step;
        # here column 0 is stored as longitude and column 1 as latitude --
        # confirm the column order against the dataset.
        lon_lat_time, vision_logits, label = batch
        self.lons = np.append(self.lons, lon_lat_time[:, 0].cpu().numpy())
        self.lats = np.append(self.lats, lon_lat_time[:, 1].cpu().numpy())
        self.times = np.append(self.times, lon_lat_time[:, 2].cpu().numpy())
        self.vision_logits = self._extend_rows(self.vision_logits, vision_logits.cpu().numpy())
        self.logits = self._extend_rows(self.logits, outputs["logits"].cpu().numpy())
        self.y_preds = self._extend_rows(self.y_preds, outputs["y_pred"].cpu().numpy())
        self.labels = self._extend_rows(self.labels, label.cpu().numpy())

    @staticmethod
    def _show_and_log_figure(key, fig):
        """Show fig, log it to wandb under key, then close it so figures do not pile up across test runs."""
        plt.show()
        wandb.log({key: wandb.Image(fig)})
        plt.close(fig)

    def on_test_epoch_end(self):
        """
        Aggregate the accumulated test predictions and log summary diagnostics.

        Logs a table and the mean of the accuracy per unique time value, a
        table and the nan-mean of the accuracy per unique (lon, lat) pair,
        three figures (spatial accuracy scatter, histogram of time values,
        temporal accuracy curve) and the top-1 / top-k accuracies returned by
        ``combined_vision_location_encoding_evaluation``.
        """
        # --- accuracy per unique time value ---
        unique_times = np.unique(self.times)
        temporal_accuracies = []
        for time_value in unique_times:
            selected = self.times == time_value
            temporal_accuracies.append(np.mean(self.labels[selected] == self.y_preds[selected]))
        temporal_df = pd.DataFrame({
            "time": unique_times,
            "temporal_accuracy": np.array(temporal_accuracies, dtype=float),
        })
        temp_table = wandb.Table(dataframe=temporal_df)
        self.logger.experiment.log({"temporal_accuracy_table": temp_table})

        # --- accuracy per unique (lon, lat) pair ---
        lonlat_pairs = np.unique(np.stack((self.lons, self.lats), axis=-1), axis=0)
        spatial_accuracies = []
        for lon, lat in lonlat_pairs:
            selected = (self.lons == lon) & (self.lats == lat)
            spatial_accuracies.append(
                np.mean(self.labels[selected] == self.y_preds[selected], axis=0)
            )
        spatial_accuracies = np.array(spatial_accuracies, dtype=float)
        spatial_df = pd.DataFrame({
            "lons": lonlat_pairs[:, 0],
            "lats": lonlat_pairs[:, 1],
            "spatial_accuracy": spatial_accuracies,
        })
        spatial_table = wandb.Table(dataframe=spatial_df)
        self.logger.experiment.log({"table of spatial accuracies": spatial_table})

        avg_spatial_accuracy = np.nanmean(spatial_accuracies)
        self.log("test_spatial_accuracy", avg_spatial_accuracy, on_step=False, on_epoch=True)
        avg_temporal_accuracy = np.mean(temporal_df["temporal_accuracy"])
        self.log("test_temporal_accuracy", avg_temporal_accuracy, on_step=False, on_epoch=True)

        # --- figures ---
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.set_title("Spatial Accuracy")
        ax.scatter(spatial_df["lons"], spatial_df["lats"], c=spatial_df["spatial_accuracy"], cmap='viridis', s=2, alpha=0.5)
        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")
        fig.colorbar(ax.collections[0], ax=ax, orientation='vertical')
        self._show_and_log_figure("Spatial Accuracy", fig)

        fig, ax = plt.subplots(figsize=(10, 5))
        ax.set_title("Distribution of Time Values")
        ax.hist(self.times, bins=366, color='indianred', alpha=0.7)
        ax.set_xlabel("Time")
        ax.set_ylabel("Frequency")
        self._show_and_log_figure("Distribution of Time Values", fig)

        fig, ax = plt.subplots(figsize=(10, 5))
        ax.set_title("Temporal Accuracy")
        ax.plot(temporal_df["time"], temporal_df["temporal_accuracy"])
        ax.set_xlabel("Time")
        ax.set_ylabel("Temporal Accuracy")
        self._show_and_log_figure("Temporal Accuracy", fig)

        # --- combined vision + location top-k evaluation ---
        k = 5
        (
            top_1_combined_accuracy,
            top_k_combined_accuracy,
            top_1_vision_accuracy,
            top_k_vision_accuracy,
            top_1_location_accuracy,
            top_k_location_accuracy,
        ) = combined_vision_location_encoding_evaluation(self.logits, self.vision_logits, self.labels, k=k)
        wandb.log({
            "Top-1 Combined Accuracy": top_1_combined_accuracy,
            f"Top-{k} Combined Accuracy": top_k_combined_accuracy,
            "Top-1 Vision Accuracy": top_1_vision_accuracy,
            f"Top-{k} Vision Accuracy": top_k_vision_accuracy,
            "Top-1 Location Accuracy": top_1_location_accuracy,
            f"Top-{k} Location Accuracy": top_k_location_accuracy,
        })


class InatVisionLocationEncoder3D(InatLocationEncoder3D):
    """
    A class for integrating vision logits with location embeddings for iNat experiments.

    The loss is computed on the sum of the location logits and the
    softmax-normalised vision logits, against a one-hot target that is
    down-weighted by the vision probability of each class.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def common_step(self, batch, batch_idx, mode):
        """
        Common step for training and validation.

        Args:
            batch (tuple): Batch of data containing lat_lon_time, vision_logits, and label.
            batch_idx (int): Batch index.
            mode (str): Mode of operation ('train' or 'val'). Currently unused.

        Returns:
            Tensor: Loss value.
        """
        lat_lon_time, vision_logits, label = batch

        # Normalize the vision logits into per-class probabilities.
        vision_probs = torch.nn.functional.softmax(vision_logits, dim=1)

        # The time-gradient penalty is currently disabled, so it contributes
        # nothing to the loss.
        time_grad_penalty = 0.0

        # Forward pass to compute location logits
        location_logits = self.forward(lat_lon_time)

        # Combine location logits with the normalised vision output
        combined_logits = location_logits + vision_probs

        # One-hot targets sized from the class dimension of the logits, so the
        # target always matches the model output instead of a fixed class count.
        num_classes = combined_logits.size(1)
        labels_one_hot = torch.nn.functional.one_hot(label.long(), num_classes=num_classes).double()

        # Scale the target by (1 - vision probability); the result stays double.
        labels_reduced = labels_one_hot * (1 - vision_probs)

        # Compute loss using combined logits and the soft (probability) targets
        loss = torch.nn.functional.cross_entropy(combined_logits, labels_reduced, reduction="mean")

        # Add time gradient penalty to the loss
        loss += time_grad_penalty

        return loss

    def training_step(self, batch, batch_idx):
        """
        Training step.

        Args:
            batch (tuple): Batch of data.
            batch_idx (int): Batch index.

        Returns:
            Tensor: Loss value plus the weighted first ortho-regularizer term.
        """
        loss = self.common_step(batch, batch_idx, mode="train")
        if self.regression:
            loss *= self.mean_var

        self.log("train_loss", loss, on_step=True, on_epoch=True)

        # compute ortho regularizer on the input coordinates; only its first
        # element is logged and added to the loss
        lat_lon_time, _, _ = batch
        ortho_regularizer = self.compute_ortho_regularizer(lat_lon_time)

        self.log("ortho_regularizer", ortho_regularizer[0], on_step=True, on_epoch=True)

        return loss + self.ortho_weight * ortho_regularizer[0]
