import copy
import os
import random
import time
from functools import partial, wraps
from typing import Callable, List, Optional

import hydra
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import wandb
from hydra.utils import get_original_cwd
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from tqdm.auto import tqdm
from sklearn.metrics import f1_score, confusion_matrix
import src.models.nn.utils as U
import src.utils as utils
import src.utils.train
from src.dataloaders import SequenceDataset  # TODO make registry
from src.tasks import decoders, encoders, tasks
from src.utils import registry
from src.utils.optim.ema import build_ema_optimizer
from src.utils.optim_groups import add_optimizer_hooks
from clearml import Task, Logger
# Module-level logger obtained from the project's training utilities.
log = src.utils.train.get_logger(__name__)

# Turn on TensorFloat32 (speeds up large model training substantially).
# NOTE: these flags are process-wide and are set as a side effect of importing this
# module. PyTorch documents TF32 as trading some float32 matmul/cuDNN precision for speed.
import torch.backends
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from src.models.nn.normalization import RBN
from src.models.sequence.attention.mha import PointwizeMultiheadAttention,PointwizeMultiheadAttentionGelu,PointwizeMultiheadAttentionGelu2, MultiheadAttention

def ser_save(model, path):
    """Serialize ``model`` into the file at ``path`` with ``dill.dump``.

    The file is opened in binary write mode, so an existing file at ``path`` is
    overwritten. ``dill`` is imported lazily so it is only required when this is called.
    """
    import dill

    with open(path, "wb") as out_file:
        dill.dump(model, out_file)
def get_layers(model, lst=None, type_of_layer=RBN):
    """Recursively collect the descendants of ``model`` that are instances of ``type_of_layer``.

    Args:
        model: ``nn.Module`` whose sub-modules are searched. ``model`` itself is not tested.
        lst: Optional list that matches are appended to. When omitted, a fresh list is
            created per call. The old ``lst=[]`` default was shared between calls and
            accumulated results.
        type_of_layer: Class (or tuple of classes, as accepted by ``isinstance``) to match.

    Returns:
        ``(count, lst)``. ``count`` is the number of matches found by this call, and
        ``lst`` is the list they were appended to. Matches are in post-order: a module's
        matching descendants come before the module itself.
    """
    if lst is None:
        lst = []
    count = 0
    for _, module in model.named_children():
        # Recurse first so nested matches precede their parent in ``lst``.
        # A leaf has no children, so this simply returns (0, lst) for it.
        child_count, _ = get_layers(module, lst, type_of_layer)
        count += child_count
        # Test every child, including leaf modules. Previously the check was nested
        # under "has children", so leaf layers of the requested type were never found.
        if isinstance(module, type_of_layer):
            lst.append(module)
            count += 1
    return count, lst

def get_attn_layers(model):
    """Return all attention layers found inside ``model`` as one list.

    Matches are grouped by class in this order: ``PointwizeMultiheadAttention``,
    ``PointwizeMultiheadAttentionGelu2``, ``PointwizeMultiheadAttentionGelu``,
    ``MultiheadAttention``. Each class gets its own ``get_layers`` search. A module that
    is an instance of several of these classes therefore appears once per matching class.
    """
    attention_types = (
        PointwizeMultiheadAttention,
        PointwizeMultiheadAttentionGelu2,
        PointwizeMultiheadAttentionGelu,
        MultiheadAttention,
    )
    collected = []
    for attention_type in attention_types:
        _, matches = get_layers(model, lst=[], type_of_layer=attention_type)
        collected.extend(matches)
    return collected

class FullModel(nn.Module):
    """Wrap a sequence backbone with an optional 3-stage encoder and classification head.

    Args:
        model: Backbone module. Its forward must return a ``(output, state)`` pair;
            ``state`` is discarded.
        encoder: Optional indexable container. Entries ``[0]``, ``[1]`` and ``[2]`` are
            applied to the input in that order.
        num_classes: When positive, a ``Linear(width, num_classes)`` head is applied to the
            mean over dimension -2 of the backbone output.
        width: Input feature size of the classification head.
        pool: Without a head, return the mean over dimension -2 instead of the raw output.

    Runtime flags are plain attributes toggled by callers:
        encoder_only: ``forward`` returns only the encoder output.
        FHE: skip the encoder in the regular forward path.
        save_input: keep a detached copy of the backbone input in ``inputs_onnx_encoded``.
        debug: stored here but not read by this class.
    """

    def __init__(self, model, encoder=None, num_classes=0, width=0, pool=False):
        super(FullModel, self).__init__()
        self.debug = False
        self.FHE = False
        self.encoder_only = False
        self.save_input = False
        # Sub-modules are registered in the order model -> encoder -> head.
        self.model = model
        self.encoder = encoder
        self.pool = pool
        self.use_classification_head = num_classes > 0
        if self.use_classification_head:
            self.classification_head = torch.nn.Linear(width, num_classes)
            self.classification_head.eval()
        self.model.eval()
        if encoder is not None:
            self.encoder.eval()

    def _run_encoder(self, x):
        """Apply encoder stages 0, 1 and 2 to ``x`` in sequence."""
        for stage in (0, 1, 2):
            x = self.encoder[stage](x)
        return x

    def forward(self, x):
        if self.encoder_only:
            return self._run_encoder(x)

        # In FHE mode the input is expected to arrive already encoded.
        if self.encoder is not None and not self.FHE:
            x = self._run_encoder(x)

        if self.save_input:
            self.inputs_onnx_encoded = x.detach()

        features, _ = self.model(x)
        if self.use_classification_head:
            return self.classification_head(features.mean(-2))
        return features.mean(-2) if self.pool else features
  
            
    
class RBNLoss(nn.Module):
    """Add the regularization terms exposed by RBN layers to a base loss.

    Each layer in ``rbn_layers`` must expose ``var_loss`` and ``mean_loss``, which are read
    when ``forward`` runs. The regularizer is
    ``sum(var_coef * layer.var_loss + mean_coef * layer.mean_loss)`` over the layers.

    If ten times the regularizer exceeds the base loss, the regularizer is multiplied by
    the detached factor ``0.1 * base / regularizer``. Its value then equals one tenth of
    the base loss, while gradients still flow through the regularizer terms.
    """

    def __init__(self, original_loss, rbn_layers, var_coef, mean_coef):
        super().__init__()
        self.original_loss = original_loss
        self.rbn_layers = rbn_layers
        self.var_coef = var_coef
        self.mean_coef = mean_coef

    def forward(self, hidden, target, keep_order=False, key_padding_mask=None, *args, **kwargs):
        """Return the base loss plus the (possibly rescaled) RBN regularizer.

        ``keep_order`` and ``key_padding_mask`` are forwarded to ``original_loss``.
        Previously they were hard-coded to ``False`` / ``None``, which silently
        discarded the caller's values.
        """
        og_loss = self.original_loss(
            hidden,
            target,
            *args,
            keep_order=keep_order,
            key_padding_mask=key_padding_mask,
            **kwargs,
        )
        rbn_loss = sum(
            self.var_coef * rbn_layer.var_loss + self.mean_coef * rbn_layer.mean_loss
            for rbn_layer in self.rbn_layers
        )
        if rbn_loss * 10 > og_loss:
            # Cap the regularizer's contribution at 10% of the base loss.
            scaler = (og_loss / rbn_loss).detach() * 0.1
            return og_loss + rbn_loss * scaler
        return og_loss + rbn_loss


class RangeLoss(nn.Module):
    """Add activation-range (and optionally LayerNorm / PowerSoftmax) penalties to a base loss.

    Args:
        original_loss: Base criterion, called as
            ``original_loss(hidden, target, *args, keep_order=..., key_padding_mask=..., **kwargs)``.
        activation_layers: Layers exposing ``get_loss()``. Their sum is weighted by ``loss_coef``.
        loss_coef: Weight of the activation-range penalty.
        ln_layers: Layers exposing a ``loss`` attribute. Used only when ``ln_loss`` is True.
        pwr_layers: Layers exposing a ``loss`` attribute. Used only when ``ln_loss`` is True.
        ln_loss: Whether to add the ``ln_layers`` / ``pwr_layers`` penalties.
        ln_loss_coef: Weight applied to both of those penalties.
        wrapper: If True, ``forward`` returns the base loss unchanged.
    """

    def __init__(self, original_loss, activation_layers=None, loss_coef=0.0, ln_layers=None, pwr_layers=None, ln_loss=False, ln_loss_coef=0.0, wrapper=False):
        super().__init__()
        self.original_loss = original_loss
        # ``None`` defaults give each instance its own list. The previous ``[]`` defaults
        # were one list object shared by every instance that relied on them.
        self.activation_layers = [] if activation_layers is None else activation_layers
        self.loss_coef = loss_coef
        self.ln_loss = ln_loss
        self.ln_layers = [] if ln_layers is None else ln_layers
        self.pwr_layers = [] if pwr_layers is None else pwr_layers
        self.ln_loss_coef = ln_loss_coef
        self.wrapper = wrapper
        print(
            "Create Ranges loss with : wrapper_mode=", self.wrapper,
            " | use_ln_loss:", ln_loss,
            " | loss coef ", self.loss_coef,
            "mun layers=()", len(self.activation_layers), len(self.ln_layers), len(self.pwr_layers), ")",
        )

    def forward(self, hidden, target, keep_order=False, key_padding_mask=None, *args, **kwargs):
        """Return the base loss plus the weighted range (and optional LN/PowerSoftmax) penalties.

        ``keep_order`` and ``key_padding_mask`` are forwarded to ``original_loss``.
        Previously they were hard-coded to ``False`` / ``None``.
        """
        og_loss = self.original_loss(
            hidden,
            target,
            *args,
            keep_order=keep_order,
            key_padding_mask=key_padding_mask,
            **kwargs,
        )
        if self.wrapper:
            return og_loss
        range_loss = sum(act_layer.get_loss() for act_layer in self.activation_layers)
        total_loss = og_loss + (self.loss_coef * range_loss)
        if self.ln_loss:
            layer_norm_loss = sum(ln_layer.loss for ln_layer in self.ln_layers)
            pwr_norm_loss = sum(pwr_layer.loss for pwr_layer in self.pwr_layers)
            total_loss = total_loss + (self.ln_loss_coef * layer_norm_loss) + (self.ln_loss_coef * pwr_norm_loss)
        return total_loss



