from email.policy import default
from inspect import indentsize
from time import time
from llid.training_loops.training_loop import IDEstimationTL
from llid.utils.intrinsic_dimension import estimate_id
from llid.utils.utils import Timer

import pickle
from torch import nn
import torch.nn.functional as F
import torch
from scipy.spatial.distance import pdist,squareform
import numpy as np
import time
from torch.utils.data import Subset,  DataLoader

from collections import defaultdict
import pdb

class ImageClassificationTL(IDEstimationTL):
    """Training loop for image classification with optional ID estimation.

    Extends ``IDEstimationTL`` with a class-weighted cross-entropy objective.
    The total training loss is cross-entropy plus the regularisation term
    returned by ``self.regs``. A placeholder L1 term, currently always 0, is
    also included. Intrinsic-dimension (ID) estimation via ``self.log_id`` can
    optionally run during training and validation steps.

    NOTE(review): ``self.model``, ``self.regs``, ``self.metric``,
    ``self.log_id`` and ``self.step_cnt`` are not defined in this class;
    presumably the base class provides them. Confirm against
    ``IDEstimationTL``.
    """

    def __init__(self, config):
        super().__init__(config)

        class_weights = self.get_class_weights(config)

        # Weighted CE; the weight tensor becomes a buffer of the loss module.
        self.ce_loss = nn.CrossEntropyLoss(weight=class_weights)

        self.config = config

    def get_class_weights(self, config):
        """Build the per-class weight tensor for the cross-entropy loss.

        Args:
            config: Run configuration. Reads ``class_weights``,
                ``num_classes`` and ``loss_weight_alpha``.

        Returns:
            A 1-D float tensor of length ``len(config.class_weights)``.
            When ``class_weights`` is ``None`` or the string ``'None'``, the
            tensor has length ``config.num_classes`` with uniform entries
            ``1 / num_classes``. Otherwise each raw weight is raised to
            ``loss_weight_alpha`` and the result is normalised to sum to 1.
        """
        weights = config.class_weights
        # Accept both a real ``None`` and the string sentinel ``'None'``.
        # The ``isinstance`` guard avoids an elementwise ``==`` (and an
        # ambiguous truth value) when the weights are an array or tensor.
        if weights is None or (isinstance(weights, str) and weights == 'None'):
            return torch.tensor([1 / (config.num_classes) for _ in range(config.num_classes)])

        alpha = config.loss_weight_alpha
        # alpha = 0 gives uniform weights; alpha = 1 keeps the raw proportions.
        powered = [w ** alpha for w in weights]
        total = sum(powered)
        return torch.tensor([p / total for p in powered])

    def training_step(self, batch, optimizer_idx=0, *args, **kwargs):
        """Compute and log the training loss for one batch.

        Args:
            batch: ``(imgs, labels)`` pair.
            optimizer_idx: Index of the optimizer being stepped. When
                ``config.alternating_reg_loss`` is set, the cross-entropy
                term is only included for ``optimizer_idx == 1``.

        Returns:
            The total loss ``base_loss + reg_loss + l1_loss``.
        """
        imgs, labels = batch
        outs = self.model(imgs)

        base_loss = 0
        # With alternating optimisation, only optimizer 1 sees the CE term.
        # The regulariser is added for every optimizer.
        if not self.config.alternating_reg_loss or optimizer_idx == 1:
            base_loss = self.ce_loss(outs, labels)

        reg_loss = self.regs(self.model, batch, outs, step=self.step_cnt, epoch=self.current_epoch)

        with torch.no_grad():
            metrics = self.metric(self.model, batch, outs, step=self.step_cnt, group="train")

        # Placeholder: no L1 penalty is currently applied. It is kept so the
        # logged "train/l1_loss" key stays stable.
        l1_loss = 0

        loss = base_loss + reg_loss + l1_loss

        self.log_dict({"train/loss": loss, "train/base_loss": base_loss, "train/reg_loss": reg_loss, "train/l1_loss": l1_loss, **metrics}, sync_dist=True)

        if "id_estimation" in self.config and self.config.id_estimation.do_steps and self.config.id_estimation.estimate_train_id:
            # Run ID estimation in eval mode. The previous mode is restored
            # in ``finally`` so an exception in ``log_id`` cannot leave the
            # model in eval mode for the rest of training.
            was_training = self.training
            self.eval()
            try:
                self.log_id(self.trainer.datamodule.train_dataloader(), "train")
            finally:
                self.train(was_training)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation losses and metrics for one batch.

        Args:
            batch: ``(imgs, labels)`` pair.
            batch_idx: Index of the batch (unused here).

        Returns:
            None. Results are reported only through ``self.log_dict``.
        """
        imgs, labels = batch
        outs = self.model(imgs)
        # Unlike training, the CE term is always included in validation.
        base_loss = self.ce_loss(outs, labels)

        reg_loss = self.regs(self.model, batch, outs, step=self.step_cnt, epoch=self.current_epoch)

        metrics = self.metric(self.model, batch, outs, step=self.step_cnt, group="val")

        self.log_dict({"val/loss": base_loss + reg_loss, "val/base_loss": base_loss, "val/reg_loss": reg_loss, **metrics}, sync_dist=True)

        # NOTE(review): this is invoked on every validation batch. Presumably
        # ``log_id`` rate-limits itself internally; confirm.
        if "id_estimation" in self.config and self.config.id_estimation.do_steps:
            self.log_id(self.trainer.datamodule.val_dataloader(), "val")

        return

    