from collections import OrderedDict
from datetime import datetime
import os
import random
from typing import Set

import hydra
from omegaconf import DictConfig, OmegaConf
import torch.multiprocessing
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb

from datasets import *
from fmax_auprc_metrics import compute_precision_recall_metrics
from metrics import compute_metrics


# Select the 'file_system' tensor sharing strategy for torch.multiprocessing.
torch.multiprocessing.set_sharing_strategy('file_system')

# NOTE(review): this only binds a module-level Python name; it does not set the
# KMP_AFFINITY environment variable. If the goal is to affect worker CPU
# placement, this presumably has to go through os.environ — confirm.
KMP_AFFINITY=None  # to put all the worker processes onto all available CPUs

def build_model_trainer(experiment_cfg: DictConfig):
    """Construct the Trainer described by an experiment configuration.

    Args:
        experiment_cfg: Experiment configuration, handed unchanged to ``Trainer``.

    Returns:
        A new ``Trainer`` instance.
    """
    trainer = Trainer(experiment_cfg)
    return trainer


class Trainer:
    def __init__(
        self,
        config: DictConfig,
    ):
        """Build the model and loss, prepare output paths, and optionally restore a checkpoint.

        Args:
            config: Experiment configuration. The ``common``, ``features``,
                ``dataset``, ``hyper_params``, ``model`` and ``log_wandb``
                entries are read here.
        """
        self.config = config
        self.common = config.common
        self.features = config.features
        self.ontology = config.dataset.ontology
        self.standardize = self.features.standardize
        torch.set_num_threads(self.common.num_threads)

        self.num_workers = self.common.num_workers
        self.hp = config.hyper_params
        self.accumulation_steps = self.hp.accumulation_steps
        self.train_batch_size = self.hp.train_batch_size
        self.val_batch_size = self.hp.val_batch_size
        self.n_class = config.model.n_class
        self.device = self.common.device
        self.model = self.build_model()

        # When class weights are configured the loss is kept per element
        # (reduction='none'); otherwise the default reduction is used.
        if hasattr(self.common, "class_weights"):
            self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')
        else:
            self.loss_fn = nn.BCEWithLogitsLoss()

        # Log configs
        time_info = datetime.now()
        self.start_time = time_info
        self.timestamp = f"{time_info.year}-{time_info.month}-{time_info.day}-{time_info.hour}-{time_info.minute}"
        self.output_path = os.path.join(self.common.output_path, self.common.experiment_name)
        self.timestamp_path = os.path.join(self.output_path, self.timestamp)
        self.checkpoint_path = os.path.join(self.timestamp_path, "checkpoints")

        # Output directories are only created when validating on the "val" split.
        if self.common.validation_dir == "val":
            os.makedirs(self.timestamp_path, exist_ok=True)
            os.makedirs(self.checkpoint_path, exist_ok=True)

        if self.common.checkpoint:
            print(f"Reloading model from {self.common.checkpoint}")
            checkpoint_file = os.path.join(self.common.checkpoints_dir, self.common.checkpoint)
            # Read the file once, mapping tensors onto the configured device so
            # a checkpoint saved on a different device can still be restored.
            state_dict = torch.load(checkpoint_file, map_location=self.device)
            if "module.linear.weight" in state_dict:
                # Keys carry a leading "module." (presumably saved from a
                # wrapped model — confirm). Strip it only where it really is a
                # prefix, so other keys are left untouched.
                prefix = "module."
                state_dict = OrderedDict(
                    (key[len(prefix):] if key.startswith(prefix) else key, value)
                    for key, value in state_dict.items()
                )
            self.model.load_state_dict(state_dict)
        else:
            print("Initializing from scratch")

        if config.log_wandb:
            self._initialize_wandb()

        self.get_files = GetDataFiles(self.config)
        self.labels = self.get_files.get_labels()

    def build_model(self):
        """Instantiate the configured model and move it to ``self.device``.

        Returns:
            The object created by ``hydra.utils.instantiate`` from
            ``self.config.model``, after ``.to(self.device)``.
        """
        model_cfg = self.config.model
        net = hydra.utils.instantiate(model_cfg)
        print(model_cfg._target_)
        net = net.to(self.device)
        print(f"Model is on device {self.device}.")
        return net

    def _initialize_wandb(self):
        """Start a WandB run for this experiment and record the configuration."""
        run_kwargs = {
            "project": f"ddots-{self.ontology.upper()}",
            "entity": "",
            "name": f"{self.common.experiment_name}",
            "dir": self.timestamp_path,
        }
        wandb.init(**run_kwargs)

        config_as_container = OmegaConf.to_container(self.config)
        wandb.config.update(config_as_container)

    def build_dataloader_for_embeds(self, data: Dict, batch_size: int) -> DataLoader:
        """Wrap ``data`` in an ``EmbeddingDataset`` and return a shuffling ``DataLoader``.

        Args:
            data: Passed to ``EmbeddingDataset`` together with ``self.config``.
            batch_size: Batch size of the returned loader.

        Returns:
            A ``DataLoader`` with ``shuffle=True`` whose batches are assembled
            by ``padding_collator`` using the feature flags from the config.
        """
        def collate(batch):
            # Feature flags are read when a batch is collated, not when the
            # loader is built.
            return padding_collator(
                batch,
                standardize=self.standardize,
                use_lm=self.features.use_lm,
                use_af=self.features.use_af,
                use_single=(self.features.af.use_single_rep or self.features.af.use_states_rep),
                use_pair=self.features.af.use_pair_rep,
            )

        return DataLoader(
            EmbeddingDataset(self.config, data),
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate,
            num_workers=self.num_workers,
        )

    def setup_train(self) -> None:
        """Seed the random generators and create the optimizer and train/val dataloaders."""
        # Random seed
        seed = self.hp.seed
        torch.backends.cudnn.deterministic = True
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed_all(seed)

        # Optimizer
        adam_kwargs = dict(
            params=self.model.parameters(),
            lr=self.hp.learning_rate,
            weight_decay=self.hp.weight_decay,
        )
        self.optimizer = optim.Adam(**adam_kwargs)

        # Datasets
        self.labels = self.get_files.get_labels()

        dataset_cfg = self.config.dataset
        train_data = self.get_files(split="train", max_length=dataset_cfg.train_max_length)
        val_data = self.get_files(split="val", max_length=dataset_cfg.val_max_length)

        self.train_dataloader = self.build_dataloader_for_embeds(
            data=train_data, batch_size=self.train_batch_size
        )
        self.val_dataloader = self.build_dataloader_for_embeds(
            data=val_data, batch_size=self.val_batch_size
        )

    def train_loop(self) -> Dict[str, float]:
        """Run one training epoch over ``self.train_dataloader``.

        The optimizer steps after every ``self.accumulation_steps`` batches
        (after every batch when that value is 0 or 1). Gradients left over at
        the end of the epoch are applied in a final step.

        Returns:
            Dict with the keys "train macro auprc", "train micro auprc",
            "train fmax", "train fmax cluster", "train metrics mean" and
            "train loss".

        Raises:
            NotImplementedError: LM + AF features without the AF pair
                representation but with a single/states representation.
            ValueError: The feature configuration selects no model input.
        """
        self.model.train()

        total_loss = []
        all_logits, all_labels, all_clusters = [], [], []

        # The feature flags do not change during an epoch; read them once.
        use_lm = self.features.get('use_lm')
        use_af = self.features.get('use_af')
        # A value of 0 means "no accumulation", i.e. step on every batch.
        steps_per_update = self.accumulation_steps if self.accumulation_steps else 1
        pending_grads = False

        for i, data in enumerate(self.train_dataloader):
            logits = None
            if use_lm and use_af:
                if self.features.af.get('use_pair_rep'):
                    logits = self.model(s=data['embeds'].to(self.device), z=data['af_pairs'].to(self.device))
                elif (self.features.af.get('use_single_rep') or self.features.af.get('use_states_rep')):
                    raise NotImplementedError

            elif use_lm:
                if self.standardize:
                    feats_mean = data['feats_mean'].unsqueeze(1).expand(data['embeds'].shape)
                    feats_std = data['feats_std'].unsqueeze(1).expand(data['embeds'].shape)
                    x = (data['embeds'] - feats_mean) / feats_std
                else:
                    x = data['embeds']
                logits = self.model(x.to(self.device))

            elif use_af:
                if self.features.af.get('use_single_rep') and self.features.af.get('use_pair_rep'):
                    logits = self.model(s=data['embeds'].to(self.device), z=data['af_pairs'].to(self.device))
                elif self.features.af.get('use_single_rep'):
                    logits = self.model(data['embeds'].to(self.device))
                elif self.features.af.get('use_pair_rep'):
                    logits = self.model(z=data['af_pairs'].to(self.device))

            if logits is None:
                # Fail with a clear message instead of an unbound-variable error.
                raise ValueError(
                    "Feature configuration selects no model input "
                    "(check features.use_lm, features.use_af and features.af.*)."
                )

            # [*, n_class]
            labels = data['labels'].reshape(shape=(-1, self.n_class))

            # .mean() leaves the default (already scalar) loss unchanged and
            # reduces the per-element loss produced with reduction='none', so
            # backward() always receives a scalar.
            # NOTE(review): the configured class weights are not applied
            # anywhere in this method — confirm whether they should be.
            loss = self.loss_fn(logits, labels.to(self.device)).mean()
            loss.backward()
            total_loss.append(loss.detach().item())
            pending_grads = True

            # gradients accumulation: step once every `steps_per_update` batches
            if (i + 1) % steps_per_update == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                pending_grads = False

            all_labels.append(labels)
            all_logits.append(logits.detach().cpu())
            all_clusters.append(data['clusters'])

        # Apply gradients from a trailing, incomplete accumulation window so
        # they are neither dropped nor carried into the next epoch.
        if pending_grads:
            self.optimizer.step()
            self.optimizer.zero_grad()

        # [N_seq, n_class]
        all_logits = torch.cat(all_logits, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        all_clusters = np.concatenate(all_clusters, axis=0)

        fmax, reweighted_fmax, macro_aupr, micro_aupr = compute_metrics(
            all_logits, all_labels, do_reweighted_fmax=True, clusters=all_clusters
        )

        metrics = {
            "train macro auprc": macro_aupr,
            "train micro auprc": micro_aupr,
            "train fmax": fmax,
            "train fmax cluster": reweighted_fmax,
            "train metrics mean": np.mean([macro_aupr, reweighted_fmax]),
            "train loss": np.mean(total_loss),
        }
        return metrics

    def eval_loop(self, train_labels: Set[int], val_labels: Set[int]) -> Dict[str, float]:
        """Evaluate the model on ``self.val_dataloader``.

        Loss and metrics are computed over all samples and, separately, for
        every distinct value found in the batches' ``'thresholds'`` entry.

        Args:
            train_labels: Label indices seen in the training split.
            val_labels: Label indices of the evaluated split. Indices in
                ``val_labels - train_labels`` are removed from logits and
                labels unless ``config.dataset.keep_zero_shot_labels`` is set.

        Returns:
            Dict with overall entries ("val macro aupr", "val micro aupr",
            "val fmax", "val fmax cluster", "val metrics mean", "val loss")
            and the same entries per threshold ``t`` under "val {t} ...".

        Raises:
            NotImplementedError: LM + AF features without the AF pair
                representation but with a single/states representation.
            ValueError: The feature configuration selects no model input.
        """
        self.model.eval()
        all_logits, all_labels, all_thresholds, all_clusters = [], [], [], []

        # The feature flags do not change during evaluation; read them once.
        use_lm = self.features.get('use_lm')
        use_af = self.features.get('use_af')

        with torch.no_grad():
            for data in self.val_dataloader:
                logits = None
                if use_lm and use_af:
                    if self.features.af.get('use_pair_rep'):
                        logits = self.model(s=data['embeds'].to(self.device), z=data['af_pairs'].to(self.device))
                    elif (self.features.af.get('use_single_rep') or self.features.af.get('use_states_rep')):
                        raise NotImplementedError

                elif use_lm:
                    if self.standardize:
                        feats_mean = data['feats_mean'].unsqueeze(1).expand(data['embeds'].shape)
                        feats_std = data['feats_std'].unsqueeze(1).expand(data['embeds'].shape)
                        x = (data['embeds'] - feats_mean) / feats_std
                    else:
                        x = data['embeds']
                    logits = self.model(x.to(self.device))

                elif use_af:
                    if self.features.af.get('use_single_rep') and self.features.af.get('use_pair_rep'):
                        logits = self.model(s=data['embeds'].to(self.device), z=data['af_pairs'].to(self.device))
                    elif self.features.af.get('use_single_rep'):
                        logits = self.model(data['embeds'].to(self.device))
                    elif self.features.af.get('use_pair_rep'):
                        logits = self.model(z=data['af_pairs'].to(self.device))

                if logits is None:
                    # Fail with a clear message instead of an unbound-variable error.
                    raise ValueError(
                        "Feature configuration selects no model input "
                        "(check features.use_lm, features.use_af and features.af.*)."
                    )

                labels = data['labels'].reshape(shape=(-1, self.n_class))
                all_labels.append(labels)
                all_logits.append(logits.detach().cpu())
                all_thresholds.append(data['thresholds'])
                all_clusters.append(data['clusters'])

            # [N_proteins, n_class]
            all_logits = torch.cat(all_logits, dim=0)
            all_labels = torch.cat(all_labels, dim=0)
            all_thresholds = np.concatenate(all_thresholds, axis=0)
            all_clusters = np.concatenate(all_clusters, axis=0)

            if not self.config.dataset.keep_zero_shot_labels:
                # optionally exclude positive labels not seen in train
                zero_shot_labels = sorted(val_labels - train_labels)
                label_mask = np.ones(all_labels.shape[-1], bool)
                label_mask[zero_shot_labels] = False
                all_logits = all_logits[:, label_mask]
                all_labels = all_labels[:, label_mask]

            # Computed once and reused below so both loops see the same order.
            unique_thresholds = np.unique(all_thresholds)

            losses, fmax_vals, fmax_cluster_vals, micro_aupr_vals, macro_aupr_vals, metrics_means = [], [], [], [], [], []
            for t in unique_thresholds:
                indices = np.where(all_thresholds == t)
                logits_t = all_logits[indices]
                labels_t = all_labels[indices]
                clusters_t = all_clusters[indices]
                loss_t = self.loss_fn(logits_t.to(self.device), labels_t.to(self.device)).mean().detach().item()
                fmax_t, reweighted_fmax_t, macro_aupr_t, micro_aupr_t = compute_metrics(
                    logits_t, labels_t, do_reweighted_fmax=True, clusters=clusters_t
                )
                # NOTE(review): this per-threshold mean uses the plain fmax,
                # while the overall mean below uses the reweighted fmax —
                # confirm which one is intended.
                metrics_mean_t = np.mean([macro_aupr_t, fmax_t])
                losses.append(loss_t)
                fmax_vals.append(fmax_t)
                fmax_cluster_vals.append(reweighted_fmax_t)
                macro_aupr_vals.append(macro_aupr_t)
                micro_aupr_vals.append(micro_aupr_t)
                metrics_means.append(metrics_mean_t)

            total_loss = self.loss_fn(all_logits.to(self.device), all_labels.to(self.device)).mean().detach().item()
            fmax, reweighted_fmax, macro_aupr, micro_aupr = compute_metrics(
                all_logits, all_labels, do_reweighted_fmax=True, clusters=all_clusters
            )
            metrics_mean = np.mean([macro_aupr, reweighted_fmax])

        metrics = {
            "val macro aupr": macro_aupr,
            "val micro aupr": micro_aupr,
            "val fmax": fmax,
            "val fmax cluster": reweighted_fmax,
            "val metrics mean": metrics_mean,
            "val loss": total_loss,
        }

        for i, t in enumerate(unique_thresholds):
            metrics[f"val {t} macro aupr"] = macro_aupr_vals[i]
            metrics[f"val {t} micro aupr"] = micro_aupr_vals[i]
            metrics[f"val {t} fmax"] = fmax_vals[i]
            metrics[f"val {t} fmax cluster"] = fmax_cluster_vals[i]
            metrics[f"val {t} metrics mean"] = metrics_means[i]
            metrics[f"val {t} loss"] = losses[i]

        # Step the scheduler only if one has actually been attached to this
        # trainer; the config entry alone does not guarantee that it exists.
        scheduler = getattr(self, "lr_scheduler", None)
        if hasattr(self.config, "lr_scheduler") and scheduler is not None:
            scheduler.step(metrics_mean)
        return metrics

    def eval_loop_new_test_metrics(self, train_labels: Set[int], val_labels: Set[int]) -> Dict[str, float]:
        """Evaluate on ``self.val_dataloader`` with ``compute_precision_recall_metrics``.

        Loss and metrics are computed over all samples and, separately, for
        every distinct value found in the batches' ``'thresholds'`` entry.
        All reported values are rounded to 4 decimals.

        Args:
            train_labels: Label indices seen in the training split.
            val_labels: Label indices of the evaluated split. Indices in
                ``val_labels - train_labels`` are removed from logits and
                labels unless ``config.dataset.keep_zero_shot_labels`` is set.

        Returns:
            Dict with overall entries ("fmax", "fmax_clust", "auprc_clust",
            "auprc_bfr", "auprc_aft", "loss") and the same entries per
            threshold ``t`` under "{t}_...".

        Raises:
            NotImplementedError: LM + AF features without the AF pair
                representation but with a single/states representation.
            ValueError: The feature configuration selects no model input.
        """
        def _metric_suite(logits, labels, clusters):
            # Returns (fmax, fmax_clust, auprc_clust, auprc_bfr, auprc_aft)
            # for one set of samples; `clusters` is only passed to the
            # cluster-aware variants.
            fmax_ = compute_precision_recall_metrics(
                logits=logits,
                labels=labels,
                metric_type="fmax",
                averaging="per_sample",
                average_before_metric=True,
            )
            fmax_clust_ = compute_precision_recall_metrics(
                logits=logits,
                labels=labels,
                metric_type="fmax",
                averaging="per_sample",
                average_before_metric=True,
                clusters=clusters,
            )
            auprc_clust_ = compute_precision_recall_metrics(
                logits=logits,
                labels=labels,
                metric_type="auprc",
                averaging="per_sample",
                average_before_metric=True,
                clusters=clusters,
            )
            auprc_bfr_ = compute_precision_recall_metrics(
                logits=logits,
                labels=labels,
                metric_type="auprc",
                averaging="per_label",
                average_before_metric=True,
            )
            auprc_aft_ = compute_precision_recall_metrics(
                logits=logits,
                labels=labels,
                metric_type="auprc",
                averaging="per_label",
                average_before_metric=False,
            )
            return fmax_, fmax_clust_, auprc_clust_, auprc_bfr_, auprc_aft_

        self.model.eval()
        all_logits, all_labels, all_thresholds, all_clusters = [], [], [], []

        # The feature flags do not change during evaluation; read them once.
        use_lm = self.features.get('use_lm')
        use_af = self.features.get('use_af')

        with torch.no_grad():
            for data in self.val_dataloader:
                logits = None
                if use_lm and use_af:
                    if self.features.af.get('use_pair_rep'):
                        logits = self.model(s=data['embeds'].to(self.device), z=data['af_pairs'].to(self.device))
                    elif (self.features.af.get('use_single_rep') or self.features.af.get('use_states_rep')):
                        raise NotImplementedError

                elif use_lm:
                    if self.standardize:
                        feats_mean = data['feats_mean'].unsqueeze(1).expand(data['embeds'].shape)
                        feats_std = data['feats_std'].unsqueeze(1).expand(data['embeds'].shape)
                        x = (data['embeds'] - feats_mean) / feats_std
                    else:
                        x = data['embeds']
                    logits = self.model(x.to(self.device))

                elif use_af:
                    if self.features.af.get('use_single_rep') and self.features.af.get('use_pair_rep'):
                        logits = self.model(s=data['embeds'].to(self.device), z=data['af_pairs'].to(self.device))
                    elif self.features.af.get('use_single_rep'):
                        logits = self.model(data['embeds'].to(self.device))
                    elif self.features.af.get('use_pair_rep'):
                        logits = self.model(z=data['af_pairs'].to(self.device))

                if logits is None:
                    # Fail with a clear message instead of an unbound-variable error.
                    raise ValueError(
                        "Feature configuration selects no model input "
                        "(check features.use_lm, features.use_af and features.af.*)."
                    )

                labels = data['labels'].reshape(shape=(-1, self.n_class))
                all_labels.append(labels)
                all_logits.append(logits.detach().cpu())
                all_thresholds.append(data['thresholds'])
                all_clusters.append(data['clusters'])

            # [N_proteins, n_class]
            all_logits = torch.cat(all_logits, dim=0)
            all_labels = torch.cat(all_labels, dim=0)
            all_thresholds = np.concatenate(all_thresholds, axis=0)
            all_clusters = torch.as_tensor(np.concatenate(all_clusters, axis=0), dtype=int)

            if not self.config.dataset.keep_zero_shot_labels:
                # optionally exclude positive labels not seen in train
                zero_shot_labels = sorted(val_labels - train_labels)
                label_mask = np.ones(all_labels.shape[-1], bool)
                label_mask[zero_shot_labels] = False
                all_logits = all_logits[:, label_mask]
                all_labels = all_labels[:, label_mask]

            # Computed once and reused below so both loops see the same order.
            unique_thresholds = np.unique(all_thresholds)

            losses, fmax_vals, fmax_clust_vals, auprc_clust_vals, auprc_bfr_vals, auprc_aft_vals = [], [], [], [], [], []

            for t in unique_thresholds:
                indices = np.where(all_thresholds == t)
                logits_t = all_logits[indices]
                labels_t = all_labels[indices]
                clusters_t = all_clusters[indices]
                loss_t = self.loss_fn(logits_t.to(self.device), labels_t.to(self.device)).mean().detach().item()

                fmax_t, fmax_clust_t, auprc_clust_t, auprc_bfr_t, auprc_aft_t = _metric_suite(
                    logits_t, labels_t, clusters_t
                )

                losses.append(loss_t)
                fmax_vals.append(fmax_t)
                fmax_clust_vals.append(fmax_clust_t)
                auprc_clust_vals.append(auprc_clust_t)
                auprc_bfr_vals.append(auprc_bfr_t)
                auprc_aft_vals.append(auprc_aft_t)

            total_loss = self.loss_fn(all_logits.to(self.device), all_labels.to(self.device)).mean().detach().item()

            fmax, fmax_clust, auprc_clust, auprc_bfr, auprc_aft = _metric_suite(
                all_logits, all_labels, all_clusters
            )
            # Use the overall metrics here, not the values left over from the
            # last threshold of the loop above.
            mean_aft = np.mean([fmax, auprc_aft])

        metrics = {
            "fmax": round(fmax, 4),
            "fmax_clust": round(fmax_clust, 4),
            "auprc_clust": round(auprc_clust, 4),
            "auprc_bfr": round(auprc_bfr, 4),
            "auprc_aft": round(auprc_aft, 4),
            "loss": round(total_loss, 4),
        }

        for i, t in enumerate(unique_thresholds):
            metrics[f"{t}_fmax"] = round(fmax_vals[i], 4)
            metrics[f"{t}_fmax_clust"] = round(fmax_clust_vals[i], 4)
            metrics[f"{t}_auprc_clust"] = round(auprc_clust_vals[i], 4)
            metrics[f"{t}_auprc_bfr"] = round(auprc_bfr_vals[i], 4)
            metrics[f"{t}_auprc_aft"] = round(auprc_aft_vals[i], 4)
            metrics[f"{t}_loss"] = round(losses[i], 4)

        # Step the scheduler only if one has actually been attached to this
        # trainer; the config entry alone does not guarantee that it exists.
        scheduler = getattr(self, "lr_scheduler", None)
        if hasattr(self.config, "lr_scheduler") and scheduler is not None:
            scheduler.step(mean_aft)
        return metrics
    
    def train(self) -> None:
        """Train for up to ``hp.num_epoch`` epochs with periodic validation.

        Validation, WandB logging, checkpointing and the early-stopping check
        run at epoch 1 and at every epoch divisible by 10. Training stops
        early when the "val metrics mean" score improves by less than 0.1%
        relative to both of the two previously recorded scores.
        """
        self.setup_train()
        last_res = 0.0001
        before_last_res = 0.0001
        # Minimum relative improvement required to keep training.
        min_rel_improvement = 0.001

        train_labels = self.labels["train"]
        val_labels = self.labels["val"]

        for epoch in range(self.hp.num_epoch):
            train_metrics = self.train_loop()

            # evaluate and log metrics only at epoch 1 and every 10 epochs
            if (epoch % 10 != 0) and epoch != 1:
                continue

            val_metrics = self.eval_loop(train_labels, val_labels)
            # all_metrics = train_metrics | val_metrics # works with python 3.9
            all_metrics = {**train_metrics, **val_metrics}

            if self.config.log_wandb:
                current_lr = self.optimizer.param_groups[0]['lr']
                all_metrics['epoch'] = epoch
                all_metrics['learning rate'] = current_lr
                # total_seconds() keeps counting past 24 hours, whereas the
                # .seconds attribute wraps around every day.
                all_metrics['minutes elapsed'] = (datetime.now() - self.start_time).total_seconds() / 60
                wandb.log(all_metrics)

            if (self.common.save_checkpoints) and (epoch > 1):
                exp_name = self.common.experiment_name
                save_name = f"{exp_name}_epoch_{epoch}_loss_{all_metrics['val loss']:.4f}_metrics_mean_{all_metrics['val metrics mean']:.3f}.pt"
                # The constructor only creates this directory for some
                # configurations; make sure it exists before writing.
                os.makedirs(self.checkpoint_path, exist_ok=True)
                torch.save(self.model.state_dict(), os.path.join(self.checkpoint_path, save_name))

            # early stopping
            new_res = val_metrics["val metrics mean"]
            if epoch == 1:
                last_res = new_res
                continue

            # Same test as (new / old) - 1 < tolerance for positive scores,
            # written as a product so a previous score of 0 cannot cause a
            # division by zero.
            stalled_vs_last = new_res < (1 + min_rel_improvement) * last_res
            stalled_vs_before_last = new_res < (1 + min_rel_improvement) * before_last_res
            if stalled_vs_last and stalled_vs_before_last:
                break
            before_last_res = last_res
            last_res = new_res

    def evaluate(self) -> Dict[str, float]:
        """Evaluate the model on the test split with ``eval_loop``.

        Replaces ``self.val_dataloader`` with a loader over the test data,
        prints every metric rounded to 2 decimals and returns the metrics.

        Returns:
            The metrics dict produced by ``eval_loop``.
        """
        # NOTE(review): the test split is loaded with dataset.val_max_length —
        # confirm that no separate test limit is intended.
        test_data = self.get_files(split="test", max_length=self.config.dataset.val_max_length)
        train_labels = self.labels["train"]
        test_labels = self.labels["test"]

        # eval_loop always reads from self.val_dataloader.
        self.val_dataloader = self.build_dataloader_for_embeds(
            data=test_data,
            batch_size=self.val_batch_size
        )

        test_metrics = self.eval_loop(train_labels, test_labels)
        for metric_name, value in test_metrics.items():
            print(metric_name, round(value, 2))
        return test_metrics

    def evaluate_new_test_metrics(self) -> Dict[str, float]:
        """Evaluate the model on the test split with ``eval_loop_new_test_metrics``.

        Replaces ``self.val_dataloader`` with a loader over the test data,
        prints every metric rounded to 2 decimals and returns the metrics.

        Returns:
            The metrics dict produced by ``eval_loop_new_test_metrics``.
        """
        # NOTE(review): the test split is loaded with dataset.val_max_length —
        # confirm that no separate test limit is intended.
        test_data = self.get_files(split="test", max_length=self.config.dataset.val_max_length)
        train_labels = self.labels["train"]
        test_labels = self.labels["test"]

        # eval_loop_new_test_metrics always reads from self.val_dataloader.
        self.val_dataloader = self.build_dataloader_for_embeds(
            data=test_data,
            batch_size=self.val_batch_size
        )

        test_metrics = self.eval_loop_new_test_metrics(train_labels, test_labels)
        for metric_name, value in test_metrics.items():
            print(metric_name, round(value, 2))
        return test_metrics
    