import json
import os
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from scipy.io import savemat
from tango.common import Registrable
from tango.integrations.torch import Model
from torch import nn
from torchmetrics import MeanAbsoluteError, MeanMetric

from .utils import disbased_weight
from .utils import expand_task_and_dataset
    
def construct_hypergraph(data_file, hyper, log_dir, output_fn):
    """Compute per-node hyperedge weights for every labelled sample and save them to a ``.mat`` file.

    For each sample in ``data_file`` that has a ``"y"`` entry, node ``i`` gets
    the weight ``disbased_weight(x[i], x[hyper[i]], sigma)``. Here ``sigma`` is
    the median of that sample's ``"euclidean_dis"`` array. Samples without a
    ``"y"`` entry are skipped.

    Args:
        data_file: mapping of sample key -> group exposing ``"euclidean_dis"``,
            ``"x"`` and optionally ``"y"`` (presumably an open h5py file --
            confirm against callers).
        hyper: mapping of node index -> list of neighbour node indices. Keys
            may be ints or their string form (JSON round-trips turn them into
            strings).
        log_dir: root output directory; the file is written to
            ``<log_dir>/hyperedges/`` (created if missing).
        output_fn: file stem of the ``.mat`` output.

    Returns:
        Path of the written ``.mat`` file. It contains ``rest_1_mats`` of shape
        (num_nodes, num_samples) and ``PMAT_CR`` (labels reshaped to a column).

    Raises:
        KeyError: if a node index has no entry in ``hyper``.
        ValueError: if no sample in ``data_file`` carries a ``"y"`` label.
    """
    features = []
    labels = []
    for key in data_file.keys():
        data = data_file[key]
        sigma = np.median(data["euclidean_dis"])
        x = data["x"][()]
        try:
            y = data["y"][()]
        except KeyError:
            # Unlabelled sample: nothing to regress against, so skip it.
            continue

        num_nodes = x.shape[0]
        hyperedges = np.zeros(num_nodes)
        for i in range(num_nodes):
            # Assignments loaded from JSON have string keys; in-memory ones may use ints.
            if i in hyper:
                knn = hyper[i]
            elif str(i) in hyper:
                knn = hyper[str(i)]
            else:
                raise KeyError(f"node {i} has no hyperedge assignment in `hyper`")
            hyperedges[i] = disbased_weight(x[i], x[knn], sigma)

        features.append(hyperedges)
        labels.append(y)

    if not features:
        raise ValueError("construct_hypergraph: no labelled samples found in data_file")

    X = np.array(features)
    # (num_samples, num_nodes) -> (num_nodes, num_samples)
    X = X.reshape(X.shape[0], -1).T
    Y = np.array(labels).reshape(-1, 1)

    out_dir = os.path.join(log_dir, "hyperedges")
    os.makedirs(out_dir, exist_ok=True)
    mat_path = os.path.join(out_dir, output_fn + '.mat')

    savemat(mat_path, dict(rest_1_mats=X, PMAT_CR=Y))
    return mat_path

class LightningModule(pl.LightningModule, Registrable):
    """Registrable base class for this project's Lightning modules.

    Mixing tango's ``Registrable`` into ``pl.LightningModule`` lets subclasses
    be looked up by name via ``@LightningModule.register("<name>")``.
    ``default_implementation`` names the registered implementation tango uses
    when no type is specified; this base class itself is registered under
    that name.
    """

    default_implementation = "default"


# Make the base class itself resolvable as the "default" implementation.
LightningModule.register("default")(LightningModule)


