import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
import mlflow
import os
import torch.nn as nn
import torchvision.models as models
from torchvision.models import resnet50, ResNet50_Weights
import random
import numpy as np
from collections import OrderedDict
from numbers import Number
import copy
import operator
import re
import pandas as pd
import zipfile
# from google.cloud import storage
import glob

# Set at import time, before any function in this module runs.
# NOTE(review): per the PyTorch reproducibility notes, cuBLAS needs this
# workspace setting (":4096:8" or ":16:8") for deterministic results on
# CUDA >= 10.2, which torch.use_deterministic_algorithms(True) in set_seed()
# relies on -- confirm against the PyTorch version in use.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

def set_seed(seed):
    """Seed every RNG used by this module and switch PyTorch to deterministic mode.

    Args:
        seed: Value handed to the Python, NumPy and PyTorch (CPU and CUDA)
            seeding functions.

    Side effects:
        Sets ``torch.backends.cudnn.deterministic`` to ``True`` and
        ``torch.backends.cudnn.benchmark`` to ``False``, enables
        ``torch.use_deterministic_algorithms`` and writes ``str(seed)`` to the
        ``PYTHONHASHSEED`` environment variable.
    """
    # Order: Python random, NumPy, PyTorch CPU, PyTorch single GPU, PyTorch all GPUs.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    torch.use_deterministic_algorithms(True)

    # NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so this
    # presumably only influences processes launched afterwards -- confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)

def seed_worker(worker_id):
    """Re-seed NumPy's and Python's RNGs from the current PyTorch initial seed.

    The signature matches a ``DataLoader`` ``worker_init_fn`` callback
    (presumably how it is used -- no caller is visible in this file).

    Args:
        worker_id: Accepted for the callback signature; not used.
    """
    # Reduce PyTorch's 64-bit seed to the 32-bit range before passing it on.
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

def normalize_to_distribution(x, dim=(-2, -1), eps=1e-8):
    """Rescale ``x`` so its entries over ``dim`` sum to approximately one.

    Args:
        x: Tensor to normalise. Entries below ``eps`` (including negatives)
            are raised to ``eps`` first.
        dim: Dimension or dimensions the sum is taken over.
        eps: Lower bound applied to ``x``; also added to the denominator.

    Returns:
        A tensor with the same shape as ``x``. Because ``eps`` is added to the
        sum as well, each slice sums to slightly less than one.
    """
    floored = x.clamp(min=eps)
    denominator = floored.sum(dim=dim, keepdim=True) + eps
    return floored / denominator

class EarlyStopping:
    """Track a monitored score and flag when it has stopped improving.

    Call the instance once per evaluation with the latest score. The first
    score becomes the baseline; afterwards every call that fails to improve on
    ``best_score`` by more than ``min_delta`` increments ``counter``, and an
    improving call resets it to zero. Once ``counter`` reaches ``patience``,
    ``early_stop`` is set to ``True`` (it is never reset).

    Args:
        patience: Number of consecutive non-improving calls that triggers
            ``early_stop``.
        min_delta: Margin by which a score must beat ``best_score`` to count
            as an improvement.
        mode: ``'min'`` if lower scores are better, ``'max'`` if higher
            scores are better.
        verbose: When true, print a message on every improvement and every
            non-improving call.

    Attributes:
        counter: Consecutive non-improving calls seen so far.
        best_score: Best score seen so far, ``None`` before the first call.
        early_stop: ``True`` once patience has been exhausted.

    Raises:
        ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, patience=5, min_delta=0.0, mode='min', verbose=False):
        if mode not in ('min', 'max'):
            raise ValueError("mode must be 'min' or 'max'")

        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def is_better(self, a, b):
        """Return whether score ``a`` improves on ``b`` by more than ``min_delta``.

        Implemented as a method rather than a lambda stored on the instance so
        that instances can be pickled and copies do not share a closure over
        the original object.
        """
        if self.mode == 'min':
            return a < b - self.min_delta
        return a > b + self.min_delta

    def __call__(self, current_score):
        """Record ``current_score`` and update ``counter`` / ``early_stop``."""
        if self.best_score is None:
            # First observation only sets the baseline; nothing is printed.
            self.best_score = current_score
            return

        if self.is_better(current_score, self.best_score):
            self.best_score = current_score
            self.counter = 0
            if self.verbose:
                print("Validation metric improved.")
            return

        self.counter += 1
        if self.verbose:
            print(f"No improvement for {self.counter} epoch(s).")
        if self.counter >= self.patience:
            self.early_stop = True

class GradCamResNet50(nn.Module):
    """Pretrained ResNet50 with a fresh linear head and a handle on its last block.

    The backbone's own ``fc`` layer is replaced by ``nn.Identity`` so the
    backbone outputs features, which ``classifier`` maps to ``num_classes``
    outputs. ``target_layer`` points at the final block of ``layer4``
    (presumably the layer Grad-CAM hooks are attached to -- no hook code is
    visible in this file).

    NOTE(review): assigning a module as an attribute registers it as a
    submodule, so that block presumably appears under both ``backbone.layer4``
    and ``target_layer`` in ``state_dict()`` -- confirm before exchanging
    checkpoints with the other ResNet50 wrappers.

    Args:
        num_classes: Output size of the linear classifier.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Attribute order (backbone, target_layer, classifier) is kept as-is
        # because it determines the module registration order.
        pretrained = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        feature_dim = pretrained.fc.in_features

        self.backbone = pretrained
        self.target_layer = pretrained.layer4[-1]

        # Strip the ImageNet head; classification happens in ``classifier``.
        pretrained.fc = nn.Identity()
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return classifier outputs for the input batch ``x``."""
        return self.classifier(self.backbone(x))

class CustomResNet50(nn.Module):
    """Pretrained ResNet50 backbone followed by a new linear classifier.

    The backbone's original ``fc`` layer is swapped for ``nn.Identity`` so it
    emits features of size ``in_features``; ``classifier`` maps those to
    ``num_classes`` outputs.

    Args:
        num_classes: Output size of the linear classifier.

    Attributes:
        in_features: Input width of the replaced ``fc`` layer, i.e. the
            feature size fed to ``classifier``.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        pretrained = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.backbone = pretrained

        # Remember the feature width before the ImageNet head is removed.
        self.in_features = pretrained.fc.in_features
        pretrained.fc = nn.Identity()

        self.classifier = nn.Linear(self.in_features, num_classes)

    def forward(self, x):
        """Return classifier outputs for the input batch ``x``."""
        features = self.backbone(x)
        return self.classifier(features)

