import torch
import torch.nn as nn
import time
import datetime
import subprocess 
import pytorch_lightning as pl
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts
import os
import numpy as np
import csv

def compute_correlations(labels, preds, return_detail = False):
    """Column-wise Pearson correlation between ``labels`` and ``preds``.

    Both tensors are detached, moved to CPU and indexed as 2-D arrays; one
    coefficient is computed per column with ``np.corrcoef``. Coefficients
    that come out as NaN (e.g. a constant column) are replaced by -1.

    Args:
        labels: 2-D tensor of reference values.
        preds: 2-D tensor of predictions with the same column layout.
        return_detail: when true, return the per-column list instead of
            the mean.

    Returns:
        A list of per-column coefficients if ``return_detail`` is true,
        otherwise a one-element float tensor holding their mean, placed on
        the device ``labels`` came from.
    """
    target_device = labels.device

    truth = labels.detach().cpu().numpy()
    guess = preds.detach().cpu().numpy()

    per_column = []
    for col in range(truth.shape[1]):
        # Off-diagonal entry of the 2x2 correlation matrix.
        per_column.append(np.corrcoef(truth[:, col], guess[:, col])[0, 1])
    per_column = np.nan_to_num(per_column, nan = -1).tolist()

    if return_detail:
        return per_column

    mean_corr = np.mean(per_column)
    return torch.FloatTensor([mean_corr]).to(target_device)

def to_gene(gene_corr):
    """Render ``(name, correlation)`` pairs as one log-friendly string.

    Each pair becomes ``"[name : 0.1234] "`` (four decimals, trailing
    space); the pieces are concatenated in input order.
    """
    return "".join("[%s : %.4f] " % (name, score) for name, score in gene_corr)

def pearsonr(x, y):
    """Pearson correlation of two 1-D tensors, kept differentiable.

    The denominator is padded with 1e-8 so constant inputs yield 0 rather
    than a division by zero; a NaN result is mapped to -1.
    """
    x_centered = x - torch.mean(x)
    y_centered = y - torch.mean(y)

    covariance = torch.dot(x_centered, y_centered)
    magnitude = torch.norm(x_centered, 2) * torch.norm(y_centered, 2)

    return torch.nan_to_num(covariance / (magnitude + 1e-8), nan=-1)