@LightningModule.register("deep_recon")
class DeepRecon(LightningModule):
    """Jointly train a region-mask model and a reconstruction model.

    For a chosen centre region, ``mask_model`` scores which other regions to
    keep. ``recon_model`` then predicts the centre region's row of
    ``batch['x']`` from the input and that soft mask. The loss is an L1
    reconstruction term plus ``beta`` times a density penalty on the
    keep-probabilities, optionally weighted per region by ``node_entropy``.
    """

    def __init__(
        self,
        recon_model: Model,
        mask_model: Model,
        learning_rate: float,
        beta: float,
        tau: float = 1.0,
        node_entropy: Optional[np.ndarray] = None,
    ) -> None:
        """
        Args:
            recon_model: called as ``recon_model(x, center_mask, token_mask)``.
            mask_model: called as ``mask_model(x, center_mask)``; returns logits
                whose last axis holds 2 classes (index 1 is used as "keep").
            learning_rate: AdamW learning rate.
            beta: weight of the density penalty in the total loss.
            tau: softmax temperature of the density penalty, also logged as
                ``temperature``. This attribute was previously read but never set.
            node_entropy: optional per-region weights for the density penalty.
                When omitted, the penalty is the plain mean keep-probability.
        """
        super().__init__()

        self.recon_model = recon_model
        self.mask_model = mask_model
        self.learning_rate = learning_rate
        self.beta = beta
        self.tau = tau

        # Registered as a buffer so it moves with the module across devices;
        # a None buffer is allowed and simply excluded from the state dict.
        weights = None if node_entropy is None else torch.as_tensor(node_entropy, dtype=torch.float32)
        self.register_buffer("node_entropy", weights)

        self.recon_criterion = nn.L1Loss()
        self.adv = MeanMetric()

    def configure_optimizers(self):
        """AdamW with a 20-epoch linear warm-up followed by cosine annealing up to ``trainer.max_epochs``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer, warmup_epochs=20, max_epochs=self.trainer.max_epochs
        )
        return [optimizer], [scheduler]

    def forward(self, batch):
        """Mask the input and reconstruct the centre region.

        Returns:
            ``(y, y_hat, masking_logits, token_mask, center_mask)``:
            ``y`` is the centre region's row of ``batch['x']``, ``y_hat`` its
            reconstruction, ``token_mask`` the soft keep-probability of every
            region (forced to 0 at the centre), and ``center_mask`` the one-hot
            centre selector.
        """
        x = batch['x']
        bs = x.size(0)
        n_region = self.trainer.datamodule.n_region
        # NOTE(review): the centre is hard-coded to region index 1 for every
        # sample; a random per-sample centre was commented out in the original.
        # ``export`` queries the mask model for every region -- confirm intended.
        center_mask = F.one_hot(torch.ones(bs, device=x.device, dtype=torch.long), num_classes=n_region)

        # Assumes x is (batch, n_region, ...) so the boolean mask selects one
        # row per sample -- TODO confirm.
        y = x[center_mask.bool()]

        masking_logits = self.mask_model(x, center_mask)
        # Class 1 of the softmax is treated as the "keep this region" probability.
        token_mask = F.softmax(masking_logits, dim=-1)[..., 1]
        # The centre region must not be visible to its own reconstruction.
        token_mask = token_mask * (1 - center_mask.float())

        y_hat = self.recon_model(x, center_mask, token_mask)

        return y, y_hat, masking_logits, token_mask, center_mask

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss components."""
        y, y_hat, masking_logits, token_mask, center_mask = self.forward(batch)

        mask_density = token_mask.mean()
        loss, recon_loss, density_loss = self._get_loss(y_hat, y, masking_logits)

        self.log("train/reconstruction", recon_loss)
        self.log("train/density_loss", density_loss)
        self.log("temperature", self.tau)
        self.log("train/mask_density", mask_density)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation losses and the "advantage" over predicting zeros.

        ``advantage`` = mean |y| (the MAE of an all-zero prediction) minus the
        MAE of ``y_hat``. Positive values mean the model beats the zero baseline.
        """
        y, y_hat, masking_logits, token_mask, center_mask = self.forward(batch)

        mae_trivial = torch.abs(y).mean()
        mae_pred = torch.abs(y_hat - y).mean()
        self.adv(mae_trivial - mae_pred)
        self.log("advantage", self.adv)

        loss, recon_loss, density_loss = self._get_loss(y_hat, y, masking_logits)
        self.log("val/reconstruction", recon_loss)
        self.log("val/density_loss", density_loss)
        self.log("val/loss", loss)

        if batch_idx == 0:
            self._log_reconstruction_figure(y, y_hat)

    def _log_reconstruction_figure(self, y, y_hat):
        """Plot up to 4 targets against their reconstructions and send the figure to the logger."""
        fig, axes = plt.subplots(4, 1, figsize=(6, 15))
        try:
            for i in range(min(4, y.size(0))):
                target = y[i].detach().cpu().numpy()
                recon = y_hat[i].detach().cpu().numpy()
                steps = range(len(target))
                axes[i].plot(steps, target, label='y')
                axes[i].plot(steps, recon, label='y_hat')
                axes[i].legend()
            # Assumes a logger whose ``experiment`` has ``.log(dict)`` (e.g. wandb) -- TODO confirm.
            self.trainer.logger.experiment.log({"val/fig": fig})
        finally:
            # Without this, pyplot keeps every figure alive (one per validation run).
            plt.close(fig)

    def _get_loss(self, y_hat, y, masking_logits):
        """Return ``(total, reconstruction, density)`` losses.

        ``density`` is the mean temperature-scaled keep-probability, weighted by
        ``node_entropy`` when one was given.
        """
        recon_loss = self.recon_criterion(y_hat, y)

        keep_probs = torch.softmax(masking_logits / self.tau, dim=-1)[..., 1]
        if self.node_entropy is None:
            density_loss = keep_probs.mean()
        else:
            density_loss = (keep_probs * self.node_entropy).mean()

        loss = recon_loss + self.beta * density_loss
        return loss, recon_loss, density_loss

    def export(self, k, data_module, log_dir: str, output_fn: str, use_exist_assign: bool = False,
               topk: bool = True):
        """Derive (or load) a region -> neighbours assignment and build the hypergraph ``.mat`` from it.

        Args:
            k: average number of kept neighbours per region; ``int(k * n_region)``
                entries of the keep-probability matrix are selected overall.
            data_module: provides ``n_region`` and the ``h5`` data file.
            log_dir: root directory; the assignment is stored in
                ``<log_dir>/hyperassign/<output_fn>.json``.
            output_fn: file stem shared by the JSON and ``.mat`` outputs.
            use_exist_assign: load a previously saved assignment instead of
                querying the mask model.
            topk: if True, keep the globally top-scoring entries; otherwise
                keep each entry whose argmax class is 1.

        Returns:
            Path of the ``.mat`` file written by ``construct_hypergraph``.

        Raises:
            ValueError: if ``int(k * n_region)`` is out of range for top-k selection.
        """
        json_path = os.path.join(log_dir, "hyperassign", output_fn + ".json")

        if use_exist_assign:
            with open(json_path, "r") as f:
                hyper = json.load(f)
        else:
            n_region = data_module.n_region
            center_mask = F.one_hot(torch.arange(n_region, device=self.device), num_classes=n_region)

            # Inference only: no autograd graph is needed.
            # Presumably yields (n_region, n_region, 2) logits -- TODO confirm.
            with torch.no_grad():
                masking_probs = self.mask_model(None, center_mask).softmax(-1)

            if topk:
                keep_probs = masking_probs[..., 1]
                num_keep = int(k * n_region)
                if not 1 <= num_keep <= keep_probs.numel():
                    raise ValueError(
                        f"k * n_region must select between 1 and {keep_probs.numel()} "
                        f"entries; got k={k}, n_region={n_region}"
                    )
                thresh = torch.topk(keep_probs.reshape(-1), num_keep, dim=-1)[0][-1]
                token_mask = keep_probs >= thresh
            else:
                token_mask = torch.argmax(masking_probs, -1)

            # JSON keys must be strings; construct_hypergraph accepts both forms.
            hyper = {
                str(c): row.nonzero(as_tuple=True)[0].tolist()
                for c, row in enumerate(token_mask)
            }
            os.makedirs(os.path.dirname(json_path), exist_ok=True)
            with open(json_path, "w") as f:
                json.dump(hyper, f)

        return construct_hypergraph(data_module.h5, hyper, log_dir, output_fn)

    
from torchmetrics import Metric

class CorrMetric(Metric):
    """Pearson correlation over all predictions/targets accumulated since the last reset.

    Batches are buffered in ``update`` and concatenated in ``compute``. The
    result is therefore the correlation of the pooled values, not an average
    of per-batch correlations.
    """

    def __init__(self):
        super().__init__()
        # List states: one detached tensor appended per ``update`` call.
        self.add_state("preds", default=[])
        self.add_state("target", default=[])

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Buffer one batch of predictions and targets, detached from the autograd graph."""
        self.preds.append(preds.detach())
        self.target.append(target.detach())

    def compute(self):
        """Return the Pearson correlation of everything buffered, or 0.0 when nothing was seen.

        Raises:
            ValueError: if the concatenated predictions and targets differ in shape.
        """
        if len(self.preds) == 0:
            # Float (not int) on the metric's device so callers get a consistent dtype.
            return torch.tensor(0.0, device=self.device)
        preds = torch.cat(self.preds)
        target = torch.cat(self.target)
        if preds.shape != target.shape:
            raise ValueError(
                f"CorrMetric: preds shape {tuple(preds.shape)} != target shape {tuple(target.shape)}"
            )

        vx = preds - torch.mean(preds)
        vy = target - torch.mean(target)
        # 1e-7 keeps the result finite when either side is constant.
        corr = torch.sum(vx * vy) / (1e-7 + torch.norm(vx) * torch.norm(vy))
        return corr


