import pdb
from torch import nn
import torch

import hydra
from hydra.utils import instantiate

from pytorch_lightning import LightningModule

from llid.utils.metrics import MetricsList

from omegaconf import open_dict

from email.policy import default
from inspect import indentsize
from time import time
from llid.utils.intrinsic_dimension import estimate_id
from llid.utils.utils import Timer

import pickle
from torch import nn
import torch.nn.functional as F
import torch
from scipy.spatial.distance import pdist,squareform
import numpy as np
import time
from torch.utils.data import Subset,  DataLoader
from itertools import chain

from collections import defaultdict
import pdb

class CommonTL(LightningModule):
    """Lightning module that assembles model, metrics and regularizers from a config.

    The config is expected to expose the entries read in this class:
    ``model``, ``metrics``, ``epoch_metrics``, ``regs``, ``optimizer``,
    ``lr_scheduler``, ``alternating_reg_loss``, ``lr`` and ``warmup_steps``.
    """

    def __init__(
        self,
        config
    ):
        """Instantiate every configured component and record the hyperparameters.

        Args:
            config: Hydra/OmegaConf-style configuration object; its ``model``,
                ``metrics``, ``epoch_metrics`` and ``regs`` nodes are passed
                to ``hydra.utils.instantiate``.
        """
        super().__init__()

        # The regularizers receive the freshly built model, so the model is
        # instantiated first.
        self.model = instantiate(config.model)
        self.metric = instantiate(config.metrics)
        self.epoch_metric = instantiate(config.epoch_metrics)
        self.regs = instantiate(config.regs, model=self.model, _recursive_=False)

        self.step_cnt = 0

        self.config = config
        self.save_hyperparameters(ignore="model")

    def configure_optimizers(self):
        """Build the optimizer(s) and the matching LR scheduler(s).

        Returns:
            A ``(optimizers, schedulers)`` pair of lists. With
            ``config.alternating_reg_loss`` set, the lists hold two entries
            each (regularizer first, model second); otherwise a single
            optimizer covers the model and regularizer parameters together.
        """
        cfg = self.config

        if not cfg.alternating_reg_loss:
            # One optimizer over the union of model and regularizer parameters.
            joint_optimizer = instantiate(
                cfg.optimizer,
                params=chain(self.model.parameters(), self.regs.parameters()),
            )
            joint_scheduler = instantiate(cfg.lr_scheduler, optimizer=joint_optimizer)
            return [joint_optimizer], [joint_scheduler]

        # Separate optimizers: one for the regularizers, one for the model.
        reg_optimizer = instantiate(cfg.optimizer, params=self.regs.parameters())
        model_optimizer = instantiate(cfg.optimizer, params=self.model.parameters())

        reg_scheduler = instantiate(cfg.lr_scheduler, optimizer=reg_optimizer)
        model_scheduler = instantiate(cfg.lr_scheduler, optimizer=model_optimizer)

        return [reg_optimizer, model_optimizer], [reg_scheduler, model_scheduler]

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure,
        on_tpu=False,
        using_native_amp=False,
        using_lbfgs=False,
    ):
        """Step the optimizer, linearly ramping the learning rate during warmup.

        While ``trainer.global_step`` is below ``config.warmup_steps`` the
        learning rate of every parameter group is overwritten with
        ``config.lr`` scaled by ``(global_step + 1) / warmup_steps`` (capped
        at 1.0). The optimizer is then stepped with the supplied closure.
        """
        base_lr = self.config.lr

        if self.trainer.global_step < self.config.warmup_steps:
            # Linear warmup factor for the current step.
            warmup_factor = min(
                1.0,
                float(self.trainer.global_step + 1) / self.config.warmup_steps,
            )
            for param_group in optimizer.param_groups:
                param_group["lr"] = warmup_factor * base_lr

        # The parameter update happens in both the warmup and the regular case.
        optimizer.step(closure=optimizer_closure)