class TrainerModel(pl.LightningModule):
    """Lightning wrapper that trains ``model`` with manual optimization.

    The training loss is MSE plus a variance-scaled ``1 - PCC`` term and an
    optional regulariser on ``model._eca_gamma``. Per-epoch train/validation
    metrics are appended to ``<store_dir>/metrics.csv``, the configuration is
    dumped once to ``<store_dir>/args.txt``, the best-correlation checkpoint
    is kept on disk, and early stopping is driven by validation PCC / MSE.

    ``config`` is read for: ``store_dir``, ``logfun``, ``verbose_step``,
    ``lr``, ``weight_decay``, ``filter_name`` and, optionally,
    ``lambda_pcc``, ``lambda_gamma``, ``early_stop_patience``,
    ``early_stop_eps_pcc``, ``early_stop_eps_mse``.
    """

    def __init__(self, config,  model):
        super().__init__()
        self.model = model
        self.config = config
        self.criterion = nn.MSELoss()
        # Optimizer and LR scheduler are stepped by hand in ``training_step``.
        self.automatic_optimization = False

        # Running bests used only for logging / checkpoint selection.
        self.min_loss  = float("inf")
        self.max_corr  = float("-inf")
        self.max_eval_corr = float("-inf")
        self.min_eval_loss = float("inf")
        self.start_time  = None
        self.last_saved = None

        # Per-epoch buffers of (detached, CPU) training predictions/targets.
        self._train_preds = []
        self._train_targets = []

        # Loss weights and the reference target std used to scale the PCC term.
        self.lambda_pcc = getattr(config, 'lambda_pcc', 0.5)
        self.lambda_gamma = getattr(config, 'lambda_gamma', 0.01)
        self._sigma0 = None

        # Early-stopping configuration and state.
        self.es_patience = int(getattr(self.config, "early_stop_patience", 3))
        self.es_eps_pcc  = float(getattr(self.config, "early_stop_eps_pcc", 0.0))
        self.es_eps_mse  = float(getattr(self.config, "early_stop_eps_mse", 0.0))
        self._es_best_pcc = float("-inf")
        self._es_best_mse = float("inf")
        self._es_bad_streak = 0

        # CSV logging (one file for both train and val)
        self.csv_path = os.path.join(config.store_dir, "metrics.csv")
        if self.trainer is None or getattr(self.trainer, "is_global_zero", True):
            os.makedirs(config.store_dir, exist_ok=True)
            self._init_metrics_csv()
            # Save args.txt in the same directory as the metrics file.
            self._write_args_txt(config)

    def _init_metrics_csv(self):
        """Create ``metrics.csv`` with its header row unless it already exists."""
        if os.path.exists(self.csv_path):
            return
        with open(self.csv_path, mode="w", newline="") as f:
            csv.writer(f).writerow([
                "epoch", "split", "when",
                "PCC_F", "PCC_S", "PCC_M", "MSE"
            ])

    def _write_args_txt(self, config):
        """Dump ``config`` to ``args.txt`` (key=value lines plus an approximate CLI).

        An existing ``args.txt`` is left untouched.
        """
        args_txt_path = os.path.join(config.store_dir, "args.txt")
        if os.path.exists(args_txt_path):
            return

        def _to_str(v):
            if isinstance(v, (str, int, float, bool)) or v is None:
                return str(v)
            if isinstance(v, (list, tuple)):
                return "[" + ", ".join(_to_str(x) for x in v) + "]"
            if isinstance(v, dict):
                return "{" + ", ".join(f"{k}: {_to_str(val)}" for k, val in v.items()) + "}"
            return repr(v)

        try:
            raw = dict(vars(config))
        except TypeError:
            # ``config`` has no ``__dict__``: fall back to its public,
            # non-callable attributes.
            raw = {k: getattr(config, k) for k in dir(config)
                   if not k.startswith("_") and not callable(getattr(config, k))}

        items = sorted(raw.items())
        lines = [f"{k} = {_to_str(v)}" for k, v in items]

        cli_parts = []
        for k, v in items:
            # bool must be tested first because bool is a subclass of int.
            if isinstance(v, bool):
                if v:
                    cli_parts.append(f"--{k}")
            elif isinstance(v, (int, float, str)):
                cli_parts.append(f"--{k} {v}")
        cli_line = " ".join(cli_parts)

        with open(args_txt_path, "w", encoding="utf-8") as f:
            f.write("# Args key=value\n")
            f.write("\n".join(lines))
            f.write("\n\n# Approximate CLI\n")
            f.write(cli_line + "\n")

    def _append_csv(self, row):
        """Append ``row`` to the metrics CSV (global-rank-zero process only)."""
        if self.trainer.is_global_zero:
            with open(self.csv_path, mode="a", newline="") as f:
                csv.writer(f).writerow(row)

    @property
    def num_training_steps(self) -> int:
        """Batches in the training dataloader divided by the device count.

        This is a per-epoch figure, not the total over all epochs.
        NOTE(review): relies on private/legacy Trainer attributes
        (``_data_connector``, ``num_gpus``, ``num_processes``) — confirm they
        exist in the installed Lightning version.
        """
        train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()
        num_devices = max(1, self.trainer.num_gpus, self.trainer.num_processes) * self.trainer.num_nodes
        return len(train_loader) // num_devices

    def correlationMetric(self,x, y):
        """Return ``1 - mean column-wise Pearson correlation`` of ``x`` and ``y``.

        Raises:
            ValueError: if ``x`` has no columns.
        """
        n_cols = x.size(1)
        if n_cols == 0:
            raise ValueError("correlationMetric needs at least one column")
        corr = 0
        for col in range(n_cols):
            corr += pearsonr(x[:, col], y[:, col])
        corr /= n_cols
        return (1 - corr).mean()

    def _predict_counts(self, data):
        """Run the model and return the prediction tensor.

        Models that return a tuple are assumed to put the prediction first.
        """
        out = self.model(data)
        return out[0] if isinstance(out, tuple) else out

    def on_train_epoch_start(self):
        """Reset the per-epoch prediction/target buffers."""
        self._train_preds.clear()
        self._train_targets.clear()

    def training_step(self,data,idx):
        """One manually-optimized step: forward, loss, backward, optimizer and LR step."""
        # Start the ETA clock on the first step seen, including when training
        # does not begin at epoch 0 / batch 0.
        if self.start_time is None or (self.current_epoch == 0 and idx == 0):
            self.start_time  = time.time()

        optimizer = self.optimizers()

        pred_count = self._predict_counts(data)
        target = data["count"]

        if self._sigma0 is None:
            # Reference target std, captured once from the first batch.
            with torch.no_grad():
                self._sigma0 = target.std()

        mse = self.criterion(pred_count, target)
        dn_pcc = self._dn_pcc(target, pred_count)

        # Optional pull of ``model._eca_gamma`` towards 1 (only if the model has it).
        gamma = getattr(self.model, "_eca_gamma", None)
        gamma_reg = 0.0
        if gamma is not None:
            gamma_reg = ((gamma - 1.0) ** 2).mean() * self.lambda_gamma
        total_loss = mse + self.lambda_pcc * dn_pcc + gamma_reg

        # Logging-only metric; no graph needed.
        with torch.no_grad():
            corrloss = self.correlationMetric(pred_count, target)

        self._train_preds.append(pred_count.detach().cpu())
        self._train_targets.append(target.detach().cpu())

        optimizer.zero_grad()
        self.manual_backward(total_loss)
        optimizer.step()

        # The scheduler is stepped once per batch.
        lr_scheduler = self.lr_schedulers()
        lr_scheduler.step()

        self.produce_log(mse.detach(),corrloss.detach(),data,idx)

    def produce_log(self,loss,corr,data,idx):
        """Gather step metrics across processes and log every ``verbose_step`` batches."""
        train_loss = self.all_gather(loss).mean().item()
        train_corr = self.all_gather(corr).mean().item()

        self.min_loss   = min(self.min_loss, train_loss)

        # ``device.index`` is None on CPU; accept it so CPU runs are logged too.
        on_primary_device = loss.device.index in (None, 0)
        if not (self.trainer.is_global_zero and on_primary_device
                and idx % self.config.verbose_step == 0):
            return

        current_lr = self.optimizers().param_groups[0]['lr']
        len_loader = self.num_training_steps

        batches_done = self.current_epoch  * len_loader + idx + 1
        batches_left = self.trainer.max_epochs * len_loader - batches_done
        time_left    = datetime.timedelta(seconds = batches_left * (time.time() - self.start_time) / batches_done)

        # ``train_corr`` is the 1 - PCC loss, so PCC is reported as 1 - train_corr.
        self.config.logfun(
            "[Epoch %d/%d] [Batch %d/%d] [Loss: %f, 1 - Corr: %f, PCC: %f, lr: %f] [Min Loss: %f] ETA: %s" %
            (self.current_epoch,
             self.trainer.max_epochs,
             idx,
             len_loader,
             train_loss,
             train_corr,
             1-train_corr,
             current_lr,
             self.min_loss,
             time_left
             )
        )

    def validation_step(self,data,idx):
        """Return ``(prediction, target)`` for one validation batch."""
        pred_count = self._predict_counts(data)
        return pred_count,data["count"]

    def training_epoch_end(self, outputs):
        """Compute epoch-level training PCC/MSE from the buffered batches and log them."""
        logfun = self.config.logfun
        if len(self._train_preds) == 0:
            return
        pred = torch.cat(self._train_preds, dim=0).to(self.device)
        targ = torch.cat(self._train_targets, dim=0).to(self.device)

        pred = self.all_gather(pred).view(-1, pred.shape[1])
        targ = self.all_gather(targ).view(-1, targ.shape[1])

        train_mse = self.criterion(pred, targ).item()

        gene_corr = compute_correlations(targ, pred, True)  # list[float]
        arr = np.asarray(gene_corr, dtype=np.float32)
        pcc_f = float(np.quantile(arr, 0.25))  # PCC@F
        pcc_s = float(np.median(arr))          # PCC@S
        pcc_m = float(np.mean(arr))            # PCC@M

        if self.trainer.is_global_zero:
            logfun("\n [TRAINEND][epoch %d] PCC@F: %.4f | PCC@S: %.4f | PCC@M: %.4f | MSE: %.6f \n"
                % (self.current_epoch, pcc_f, pcc_s, pcc_m, train_mse))
            self._append_csv([
                int(self.current_epoch),
                "train",
                "end",
                pcc_f, pcc_s, pcc_m, train_mse
            ])
        self._train_preds.clear()
        self._train_targets.clear()

    def _update_early_stopping(self, val_pcc_m, val_mse):
        """Update the early-stopping state; return True when patience is exhausted.

        An improvement in either metric (beyond its epsilon) resets the bad
        streak; otherwise the streak grows when either metric is worse than
        its best by more than epsilon.
        """
        pcc_improved = (val_pcc_m > self._es_best_pcc + self.es_eps_pcc)
        mse_improved = (val_mse   < self._es_best_mse - self.es_eps_mse)

        if pcc_improved or mse_improved:
            if pcc_improved:
                self._es_best_pcc = float(val_pcc_m)
            if mse_improved:
                self._es_best_mse = float(val_mse)
            self._es_bad_streak = 0
        else:
            pcc_worse = (val_pcc_m < self._es_best_pcc - self.es_eps_pcc)
            mse_worse = (val_mse   > self._es_best_mse + self.es_eps_mse)
            if pcc_worse or mse_worse:
                self._es_bad_streak += 1

        return self._es_bad_streak >= self.es_patience

    def _log_gpu_status(self):
        """Log the output of ``nvidia-smi``; a failing call is logged, not raised."""
        try:
            report = subprocess.check_output(["nvidia-smi"]).decode("utf-8")
        except (OSError, subprocess.CalledProcessError) as err:
            self.config.logfun("nvidia-smi unavailable: %s" % err)
            return
        for line in report.split("\n"):
            self.config.logfun(line)

    def validation_epoch_end(self,outputs):
        """Aggregate validation batches, log/checkpoint, and apply early stopping.

        ``outputs`` is the list of ``(prediction, target)`` pairs returned by
        ``validation_step``.
        """
        logfun = self.config.logfun
        if not outputs:
            return

        # Infer the number of target columns instead of hard-coding it, and
        # concatenate (rather than stack) so a smaller last batch is accepted.
        n_targets = outputs[0][1].shape[-1]
        pred_count = torch.cat([o[0].reshape(-1, n_targets) for o in outputs], dim=0)
        count = torch.cat([o[1].reshape(-1, n_targets) for o in outputs], dim=0)
        pred_count = self.all_gather(pred_count).reshape(-1, n_targets)
        count = self.all_gather(count).reshape(-1, n_targets)

        total_loss = self.criterion(pred_count,count).item()

        gene_corr_ori = compute_correlations(count, pred_count, True)
        corr = np.mean(gene_corr_ori)
        gene_corr = sorted(zip(self.config.filter_name, gene_corr_ori), key = lambda x : x[1])

        arr   = np.asarray(gene_corr_ori, dtype=np.float32)
        pcc_f = float(np.quantile(arr, 0.25))   # PCC@F: quantile
        pcc_s = float(np.median(arr))           # PCC@S: median
        pcc_m = float(np.mean(arr))             # PCC@M: mean

        # The metrics must exist before the early-stopping comparison uses them.
        val_pcc_m = pcc_m
        val_mse   = total_loss
        should_stop = self._update_early_stopping(val_pcc_m, val_mse)

        if self.trainer.is_global_zero:
            # Only the GPU report depends on GPUs being present; logging,
            # CSV rows and checkpointing also run on CPU.
            if self.trainer.num_gpus != 0:
                self._log_gpu_status()

            if corr > self.max_eval_corr:
                self.save(self.current_epoch, total_loss,corr)
            self.max_eval_corr = max(self.max_eval_corr,corr)
            self.min_eval_loss = min(self.min_eval_loss, total_loss)

            logfun("==" * 25)
            logfun(
                "[VALIDEND][Corr :%f, Loss: %f] [Min Loss :%f, Max Corr: %f PCC@F: %f, PCC@S: %f, PCC@M: %f]" %
                (corr,
                 total_loss,
                 self.min_eval_loss,
                 self.max_eval_corr,
                 pcc_f,
                 pcc_s,
                 pcc_m
                 )
                )
            self._append_csv([
                int(self.current_epoch),
                "val",
                "end",
                pcc_f, pcc_s, pcc_m, total_loss
            ])
            logfun("Top 10 gene corr")
            logfun(to_gene(gene_corr[-10:][::-1]))
            logfun("==" * 25)
            logfun("End Evaluation")

        # Decide on stopping after logging so the stopping epoch is recorded.
        if should_stop:
            if self.trainer.is_global_zero:
                logfun(
                    f"[EarlyStop] epoch {self.current_epoch}: "
                    f"no improvement vs best for {self._es_bad_streak} consecutive validations. "
                    f"(best PCC_M={self._es_best_pcc:.4f}, best MSE={self._es_best_mse:.6f}; "
                    f"now PCC_M={val_pcc_m:.4f}, MSE={val_mse:.6f})"
                )
            self.trainer.should_stop = True
            return

        self._es_prev_pcc = float(val_pcc_m)
        self._es_prev_mse = float(val_mse)

    def save(self, epoch,loss, acc):
        """Write the model's ``state_dict`` and drop the previously saved checkpoint.

        The file is named ``<epoch>_<loss>_<acc>.pt`` inside ``store_dir``.

        Returns:
            The path of the checkpoint just written.
        """
        previous = self.last_saved
        self.config.logfun(previous)

        output_path = os.path.join(self.config.store_dir, "%d_%f_%f.pt" % (epoch,loss,acc))
        # Write the new checkpoint first so a failed save never leaves the
        # run without any checkpoint on disk.
        torch.save(self.model.state_dict(), output_path)
        self.last_saved = output_path

        if previous is not None and previous != output_path and os.path.exists(previous):
            os.remove(previous)

        # Single formatted string: works for print-like and logger-like ``logfun``.
        self.config.logfun("EP:%d Model Saved on: %s" % (epoch, output_path))
        return output_path

    def configure_optimizers(self):
        """Build AdamW and a per-step cosine-annealing-with-warmup scheduler."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr = self.config.lr,
            betas = (0.9, 0.999),
            weight_decay = self.config.weight_decay,
        )

        # The scheduler needs warmup_steps < first_cycle_steps; clamp both so
        # very short epochs do not trip that check. Unchanged (100 warmup
        # steps) whenever half an epoch is longer than 100 steps.
        first_cycle_steps = max(1, self.num_training_steps // 2)
        warmup_steps = min(100, first_cycle_steps - 1)

        lr_scheduler = CosineAnnealingWarmupRestarts(
            optimizer,
            first_cycle_steps=first_cycle_steps,
            cycle_mult=1.0,
            max_lr=self.config.lr,
            min_lr=self.config.lr/5,
            warmup_steps=warmup_steps
        )

        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

    def _pcc_loss(self, y_true, y_pred):
        """Return ``1 - mean over columns of the Pearson correlation``."""
        y_true_c = y_true - y_true.mean(dim=0, keepdim=True)
        y_pred_c = y_pred - y_pred.mean(dim=0, keepdim=True)
        num = (y_true_c * y_pred_c).sum(dim=0)
        # 1e-8 keeps constant columns from dividing by zero.
        den = (y_true_c.pow(2).sum(dim=0).sqrt() * y_pred_c.pow(2).sum(dim=0).sqrt() + 1e-8)
        pcc = (num / den).mean()
        return (1 - pcc)

    def _dn_pcc(self, y_true, y_pred):
        """PCC loss scaled by ``std(y_pred) / sigma0``; the scale carries no gradient.

        ``sigma0`` is the stored reference std, (re)initialised from ``y_true``
        when it is missing or non-positive.
        """
        base = self._pcc_loss(y_true, y_pred)
        with torch.no_grad():
            sigma_hat = y_pred.std()
            if self._sigma0 is None or self._sigma0 <= 0:
                self._sigma0 = y_true.std()
        scale = (sigma_hat / (self._sigma0 + 1e-12)).detach()
        return base * scale
                        
                        
                        
        
     