class CorrLoss(nn.Module):
    """Negative Pearson correlation between predictions and targets.

    Minimising this loss maximises the correlation. ``eps`` is added to the
    denominator, so a constant input gives 0 rather than a division by zero.
    """

    def __init__(self, eps: float = 1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, y_hat, y):
        """Return ``-corr(y_hat, y)``, with means and sums taken over all elements."""
        pred_dev = y_hat - y_hat.mean()
        true_dev = y - y.mean()
        covariance_term = (pred_dev * true_dev).sum()
        scale = pred_dev.norm() * true_dev.norm() + self.eps
        return -(covariance_term / scale)

@LightningModule.register("regressor")
class Regressor(LightningModule):
    """Regress a per-sample target with a pluggable backbone while tracking Pearson correlation.

    The objective is the L1 loss, or the backbone's own ``loss`` when it
    returns one. When the backbone sets ``min_term``, ``beta`` times an
    entropy-weighted penalty on its mask probabilities is added. During
    prediction the backbone's ``last`` output is collected per
    ``<dataset>_<task>`` and written to ``.mat`` files.
    """

    def __init__(
        self,
        model: Model,
        learning_rate: float,
        mat_save_path: str,
        beta: float = 0.0,
        node_entropy: Optional[Dict[str, np.ndarray]] = None,
        warmup_epochs: int = 0,
    ) -> None:
        """
        Args:
            model: backbone; see ``forward`` for the output keys it must provide.
            learning_rate: AdamW learning rate.
            mat_save_path: path prefix of the ``.mat`` files written after prediction.
            beta: weight of the mask penalty (``min_term``).
            node_entropy: optional mapping task name -> per-node array. Each array
                is divided by 4 and registered as buffer ``entropy_<task>``.
            warmup_epochs: epochs of linear LR warm-up. 0 (the default) disables
                warm-up, matching the previous hard-coded behaviour.
        """
        super().__init__()

        self.model = model
        self.learning_rate = learning_rate
        self.beta = beta
        self.mat_save_path = mat_save_path
        self.warmup_epochs = warmup_epochs

        if node_entropy is not None:
            for key, value in node_entropy.items():
                # Scaled by 1/4 as in the original (reason not visible here).
                self.register_buffer("entropy_" + key, torch.from_numpy(np.asarray(value) / 4.))

        self.criterion = nn.L1Loss()
        self.train_metric = CorrMetric()
        self.val_metric = CorrMetric()

    def configure_optimizers(self):
        """AdamW with an optional linear warm-up over ``warmup_epochs`` epochs, constant LR afterwards."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        warmup_epochs = self.warmup_epochs

        def lr_scaler(epoch):
            # Ramp from 0 to 1 during warm-up; never entered when warmup_epochs <= 0.
            if epoch < warmup_epochs:
                return epoch / warmup_epochs
            return 1.

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_scaler)
        return [optimizer], [scheduler]

    def forward(self, batch, is_predicting: bool = False):
        """Run the backbone on ``batch`` and assemble the loss.

        The backbone must return ``preds`` (1-D) and ``last``. It may also
        return ``loss``, ``mask``, ``mask_logits``, ``min_term`` and, for HGSL,
        ``trues``, which replaces ``batch['y']``.

        Returns:
            ``((loss, max_term, min_term), (y_hat, y, last), density)``; ``density``
            is -1 and ``min_term`` is 0 when the mask penalty is not requested.

        Raises:
            NotImplementedError: if ``batch['meta']`` is neither a list nor a dict.
        """
        x = batch['x']
        meta = batch['meta']

        # Collated metadata is either a list of per-sample dicts or a dict of
        # per-key lists; the first sample's task/dataset stands for the batch.
        if isinstance(meta, list):
            task_key = meta[0]['task']
            dataset_name = meta[0]['dataset_name']
        elif isinstance(meta, dict):
            task_key = meta['task'][0]
            dataset_name = meta['dataset_name'][0]
        else:
            raise NotImplementedError

        model_name = type(self.model).__name__
        if model_name == "BrainNetworkTransformer":
            outputs = self.model((None, x, batch['y']))
            assert 'loss' not in outputs
        elif model_name == "BrainGNN":
            outputs = self.model(batch, self.criterion)
            assert 'loss' in outputs
        elif model_name == "HGSL":
            outputs = self.model(x, batch['y'])
            batch['y'] = outputs['trues']
        else:
            outputs = self.model(x, task_key, is_predicting)
            assert 'loss' not in outputs

        y_hat = outputs['preds']
        last = outputs['last']
        mask = outputs.get('mask', None)
        mask_logits = outputs.get('mask_logits', None)
        assert y_hat.dim() == 1

        y = batch['y']
        assert y.dim() == 1

        if 'loss' in outputs:
            max_term = - outputs['loss']
        else:
            max_term = - self.criterion(y_hat, y)

        # Up-weight the ABCD1 resting-state task. The dataset name is now read
        # for dict-collated meta too (previously meta[0] crashed on dicts).
        if dataset_name == "ABCD1" and task_key == 'Rest':
            max_term = max_term * 1.5

        if outputs.get("min_term", False):
            density = mask.mean()
            node_entropy = getattr(self, "entropy_" + task_key)
            min_term = (torch.sigmoid(mask_logits.squeeze()) * node_entropy).mean()
        else:
            density = -1
            min_term = 0

        loss = - max_term + self.beta * min_term

        return (loss, max_term, min_term), (y_hat, y, last), density

    def training_step(self, batch, batch_idx):
        """Update the training correlation and log the loss components."""
        (loss, max_term, min_term), (y_hat, y, _), density = self.forward(batch)
        self.train_metric(y_hat, y)
        self.log("train/loss", loss)
        self.log("train/density", density)
        self.log("train/max_term", max_term)
        self.log("train/min_term", min_term)
        return loss

    def on_train_epoch_end(self):
        self.log("train/corr", self.train_metric)

    def validation_step(self, batch, batch_idx):
        """Update the validation correlation and log the loss components."""
        (loss, max_term, min_term), (y_hat, y, _), density = self.forward(batch)
        self.val_metric(y_hat, y)
        self.log("val/loss", loss)
        self.log("val/density", density)
        self.log("val/max_term", max_term)
        self.log("val/min_term", min_term)

    def on_validation_epoch_end(self):
        self.log("val/corr", self.val_metric)

    def on_predict_start(self):
        """Reset the per-``<dataset>_<task>`` prediction and target buffers."""
        self.preds = defaultdict(list)
        self.trues = defaultdict(list)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Buffer the backbone's ``last`` output and the targets under ``<dataset>_<task>``."""
        _, (_, _, last), _ = self(batch, True)

        meta = batch['meta']
        if isinstance(meta, list):
            meta = meta[0]
        elif isinstance(meta, dict):
            meta = {k: v[0] for k, v in meta.items()}
        task: str = meta['task']
        dataset_name: str = meta['dataset_name']
        key = dataset_name + '_' + task

        self.preds[key].append(last.detach().cpu().numpy())
        self.trues[key].append(batch['y'].cpu().numpy())

    def on_predict_end(self) -> None:
        """Write one ``f"{mat_save_path}_{key}.mat"`` file per buffered key.

        Predictions are z-scored over all their elements, transposed, and
        stacked with their negation as ``rest_1_mats``. Targets are saved as
        ``PMAT_CR`` with a trailing singleton axis.
        """
        for key, pred_chunks in self.preds.items():
            preds = np.concatenate(pred_chunks)
            preds = (preds - preds.mean()) / (1e-7 + preds.std())
            trues = np.concatenate(self.trues[key])
            # Presumably (n_samples, d) -> (d, n_samples) -- TODO confirm.
            preds = np.transpose(preds)
            X = np.concatenate([preds, -preds], axis=0)
            Y = trues[..., None]
            savemat(self.mat_save_path + f"_{key}.mat", dict(rest_1_mats=X, PMAT_CR=Y))