class WholeFish(nn.Module):
    """ResNet50 backbone plus linear classifier whose weights can be reset.

    The backbone's ``fc`` layer is replaced by ``nn.Identity`` and a new
    ``classifier`` maps the backbone features to ``num_classes`` outputs.

    Args:
        num_classes: Output size of the linear classifier.
        weights: Optional state dict for this module. When given, a deep copy
            is loaded with ``load_state_dict`` after construction; when
            ``None`` the backbone starts from the pretrained
            ``ResNet50_Weights.DEFAULT`` weights.
    """

    def __init__(self, num_classes=2, weights=None):
        super().__init__()

        # The pretrained checkpoint is only fetched when it will actually be
        # used: a supplied state dict is loaded strictly below and replaces
        # every backbone parameter and buffer anyway, so downloading and
        # loading the ImageNet weights first would be wasted work.
        pretrained = ResNet50_Weights.DEFAULT if weights is None else None
        self.backbone = models.resnet50(weights=pretrained)

        # Replace the final fully connected layer
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.classifier = nn.Linear(in_features, num_classes)

        if weights is not None:
            self.load_state_dict(copy.deepcopy(weights))

    def reset_weights(self, weights):
        """Load a deep copy of the state dict ``weights`` into this module."""
        self.load_state_dict(copy.deepcopy(weights))

    def forward(self, x):
        """Return classifier outputs for the input batch ``x``."""
        x = self.backbone(x)
        x = self.classifier(x)
        return x

class ResNet50_Featurizer(nn.Module):
    """Pretrained ResNet50 with its ``fc`` layer removed, used as a feature extractor.

    ``forward`` returns whatever the backbone produces once ``fc`` has been
    replaced by ``nn.Identity``.
    """

    def __init__(self):
        super().__init__()
        pretrained = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        # Drop the ImageNet head so the backbone outputs features.
        pretrained.fc = nn.Identity()
        self.backbone = pretrained

    def forward(self, x):
        """Return the backbone features for the input batch ``x``."""
        return self.backbone(x)