def count_parameters(model):
    """Return the total number of scalar elements across ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted; frozen ones are skipped.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

# Lots of annoying hacks to get WandbLogger to continuously retry on failure
class DummyExperiment:
    """No-op stand-in for a logger experiment.

    Any unknown attribute resolves to a method that accepts all arguments and returns
    ``None``. Indexing returns the same object and item assignment is ignored, so chained
    calls such as ``exp[0].add_image(...)`` are harmless. ``rank_zero_experiment`` returns
    one of these when the real experiment is not available.
    """

    def nop(self, *args, **kw):
        """Accept and ignore any arguments."""
        return None

    def __getattr__(self, name):
        # Only consulted for attributes that do not exist, so every unknown name maps to ``nop``.
        return self.nop

    def __getitem__(self, idx) -> "DummyExperiment":
        # enables self.logger.experiment[0].add_image(...)
        return self

    def __setitem__(self, *args, **kwargs) -> None:
        return None


def rank_zero_experiment(fn: Callable) -> Callable:
    """Decorate an ``experiment`` getter so that it returns a ``DummyExperiment`` off rank 0.

    The wrapped getter runs through ``rank_zero_only``. If that produces a truthy result
    it is returned. Otherwise (including when ``rank_zero_only`` skips the call on other
    ranks) a fresh ``DummyExperiment`` is returned.
    """

    @wraps(fn)
    def experiment(self):
        @rank_zero_only
        def fetch():
            return fn(self)

        real_experiment = fetch()
        if real_experiment:
            return real_experiment
        return DummyExperiment()

    return experiment


class CustomWandbLogger(WandbLogger):
    """``WandbLogger`` whose ``experiment`` property retries ``wandb.init`` until it succeeds.

    On ranks other than 0 the property yields a ``DummyExperiment`` (via
    ``rank_zero_experiment``), so logging calls there become no-ops.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``WandbLogger.__init__``.

        The retry-on-failure behaviour lives in the ``experiment`` property, not here.
        """

        super().__init__(*args, **kwargs)

    @property
    @rank_zero_experiment
    def experiment(self):
        r"""
        Actual wandb object. To use wandb features in your
        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
        Example::
        .. code-block:: python
            self.logger.experiment.some_wandb_function()

        The run is created lazily on first access and cached in ``self._experiment``.
        Resolution order: reuse an already-active ``wandb.run``, else attach via the
        private ``wandb._attach`` when an ``_attach_id`` is set and that API exists, else
        call ``wandb.init``. ``wandb.init`` is retried forever: any exception is printed
        and followed by a random 30-60 s sleep.
        """
        if self._experiment is None:
            if self._offline:
                # Must be set before wandb.init reads the environment.
                # NOTE(review): "dryrun" is the older spelling of wandb's offline mode — confirm it is still honoured.
                os.environ["WANDB_MODE"] = "dryrun"

            attach_id = getattr(self, "_attach_id", None)
            if wandb.run is not None:
                # wandb process already created in this instance
                rank_zero_warn(
                    "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
                )
                self._experiment = wandb.run
            elif attach_id is not None and hasattr(wandb, "_attach"):
                # attach to wandb process referenced
                self._experiment = wandb._attach(attach_id)
            else:
                # create new wandb process
                # Deliberately unbounded retry: training should not die because the wandb
                # service is temporarily unreachable. This blocks until init succeeds.
                while True:
                    try:
                        self._experiment = wandb.init(**self._wandb_init)
                        break
                    except Exception as e:
                        print("wandb Exception:\n", e)
                        # Random back-off (inclusive 30..60 s) spreads out concurrent retries.
                        t = random.randint(30, 60)
                        print(f"Sleeping for {t} seconds")
                        time.sleep(t)

                # define default x-axis
                # Only when the installed wandb exposes define_metric: every metric ("*")
                # is plotted against "trainer/global_step".
                if getattr(self._experiment, "define_metric", None):
                    self._experiment.define_metric("trainer/global_step")
                    self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)

        return self._experiment


class SequenceLightningModule(pl.LightningModule):
    def __init__(self, config):
        """Build the Lightning module from a Hydra/OmegaConf ``config``.

        Reads optional feature flags from ``config``:
        - RBN loss, range/LayerNorm loss and polynomial replacements;
        - tokenizer, monitoring and prompt settings;
        - BDFT (bank-data fine-tuning).

        It then saves the config as hyperparameters, instantiates the dataset from
        ``SequenceDataset.registry`` and calls ``setup()`` eagerly.

        When ``config.BDFT`` is truthy, construction continues as follows. It runs a
        standalone fine-tuning loop on the bank dataset and exports the model to ONNX.
        Finally it raises ``ValueError("finetunning done")``, so construction never
        returns in that mode.
        """
        # Disable profiling executor. This reduces memory and increases speed.
        try:
            torch._C._jit_set_profiling_executor(False)
            torch._C._jit_set_profiling_mode(False)
        except AttributeError:
            pass

        super().__init__()

        # Regularization Batch Norm (RBN) loss.
        self.use_RBN_loss = "RBN" in config and bool(config.RBN)
        if self.use_RBN_loss:
            self.var_coef = config.get("RBN_V", 0.0)
            self.mean_coef = config.get("RBN_M", 0.0)

        # Polynomial LayerNorm replacement is enabled only when both bounds are configured.
        self.replace_ln_to_poly = ("LNMin" in config) and ("LNMax" in config)
        if self.replace_ln_to_poly:
            self.ln_min = config.get("LNMin", 0.5)
            self.ln_max = config.get("LNMax", 12.0)
        self.act_range = config.get("PactRange", None)

        # Range loss plus optional LayerNorm loss. The LN-loss attributes get defaults up
        # front. Previously use_ln_loss was missing when "RL" was absent, and
        # ln_loss_coef was missing when RL was falsy.
        self.use_range_loss = "RL" in config and bool(config.RL)
        self.use_ln_loss = False
        self.ln_loss_coef = 0.0
        if self.use_range_loss:
            self.should_config_optimizer = True
            self.range_loss_coef = config.get("RLC", 0.1)
            if "LNL" in config:
                self.use_ln_loss = True
                self.ln_loss_coef = config.get("LNLC", 0.1)  # LayerNorm loss coefficient

        self.bert_tokenizer = "BERTtokenizer" in config and bool(config.BERTtokenizer)
        self.replace_poly_act = "PolyAct" in config and bool(config.PolyAct)
        self.without_monitoring_ranges = "WM" in config and bool(config.WM)
        self.save_per_epoch = "save_per_epoch" in config and bool(config.save_per_epoch)

        if "prompt" in config:
            self.prompt = config.prompt
            self.prompt_topk = config.get("topk", 20)
            self.prompt_options = config.get("prompt_options", None)
        else:
            self.prompt = None

        self.BDFT = config.BDFT if "BDFT" in config else False  # Bank-data fine-tuning
        self.save_attn_mat = config.save_attn_mat if "save_attn_mat" in config else False
        self.LN_aprrox_iters = config.LN_aprrox_iters if "LN_aprrox_iters" in config else 20

        # Passing in config expands it one level, so can access by self.hparams.train instead of self.hparams.config.train
        self.save_hyperparameters(config, logger=False)

        # Dataset arguments
        self.dataset = SequenceDataset.registry[self.hparams.dataset._name_](
            **self.hparams.dataset
        )
        self.dataset.bert_tokenizer = self.bert_tokenizer

        # Check hparams
        self._check_config()

        # PL has some bugs, so add hooks and make sure they're only called once
        self._has_setup = False
        self.config = config
        # Called eagerly. Presumably it builds self.model / self.encoder and the monitored
        # layer lists used by the BDFT helpers below — confirm in setup().
        self.setup()  ## Added by KS
        self.epoch = config.get("epoch", -1)

        if self.BDFT:
            self._bdft_run(config)

    def _bdft_run(self, config):
        """Fine-tune a ``FullModel`` on the bank dataset, export it to ONNX, then stop by raising."""
        self.clr_logger = None
        if "clrml" in config and config.clrml:
            folder_name = config.get("clrml_folder", "HE-friendly-wikitext103")
            project_name = "HE-friendly-Attention/" + folder_name
            task = Task.init(project_name=project_name, task_name=config.clrml_name)
            print("Initialize clearml on: ", project_name + "/" + config.clrml_name)
            self.clr_logger = task.get_logger()
            self.clr_parameters = task.connect(OmegaConf.to_container(config))

        device = "cuda"
        from src.dataloaders.datasets.bank_hf_data import get_bank_dataloaders
        dataloader_train, dataloader_val = get_bank_dataloaders(batch_size=8)
        model = FullModel(self.model, self.encoder, num_classes=3, width=config.model.d_model).to(device)
        ckpt = self.config["train"]["ckpt"]
        if ckpt is not None and ckpt != "":
            print("load cp from", ckpt)
            cp = torch.load(ckpt)
            model.load_state_dict(cp['state_dict'], strict=False)
        model.to(device)
        # The classification fine-tune uses non-causal attention.
        for attn_layer in get_attn_layers(self.model):
            attn_layer.causal = False
        print(model)

        epochs = self.config.trainer["max_epochs"]
        criteria = torch.nn.CrossEntropyLoss()
        weight_decay = 0.01
        optimizer = torch.optim.SGD(model.parameters(), lr=config.optimizer.lr, weight_decay=weight_decay)

        for epoch in range(epochs):
            self._bdft_reset_monitors()
            self._bdft_train_epoch(model, dataloader_train, optimizer, criteria, device, epoch, epochs)

            model.eval()
            self._bdft_reset_monitors()
            last_epoch = epoch == (epochs - 1)
            if last_epoch:
                # Final epoch: switch layers to their polynomial forms before validating.
                self._bdft_replace_with_polynomials()
                model.debug = True
            self._bdft_eval_epoch(model, dataloader_val, criteria, device, epoch, epochs, last_epoch)

        self._bdft_export_onnx(model)
        raise ValueError("finetunning done")

    def _bdft_report(self, title, series, epoch, value):
        """Send one scalar to ClearML; silently skipped when ClearML logging is disabled."""
        if self.clr_logger is not None:
            self.clr_logger.report_scalar(title=title, series=series, iteration=epoch, value=value)

    def _bdft_reset_monitors(self):
        """Reset the statistics gathered by monitored activation, LayerNorm and PowerSoftmax layers."""
        for act_layer in self.poly_activations:
            act_layer.reset_ranges()
        for norm_layer in self.ln_lyers_for_monitoring:
            norm_layer.reset_stat()
        for pwr_layer in self.powersoftmax_lyers_for_monitoring:
            pwr_layer.reset_stat()

    def _bdft_report_layer_stats(self, prefix, epoch):
        """Report per-layer and global LayerNorm, PowerSoftmax and activation-range statistics."""
        ln_mins, ln_maxs = [], []
        for idx, layer in enumerate(self.ln_lyers_for_monitoring):
            stats = layer.get_stat()
            series = f"{prefix}-{idx}"
            ln_mins.append(stats["min"].item())
            ln_maxs.append(stats["max"].item())
            self._bdft_report("ln-min", series, epoch, stats["min"].item())
            self._bdft_report("ln-max", series, epoch, stats["max"].item())
            self._bdft_report("ln-mean", series, epoch, stats["mean"].item())
        if ln_mins:
            self._bdft_report("Global Stat LN", prefix + "-min", epoch, min(ln_mins))
            self._bdft_report("Global Stat LN", prefix + "-max", epoch, max(ln_maxs))

        div_mins, div_maxs, score_mins, score_maxs = [], [], [], []
        for idx, layer in enumerate(self.powersoftmax_lyers_for_monitoring):
            stats = layer.get_stat()
            series = f"{prefix}-{idx}"
            div_mins.append(stats["min-div"].item())
            div_maxs.append(stats["max-div"].item())
            score_mins.append(stats["min-score"].item())
            score_maxs.append(stats["max-score"].item())
            self._bdft_report("PwrSftMax-div-min", series, epoch, stats["min-div"].item())
            self._bdft_report("PwrSftMax-div-max", series, epoch, stats["max-div"].item())
            self._bdft_report("PwrSftMax-div-mean", series, epoch, stats["mean-div"].item())
            self._bdft_report("PwrSftMax-scores-max", series, epoch, stats["max-score"].item())
            self._bdft_report("PwrSftMax-scores-min", series, epoch, stats["min-score"].item())
        if div_mins:
            self._bdft_report("Global Stat PwrSftMax-div", prefix + "-min", epoch, min(div_mins))
            self._bdft_report("Global Stat PwrSftMax-div", prefix + "-max", epoch, max(div_maxs))
            self._bdft_report("Global Stat PwrSftMax-score", prefix + "-min", epoch, min(score_mins))
            self._bdft_report("Global Stat PwrSftMax-score", prefix + "-max", epoch, max(score_maxs))

        # Monitor activations:
        act_mins, act_maxs = [], []
        for idx, layer in enumerate(self.poly_activations):
            min_val, max_val = layer.get_min_max()
            act_mins.append(min_val)
            act_maxs.append(max_val)
            self._bdft_report("Range per Layer", f"{prefix}-min-{idx}", epoch, min_val)
            self._bdft_report("Range per Layer", f"{prefix}-max-{idx}", epoch, max_val)
        if act_mins:
            self._bdft_report("Global Ranges", prefix + "-min", epoch, min(act_mins))
            self._bdft_report("Global Ranges", prefix + "-max", epoch, max(act_maxs))

    def _bdft_train_epoch(self, model, dataloader, optimizer, criteria, device, epoch, epochs):
        """Run one training epoch, then print and report loss, accuracy and layer statistics."""
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0
        all_preds, all_labels = [], []
        for batch in tqdm(dataloader):
            b_input_ids = batch['input_ids'].to(device)
            b_labels = batch['labels'].to(device)
            optimizer.zero_grad()
            outputs = model(b_input_ids)
            loss = criteria(outputs, b_labels)
            train_loss += loss.item()
            loss.backward()
            optimizer.step()
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == b_labels).sum().item()
            total += b_labels.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(b_labels.cpu().numpy())

        avg_train_loss = train_loss / len(dataloader)
        train_accuracy = correct / total
        print(f"Epoch {epoch+1}/{epochs} - Training Loss: {avg_train_loss}")
        print(f"Epoch {epoch+1}/{epochs} - Training Accuracy: {train_accuracy:.2f}")
        print(f"Epoch {epoch+1}/{epochs} - Training Confusion Matrix:")
        print(confusion_matrix(all_labels, all_preds))
        # Report the numeric accuracy; the previous code passed a formatted string.
        self._bdft_report("loss", "train", epoch, avg_train_loss)
        self._bdft_report("accuracy", "train", epoch, train_accuracy)
        self._bdft_report_layer_stats("train", epoch)

    def _bdft_replace_with_polynomials(self):
        """Switch ReLU activations, LayerNorms and PowerSoftmax layers to their polynomial variants."""
        print("Num of act layers", len(self.poly_activations))
        for act_layer in self.poly_activations:
            if act_layer.act_type == "relu":
                # NOTE(review): fixed range 10 is used here rather than self.act_range — confirm intended.
                act_layer.replace_to_poly(range_val=10)
        print("Num of ln layers", len(self.ln_lyers_for_monitoring))
        for i, ln_layer in enumerate(self.ln_lyers_for_monitoring):
            print("Replace LN to Poly")
            ln_layer.debug = True
            ln_layer.layer_idx = i
            ln_layer.replace_to_poly(degree=self.LN_aprrox_iters)
        print("Num of PowerSoftmax layers", len(self.powersoftmax_lyers_for_monitoring))
        for i, pwr_layer in enumerate(self.powersoftmax_lyers_for_monitoring):
            pwr_layer.poly_div = True
            print("Replace to Poly PowerSoftmax", "pwr_layer.poly_div=", pwr_layer.poly_div)
            pwr_layer.scalePowerSoftmaxbyCons = 1
            pwr_layer.stable = True
            pwr_layer.debug = True
            pwr_layer.layer_idx = i

    def _bdft_eval_epoch(self, model, dataloader, criteria, device, epoch, epochs, last_epoch):
        """Validate for one epoch; on the last epoch also dump encoded inputs, labels and logits to HDF5."""
        val_loss = 0.0
        correct = 0
        total = 0
        all_preds, all_labels = [], []
        h5_inputs, h5_labels, h5_logits = [], [], []
        # FullModel keeps a detached copy of each encoded input; read back on the last epoch.
        model.save_input = True
        with torch.no_grad():
            for step, batch in enumerate(tqdm(dataloader)):
                if last_epoch and step == 1:
                    # Debug flags stay on only for the first batch, which is cut to one sample below.
                    for pwr_layer in self.powersoftmax_lyers_for_monitoring:
                        pwr_layer.debug = False
                    for ln_layer in self.ln_lyers_for_monitoring:
                        ln_layer.debug = False
                    model.debug = False
                b_input_ids = batch['input_ids'].to(device)
                b_labels = batch['labels'].to(device)
                if last_epoch and step == 0:
                    b_input_ids = b_input_ids[:1, ...]
                    b_labels = b_labels[:1, ...]

                outputs = model(b_input_ids)
                loss = criteria(outputs, b_labels)
                val_loss += loss.item()
                preds = torch.argmax(outputs, dim=1)
                correct += (preds == b_labels).sum().item()
                total += b_labels.size(0)
                labels = b_labels.cpu().numpy()
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels)
                if last_epoch:
                    h5_inputs.append(model.inputs_onnx_encoded.detach().cpu().numpy())
                    h5_labels.append(labels)
                    h5_logits.append(outputs.detach().cpu().numpy())

        avg_val_loss = val_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs} - Val Loss: {avg_val_loss}")
        val_accuracy = correct / total
        print(f"Epoch {epoch+1}/{epochs} - Validation Accuracy: {val_accuracy:.2f}")
        print(f"Epoch {epoch+1}/{epochs} - Validation Confusion Matrix:")
        print(confusion_matrix(all_labels, all_preds))
        self._bdft_report("loss", "test", epoch, avg_val_loss)
        self._bdft_report("accuracy", "test", epoch, val_accuracy)
        self._bdft_report_layer_stats("test", epoch)

        # The HDF5 file is only opened when it is actually written. Previously it was
        # truncated on every epoch.
        if last_epoch:
            self._bdft_save_h5('samples.h5', h5_inputs, h5_labels, h5_logits)
        model.save_input = False

    @staticmethod
    def _bdft_save_h5(path_h5, inputs, labels, logits):
        """Concatenate per-batch arrays along axis 0 and write them to ``path_h5`` (overwriting it)."""
        import h5py

        inputs = np.concatenate(inputs, axis=0)
        labels = np.concatenate(labels, axis=0)
        logits = np.concatenate(logits, axis=0)
        with h5py.File(path_h5, 'w') as hf:
            hf.create_dataset('input_ids', data=inputs)
            print("shapes", inputs.shape, labels.shape, logits.shape)
            hf.create_dataset('labels', data=labels)
            hf.create_dataset('logits', data=logits)
        print("saved h5 file to", path_h5)

    def _bdft_export_onnx(self, model):
        """Export ``model`` to ``self.onnxPath``, using random token ids as the example input."""
        random_tokens = torch.randint(low=0, high=100, size=(1, self.config.loader['l_max']), dtype=torch.int64).cuda()
        from src.models.nn.normalization import CLN
        _, cln_layers = get_layers(model, lst=[], type_of_layer=CLN)
        for ln_layer in cln_layers:
            ln_layer.no_monitor = True

        model.eval()
        with torch.no_grad():
            model(random_tokens)  # warm-up forward pass before export; output unused
        create_onnx_model_and_encoder(model, self.onnxPath, export_params=True, tokens=random_tokens.cpu(), withoutcopy=True)
        print(operator_count(self.onnxPath))


    def setup(self, stage=None):
        """Build dataset, model, monitoring hooks, task, encoder/decoder and loss.

        The dataset setup at the top runs on every call (unless
        ``train.disable_dataset`` is set); everything after the ``_has_setup``
        guard runs only once.

        Side effects (stored on ``self``): ``model``, the monitored layer lists
        (``ln_lyers_for_monitoring``, ``attentoin_lyers_for_monitoring``,
        ``powersoftmax_lyers_for_monitoring``, ``poly_activations``), ``task``,
        ``encoder``, ``decoder``, ``loss``, ``loss_val``, ``metrics``,
        ``val_n_batches`` and ``train_n_batches``. It also pops ``encoder`` /
        ``decoder`` from ``hparams.model`` and, for the "lm" task, deletes
        several keys from ``hparams.task``.

        Args:
            stage: Lightning stage name; not read by this method.
        """
        if not self.hparams.train.disable_dataset:
            self.dataset.setup()

        # We need to set up the model in setup() because for some reason when training with DDP, one GPU uses much more memory than the others
        # In order to not overwrite the model multiple times during different stages, we need this hack
        # TODO PL 1.5 seems to have an option to skip hooks to avoid this
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/5410#issuecomment-762257024
        if self._has_setup:
            return
        else:
            self._has_setup = True

        # Convenience feature: if model specifies encoder, combine it with main encoder
        encoder_cfg = utils.to_list(self.hparams.encoder) + utils.to_list(
            self.hparams.model.pop("encoder", None)
        )
        decoder_cfg = utils.to_list(
            self.hparams.model.pop("decoder", None)
        ) + utils.to_list(self.hparams.decoder)

        # Instantiate model
        self.model = utils.instantiate(registry.model, self.hparams.model)
        # Collect all CLN (layer-norm) layers; their statistics are reported/reset in _shared_step.
        from src.models.nn.normalization import CLN 
        _, self.ln_lyers_for_monitoring = get_layers(self.model, lst=[], type_of_layer=CLN)

        # Collect PointwizeMultiheadAttention layers and, from each, the activation module
        # of its inner attention layer (the power-softmax stats source used for monitoring).
        from src.models.sequence.attention.mha import PointwizeMultiheadAttention 
        _, self.attentoin_lyers_for_monitoring = get_layers(self.model, lst=[], type_of_layer=PointwizeMultiheadAttention)
        self.powersoftmax_lyers_for_monitoring = [layer.mha.multihead_attention.attention_layer.act for layer in(self.attentoin_lyers_for_monitoring)]
        
        # Optional post-init hook: call the method named by `_name_` (remaining config
        # entries as kwargs) on every submodule that defines it.
        if (name := self.hparams.train.post_init_hook['_name_']) is not None:
            kwargs = self.hparams.train.post_init_hook.copy()
            del kwargs['_name_']
            for module in self.modules():
                if hasattr(module, name):
                    getattr(module, name)(**kwargs)
        
        # NOTE(review): replace_activations presumably wraps the model's activations in
        # polyAct modules; the returned count is checked against the polyAct layers found below.
        from poly_utils import replace_activations, polyAct
        if not self.without_monitoring_ranges:
            num_activations_replaced = replace_activations(self.model)
  
        _, self.poly_activations = get_layers(self.model, lst=[], type_of_layer=polyAct)
        if not self.without_monitoring_ranges:
            assert (num_activations_replaced == len(self.poly_activations))
            print(num_activations_replaced , "activations replaced by polyAct")

        # Approximate activations by poly activation (only polyAct layers whose act_type is "relu"):
        if self.replace_poly_act:
            for act_layer in self.poly_activations:
                #act_layer.replace_to_poly(range_val=self.act_range)
                if act_layer.act_type=="relu":
                    act_layer.replace_to_poly(range_val=self.act_range)
        # Optionally replace every monitored CLN layer by a polynomial version over [ln_min, ln_max].
        if self.replace_ln_to_poly:
            for ln_layer in self.ln_lyers_for_monitoring:
                print("Replace LN to Poly")
                ln_layer.replace_to_poly(self.ln_min, self.ln_max)
    
        ### Collect all RBN layers (only needed to build RBNLoss further below)
        if self.use_RBN_loss: 
            from src.models.nn.normalization import RBN
            cnt_rbn_layers, rbn_layers = get_layers(self.model, lst=[], type_of_layer=RBN)
            print("cnt_rbn_layers:" , cnt_rbn_layers)
            
            
        #rbn_loss = get_loss(rbn_layers)
        # For the "lm" task, strip these task options before instantiating the task and
        # switch metrics to perplexity. The names suggest adaptive embedding/softmax
        # settings; confirm. NOTE(review): each `del` raises if the key is absent.
        if self.hparams.task["_name_"] == "lm": # ONNX FRIENDLY ENCODER:
            del self.hparams.task['loss']
            del self.hparams.task['init_scale']  
            del self.hparams.task['bias_scale']  
            del self.hparams.task['div_val']  
            del self.hparams.task['cutoffs']  
            del self.hparams.task['tie_weights']  
            del self.hparams.task['tie_projs'] 
            del self.hparams.task['dropemb']  
            del self.hparams.task['dropsoft'] 
            self.hparams.task['metrics'] = "ppl"
            
        # Instantiate the task
        self.task = utils.instantiate(
            tasks.registry, self.hparams.task, dataset=self.dataset, model=self.model
        )

        # Create encoders and decoders
        encoder = encoders.instantiate(
            encoder_cfg, dataset=self.dataset, model=self.model
        )
        decoder = decoders.instantiate(
            decoder_cfg, model=self.model, dataset=self.dataset
        )

        # Extract the modules so they show up in the top level parameter count
        self.encoder = U.PassthroughSequential(self.task.encoder, encoder)
        self.decoder = U.PassthroughSequential(decoder, self.task.decoder)
        


        # Training loss selection:
        # - RBNLoss over the collected RBN layers (var/mean coefficients),
        # - RangeLoss over activation / LN / power-softmax layers,
        # - otherwise RangeLoss(wrapper=True), presumably a thin wrapper around task.loss — confirm.
        if self.use_RBN_loss:
            norm_factor = 1.0
            self.loss = RBNLoss(self.task.loss, rbn_layers=rbn_layers, var_coef=self.var_coef*norm_factor, mean_coef=self.mean_coef*norm_factor)
        elif self.use_range_loss: 
            self.loss = RangeLoss(self.task.loss, activation_layers=self.poly_activations, loss_coef=self.range_loss_coef, ln_layers=self.ln_lyers_for_monitoring, pwr_layers=self.powersoftmax_lyers_for_monitoring, ln_loss=self.use_ln_loss, ln_loss_coef=self.ln_loss_coef)
        else:
            #self.loss = self.task.loss
            self.loss = RangeLoss(self.task.loss, wrapper=True)
        # Validation/test use the task's own loss, or its dedicated loss_val when it defines one.
        self.loss_val = self.task.loss
        if hasattr(self.task, 'loss_val'):
            self.loss_val = self.task.loss_val
        self.metrics = self.task.metrics

        # Handle state logic
        self._initialize_state()

        # Cache loader lengths; training_step / validation_step use them to detect the last batch.
        val_dataloaders = self.val_dataloader()
        self.val_n_batches = [len(dataloader) for dataloader in val_dataloaders]
        self.train_n_batches = len(self.train_dataloader())
        print("train number of batches:",self.train_n_batches)
        print("val and test number of batches:",self.val_n_batches)
    
    #def load_state_dict(self, state_dict, strict=False):
    def load_state_dict(self, state_dict, strict=True):
        """Load weights, first passing them through the configured pretrained-state hook.

        When ``train.pretrained_model_state_hook._name_`` is set, the hook is
        instantiated from the registry and may rewrite the incoming state dict
        (e.g. to inflate 2D convs to 3D convs) before it is handed to the parent
        implementation.

        Returns:
            Whatever the parent ``load_state_dict`` returns.
        """
        hook_cfg = self.hparams.train.pretrained_model_state_hook
        if hook_cfg['_name_'] is not None:
            state_hook = utils.instantiate(
                registry.model_state_hook,
                hook_cfg.copy(),
                partial=True,
            )
            state_dict = state_hook(self.model, state_dict)

        print("Custom load_state_dict function is running.")
        # The parent's return value must be propagated to callers.
        return super().load_state_dict(state_dict, strict=strict)

    def _check_config(self):
        assert self.hparams.train.state.mode in [None, "none", "null", "reset", "bptt", "tbptt"]
        assert (
            (n := self.hparams.train.state.n_context) is None
            or isinstance(n, int)
            and n >= 0
        )
        assert (
            (n := self.hparams.train.state.n_context_eval) is None
            or isinstance(n, int)
            and n >= 0
        )

    def _initialize_state(self):
        """Called at model setup and start of epoch to completely reset state"""
        self._state = None
        self._memory_chunks = []

    def _reset_state(self, batch, device=None):
        """Called to construct default_state when necessary, e.g. during BPTT"""
        device = device or batch[0].device
        self._state = self.model.default_state(*batch[0].shape[:1], device=device)

    def _detach_state(self, state):
        if isinstance(state, torch.Tensor):
            return state.detach()
        elif isinstance(state, tuple):
            return tuple(self._detach_state(s) for s in state)
        elif isinstance(state, list):
            return [self._detach_state(s) for s in state]
        elif isinstance(state, dict):
            return {k: self._detach_state(v) for k, v in state.items()}
        elif state is None:
            return None
        else:
            raise NotImplementedError

    def _process_state(self, batch, batch_idx, train=True):
        """Handle logic for state context."""
        # Number of context steps
        key = "n_context" if train else "n_context_eval"
        n_context = self.hparams.train.state.get(key)

        # Don't need to do anything if 0 context steps. Make sure there is no state
        if n_context == 0 and self.hparams.train.state.mode not in ['tbptt']:
            self._initialize_state()
            return

        # Reset state if needed
        if self.hparams.train.state.mode == "reset":
            if batch_idx % (n_context + 1) == 0:
                self._reset_state(batch)

        # Pass through memory chunks
        elif self.hparams.train.state.mode == "bptt":
            self._reset_state(batch)
            with torch.no_grad():  # should be unnecessary because individual modules should handle this
                for _batch in self._memory_chunks:
                    self.forward(_batch)
            # Prepare for next step
            self._memory_chunks.append(batch)
            self._memory_chunks = self._memory_chunks[-n_context:]

        elif self.hparams.train.state.mode == 'tbptt':
            _, _, z = batch
            reset = z["reset"]
            if reset:
                self._reset_state(batch)
            else:
                self._state = self._detach_state(self._state)

    def _on_epoch_start(self):
        self._initialize_state()
        

    def forward(self, batch):
        """Pass a batch through the encoder, the backbone model and the decoder.

        Args:
            batch: ``(x, y)`` or ``(x, y, extra)`` where ``extra`` is a dict of
                additional dataloader arguments (e.g. sequence length, resolution),
                forwarded to the encoder and decoder.

        Returns:
            ``(decoded_x, y, decoder_kwargs)``. The backbone's new state is stored
            in ``self._state``.
        """
        x, y, *extra = batch
        if extra:
            assert len(extra) == 1 and isinstance(extra[0], dict), "Dataloader must return dictionary of extra arguments"
            loader_kwargs = extra[0]
        else:
            loader_kwargs = {}
        # The encoder may produce model-specific kwargs (e.g. a key_padding_mask or RNN state).
        x, model_kwargs = self.encoder(x, **loader_kwargs)
        x, new_state = self.model(x, **model_kwargs, state=self._state)
        self._state = new_state
        x, decoder_kwargs = self.decoder(x, state=new_state, **loader_kwargs)
        return x, y, decoder_kwargs
    
    def forward_tokens(self, model, batch):
        """Run ``model`` directly on a batch of tokens and return its raw output.

        The caller supplies a model that already embeds tokens (e.g. a wrapper
        combining encoder and backbone); no decoder is applied here.
        """
        output = model(batch)
        return output

    def step(self, x_t):
        """Advance the model by one recurrent step on input ``x_t``.

        The encoder output's first element feeds ``model.step`` with the current
        ``self._state``; the returned state is stored back on ``self`` and used by
        ``decoder.step``, whose first output element is returned.
        """
        # NOTE: potential edge case for encoders that expect (B, L, H) inputs.
        encoded, *_unused = self.encoder(x_t)
        stepped, new_state = self.model.step(encoded, state=self._state)
        self._state = new_state
        decoded, *_unused = self.decoder.step(stepped, state=new_state)
        return decoded

    def _shared_step(self, batch, batch_idx, prefix="train", last_batch=False):
        """Run one step shared by training, validation and testing.

        At the first batch of a loader, the range/statistics trackers of the
        monitored layers are reset. The step then updates recurrent state, runs
        the forward pass, computes the loss and metrics, and logs them under
        ``"{prefix}/<metric>"``. The loss is ``self.loss`` for training and
        ``self.loss_val`` otherwise. On the last batch of a loader, the collected
        LayerNorm, power-softmax and activation-range statistics are reported to
        ClearML when a ClearML logger is configured.

        Args:
            batch: Dataloader batch ``(x, y[, extra_kwargs_dict])``.
            batch_idx: Index of the batch within the current loader.
            prefix: Loader name used as the metric prefix; "train" selects the
                training loss and training-state handling.
            last_batch: Whether this is the final batch of the loader.

        Returns:
            The loss value.
        """
        if batch_idx == 0:
            # New pass over this loader: restart statistics accumulation.
            for act_layer in self.poly_activations:
                act_layer.reset_ranges()
            for norm_layer in self.ln_lyers_for_monitoring:
                norm_layer.reset_stat()
            for power_layer in self.powersoftmax_lyers_for_monitoring:
                if hasattr(power_layer, 'reset_stat'):
                    power_layer.reset_stat()

        self._process_state(batch, batch_idx, train=(prefix == "train"))

        x, y, w = self.forward(batch)

        # Loss
        if prefix == 'train':
            loss = self.loss(x, y, **w)
        else:
            loss = self.loss_val(x, y, **w)

        # Metrics
        metrics = self.metrics(x, y, **w)
        metrics["loss"] = loss
        metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}

        # ClearML logging of the statistics accumulated over the whole loader.
        if last_batch and self.clr_logger is not None:
            self._clearml_report_ln_stats(prefix)
            self._clearml_report_pwr_stats(prefix)
            self._clearml_report_act_ranges(prefix)

        # Calculate torchmetrics: these are accumulated and logged at the end of epochs
        self.task.torchmetrics(x, y, prefix)

        self.log_dict(
            metrics,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            add_dataloader_idx=False,
            sync_dist=True,
        )

        return loss

    def _clearml_report_ln_stats(self, prefix):
        """Report per-layer and global min/max (plus per-layer mean) of the monitored CLN layers."""
        report = self.clr_logger.report_scalar
        global_ln_min = global_ln_max = None
        for cnt, layer in enumerate(self.ln_lyers_for_monitoring):
            stat = layer.get_stat()
            ln_min = stat["min"].item()
            ln_max = stat["max"].item()
            global_ln_min = ln_min if cnt == 0 else min(global_ln_min, ln_min)
            global_ln_max = ln_max if cnt == 0 else max(global_ln_max, ln_max)
            series = f"{prefix}-{cnt}"
            report(title="ln-min", series=series, iteration=self.epoch, value=ln_min)
            report(title="ln-max", series=series, iteration=self.epoch, value=ln_max)
            report(title="ln-mean", series=series, iteration=self.epoch, value=stat["mean"].item())

        if self.ln_lyers_for_monitoring:
            report(title="Global Stat LN", series=prefix + "-min", iteration=self.epoch, value=global_ln_min)
            report(title="Global Stat LN", series=prefix + "-max", iteration=self.epoch, value=global_ln_max)

    def _clearml_report_pwr_stats(self, prefix):
        """Report per-layer and global power-softmax divisor/score statistics.

        Layers without ``get_stat`` are skipped. The global extrema are seeded from
        the first layer that does report statistics, not from layer index 0. This
        avoids an unbound global when the first layer has no ``get_stat``.
        """
        report = self.clr_logger.report_scalar
        g_min_div = g_max_div = g_min_score = g_max_score = None
        for cnt, pwrlayer in enumerate(self.powersoftmax_lyers_for_monitoring):
            if not hasattr(pwrlayer, 'get_stat'):
                continue
            stat = pwrlayer.get_stat()
            min_div = stat["min-div"].item()
            max_div = stat["max-div"].item()
            min_score = stat["min-score"].item()
            max_score = stat["max-score"].item()
            if g_min_div is None:
                g_min_div, g_max_div = min_div, max_div
                g_min_score, g_max_score = min_score, max_score
            else:
                g_min_div = min(g_min_div, min_div)
                g_max_div = max(g_max_div, max_div)
                g_min_score = min(g_min_score, min_score)
                g_max_score = max(g_max_score, max_score)

            series = f"{prefix}-{cnt}"
            report(title="PwrSftMax-div-min", series=series, iteration=self.epoch, value=min_div)
            report(title="PwrSftMax-div-max", series=series, iteration=self.epoch, value=max_div)
            report(title="PwrSftMax-div-mean", series=series, iteration=self.epoch, value=stat["mean-div"].item())
            report(title="PwrSftMax-scores-max", series=series, iteration=self.epoch, value=max_score)
            report(title="PwrSftMax-scores-min", series=series, iteration=self.epoch, value=min_score)

        if g_min_div is None:  # no layer reported statistics
            return
        report(title="Global Stat PwrSftMax-div", series=prefix + "-min", iteration=self.epoch, value=g_min_div)
        report(title="Global Stat PwrSftMax-div", series=prefix + "-max", iteration=self.epoch, value=g_max_div)
        report(title="Global Stat PwrSftMax-score", series=prefix + "-min", iteration=self.epoch, value=g_min_score)
        report(title="Global Stat PwrSftMax-score", series=prefix + "-max", iteration=self.epoch, value=g_max_score)

    def _clearml_report_act_ranges(self, prefix):
        """Report per-layer and global input ranges observed by the polyAct activations."""
        report = self.clr_logger.report_scalar
        global_min = global_max = None
        for cnt, layer in enumerate(self.poly_activations):
            min_val, max_val = layer.get_min_max()
            global_min = min_val if cnt == 0 else min(global_min, min_val)
            global_max = max_val if cnt == 0 else max(global_max, max_val)
            report(title="Range per Layer", series=f"{prefix}-min-{cnt}", iteration=self.epoch, value=min_val)
            report(title="Range per Layer", series=f"{prefix}-max-{cnt}", iteration=self.epoch, value=max_val)

        if self.poly_activations:
            report(title="Global Ranges", series=prefix + "-min", iteration=self.epoch, value=global_min)
            report(title="Global Ranges", series=prefix + "-max", iteration=self.epoch, value=global_max)

    def on_train_epoch_start(self):
        """Lightning hook: reset state and training torchmetrics at the start of every training epoch.

        When range-loss training is enabled and ``should_config_optimizer`` is
        still set, the scheduler's first LR lambda is replaced (once) by a
        constant factor ``hparams.optimizer.lr / current_lr``. The scheduler is
        then stepped immediately.
        """
        self._on_epoch_start()
        self.task._reset_torchmetrics("train")
        # HACK: temporary way to override the learning rate without changing the default settings.
        # TODO: wrap with standard Lightning logic.
        # NOTE(review): LambdaLR multiplies the *initial* lr by this factor, so the resulting
        # lr equals optimizer.lr only if the initial lr equals curr_lr — confirm.
        if not (self.use_range_loss and self.should_config_optimizer):
            return
        scheduler = self.lr_schedulers()
        curr_lr = scheduler.get_last_lr()[0]
        new_lr = self.hparams.optimizer.lr
        scheduler.lr_lambdas[0] = lambda epoch: (new_lr / curr_lr)
        scheduler.step()
        self.should_config_optimizer = False

    def training_epoch_end(self, outputs):
        """Lightning hook: log training torchmetrics and close the epoch.

        Logs the accumulated training torchmetrics and reports the current
        ``callback_metrics`` to ClearML (when configured). It then advances
        ``self.epoch`` and, if ``self.prompt`` is set, runs the prompt probe,
        which terminates the process.
        """
        super().training_epoch_end(outputs)
        self.log_dict(
            {f"train/{k}": v for k, v in self.task.get_torchmetrics("train").items()},
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            add_dataloader_idx=False,
            sync_dist=True,
        )

        if self.clr_logger is not None:
            self._clearml_report_epoch_metrics()
        self.epoch += 1

        # Testing the model on pre-defined prompts
        if self.prompt is not None:
            self._run_prompt_probe()

    def _clearml_report_epoch_metrics(self):
        """Report Lightning ``callback_metrics`` of the train/val/test loaders to ClearML.

        Each key is split on its last "/" into loader name and metric name. Keys
        without a "/" and keys whose loader part is not exactly 'train', 'val' or
        'test' (e.g. 'trainer/...' or 'val/ema/...') are skipped. When at least one
        'test' metric is present, the learning rate is reported once.
        """
        report = self.clr_logger.report_scalar
        saw_test = False
        for key, val in self.trainer.callback_metrics.items():
            loader_name, sep, metric_name = key.rpartition("/")
            if not sep or loader_name not in ('train', 'test', 'val'):
                continue
            report(title=metric_name, series=loader_name, iteration=self.epoch, value=val)
            saw_test = saw_test or loader_name == 'test'

        if saw_test:  # lr only on testing
            report(title="Learning Rate", series=".", iteration=self.epoch, value=self.optimizers().param_groups[0]["lr"])
            # Best-effort: not every scheduler setup exposes a reportable get_lr().
            # TODO: wrap this logic with the standard Lightning interface.
            try:
                report(title="Learning Rate", series="*", iteration=self.epoch, value=self.lr_schedulers().get_lr())
            except Exception:
                log.debug("Could not report scheduler learning rate", exc_info=True)

    def _run_prompt_probe(self):
        """Print next-token predictions for ``self.prompt``, then terminate the process.

        Attention layers are switched to non-causal and the encoder+model pair is
        run once, in eval mode and without gradients. Afterwards it prints the
        argmax token, optionally the ``prompt_topk`` most probable tokens, and,
        if ``prompt_options`` is set, the given whitespace-separated candidate
        words sorted by probability.
        """
        import torch.nn.functional as F

        tokenizer = self.dataset.vocab
        tokens = tokenizer.convert_to_tensor(tokenizer.tokenize(self.prompt)).cuda().unsqueeze(0)
        for attn_layer in get_attn_layers(self.model):
            attn_layer.causal = False
        model = FullModel(self.model, self.encoder).cuda()
        # The process exits afterwards, so switching to eval mode needs no restore.
        model.eval()
        with torch.no_grad():
            out = self.forward_tokens(model, tokens)
            # Distribution over the vocabulary at the final position.
            out_proj = F.softmax(self.task.loss.compute_logits(out)[-1, :], dim=-1)
        print("tokens after softmax: ", out_proj.shape, out_proj.argmax())

        if self.prompt_topk > 0:
            values, idxs = out_proj.topk(self.prompt_topk)
            print("values: ", values.shape)
            print("idxs: ", idxs.shape)
            for prob, idx in zip(values.tolist(), idxs.tolist()):
                print((tokenizer.get_symbols([idx])[0], round(prob, 5)))

        if self.prompt_options is not None:
            options = self.prompt_options.split()
            print("Options:" , options)
            options_with_prob = {option: out_proj[tokenizer.get_idx(option)].item() for option in options}
            for word in sorted(options_with_prob, key=options_with_prob.get, reverse=True):
                print(word, options_with_prob[word])
        # One-shot diagnostic: stop the run once the predictions are printed.
        raise SystemExit
        

    def on_validation_epoch_start(self):
        """Lightning hook: reset state and per-loader validation torchmetrics.

        When ``save_attn_mat`` is set, each attention layer (or, for
        power-softmax layers, the activation of its inner attention layer) is
        tagged with its index and with ``config.attn_dirpath``.
        """
        self._on_epoch_start()
        for loader_name in self.val_loader_names:
            self.task._reset_torchmetrics(loader_name)
        if not self.save_attn_mat:
            return
        for layer_idx, attn_layer in enumerate(get_attn_layers(self.model)):
            # PowerSoftmax layers are tagged on their inner activation; softmax layers on themselves.
            if hasattr(attn_layer.mha, 'multihead_attention'):
                target = attn_layer.mha.multihead_attention.attention_layer.act
            else:
                target = attn_layer
            target.layer_idx = layer_idx
            target.path = self.config.attn_dirpath


    def validation_epoch_end(self, outputs):
        """Lightning hook: log validation torchmetrics and run optional per-epoch exports.

        * Logs the accumulated torchmetrics of every validation loader.
        * If ``save_per_epoch`` is set, exports ``self.model`` to
          ``<model_checkpoint.dirpath>/epoch<N>/onnx.onnx`` using a random dummy
          input of shape (3, 512, 768). Attention layers are made non-causal for
          the export only; their previous ``causal`` flags are restored
          afterwards, so subsequent training is unaffected.
        * If ``save_attn_mat`` is set, prints "done" and exits the process.
        """
        super().validation_epoch_end(outputs)
        for name in self.val_loader_names:
            self.log_dict(
                {f"{name}/{k}": v for k, v in self.task.get_torchmetrics(name).items()},
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                add_dataloader_idx=False,
                sync_dist=True,
            )

        if self.save_per_epoch:
            path = self.config["callbacks"]["model_checkpoint"]["dirpath"]
            curr_path = os.path.join(path, "epoch" + str(self.epoch))
            os.makedirs(curr_path, exist_ok=True)

            # Dummy input fed straight to the backbone (bypassing the token encoder).
            # NOTE(review): shape assumes batch 3, length 512, model width 768 — confirm.
            dummy_input = torch.randn(3, 512, 768)
            missing = object()
            attn_layers = list(get_attn_layers(self.model))
            previous_causal = [getattr(layer, "causal", missing) for layer in attn_layers]
            for attn_layer in attn_layers:
                attn_layer.causal = False
            try:
                create_onnx(self.model, os.path.join(curr_path, "onnx.onnx"), export_params=True, tokens=dummy_input)
            finally:
                # Restore masking so later training epochs keep their causal attention.
                for attn_layer, was_causal in zip(attn_layers, previous_causal):
                    if was_causal is not missing:
                        attn_layer.causal = was_causal
            print("onnx saved:", curr_path)

        if self.save_attn_mat:
            print('============')
            print('done')
            raise SystemExit
     
    def on_test_epoch_start(self):
        """Lightning hook: reset state and the torchmetrics of every test loader."""
        self._on_epoch_start()
        for loader_name in self.test_loader_names:
            self.task._reset_torchmetrics(loader_name)

    def test_epoch_end(self, outputs):
        """Lightning hook: log the accumulated torchmetrics of every test loader."""
        super().test_epoch_end(outputs)
        for loader_name in self.test_loader_names:
            loader_metrics = self.task.get_torchmetrics(loader_name)
            prefixed = {f"{loader_name}/{metric}": value for metric, value in loader_metrics.items()}
            self.log_dict(
                prefixed,
                on_step=False,
                on_epoch=True,
                prog_bar=True,
                add_dataloader_idx=False,
                sync_dist=True,
            )
        
        
        

    def training_step(self, batch, batch_idx):
        """Lightning training step: run the shared step and log per-step values.

        Raises:
            ValueError: if ``batch_idx`` is not below ``self.train_n_batches``.

        Returns:
            The training loss.
        """
        if batch_idx >= self.train_n_batches:
            raise ValueError("batch index too large")
        is_last = batch_idx == self.train_n_batches - 1
        loss = self._shared_step(batch, batch_idx, prefix="train", last_batch=is_last)

        step_log_kwargs = dict(
            on_step=True,
            on_epoch=False,
            prog_bar=False,
            add_dataloader_idx=False,
            sync_dist=True,
        )
        # Log the loss explicitly so it shows up in WandB; the epoch goes under 'trainer'
        # to share a prefix with 'global_step'. Note: with DDP this hits a progress-bar bug
        # (as of PL 1.4.6): https://github.com/PyTorchLightning/pytorch-lightning/pull/9142
        self.log_dict({"trainer/loss": loss, "trainer/epoch": self.current_epoch}, **step_log_kwargs)

        # Extra values exposed by submodules via a `metrics` attribute (e.g. output norms);
        # the root module itself is skipped.
        extra_metrics = {}
        for submodule in list(self.modules())[1:]:
            if hasattr(submodule, "metrics"):
                extra_metrics.update(submodule.metrics)
        self.log_dict(extra_metrics, **step_log_kwargs)

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Lightning validation step for loader ``dataloader_idx``.

        For loaders whose name ends with "/ema", the EMA weights are swapped in
        around the step, but only once the optimizer has stepped. This excludes
        the initial sanity check of the 0-th epoch.

        Raises:
            ValueError: if ``batch_idx`` is not below the loader's batch count.

        Returns:
            The validation loss.
        """
        n_batches = self.val_n_batches[dataloader_idx]
        if batch_idx >= n_batches:
            raise ValueError("batch index too large")
        loader_name = self.val_loader_names[dataloader_idx]
        use_ema = loader_name.endswith("/ema") and self.optimizers().optimizer.stepped
        if use_ema:
            self.optimizers().swap_ema()
        loss = self._shared_step(
            batch, batch_idx, prefix=loader_name, last_batch=(batch_idx == n_batches - 1)
        )
        if use_ema:
            self.optimizers().swap_ema()
        return loss

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        return self._shared_step(
            batch, batch_idx, prefix=self.test_loader_names[dataloader_idx]
        )

    def configure_optimizers(self):
        """Build the optimizer (and optional LR scheduler) from ``self.hparams``.

        Steps, in order:
          1. If a ClearML logger is attached, record the parameter count.
          2. If ``train.optimizer_param_grouping`` is configured, call
             ``add_optimizer_hooks`` on ``self.model``; parameters that end up
             with an ``_optim`` attribute get special hyperparameters.
          3. Build the base optimizer over every parameter WITHOUT ``_optim``,
             wrapped by ``build_ema_optimizer`` when ``train.ema > 0``.
          4. Add one extra param group per unique ``_optim`` dict, merging the
             base optimizer hparams with that dict.
          5. If ``train.layer_decay['_name_']`` is not None, REPLACE all param
             groups with per-layer groups (see the note in that section).
          6. If ``self.hparams`` has a ``scheduler`` entry, build it as well.

        Returns:
            The optimizer alone when no scheduler is configured, otherwise
            ``([optimizer], [scheduler_dict])`` in Lightning's format.

        NOTE(review): this method deletes ``self.hparams.optimizer._name_``.
        A second call would no longer find that key and would presumably fail.
        It looks like this is meant to run once per fit; confirm.
        """
        if self.clr_logger is not None:
            # NOTE: the key is misspelled ("Trianable"). It names the logged
            # ClearML parameter, so it is left as-is to match earlier runs.
            # ``count_parameters`` is defined outside this view.
            self.clr_parameters["Trianable Parameters"] = count_parameters(self)
            #self.clr_logger.report_text(msg = "Trainable Parameters:" + str(count_parameters(self)), print_console=False)

        # Set zero weight decay for some params
        # (add_optimizer_hooks presumably attaches an ``_optim`` dict to the selected params)
        if 'optimizer_param_grouping' in self.hparams.train:
            add_optimizer_hooks(self.model, **self.hparams.train.optimizer_param_grouping)

        # Normal parameters: everything that has no per-parameter ``_optim`` override
        all_params = list(self.parameters())
        params = [p for p in all_params if not hasattr(p, "_optim")]


        # Construct optimizer, add EMA if necessary
        if self.hparams.train.ema > 0.0:
            optimizer = utils.instantiate(
                registry.optimizer,
                self.hparams.optimizer,
                params,
                wrap=build_ema_optimizer,
                polyak=self.hparams.train.ema,
            )
        else:
            optimizer = utils.instantiate(registry.optimizer, self.hparams.optimizer, params)
        # Presumably removed so that the registry key ``_name_`` is not spread into
        # the param groups below via ``**self.hparams.optimizer``. This mutates hparams.
        del self.hparams.optimizer._name_

        # Add parameters with special hyperparameters
        hps = [getattr(p, "_optim") for p in all_params if hasattr(p, "_optim")]
        # dict.fromkeys de-duplicates the frozensets while keeping first-seen order.
        # sorted() on frozensets compares by subset relation, which is only a
        # partial order, so the resulting group order is not a well-defined sort.
        hps = [
            # dict(s) for s in set(frozenset(hp.items()) for hp in hps)
            dict(s) for s in sorted(list(dict.fromkeys(frozenset(hp.items()) for hp in hps)))
            # dict(s) for s in dict.fromkeys(frozenset(hp.items()) for hp in hps)
        ]  # Unique dicts
        for hp in hps:
            params = [p for p in all_params if getattr(p, "_optim", None) == hp]
            # Per-parameter settings in ``hp`` override the base optimizer hparams
            optimizer.add_param_group(
                {"params": params, **self.hparams.optimizer, **hp}
            )

        ### Layer Decay ###
        # NOTE(review): this branch resets ``optimizer.param_groups`` and rebuilds it
        # from ``self.named_parameters()`` with only lr/weight_decay per layer, so the
        # ``_optim`` groups added above (and their special hparams) are discarded.
        # Indexing ``layer_decay['_name_']`` also requires that entry to exist in the config.

        if self.hparams.train.layer_decay['_name_'] is not None:
            get_num_layer = utils.instantiate(
                registry.layer_decay,
                self.hparams.train.layer_decay['_name_'],
                partial=True,
            )

            # Go through all parameters and get num layer
            # (assumes get_num_layer returns an orderable number, e.g. int - TODO confirm)
            layer_wise_groups = {}
            num_max_layers = 0
            for name, p in self.named_parameters():
                # Get layer id for each parameter in the model
                layer_id = get_num_layer(name)

                # Add to layer wise group
                if layer_id not in layer_wise_groups:
                    layer_wise_groups[layer_id] = {
                        'params': [],
                        'lr': None,
                        'weight_decay': self.hparams.optimizer.weight_decay
                    }
                layer_wise_groups[layer_id]['params'].append(p)

                if layer_id > num_max_layers: num_max_layers = layer_id

            # Update lr for each layer: lr * decay ** (max_layer - layer_id), so the
            # deepest layer keeps the base lr and earlier layers are scaled down (for decay < 1)
            for layer_id, group in layer_wise_groups.items():
                group['lr'] = self.hparams.optimizer.lr * (self.hparams.train.layer_decay.decay ** (num_max_layers - layer_id))

            # Reset the torch optimizer's param groups
            optimizer.param_groups = []
            for layer_id, group in layer_wise_groups.items():
                optimizer.add_param_group(group)

        # Print optimizer info for debugging
        keys = set([k for hp in hps for k in hp.keys()])  # Special hparams
        utils.train.log_optimizer(log, optimizer, keys)

        # Configure scheduler
        if "scheduler" not in self.hparams:
            return optimizer
        lr_scheduler = utils.instantiate(
            registry.scheduler, self.hparams.scheduler, optimizer
        )
        scheduler = {
            "scheduler": lr_scheduler,
            "interval": self.hparams.train.interval,  # 'epoch' or 'step'
            "monitor": self.hparams.train.monitor,
            "name": "trainer/lr",  # default is e.g. 'lr-AdamW'
        }
        # See documentation for how to configure the return
        # https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.core.lightning.html#pytorch_lightning.core.lightning.LightningModule.configure_optimizers
        return [optimizer], [scheduler]

    def train_dataloader(self):
        """Build the training dataloader.

        Calls ``self.dataset.train_dataloader`` with the ``hparams.loader``
        kwargs. Logs the number of examples and steps when both can be
        determined.

        Returns:
            The loader returned by ``self.dataset.train_dataloader``.
        """
        train_loader = self.dataset.train_dataloader(**self.hparams.loader)
        # Logging is best-effort: some loaders presumably have no usable ``len``
        # or ``.dataset``. Catch ``Exception`` rather than using a bare ``except``
        # so that KeyboardInterrupt / SystemExit still propagate.
        try:
            log.info(
                "Loaded 'train' dataloader:".ljust(30) +
                f"{len(train_loader.dataset):7} examples | {len(train_loader):6} steps"
            )
        except Exception:
            pass
        return train_loader

    def _eval_dataloaders_names(self, loaders, prefix):
        """Pair evaluation loaders with logging names derived from ``prefix``.

        - dict-like ``loaders``: one name per key, ``f"{prefix}/{key}"``, or the
          bare ``prefix`` for a ``None`` key. Loaders come back in ``values()`` order.
        - list-like ``loaders``: ``f"{prefix}/{index}"`` per position. The loaders
          object is returned unchanged.
        - anything else: treated as one loader named ``prefix``.

        Returns:
            ``(names, loaders)``.
        """
        if utils.is_dict(loaders):
            names = [prefix if key is None else f"{prefix}/{key}" for key in loaders.keys()]
            return names, list(loaders.values())
        if utils.is_list(loaders):
            names = [f"{prefix}/{idx}" for idx in range(len(loaders))]
            return names, loaders
        return [prefix], [loaders]

    def _eval_dataloaders(self):
        # Return all val + test loaders
        val_loaders = self.dataset.val_dataloader(**self.hparams.loader)
        test_loaders = self.dataset.test_dataloader(**self.hparams.loader)
        val_loader_names, val_loaders = self._eval_dataloaders_names(val_loaders, "val")
        test_loader_names, test_loaders = self._eval_dataloaders_names(
            test_loaders, "test"
        )

        # Duplicate datasets for ema
        if self.hparams.train.ema > 0.0:
            val_loader_names += [name + "/ema" for name in val_loader_names]
            val_loaders = val_loaders + val_loaders
            test_loader_names += [name + "/ema" for name in test_loader_names]
            test_loaders = test_loaders + test_loaders

        # adding option to only have val loader at eval (eg if test is duplicate)
        if self.hparams.train.get("remove_test_loader_in_eval", None) is not None:
            eval_loader_names = val_loader_names
            eval_loaders = val_loaders
        # default behavior is to add test loaders in eval
        else:
            eval_loader_names = val_loader_names + test_loader_names
            eval_loaders = val_loaders + test_loaders

        return eval_loader_names, eval_loaders

    def val_dataloader(self):
        """Return all evaluation loaders and remember their names.

        Names and loaders come from ``self._eval_dataloaders()``. The names are
        stored in ``self.val_loader_names`` for use by the eval step.
        """
        val_loader_names, val_loaders = self._eval_dataloaders()
        self.val_loader_names = val_loader_names
        for name, loader in zip(val_loader_names, val_loaders):
            # Best-effort logging per loader, so one loader without a usable ``len``
            # does not hide the stats of every loader after it. ``Exception``
            # (not a bare ``except``) lets KeyboardInterrupt / SystemExit propagate.
            try:
                log.info(
                    f"Loaded '{name}' dataloader:".ljust(30) +
                    f"{len(loader.dataset):7} examples | {len(loader):6} steps"
                )
            except Exception:
                pass

        return val_loaders

    def test_dataloader(self):
        test_loader_names, test_loaders = self._eval_dataloaders()
        self.test_loader_names = ["final/" + name for name in test_loader_names]
        return test_loaders


