import pytorch_lightning as PL
import torch
from metrics import psnr, ssim, ciede2000
import os.path as osp
from PIL import Image
import os
from typing import Dict, Any, List
import csv
import warnings
try:
    import pyiqa
except:
    None


class Base_Trainer(PL.LightningModule):
    """Shared Lightning logic for the restoration models in this project.

    Subclasses provide ``forward``. This base class implements:
    - validation/test loops scoring SSIM, PSNR and CIEDE2000 per sample;
    - prediction-time image saving;
    - optional no-reference IQA scoring through ``pyiqa``.

    Validation/test batches are unpacked as ``(smoky, real_clear)`` and
    prediction batches as ``(smoky, smoky_path)``. Model outputs are
    presumably in [0, 1] (``_save_image`` scales them by 255) — TODO confirm
    against the concrete models.
    """

    def __init__(self, cfg):
        super().__init__()
        self.save_hyperparameters('cfg')
        self.val_metrics = MetricTracker()
        self.test_metrics = MetricTracker()
        self.pred_metrics = MetricTracker()
        # Cache of pyiqa metric networks keyed by (metric name, device string).
        # It is deliberately a plain dict rather than an nn.ModuleDict. That way
        # the IQA networks are not registered as submodules and do not end up
        # in checkpoints / state_dict.
        self._iqa_metric_cache = {}

    def inference(self, x):
        """Run the model on ``x`` and return the final prediction.

        If ``forward`` returns a list, its last element is used as the prediction.
        """
        pred = self.forward(x)
        return pred[-1] if isinstance(pred, list) else pred

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        """Accumulate per-sample full-reference metrics for one validation batch."""
        smoky, real_clear = batch
        pred_clear = self.inference(smoky)
        self.val_metrics.update_metrics(**self._cal_ref_metrics(pred_clear, real_clear))

    def on_validation_epoch_end(self):
        """Log the epoch mean of each validation metric, then reset the tracker."""
        metrics = self.val_metrics.get_metrics()
        # A metric that was never recorded falls back to [0.0], i.e. a mean of 0.
        log_dict = {
            f'val_metrics/{key}': self._mean_std(metrics.get(key, [0.0]))[0]
            for key in ('ssim', 'psnr', 'ciede')
        }
        # sync_dist averages the per-process means across ranks under DDP.
        self.log_dict(log_dict, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.val_metrics = MetricTracker()

    def test_step(self, batch, batch_idx):
        """Accumulate per-sample full-reference metrics for one test batch."""
        smoky, real_clear = batch
        pred_clear = self.inference(smoky)
        self.test_metrics.update_metrics(**self._cal_ref_metrics(pred_clear, real_clear))

    def on_test_epoch_end(self):
        """Log the mean and std of each test metric, then reset the tracker."""
        metrics = self.test_metrics.get_metrics()
        means = {}
        stds = {}
        for key in ('ssim', 'psnr', 'ciede'):
            mean, std = self._mean_std(metrics.get(key, [0.0]))
            means[f'test_metrics/{key}'] = mean
            stds[f'test_metrics/{key}_std'] = std
        log_dict = {**means, **stds}
        # sync_dist for consistency with validation; it is a no-op on a single device.
        self.log_dict(log_dict, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.test_metrics = MetricTracker()  # reset for the next test run

    def predict_step(self, batch, batch_idx):
        """Restore a batch and save each output under its input file's basename."""
        smoky, smoky_path = batch
        pred_clear = self.inference(smoky)
        # metric_results = self.cal_no_ref_metrics(pred_clear)
        # self.pred_metrics.update_metrics(**metric_results)
        for img, img_path in zip(pred_clear, smoky_path):
            self._save_image(img, osp.basename(img_path))
        return None

    # NOTE: disabled hook that would dump no-reference metrics at the end of prediction.
    # def on_predict_epoch_end(self, results: List[Any] = None) -> None:
    #     self._save_metrics_csv(self.pred_metrics.get_metrics())
    #     self.pred_metrics = MetricTracker()

    @staticmethod
    def _to_float(value):
        """Convert a scalar metric value (tensor, numpy scalar or number) to a Python float."""
        return value.item() if isinstance(value, torch.Tensor) else float(value)

    @staticmethod
    def _mean_std(vals):
        """Return ``(mean, std)`` of ``vals`` as Python floats.

        The std is the unbiased sample std. It is reported as 0.0 instead of
        NaN when fewer than two values exist. An empty input yields
        ``(0.0, 0.0)``.
        """
        t = torch.tensor([float(v) for v in vals], dtype=torch.float64)
        if t.numel() == 0:
            return 0.0, 0.0
        std = t.std().item() if t.numel() > 1 else 0.0
        return t.mean().item(), std

    def _cal_ref_metrics(self, pred, target):
        """Return per-sample SSIM/PSNR/CIEDE2000 for a batch, keyed by tracker name."""
        return {
            'ssim': self._cal_ssim(pred, target),
            'psnr': self._cal_psnr(pred, target),
            'ciede': self._cal_ciede(pred, target),
        }

    def _cal_ssim(self, img_tensor_1, img_tensor_2):
        """Per-sample SSIM; each sample is passed to ``ssim`` with a batch dim of 1."""
        return [
            self._to_float(ssim(a.unsqueeze(0), b.unsqueeze(0)))
            for a, b in zip(img_tensor_1, img_tensor_2)
        ]

    def _cal_psnr(self, img_tensor_1, img_tensor_2):
        """Per-sample PSNR; each sample is passed to ``psnr`` with a batch dim of 1."""
        return [
            self._to_float(psnr(a.unsqueeze(0), b.unsqueeze(0)))
            for a, b in zip(img_tensor_1, img_tensor_2)
        ]

    def _cal_ciede(self, img_tensor_1, img_tensor_2):
        """Per-sample CIEDE2000; ``ciede2000`` receives un-batched samples."""
        return [
            self._to_float(ciede2000(a, b))
            for a, b in zip(img_tensor_1, img_tensor_2)
        ]

    def cal_no_ref_metrics(
        self,
        img_tensor: torch.Tensor,  # input tensor [B, C, H, W]
        metric_names: List[str] = ['brisque_matlab', 'niqe_matlab', 'piqe', 'pi', 'nrqm'],
        clamp_input: bool = True,  # clamp the input into [0, 1]
        safe_mode: bool = True,    # skip failing metrics instead of raising
    ) -> Dict[str, List[float]]:
        """Compute no-reference IQA scores for a batch with ``pyiqa``.

        Args:
            img_tensor: Model output of shape [B, C, H, W].
            metric_names: pyiqa metric names (e.g. 'brisque', 'clipiqa', 'niqe').
                The default list is only read, never mutated.
            clamp_input: If True, clamp the input to [0, 1]. Otherwise, only warn
                when values fall outside that range.
            safe_mode: If True, skip (with a warning) metrics that fail to be
                created or evaluated, including when pyiqa is not installed.

        Returns:
            Dict mapping each successfully computed metric name to a flat list
            of per-image scores as plain Python floats.
        """
        # 1. Input validation and value-range handling.
        assert isinstance(img_tensor, torch.Tensor), "输入必须是torch.Tensor"
        assert img_tensor.dim() == 4, "输入形状应为[B,C,H,W]"

        if clamp_input:
            img_tensor = img_tensor.clamp(0, 1)
        elif img_tensor.min() < 0 or img_tensor.max() > 1:
            warnings.warn("输入值超出[0,1]范围，可能导致IQA计算错误！")

        # 2. Force float32, since model outputs may be float16.
        img_tensor = img_tensor.to(torch.float32)
        device = img_tensor.device

        # pyiqa is optional (its module-level import is guarded). Import it
        # here so a missing install gives a clear message, not a NameError.
        try:
            import pyiqa
        except ImportError as err:
            if safe_mode:
                warnings.warn(f"pyiqa is not available; skipping all no-reference metrics ({err})")
                return {}
            raise

        # 3. Compute each metric, reusing already-created metric networks.
        results = {}
        cache = self._iqa_metric_cache
        for name in metric_names:
            try:
                key = (name, str(device))
                if key not in cache:
                    cache[key] = pyiqa.create_metric(name, device=device)
                scores = cache[key](img_tensor)
                # flatten() guarantees a flat list even for 0-dim or [B, 1] scores.
                results[name] = scores.detach().flatten().cpu().tolist()
            except Exception as e:
                if safe_mode:
                    warnings.warn(f"指标 {name} 被跳过（错误：{str(e)}）")
                    continue
                raise

        return results

    def _save_image(self, img_tensor, name):
        """Save one image tensor as file ``name`` under ``cfg.common.path_to_save_image``.

        ``img_tensor`` is indexed as [C, H, W] (see the permute) and is
        presumably in [0, 1].
        """
        # Use mul(), not mul_(): detach() shares storage, so an in-place op
        # would overwrite the caller's prediction tensor. Adding 0.5 rounds
        # to nearest instead of truncating on the uint8 cast.
        narr = (
            img_tensor.detach()
            .mul(255)
            .add_(0.5)
            .clamp_(0, 255)
            .permute(1, 2, 0)
            .to('cpu', torch.uint8)
            .numpy()
        )
        if narr.shape[2] == 1:
            # PIL cannot build an image from an (H, W, 1) array; use (H, W) grayscale.
            narr = narr[:, :, 0]
        path = self.hparams.cfg.common.path_to_save_image
        os.makedirs(path, exist_ok=True)
        Image.fromarray(narr).save(osp.join(path, name))

    def _save_metrics_csv(self, metrics: Dict[str, list]):
        """Write the mean and std of every metric to ``metrics.csv``.

        The file goes in the directory given by ``cfg.common.path_to_save_image``.
        """
        save_dir = self.hparams.cfg.common.path_to_save_image
        os.makedirs(save_dir, exist_ok=True)
        csv_path = osp.join(save_dir, "metrics.csv")

        stats = {key: self._mean_std(vals) for key, vals in metrics.items()}

        with open(csv_path, mode='w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(["Metric", "Mean", "Std"])
            for metric_name, (mean, std) in stats.items():
                writer.writerow([metric_name, f"{mean:.6f}", f"{std:.6f}"])


class MetricTracker:
    """Accumulate per-sample metric values across batches, keyed by metric name."""

    def __init__(self, metrics=None):
        # Adopt the caller's dict if one is given; it is updated in place.
        # Otherwise create a fresh dict per instance, so trackers never share state.
        self.metrics = {} if metrics is None else metrics

    def update_metrics(self, **new_metrics: Any):
        """Add new metric values (e.g. SSIM, PSNR) to the history.

        Each keyword is a metric name. A list or tuple value is treated as a
        batch of per-sample values and extended onto that metric's history.
        Any other value is appended as a single entry.
        """
        for key, value in new_metrics.items():
            history = self.metrics.setdefault(key, [])
            if isinstance(value, (list, tuple)):
                history.extend(value)
            else:
                history.append(value)

    def get_metrics(self) -> Dict[str, list]:
        """Return the tracked metrics dict (the live object, not a copy)."""
        return self.metrics