from copy import deepcopy
import logging
import hydra
import math
import torch
from agent.pretrain.train_diffusion_agent import batch_to_device

log = logging.getLogger(__name__)


class RIMixin:
    """Mixin adding imitation-learning (IL) data loading and evaluation helpers.

    The host class is expected to provide ``self.cfg`` (with ``il_dataset``
    and ``il_train`` entries) and ``self.model``.
    """

    def _make_il_dataloader(self, dataset):
        """Build a shuffled DataLoader over ``dataset`` using ``self.il_batch_size``.

        Datasets whose ``device`` attribute equals ``"cpu"`` are loaded with 4
        worker processes and pinned memory; any other dataset is loaded in the
        main process without pinning.
        """
        on_cpu = bool(dataset.device == "cpu")
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.il_batch_size,
            num_workers=4 if on_cpu else 0,
            shuffle=True,
            pin_memory=on_cpu,
        )

    def init_il(self):
        """Instantiate the IL dataset and its train (and optional val) loaders.

        Sets ``il_batch_size``, ``dataset_train``, ``dataloader_train`` and
        ``dataloader_val``. ``dataloader_val`` stays ``None`` unless
        ``cfg.il_train.train_split`` is present and smaller than 1, in which
        case ``dataset_val`` is created as a deep copy of the training dataset
        restricted to the validation indices.
        """
        il_cfg = self.cfg.il_train
        self.il_batch_size = il_cfg.batch_size
        self.dataset_train = hydra.utils.instantiate(self.cfg.il_dataset)
        self.dataloader_train = self._make_il_dataloader(self.dataset_train)

        self.dataloader_val = None
        if "train_split" in il_cfg and il_cfg.train_split < 1:
            # The split is applied to the training dataset first; the copy is
            # then pointed at the held-out indices returned by the split.
            val_indices = self.dataset_train.set_train_val_split(il_cfg.train_split)
            self.dataset_val = deepcopy(self.dataset_train)
            self.dataset_val.set_indices(val_indices)
            # NOTE(review): the validation loader is shuffled as well; confirm
            # that no consumer relies on a fixed validation order.
            self.dataloader_val = self._make_il_dataloader(self.dataset_val)

    def get_il_loss(self, batch):
        """Compute the IL loss for a batch of data.

        The batch is passed through ``batch_to_device`` when the training
        dataset lives on the CPU. ``train()`` is called on ``model.actor_ft``
        when the model has one, otherwise on ``model.actor``.

        Returns:
            The result of ``self.model.il_loss(*batch)``.
        """
        if self.dataset_train.device == "cpu":
            batch = batch_to_device(batch)
        # Prefer the fine-tuned actor when the model exposes one.
        if hasattr(self.model, "actor_ft"):
            self.model.actor_ft.train()
        else:
            self.model.actor.train()
        return self.model.il_loss(*batch)

    def get_il_value(self, batch):
        """Compute the mean critic value over the states of an IL batch.

        The conditions are taken from ``batch[1]`` for a 2-tuple batch, or from
        ``batch.conditions`` otherwise. ``model.critic`` is used when present,
        falling back to ``model.critic_v``; the critic is put in eval mode and
        evaluated without gradient tracking.

        Returns:
            ``critic({"state": state}).mean()``.

        Raises:
            ValueError: if the batch is neither a 2-tuple nor an object with a
                ``conditions`` attribute.
            AttributeError: if the model has neither ``critic`` nor
                ``critic_v``.
        """
        if self.dataset_train.device == "cpu":
            batch = batch_to_device(batch)

        if isinstance(batch, tuple) and len(batch) == 2:
            conditions = batch[1]
        elif hasattr(batch, 'conditions'):
            conditions = batch.conditions
        else:
            raise ValueError("Unknown batch structure")
        obs = {"state": conditions['state']}

        for critic_name in ("critic", "critic_v"):
            if hasattr(self.model, critic_name):
                critic = getattr(self.model, critic_name)
                break
        else:
            # Fail with a clear message instead of an unbound local variable.
            raise AttributeError(
                "Model has neither a 'critic' nor a 'critic_v' to evaluate IL values"
            )

        critic.eval()
        with torch.no_grad():
            return critic(obs).mean()