### pytorch-lightning utils and entrypoint ###

def _num_requested_devices(devices):
    """Return how many devices a ``trainer.devices`` value asks for.

    - Ints are returned as-is.
    - Sized non-string values (e.g. a list of device ids such as ``[0, 1]``)
      count their elements.
    - Anything else (``None``, bools, strings such as ``"auto"``) returns 0,
      so no strategy is inferred from it.
    """
    if isinstance(devices, bool):
        return 0
    if isinstance(devices, int):
        return devices
    if isinstance(devices, str):
        return 0
    try:
        return len(devices)
    except TypeError:
        return 0


def create_trainer(config):
    """Build a ``pl.Trainer`` from the Hydra ``config``.

    Sets up, in order:
      - an optional WandB logger;
      - the Lightning callbacks listed under ``config.callbacks``;
      - an optional profiler;
      - DDP when several GPUs are requested and no strategy is set.

    NOTE: ``config`` is mutated in place. ``callback._name_`` is written on each
    callback entry, ``trainer.profiler`` is popped, and ``trainer.strategy``
    may be set.
    """
    callbacks: List[pl.Callback] = []
    logger = None

    # WandB Logging
    if config.get("wandb") is not None:
        # Pass in wandb.init(config=) argument to get the nice 'x.y.0.z' hparams logged
        # Can pass in config_exclude_keys='wandb' to remove certain groups
        import wandb

        logger = CustomWandbLogger(
            config=utils.to_dict(config, recursive=True),
            settings=wandb.Settings(start_method="fork"),
            **config.wandb,
        )

    # ``callbacks`` may be absent or null. Treat both as "no callbacks" so the
    # progressive-resizing check below cannot crash either.
    callback_cfgs = config.get("callbacks") or {}

    # Lightning callbacks
    for _name_, callback in callback_cfgs.items():
        if callback is None:
            continue
        # Presumably skipped because the LR monitor needs a logger to report to
        if config.get("wandb") is None and _name_ in ["learning_rate_monitor"]:
            continue
        log.info(f"Instantiating callback <{registry.callbacks[_name_]}>")
        callback._name_ = _name_
        callbacks.append(utils.instantiate(registry.callbacks, callback))

    # Profiler: instantiated here and removed from the trainer kwargs
    profiler = None
    if config.trainer.get("profiler", None) is not None:
        profiler = hydra.utils.instantiate(config.trainer.profiler)
        config.trainer.pop("profiler")

    # Configure DDP automatically for multi-GPU runs, but never override a
    # strategy the user already chose.
    trainer_cfg = config.trainer
    if (
        trainer_cfg.get("accelerator") == "gpu"
        and _num_requested_devices(trainer_cfg.get("devices")) > 1
        and trainer_cfg.get("strategy") is None
    ):
        print("ddp automatically configured, more than 1 gpu used!")
        trainer_cfg.strategy = "ddp"

    # Report the ProgressiveResizing schedule, if configured
    progressive = callback_cfgs.get("progressive_resizing", None)
    if progressive is not None:
        stage_params = progressive.stage_params
        print(f"Progressive Resizing: {len(stage_params)} stages")
        for i, e in enumerate(stage_params):
            # Stage params are resolution and epochs, pretty print
            print(f"\tStage {i}: {e['resolution']} @ {e['epochs']} epochs")

    return pl.Trainer(
        logger=logger,
        callbacks=callbacks,
        profiler=profiler,
        **config.trainer,
    )


