from copy import deepcopy
import pdb
from llid.training_loops.training_loop import IDEstimationTL
from torch import nn
import torch
from collections import defaultdict
import pickle

class SequenceClassificationTL(IDEstimationTL):
    """Training loop for sequence-classification models with optional ID tracking.

    Each train/validation step runs the model, logs the loss together with
    the metrics returned by ``self.metric`` and, when the ``id_estimation``
    config section enables step-based estimation, calls :meth:`log_id` to
    estimate the ID (presumably "intrinsic dimension" -- confirm against
    ``IDEstimationTL``) of selected layer outputs.

    Attributes such as ``model``, ``metric``, ``timer``, ``id_dict``,
    ``step_cnt``, ``train_step``, ``last_train_step_done_for_val`` and
    ``compute_id`` are not defined here; they are expected to be provided by
    the base classes.
    """

    def __init__(self, config):
        super().__init__(config)
        self.ce_loss = nn.CrossEntropyLoss()

    def _step_id_estimation_enabled(self):
        """Return a truthy value when the config requests per-step ID estimation."""
        return "id_estimation" in self.config and self.config.id_estimation.do_steps

    def training_step(self, batch, batch_idx=None):
        """Run one training step and return the loss produced by the model.

        Args:
            batch: Mapping of keyword arguments forwarded to ``self.model``.
            batch_idx: Unused; accepted so the signature matches
                ``validation_step``.

        Returns:
            ``outs.loss`` from the model's output.
        """
        outs = self.model(**batch)
        loss = outs.loss

        with torch.no_grad():
            metrics = self.metric(self.model, batch, outs, step=self.step_cnt, group="train")

        self.log_dict({"train/loss": loss, **metrics}, sync_dist=True)

        if self._step_id_estimation_enabled() and self.config.id_estimation.estimate_train_id:
            self.log_id(self.trainer.datamodule.train_dataloader(), "train")

        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation step and return the loss produced by the model.

        Args:
            batch: Mapping of keyword arguments forwarded to ``self.model``.
            batch_idx: Unused.

        Returns:
            ``outs.loss`` from the model's output.
        """
        outs = self.model(**batch)
        loss = outs.loss

        with torch.no_grad():
            metrics = self.metric(self.model, batch, outs, step=self.step_cnt, group="val")

        self.log_dict({"val/loss": loss, **metrics}, sync_dist=True)

        if self._step_id_estimation_enabled():
            self.log_id(self.trainer.datamodule.val_dataloader(), "val")

        return loss

    def log_id(self, dataloader, group):
        """Estimate and record the ID of the configured layers over ``dataloader``.

        Forward hooks capture the outputs of every module whose dotted name
        is listed in ``config.id_estimation.layers``. Each captured output is
        flattened to ``(batch, -1)``, concatenated over all batches and passed
        to ``self.compute_id``. Results are logged, stored in
        ``self.id_dict[group][current_step]`` and the whole ``id_dict`` is
        pickled to ``config.id_estimation.save_path``.

        The method returns early (doing nothing besides starting the timer)
        when no layers are configured, when the current step/epoch is not
        scheduled, when the validation estimate for this step already exists,
        or when the initial (step 0) estimate is disabled.

        Args:
            dataloader: Iterable of batches; each batch is a mapping whose
                values support ``.to(device)``.
            group: Label used in log keys and as the first ``id_dict`` key
                (the callers in this class pass ``"train"`` or ``"val"``).
        """
        self.timer.start()

        id_cfg = self.config.id_estimation
        current_step = self.train_step if id_cfg.do_steps else self.current_epoch

        if len(id_cfg.layers) == 0:
            return

        # Always estimate on the final scheduled step; otherwise only every
        # `estimate_id_every` steps/epochs.
        is_final_step = current_step == (id_cfg.max_epochs - 1)
        if not is_final_step and current_step % id_cfg.estimate_id_every != 0:
            return

        # validation_step calls this once per validation batch; only the first
        # call for a given training step should do the work.
        if id_cfg.do_steps and self.last_train_step_done_for_val == current_step and group == "val":
            return

        if current_step == 0 and not id_cfg.estimate_initial_id:
            return

        with torch.no_grad():
            # Track weight norms (L2 norm of every named parameter). Computed
            # under no_grad so no autograd graph is built for logging values.
            weight_norms = {
                param_name: (weights ** 2).sum().sqrt()
                for param_name, weights in self.model.named_parameters()
            }
        self.log_dict(weight_norms, sync_dist=True)

        self.intermediates = defaultdict(list)
        hooks = []

        def make_capture_hook(layer_name):
            def hook(module, inputs, output):
                captured = output
                # Modules may return (nested) tuples; keep the first element.
                while isinstance(captured, tuple):
                    captured = captured[0]
                # copy=True guarantees an independent CPU tensor even when the
                # model already lives on the CPU, so later in-place ops on the
                # module output cannot alter what was captured.
                self.intermediates[layer_name].append(captured.detach().to("cpu", copy=True))

            return hook

        def register_hooks(module, prefix=""):
            # Walk the module tree, hooking every child whose dotted path is
            # listed in the config; always descend so nested targets are found.
            for child_name, child in module.named_children():
                full_name = prefix + child_name
                if full_name in id_cfg.layers:
                    hooks.append(child.register_forward_hook(make_capture_hook(full_name)))
                register_hooks(child, full_name + ".")

        # NOTE(review): the model's train/eval mode is left as-is, so calls
        # coming from training_step run these passes in whatever mode the
        # trainer set -- confirm that this is intended.
        try:
            with torch.no_grad():
                register_hooks(self.model)
                for batch in dataloader:
                    self.model(**{key: value.to(self.device) for key, value in batch.items()})
        finally:
            # Never leave capture hooks attached to the model, even when
            # inference raises.
            for handle in hooks:
                handle.remove()

        self.timer.checkpoint("Post inference")

        with torch.no_grad():
            self.intermediates = {
                layer_name: torch.cat([chunk.reshape(chunk.shape[0], -1) for chunk in chunks], 0)
                for layer_name, chunks in self.intermediates.items()
            }

            for name, intermediate in self.intermediates.items():
                id_value, id_err = self.compute_id(intermediate)
                # NOTE(review): the logged error key has no "total/" segment
                # while the id_dict key below does; kept as-is so existing
                # dashboards/readers keep working.
                self.log_dict(
                    {f"{group}/total/id_{name}": id_value, f"{group}/id_{name}_error": id_err},
                    sync_dist=True,
                )

                self.id_dict[group][current_step][f"{group}/total/id_{name}"] = id_value
                self.id_dict[group][current_step][f"{group}/total/id_{name}_error"] = id_err

        self.timer.clear()
        with open(self.config.id_estimation.save_path, 'wb') as f:
            pickle.dump(self.id_dict, f)

        # Only a validation run may arm the validation de-duplication guard;
        # otherwise a train-group estimation at the same step would suppress
        # the validation estimate.
        if group == "val":
            self.last_train_step_done_for_val = current_step
