import sys
import torch
import numpy as np
from einops import rearrange

from . import Filler
from ..nn.utils.awp import AWP
from lib.utils import parse_mask_ratio

class GCNDecFiller(Filler):
    """Inductive imputation filler with optional AWP and domain adaptation.

    Every step derives the set of nodes observed at least once in the batch
    mask ("known" nodes) and hands it to the model via
    ``batch_data["known_set"]``. Validation hides the test nodes and testing
    hides the validation nodes (``batch_data["keep_set"]``), so the model only
    sees known nodes plus the split being evaluated.

    When ``use_awp`` is true, Lightning's automatic optimization is disabled
    and ``training_step`` performs backward / clip / step itself, wrapping the
    forward pass in ``AWP.perturb()`` / ``AWP.restore()`` (presumably
    adversarial weight perturbation — see ``lib.nn.utils.awp``).
    """

    def __init__(self,
                 model_class,
                 model_kwargs,
                 optim_class,
                 optim_kwargs,
                 loss_fn,
                 use_awp,
                 known_mask_ratio,
                 mask_decay,
                 domain_adaptation,
                 scaled_target=False,
                 whiten_prob=0.05,
                 metrics=None,
                 scheduler_class=None,
                 scheduler_kwargs=None,
                 preemptive=False,
                 max_decay_epoch=100):
        """
        Args:
            model_class, model_kwargs, optim_class, optim_kwargs, loss_fn,
            scaled_target, whiten_prob, metrics, scheduler_class,
            scheduler_kwargs: forwarded unchanged to ``Filler``.
            use_awp: if true, use manual optimization with AWP in
                ``training_step``.
            known_mask_ratio: parsed with ``parse_mask_ratio``; entries ``[1]``
                and ``[0]`` are the start and end of the mask-ratio decay.
            mask_decay: if true, pass a linearly decaying ``mask_ratio`` to the
                model during training.
            domain_adaptation: if true, the model returns an auxiliary
                prediction whose loss is added to the training loss.
            preemptive: if false, the training target is restricted to the
                known nodes before computing the loss.
            max_decay_epoch: number of epochs over which the mask ratio decays;
                from this epoch on it stays at ``known_mask_ratio[0]``.
        """
        super(GCNDecFiller, self).__init__(model_class=model_class,
                                           model_kwargs=model_kwargs,
                                           optim_class=optim_class,
                                           optim_kwargs=optim_kwargs,
                                           loss_fn=loss_fn,
                                           scaled_target=scaled_target,
                                           whiten_prob=whiten_prob,
                                           metrics=metrics,
                                           scheduler_class=scheduler_class,
                                           scheduler_kwargs=scheduler_kwargs)
        self.preemptive = preemptive
        self.use_awp = use_awp
        if self.use_awp:
            # Optimization is done by hand in training_step.
            self.automatic_optimization = False
            # NOTE(review): these attributes do not configure the Trainer;
            # Trainer-level gradient clipping must also be disabled — confirm.
            self.gradient_clip_val = None
            self.gradient_clip_algorithm = None
        self.known_mask_ratio = parse_mask_ratio(known_mask_ratio)
        self.mask_decay = mask_decay
        self.max_decay_epoch = max_decay_epoch
        self.domain_adaptation = domain_adaptation

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _observed_nodes(mask):
        """Return ascending indices of nodes with any nonzero entry in ``mask``.

        ``mask`` must be 4-D with a trailing singleton axis (the einops pattern
        ``"b s n 1"`` enforces this); axis 2 is treated as the node axis.
        """
        mask = rearrange(mask, "b s n 1 -> (b s) n")
        return torch.where(mask.sum(0) > 0)[0].tolist()

    @staticmethod
    def _check_split_order(held_out, known_set, message):
        """Raise ``ValueError(message)`` if a held-out node index is smaller
        than the largest known node index.

        The check is vacuous (never raises) when either list is empty.
        """
        if held_out and known_set and min(held_out) < max(known_set):
            raise ValueError(message)

    @staticmethod
    def _complement(num_nodes, excluded):
        """Return the ascending indices in ``range(num_nodes)`` not in ``excluded``."""
        excluded = set(excluded)
        return [idx for idx in range(num_nodes) if idx not in excluded]

    def _current_mask_ratio(self):
        """Linearly interpolate the mask ratio for the current epoch.

        Goes from ``known_mask_ratio[1]`` at epoch 0 to ``known_mask_ratio[0]``
        at epoch ``max_decay_epoch``, then stays there (the original code kept
        extrapolating past that epoch, eventually producing negative ratios).
        """
        start, end = self.known_mask_ratio[1], self.known_mask_ratio[0]
        if self.max_decay_epoch > 0:
            progress = min(self.current_epoch / self.max_decay_epoch, 1.0)
        else:
            progress = 1.0
        return start - (start - end) * progress

    def _get_perturbed_adjacency(self, batch_data):
        """Ask the training dataset for a perturbed adjacency matrix.

        ``batch_data`` is currently unused. The adjacency comes from
        ``trainer.datamodule.torch_dataset.dataset.get_perturbed_adj_and_position``
        called with ``adj_threshold`` / ``perturbation_strength`` taken from
        ``self.hparams`` (defaults 0.1 / 0.5).

        Returns:
            The perturbed adjacency, or ``None`` if the dataset does not support
            it or generation failed. Failure is deliberately best-effort: a
            warning is printed and training continues.
        """
        datamodule = getattr(self.trainer, 'datamodule', None)
        torch_dataset = getattr(datamodule, 'torch_dataset', None)
        dataset = getattr(torch_dataset, 'dataset', None)
        if dataset is None or not hasattr(dataset, 'get_perturbed_adj_and_position'):
            return None

        adj_threshold = getattr(self.hparams, 'adj_threshold', 0.1)
        perturbation_strength = getattr(self.hparams, 'perturbation_strength', 0.5)
        try:
            _, perturbed_adj = dataset.get_perturbed_adj_and_position(
                adj_threshold=adj_threshold,
                perturbation_strength=perturbation_strength
            )
        except Exception as e:
            print(f"Warning: Failed to generate perturbed adjacency matrix: {e}")
            return None
        return perturbed_adj

    # -------------------------------------------------------------- Lightning hooks

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        Pops ``val_mask``, ``test_mask`` and ``y`` from the batch so the model
        never sees them, and adds ``known_set`` (plus ``keep_set``,
        ``mask_ratio`` and ``perturbed_adj`` when the corresponding options
        are enabled). With ``use_awp`` the optimizer step is performed here.

        Raises:
            ValueError: under ``domain_adaptation``, if a test node index is
                smaller than the largest known node index.
        """
        batch_data, batch_preprocessing = self._unpack_batch(batch)

        # Inductive setting: only nodes observed at least once in this batch
        # are given to the model. batch_data is mutated in place; presumably
        # predict_batch(batch) reads the same dict — TODO confirm.
        known_set = self._observed_nodes(batch_data["mask"])
        batch_data["known_set"] = known_set

        # Held-out masks and the target must not reach the model.
        batch_data.pop('val_mask', None)
        test_mask = batch_data.pop('test_mask', None)
        y = batch_data.pop('y')

        if self.domain_adaptation:
            test_set = self._observed_nodes(test_mask)
            self._check_split_order(test_set, known_set,
                                    "Test set should be larger than known set")
            batch_data["keep_set"] = self._complement(y.shape[2], test_set)

        # AWP starts from the second epoch.
        if self.use_awp and self.current_epoch >= 1:
            self.awp.perturb()

        if self.mask_decay:
            batch_data["mask_ratio"] = self._current_mask_ratio()

        if getattr(self.hparams, 'use_node_perturbation', False):
            perturbed_adj = self._get_perturbed_adjacency(batch_data)
            if perturbed_adj is not None:
                batch_data['perturbed_adj'] = perturbed_adj

        if self.domain_adaptation:
            imputation, mask, known_set, imputation_i, mask_i = self.predict_batch(
                batch, preprocess=False, postprocess=False)
        else:
            imputation, mask, known_set = self.predict_batch(
                batch, preprocess=False, postprocess=False)

        # Compare in scaled space if scaled_target, else in original space.
        if self.scaled_target:
            target = self._preprocess(y, batch_preprocessing)
        else:
            target = y
            imputation = self._postprocess(imputation, batch_preprocessing)
            if self.domain_adaptation:
                imputation_i = self._postprocess(imputation_i, batch_preprocessing)

        if not self.preemptive:
            target = target[:, :, known_set, :]
        loss = self.loss_fn(imputation, target, mask)
        if self.domain_adaptation:
            loss = loss + self.loss_fn(imputation_i, target, mask_i)

        if self.use_awp:
            # Manual optimization: manual_backward lets Lightning apply its
            # precision/strategy handling, unlike a bare loss.backward().
            self.manual_backward(loss)
            # Presumably restores the weights saved by perturb() before the
            # optimizer updates them — see AWP.
            self.awp.restore()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer = self.optimizers()
            optimizer.step()
            optimizer.zero_grad()

        # Metrics are always logged in the original (unscaled) space.
        if self.scaled_target:
            imputation = self._postprocess(imputation, batch_preprocessing)
            target = self._postprocess(target, batch_preprocessing)
        self.train_metrics.update(imputation.detach(), target, mask)  # test_mask = mask, during training
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        self.log('train_loss', loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate on the validation nodes with the test nodes hidden.

        Raises:
            ValueError: if a test node index is smaller than the largest known
                node index.
        """
        batch_data, batch_preprocessing = self._unpack_batch(batch)

        known_set = self._observed_nodes(batch_data["mask"])
        batch_data["known_set"] = known_set

        val_mask = batch_data.pop('val_mask', None)
        test_mask = batch_data.pop('test_mask', None)
        y = batch_data.pop('y')

        # Remove test nodes from the graph; target and mask follow keep_set.
        test_set = self._observed_nodes(test_mask)
        self._check_split_order(test_set, known_set,
                                "Test set should be larger than known set")
        keep_set = self._complement(y.shape[2], test_set)
        batch_data["keep_set"] = keep_set
        y = y[:, :, keep_set, :]
        val_mask = val_mask[:, :, keep_set, :]

        imputation = self.predict_batch(batch, preprocess=False, postprocess=False)

        if self.scaled_target:
            target = self._preprocess(y, batch_preprocessing)
        else:
            target = y
            imputation = self._postprocess(imputation, batch_preprocessing)

        val_loss = self.loss_fn(imputation, target, val_mask)

        # Metrics are computed against the unscaled target y.
        if self.scaled_target:
            imputation = self._postprocess(imputation, batch_preprocessing)
        self.val_metrics.update(imputation.detach(), y, val_mask)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        self.log('val_loss', val_loss.detach(), on_step=False, on_epoch=True, logger=True, prog_bar=False)
        return val_loss

    def test_step(self, batch, batch_idx):
        """Evaluate on the test nodes with the validation nodes hidden.

        Raises:
            ValueError: if a validation node index is smaller than the largest
                known node index.
        """
        batch_data, batch_preprocessing = self._unpack_batch(batch)

        known_set = self._observed_nodes(batch_data["mask"])
        batch_data["known_set"] = known_set

        val_mask = batch_data.pop('val_mask', None)
        test_mask = batch_data.pop('test_mask', None)
        y = batch_data.pop('y')

        # Remove validation nodes from the graph; target and mask follow keep_set.
        val_set = self._observed_nodes(val_mask)
        self._check_split_order(val_set, known_set,
                                "Validation set should be larger than known set")
        keep_set = self._complement(y.shape[2], val_set)
        batch_data["keep_set"] = keep_set
        y = y[:, :, keep_set, :]
        test_mask = test_mask[:, :, keep_set, :]

        # Outputs are postprocessed, so loss and metrics use the original scale.
        imputation = self.predict_batch(batch, preprocess=False, postprocess=True)
        test_loss = self.loss_fn(imputation, y, test_mask)

        self.test_metrics.update(imputation.detach(), y, test_mask)
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True, logger=True, prog_bar=True)
        return test_loss

    def configure_optimizers(self):
        """Build the optimizer (and optional scheduler) config for Lightning.

        A ``'monitor'`` entry in ``scheduler_kwargs`` is moved to the returned
        config. A copy is used so repeated calls do not lose it.
        """
        cfg = dict()
        optimizer = self.optim_class(self.parameters(), **self.optim_kwargs)
        if self.use_awp:
            # AWP is created in setup() without an optimizer; attach it here.
            self.awp.optimizer = optimizer
        cfg['optimizer'] = optimizer
        if self.scheduler_class is not None:
            scheduler_kwargs = dict(self.scheduler_kwargs or {})
            metric = scheduler_kwargs.pop('monitor', None)
            scheduler = self.scheduler_class(optimizer, **scheduler_kwargs)
            cfg['lr_scheduler'] = scheduler
            if metric is not None:
                cfg['monitor'] = metric
        return cfg

    def setup(self, stage=None):
        """Create the AWP helper once when ``use_awp`` is enabled.

        Only ``fc_1.weight`` is perturbed, with ``adv_lr = adv_eps = 0.001``.
        """
        if self.use_awp:
            if not hasattr(self, 'awp') or self.awp is None:
                self.awp = AWP(
                    model=self.model,
                    optimizer=None,
                    adv_param=['fc_1.weight'],
                    adv_lr=0.001,
                    adv_eps=0.001
                )

    def training_epoch_end(self, outputs):
        """Step the LR scheduler by hand under manual optimization (``use_awp``).

        Does nothing when no scheduler is configured.
        NOTE(review): schedulers whose ``step`` needs a metric (e.g.
        ReduceLROnPlateau) are not supported here — confirm.
        """
        if not self.use_awp:
            return
        scheduler = self.lr_schedulers()
        if scheduler is None:
            return
        if isinstance(scheduler, dict):
            scheduler = scheduler['scheduler']
        scheduler.step()