def train(config):
    """Run one full experiment described by ``config``.

    Steps:
      1. Seed everything when ``train.seed`` is set.
      2. Build the trainer and the ``SequenceLightningModule``.
      3. Optionally attach a ClearML task.
      4. Run the optional initial validation pass.
      5. ``fit`` the model, resuming from ``train.ckpt`` when it is given.
      6. Run the optional ``test`` pass.
    """
    seed = config.train.seed
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    trainer = create_trainer(config)
    model = SequenceLightningModule(config)
    model.config = config

    # ClearML tracking is opt-in via a truthy top-level ``clrml`` entry. Without
    # it the model gets ``clr_logger = None`` (checked in configure_optimizers).
    if "clrml" in config and config.clrml:
        folder_name = config.get("clrml_folder", "HE-Friendly-Attention")
        project_name = "HE-Friendly-Attention/" + folder_name
        task = Task.init(project_name=project_name, task_name=config.clrml_name)
        print("Initialize clearml on: ", project_name + "/" + config.clrml_name)
        model.clr_logger = task.get_logger()
        model.clr_parameters = task.connect(OmegaConf.to_container(config))
    else:
        model.clr_logger = None

    # Run initial validation epoch (useful for debugging, finetuning)
    if config.train.validate_at_start:
        print("Running validation before training")
        trainer.validate(model)

    print('-' * 20, "\n", model)

    ckpt_path = config.train.ckpt
    if ckpt_path is None:
        trainer.fit(model)
    else:
        trainer.fit(model, ckpt_path=ckpt_path)

    if config.train.test:
        trainer.test(model)