class IDEstimationTL(CommonTL):
    """CommonTL variant that periodically estimates intrinsic dimension (ID).

    When the config has an ``id_estimation`` node, the ID of the raw data
    and of the outputs of selected child modules of ``self.model`` is
    estimated on class-restricted subsets, logged through ``log_dict`` and
    pickled to ``config.id_estimation.save_path``.
    """

    def __init__(
        self,
        config
    ):
        """Set up the timer and result store when ID estimation is configured.

        Args:
            config: Configuration object forwarded to ``CommonTL``; the
                optional ``id_estimation`` node enables ID estimation.
        """
        super().__init__(config)

        if "id_estimation" in config:
            self.timer = Timer(verbose=self.config.id_estimation.timing_verbose)
            self.id_dict = {}
            self.id_dict["train"] = defaultdict(dict)
            self.id_dict["val"] = defaultdict(dict)

        self.train_step = 0
        self.last_train_step_done_for_val = -1

    def on_after_backward(self):
        """Advance the training-step counter used when ``do_steps`` is set."""
        self.train_step += 1

    def restrict_dataloader(self, dataloader, num_per_class=-1, cls=-1):
        """Return a dataloader over a class-restricted subset of the dataset.

        Args:
            dataloader: Dataloader whose ``dataset`` exposes ``targets``.
            num_per_class: Maximum number of samples kept per class;
                ``-1`` keeps all of them.
            cls: Only keep samples of this class; ``-1`` keeps every class.

        Returns:
            A ``DataLoader`` over a ``Subset`` of the original dataset, built
            with the keyword arguments in ``config.id_estimation.dataloader``.
        """
        used_indices = []
        cls_to_count = defaultdict(int)

        for idx, raw_label in enumerate(dataloader.dataset.targets):
            # Tensor labels hash by identity, which would defeat the
            # per-class counter; normalise to a plain int first.
            label = int(raw_label)

            if cls != -1 and label != cls:
                continue
            if num_per_class != -1 and cls_to_count[label] >= num_per_class:
                continue

            cls_to_count[label] += 1
            used_indices.append(idx)

        dataset = Subset(dataloader.dataset, used_indices)

        return DataLoader(dataset, **self.config.id_estimation.dataloader)

    def compute_id(self, intermediates):
        """Estimate the intrinsic dimension of a set of flattened points.

        Pairwise distances are computed once; ``estimate_id`` is then run on
        ``config.id_estimation.nres`` random sub-samples, each containing a
        ``config.id_estimation.fraction`` share of the points.

        Args:
            intermediates: 2-D tensor with one row per data point.

        Returns:
            ``(mean, std)`` of the ID estimates over the sub-samples.
        """
        cfg = self.config.id_estimation
        estimates = []
        n = int(np.round(intermediates.shape[0] * cfg.fraction))

        self.timer.checkpoint(f"Pre pdists {intermediates.shape}")
        # Use the GPU for the pairwise distances when there is one, but do
        # not crash on CPU-only machines.
        if torch.cuda.is_available():
            intermediates = intermediates.cuda()
        dist = F.pdist(intermediates).cpu()
        self.timer.checkpoint(f"Post pdists {intermediates.shape}")
        # Expand the condensed distance vector into a square matrix.
        dist = squareform(dist.detach().numpy())
        self.timer.checkpoint("Post squareform")

        for i in range(cfg.nres):
            # Random subset of points: select matching rows and columns.
            perm = np.random.permutation(dist.shape[0])[0:n]
            dist_s = dist[perm, :]
            dist_s = dist_s[:, perm]
            estimates.append(estimate_id(dist_s, verbose=False)[2])
            self.timer.checkpoint(f"ID pass {i}")

        return np.mean(estimates), np.std(estimates)

    def _build_id_dataloaders(self, dataloader, num_classes):
        """Return the ``"total"`` and per-class ``"Class_<k>"`` dataloaders."""
        cfg = self.config.id_estimation
        loaders = {
            "total": self.restrict_dataloader(
                dataloader, num_per_class=cfg.combined_total_per_class
            )
        }
        for cls in range(num_classes):
            loaders[f"Class_{cls}"] = self.restrict_dataloader(
                dataloader, num_per_class=cfg.num_per_class, cls=cls
            )
        return loaders

    def _make_capture_hook(self, name):
        """Return a forward hook that stores a module's output under ``name``."""
        def hook(model, input, output):
            print(name, output.shape)
            to_save = output.detach().cpu()
            self.intermediates[name].append(to_save)

        return hook

    def log_id(self, dataloader, group):
        """Estimate, log and persist IDs for the data and the selected layers.

        Nothing is done when no layers are configured, or when the current
        step/epoch is neither a multiple of ``estimate_id_every`` nor equal
        to ``max_epochs - 1``.

        Args:
            dataloader: Dataloader whose dataset exposes ``targets``.
            group: Prefix for the logged metric names (e.g. ``"train"``).
        """
        self.timer.start()

        cfg = self.config.id_estimation
        current_step = self.train_step if cfg.do_steps else self.current_epoch

        if len(cfg.layers) == 0:
            return
        is_final_step = current_step == (cfg.max_epochs - 1)
        if not is_final_step and current_step % cfg.estimate_id_every != 0:
            return

        num_classes = int(max(dataloader.dataset.targets)) + 1

        # ID of the raw inputs; only computed once, at step/epoch 0.
        if cfg.estimate_data_id and current_step == 0:
            data_loaders = self._build_id_dataloaders(dataloader, num_classes)

            for data_name, current_dataloader in data_loaders.items():
                # A class without samples has nothing to estimate (and would
                # make torch.cat fail on an empty list).
                if len(current_dataloader.dataset) == 0:
                    continue

                data = torch.cat([imgs for imgs, _ in current_dataloader], 0)
                num_datapoints = data.shape[0]

                id_mean, id_err = self.compute_id(data.reshape(num_datapoints, -1))
                # Every key carries data_name so that the subsets do not
                # overwrite each other's error metric.
                self.log_dict(
                    {
                        f"{group}/{data_name}/id_data": id_mean,
                        f"{group}/{data_name}/id_data_error": id_err,
                        f"{group}/{data_name}/num_datapoints": float(
                            len(current_dataloader.dataset.indices)
                        ),
                    },
                    sync_dist=True,
                )

                self.id_dict[f"{group}/{data_name}/id_data"] = id_mean
                self.id_dict[f"{group}/{data_name}/id_data_error"] = id_err

        if current_step == 0 and not cfg.estimate_initial_id:
            return

        # ID of the outputs of the configured child modules.
        layer_loaders = self._build_id_dataloaders(dataloader, num_classes)

        for data_name, current_dataloader in layer_loaders.items():
            with torch.no_grad():
                self.intermediates = defaultdict(list)
                hooks = [
                    module.register_forward_hook(self._make_capture_hook(name))
                    for name, module in self.model.named_children()
                    if name in cfg.layers
                ]

                try:
                    for imgs, _ in current_dataloader:
                        self.model(imgs.to(self.device))
                finally:
                    # Never leave the hooks attached, even if inference fails.
                    for hook in hooks:
                        hook.remove()

                self.timer.checkpoint("Post inference")

                # Concatenate the per-batch outputs, one flattened row per sample.
                self.intermediates = {
                    name: torch.cat(
                        [chunk.reshape(chunk.shape[0], -1) for chunk in chunks], 0
                    )
                    for name, chunks in self.intermediates.items()
                }

                for name, intermediate in self.intermediates.items():
                    id_mean, id_err = self.compute_id(intermediate)
                    self.log_dict(
                        {
                            f"{group}/{data_name}/id_{name}": id_mean,
                            f"{group}/{data_name}/id_{name}_error": id_err,
                            f"{group}/{data_name}/num_datapoints": float(
                                len(current_dataloader.dataset.indices)
                            ),
                        },
                        sync_dist=True,
                    )

                    self.id_dict[f"{group}/{data_name}/id_{name}"] = id_mean
                    self.id_dict[f"{group}/{data_name}/id_{name}_error"] = id_err

                self.timer.clear()
                # Persist after every subset so partial results survive a crash.
                with open(cfg.save_path, 'wb') as f:
                    pickle.dump(self.id_dict, f)

    def on_train_epoch_end(self, *args, **kwargs):
        """Log epoch metrics (and optionally IDs) on the training dataloader."""
        self.eval()
        try:
            train_loader = self.trainer.datamodule.train_dataloader()
            metrics = self.epoch_metric(
                self.model, train_loader, epoch=self.current_epoch, group="train"
            )

            self.log_dict(metrics, sync_dist=True)

            if (
                "id_estimation" in self.config
                and not self.config.id_estimation.do_steps
                and self.config.id_estimation.estimate_train_id
            ):
                self.log_id(self.trainer.datamodule.train_dataloader(), "train")
        finally:
            # Restore training mode even if metric or ID computation fails.
            self.train()

    def on_validation_epoch_end(self, *args, **kwargs):
        """Log epoch metrics (and optionally IDs) on the validation dataloader."""
        val_loader = self.trainer.datamodule.val_dataloader()
        metrics = self.epoch_metric(
            self.model, val_loader, epoch=self.current_epoch, group="val"
        )

        self.log_dict(metrics, sync_dist=True)

        if "id_estimation" in self.config and not self.config.id_estimation.do_steps:
            self.log_id(self.trainer.datamodule.val_dataloader(), "val")