import os
import torch
from torchmetrics import Metric


class LossMetric(Metric):
    """Per-sample average loss accumulated over updates and reduced across processes.

    Two float64 scalar states, ``train_loss`` and ``num_samples``, are summed on
    every ``update`` and summed across processes on sync (``dist_reduce_fx="sum"``);
    :meth:`compute` returns ``train_loss / num_samples``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # LOCAL_RANK is set by distributed launchers (e.g. torchrun). Default to 0 so
        # the metric can also be built in a plain single-process run instead of
        # failing with KeyError.
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        # Put the states on this process's GPU when CUDA is available; an integer
        # device index alone would require CUDA and break CPU-only runs.
        device = (
            torch.device("cuda", self.local_rank)
            if torch.cuda.is_available()
            else torch.device("cpu")
        )
        self.add_state("train_loss", default=torch.tensor(0.0, dtype=torch.float64, device=device), dist_reduce_fx="sum")
        self.add_state("num_samples", default=torch.tensor(0.0, dtype=torch.float64, device=device), dist_reduce_fx="sum")

    @staticmethod
    def _as_state_value(value, state: torch.Tensor) -> torch.Tensor:
        """Convert ``value`` to a detached scalar matching ``state``'s dtype and device.

        Non-scalar inputs are summed, so a tensor of per-sample values is accepted
        as well as a Python number or a 0-dim tensor.
        """
        if isinstance(value, torch.Tensor):
            # Never keep an autograd graph alive inside metric state.
            value = value.detach()
        # Follow the state's *current* device (the metric may have been moved with
        # .to()/.cpu()) instead of a fixed rank index, to avoid device mismatches.
        return torch.as_tensor(value, dtype=state.dtype, device=state.device).sum()

    def update(self, train_loss: torch.Tensor, num_samples: torch.Tensor) -> None:
        """Add one batch's loss and sample count to the running totals.

        Args:
            train_loss: Loss to accumulate (number or tensor; non-scalar tensors are
                summed). Because ``compute`` divides by ``num_samples``, this is
                presumably the batch's summed (not mean) loss — confirm at call sites.
            num_samples: Number of samples the loss covers (number or tensor;
                non-scalar tensors are summed).
        """
        self.train_loss += self._as_state_value(train_loss, self.train_loss)
        self.num_samples += self._as_state_value(num_samples, self.num_samples)

    def compute(self) -> torch.Tensor:
        """Return the accumulated loss divided by the accumulated sample count.

        The division is float64 tensor division, so with ``num_samples == 0`` the
        result is ``nan`` (0/0) or ``inf`` rather than an exception.
        """
        return self.train_loss / self.num_samples
    


class SuccessMetric(Metric):
    """Success rate accumulated over updates and reduced across processes.

    Two float64 scalar states, ``num_success`` and ``num_samples``, are summed on
    every ``update`` and summed across processes on sync (``dist_reduce_fx="sum"``);
    :meth:`compute` returns ``num_success / num_samples``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # LOCAL_RANK is set by distributed launchers (e.g. torchrun). Default to 0 so
        # the metric can also be built in a plain single-process run instead of
        # failing with KeyError.
        self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
        # Put the states on this process's GPU when CUDA is available; an integer
        # device index alone would require CUDA and break CPU-only runs.
        device = (
            torch.device("cuda", self.local_rank)
            if torch.cuda.is_available()
            else torch.device("cpu")
        )
        self.add_state("num_success", default=torch.tensor(0.0, dtype=torch.float64, device=device), dist_reduce_fx="sum")
        self.add_state("num_samples", default=torch.tensor(0.0, dtype=torch.float64, device=device), dist_reduce_fx="sum")

    @staticmethod
    def _as_state_value(value, state: torch.Tensor) -> torch.Tensor:
        """Convert ``value`` to a detached scalar matching ``state``'s dtype and device.

        Non-scalar inputs are summed, so e.g. a boolean tensor of per-sample success
        flags is counted, in addition to Python numbers and 0-dim tensors.
        """
        if isinstance(value, torch.Tensor):
            # Never keep an autograd graph alive inside metric state.
            value = value.detach()
        # Follow the state's *current* device (the metric may have been moved with
        # .to()/.cpu()) instead of a fixed rank index, to avoid device mismatches.
        return torch.as_tensor(value, dtype=state.dtype, device=state.device).sum()

    def update(self, num_success: torch.Tensor, num_samples: torch.Tensor) -> None:
        """Add one batch's success count and sample count to the running totals.

        Args:
            num_success: Number of successful samples (number or tensor; non-scalar
                tensors, including boolean flags, are summed).
            num_samples: Number of samples in the batch (number or tensor;
                non-scalar tensors are summed).
        """
        self.num_success += self._as_state_value(num_success, self.num_success)
        self.num_samples += self._as_state_value(num_samples, self.num_samples)

    def compute(self) -> torch.Tensor:
        """Return the accumulated success count divided by the accumulated sample count.

        The division is float64 tensor division, so with ``num_samples == 0`` the
        result is ``nan`` (0/0) or ``inf`` rather than an exception.
        """
        return self.num_success / self.num_samples