@hydra.main(config_path="configs", config_name="config.yaml")
def main(config: DictConfig):
    """Hydra entry point: post-process the composed config, print it, then train.

    ``config`` is the composed Hydra config (a ``DictConfig``, not the
    ``OmegaConf`` factory class).
    """
    # Process config:
    # - register evaluation resolver
    # - filter out keys used only for interpolation
    # - optional hooks, including disabling python warnings or debug friendly configuration
    config = utils.train.process_config(config)

    # Pretty print config using Rich library
    utils.train.print_config(config, resolve=True)

    train(config)
    
import onnx
from collections import defaultdict

def operator_count(onnx_file):
    """Count how often each operator type occurs in an ONNX model's top-level graph.

    Args:
        onnx_file: anything ``onnx.load`` accepts (e.g. a path to a ``.onnx`` file).

    Returns:
        A plain ``dict`` mapping each node ``op_type`` to its count, in
        first-seen order. Only ``model.graph.node`` is visited; nodes inside
        nested subgraphs are not counted.
    """
    graph = onnx.load(onnx_file).graph
    counts = {}
    for node in graph.node:
        counts[node.op_type] = counts.get(node.op_type, 0) + 1
    return counts

# Script entry point. ``main`` is wrapped by ``hydra.main``, which composes the
# config and passes it in, so no arguments are given here.
if __name__ == "__main__":
    main()