class ResNet50_MLP(nn.Module):
    """Linear classifier sized to ResNet50's feature width.

    ``forward`` applies only ``classifier``; ``backbone`` is an
    ``nn.Identity`` placeholder and is not called. Inputs are therefore
    expected to already have ``in_features`` as their last dimension
    (presumably features precomputed by a ResNet50 featurizer -- confirm
    against the caller).

    Args:
        num_classes: Output size of the linear classifier.

    Attributes:
        in_features: Input width of ResNet50's ``fc`` layer.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # A ResNet50 is built only to read the width of its ``fc`` layer and is
        # discarded right away, so no pretrained checkpoint is requested
        # (``weights=None``): that avoids a download and a state-dict load
        # whose result was never used.
        # NOTE(review): the throwaway model is still constructed (rather than
        # hard-coding the width) so the global RNG is presumably advanced
        # exactly as before and seeded classifier initialisation is unchanged.
        reference = models.resnet50(weights=None)
        self.in_features = reference.fc.in_features

        self.backbone = nn.Identity()
        self.classifier = nn.Linear(self.in_features, num_classes)

    def forward(self, x):
        """Return classifier outputs for the feature batch ``x``."""
        x = self.classifier(x)
        return x

class ParamDict(OrderedDict):
    """Code adapted from https://github.com/Alok/rl_implementations/tree/master/reptile.
    A dictionary where the values are Tensors, meant to represent weights of
    a model. This subclass lets you perform arithmetic on weights directly.

    Supported operands are numbers (applied to every value) and dicts
    (combined key by key, using this dictionary's keys). Every operation
    returns a new ``ParamDict``; any other operand type raises
    ``NotImplementedError``. A dict operand that lacks one of this
    dictionary's keys raises ``KeyError``.
    """

    def __init__(self, *args, **kwargs):
        # Keyword arguments must be forwarded as keywords; unpacking them with
        # a single ``*`` would pass the key names as positional arguments.
        super().__init__(*args, **kwargs)

    def _prototype(self, other, op):
        """Apply ``op(value, other_value)`` to every entry and return a new ParamDict."""
        if isinstance(other, Number):
            return ParamDict({k: op(v, other) for k, v in self.items()})
        elif isinstance(other, dict):
            return ParamDict({k: op(self[k], other[k]) for k in self})
        else:
            raise NotImplementedError

    def __add__(self, other):
        return self._prototype(other, operator.add)

    # Addition is commutative; the reflected form also lets ``sum()`` start
    # from its default integer 0.
    __radd__ = __add__

    def __rmul__(self, other):
        return self._prototype(other, operator.mul)

    __mul__ = __rmul__

    def __neg__(self):
        return ParamDict({k: -v for k, v in self.items()})

    def __sub__(self, other):
        # self - other
        return self._prototype(other, operator.sub)

    def __rsub__(self, other):
        # other - self: subtraction is not commutative, so the operands are
        # swapped instead of reusing ``__sub__``.
        return self._prototype(other, lambda mine, theirs: theirs - mine)

    def __truediv__(self, other):
        return self._prototype(other, operator.truediv)

class MLFlowSaver:
    """Export the finished child runs of one MLflow experiment as CSV sheets.

    On construction the experiment is looked up by name and three things are
    built: a table of finished parent runs, an experiment-level summary sheet
    of the finished child runs (``df_exp``), and one per-run metric-history
    sheet per child run (``list_df_runs`` with matching ``run_names``).
    ``save_experiment_and_runs`` writes them all into a single zip archive.

    Args:
        experiment_name: Name of the MLflow experiment to export.

    Raises:
        ValueError: If no experiment with that name exists.
    """

    # Metrics left out of the per-run history sheets.
    _SUMMARY_METRICS = frozenset(('best_val_aa', 'best_val_wga', 'best_val_wga_idx'))

    def __init__(self, experiment_name):
        self.client = mlflow.tracking.MlflowClient()
        self.experiment_name = experiment_name
        # Non-metric columns that always appear in the experiment sheet.
        self.EXP_STATS_INCLUSION_LIST = [
            'start_time', 'end_time',
            'tags.mlflow.runName', 'parent_name'
        ]
        self.experiment = mlflow.get_experiment_by_name(self.experiment_name)
        if self.experiment is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(f"MLflow experiment {experiment_name!r} was not found")
        self.experiment_id = self.experiment.experiment_id
        self.df_parents = self.get_parents_df()

        # Finished runs that carry a parent-run tag, i.e. the child runs.
        self.df_runs = mlflow.search_runs(experiment_ids=[self.experiment_id],
            filter_string='tags.mlflow.parentRunId != "None" AND status = "FINISHED"')
        self.df_exp = self.create_experiment_sheet()
        self.list_df_runs, self.run_names = self.create_indiv_run_sheets()

    def get_parents_df(self):
        """Return run name and run id of every finished run without a parent."""
        df_parents = mlflow.search_runs(experiment_ids=[self.experiment_id],
            filter_string='status = "FINISHED"')
        df_parents = df_parents[df_parents['tags.mlflow.parentRunId'].isna()]
        df_parents = df_parents[['tags.mlflow.runName', 'run_id']]
        return df_parents

    def create_experiment_sheet(self):
        """Build the experiment-level summary of the child runs.

        Adds a ``parent_name`` column to ``self.df_runs`` (in place), keeps the
        columns in ``EXP_STATS_INCLUSION_LIST`` plus every column whose name
        matches ``metrics.best`` or ``params``, and sorts by best validation
        WGA, then best validation AA, both descending.

        A child whose parent is not among the finished parent runs gets a
        missing ``parent_name`` rather than aborting the export.

        Returns:
            The sorted summary DataFrame.
        """
        df_exp = self.df_runs

        # One dictionary lookup per child instead of filtering the parent
        # table once per row.
        parent_names = dict(zip(self.df_parents['run_id'],
                                self.df_parents['tags.mlflow.runName']))
        df_exp['parent_name'] = df_exp['tags.mlflow.parentRunId'].map(parent_names)

        best_metrics_cols = [col for col in df_exp.columns if re.search("metrics.best|params", col)]

        inclusion_list = self.EXP_STATS_INCLUSION_LIST.copy()
        inclusion_list.extend(best_metrics_cols)

        df_exp = df_exp[inclusion_list].sort_values(by=['metrics.best_val_wga', 'metrics.best_val_aa'],
                                                    ascending=False)
        return df_exp

    @staticmethod
    def _metrics_to_frame(all_metrics):
        """Turn ``{metric name: list of values}`` into a sheet with a leading 1-based ``epoch`` column.

        Histories of different lengths are aligned from the first row and
        padded with NaN, so a metric logged fewer times than the others does
        not make the frame construction fail.
        """
        frame = pd.DataFrame({name: pd.Series(values, dtype=float)
                              for name, values in all_metrics.items()})
        metric_cols = frame.columns.to_list()
        frame['epoch'] = list(range(1, len(frame) + 1))
        return frame.loc[:, ['epoch'] + metric_cols]

    def create_indiv_run_sheets(self):
        """Fetch the metric history of every child run.

        Returns:
            A pair ``(frames, names)``: one DataFrame per run in
            ``self.df_runs`` and the run names in the same order.
        """
        indiv_runs = []
        run_names = []

        for _, row in self.df_runs.iterrows():
            run_id = row['run_id']
            run = mlflow.get_run(run_id)

            all_metrics = OrderedDict()
            for key in run.data.metrics:
                if key in self._SUMMARY_METRICS:
                    continue
                history = self.client.get_metric_history(run_id, key)
                all_metrics[key] = [m.value for m in history]

            indiv_runs.append(self._metrics_to_frame(all_metrics))
            run_names.append(row['tags.mlflow.runName'])

        return indiv_runs, run_names

    @staticmethod
    def _archive_csv(zipf, frame, csv_path, cleanup):
        """Write ``frame`` to ``csv_path`` as gzip CSV and add it to ``zipf``.

        With ``cleanup`` the file on disk is removed afterwards, including
        when writing or archiving fails.
        """
        try:
            frame.to_csv(csv_path, index=False, compression='gzip')
            zipf.write(csv_path, arcname=os.path.basename(csv_path))
        finally:
            if cleanup and os.path.exists(csv_path):
                os.remove(csv_path)

    def save_experiment_and_runs(self, root_dir='.', cleanup=True):
        """Write the experiment sheet and all run sheets into one zip archive.

        The archive is ``<root_dir>/<experiment_name>_runs.zip``. Each sheet is
        first written next to it as a gzip-compressed CSV and then added to
        the archive under its file name.

        Args:
            root_dir: Directory that receives the archive and the CSV files.
            cleanup: When true, delete each CSV file once it has been archived.
        """
        zip_path = os.path.join(root_dir, f'{self.experiment_name}_runs.zip')

        sheets = [(f'{self.experiment_name}_Run_Metrics.csv.gz', self.df_exp)]
        for indiv_run, run_name in zip(self.list_df_runs, self.run_names):
            sheets.append((f'{self.experiment_name}_{run_name}_Epoch_Metrics.csv.gz', indiv_run))

        with zipfile.ZipFile(zip_path, 'w') as zipf:
            for csv_name, frame in sheets:
                self._archive_csv(zipf, frame, os.path.join(root_dir, csv_name), cleanup)
        print('Successfully saved experiment and runs!')

# def upload_to_gcs(bucket_name, source_file_path, destination_blob_path):
#     client = storage.Client()
#     bucket = client.bucket(bucket_name)
#     blob = bucket.blob(destination_blob_path)
#     blob.upload_from_filename(source_file_path)
#     print(f'Uploaded to gs://{bucket_name}/{destination_blob_path}')

# def upload_folder_to_gcs(bucket_name, source_folder_path, destination_folder_path):
#     client = storage.Client()
#     bucket = client.bucket(bucket_name)
#     for file in glob.glob(os.path.join(source_folder_path, '**', '*'), recursive=True):
#         if os.path.isfile(file):
#             relative_path = os.path.relpath(file, start=source_folder_path)
#             destination_blob_path = os.path.join(destination_folder_path, relative_path)
#             blob = bucket.blob(destination_blob_path)
#             blob.upload_from_filename(file)
#             print(f'Uploaded to gs://{bucket_name}/{destination_blob_path}')