class StdScheduler:
    """Schedule the denoising (exploration) std over training iterations.

    The schedule has three phases, selected in ``step``:

    1. warmup (``itr < std_warmup``): linear ramp from ``std_init`` to ``std_max``;
    2. decay (``std_warmup <= itr < std_max_itr``): ``"linear"``, ``"cosine"``
       or ``"il"`` schedule, chosen by ``expl_schedule``;
    3. end (``itr >= std_max_itr``): constant ``std_min``.

    All settings are read from ``cfg.train`` via ``.get(key, default)``.
    """

    # Smoothing factor of the exponential moving average of the IL value.
    _IL_EMA_ALPHA = 0.1
    # Largest amount the IL schedule may subtract from ``std_max``.
    _IL_MAX_STD_DROP = 0.02

    def __init__(self, cfg):
        """Read schedule settings from ``cfg.train`` and reset IL statistics."""
        train_cfg = cfg.train
        self.expl_schedule = train_cfg.get("expl_schedule", None)
        self.il_val_ema = None
        # Running extrema of the EMA. Both start at the neutral element of
        # their reduction so that negative IL values are tracked correctly.
        self.il_val_min = float('inf')
        self.il_val_max = float('-inf')
        self.std_min = train_cfg.get("std_min", 0.1)
        self.std_max = train_cfg.get("std_max", 0.2)
        self.std_init = train_cfg.get("std_init", 0.1)
        self.std_warmup_itr = train_cfg.get("std_warmup", 5)
        self.std_max_itr = train_cfg.get("std_max_itr", 1000)
        log.info(f"STD schedule: {self.expl_schedule} | min: {self.std_min} | max: {self.std_max} | warmup: {self.std_warmup_itr}")

    def _decay_progress(self, itr):
        """Return the fraction of the decay phase completed at ``itr``, in [0, 1]."""
        span = self.std_max_itr - self.std_warmup_itr
        if span <= 0:
            # No decay phase is configured; treat the decay as finished.
            return 1.0
        progress = float(itr - self.std_warmup_itr) / span
        return min(1.0, max(0.0, progress))

    def decay_denoising_std(self, itr):
        """Linearly decay the std from ``std_max`` to ``std_min``.

        Returns:
            The std at iteration ``itr`` as a float, never below ``std_min``.
        """
        progress = self._decay_progress(itr)
        new_std = max(self.std_min, self.std_max - progress * (self.std_max - self.std_min))
        return float(new_std)

    def decay_denoising_std_cosine(self, itr):
        """Decay the std from ``std_max`` to ``std_min`` along a half cosine.

        Returns:
            The std at iteration ``itr`` as a float.
        """
        progress = self._decay_progress(itr)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        new_std = self.std_min + (self.std_max - self.std_min) * cosine_decay
        return float(new_std)

    def decay_denoising_std_il(self, itr, current_il_val):
        """Lower the std according to where the IL value sits in its history.

        An exponential moving average of ``current_il_val`` is maintained
        together with its running minimum and maximum. ``frac`` is the position
        of the EMA between those extrema, and the std is ``std_max`` minus
        ``frac * (std_max - std_min)``, with the subtraction capped at 0.02.

        Args:
            itr: current iteration (not used by this schedule).
            current_il_val: latest IL value (number or scalar convertible with
                ``float``), or ``None`` when no value is available.

        Returns:
            ``std_init`` when ``current_il_val`` is ``None``, otherwise the
            newly computed std as a float.
        """
        if current_il_val is None:
            return float(self.std_init)

        # Store plain floats so the scheduler state never holds tensors.
        current_il_val = float(current_il_val)
        alpha = self._IL_EMA_ALPHA
        if self.il_val_ema is None:
            self.il_val_ema = current_il_val
        else:
            self.il_val_ema = alpha * current_il_val + (1 - alpha) * self.il_val_ema

        self.il_val_max = max(self.il_val_max, self.il_val_ema)
        self.il_val_min = min(self.il_val_min, self.il_val_ema)

        # Normalise only once the observed range is non-degenerate.
        if self.il_val_max > self.il_val_min + 1e-6:
            frac = (self.il_val_ema - self.il_val_min) / (self.il_val_max - self.il_val_min)
        else:
            frac = 0.0

        new_std = self.std_max - min(frac * (self.std_max - self.std_min), self._IL_MAX_STD_DROP)
        log.info(
            f"[STD] IL EMA: {self.il_val_ema:.4f} | min: {self.il_val_min:.4f} | max: {self.il_val_max:.4f} | frac: {frac:.4f}")
        log.info(f"[STD] Denoising std: {new_std:.4f} | frac: {frac:.4f}")
        return float(new_std)

    def step(self, itr, current_il_val=None):
        """Return the std to use at iteration ``itr``.

        Args:
            itr: current training iteration.
            current_il_val: latest IL value; only used by the ``"il"`` schedule.

        Raises:
            ValueError: if the decay phase is reached and ``expl_schedule`` is
                not ``"linear"``, ``"cosine"`` or ``"il"``.
        """
        if itr < self.std_warmup_itr:
            # Warmup: ramp linearly from std_init up to std_max.
            progress = float(itr) / float(self.std_warmup_itr)
            return float(self.std_init + progress * (self.std_max - self.std_init))
        if itr >= self.std_max_itr:
            return float(self.std_min)

        if self.expl_schedule == "linear":
            return self.decay_denoising_std(itr)
        if self.expl_schedule == "cosine":
            return self.decay_denoising_std_cosine(itr)
        if self.expl_schedule == "il":
            return self.decay_denoising_std_il(itr, current_il_val)
        raise ValueError(f"Unknown exploration schedule: {self.expl_schedule}")
