import os
import os.path as osp
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lrs
from omegaconf import OmegaConf
import pytorch_lightning as pl
from src.tools.geo_utils import batch_rmsd
from src.chroma.data import Protein
from src.models.refiner_pr import RefinerPR
from src.datasets.bio_tokenizer import BioTokenizer
from src.tools.perturbation import perturb_structure


class RefinerPR_ITF(pl.LightningModule):
    """Lightning wrapper around ``RefinerPR`` for protein structure refinement.

    Handles optimisation, validation/test RMSD bookkeeping, PDB export of
    ground-truth and predicted structures, and dumping of predictions to a
    JSON-lines file. The model's ``forward`` is not released yet.

    All constructor keyword arguments are stored with ``save_hyperparameters``
    and merged over ``src/models/configs/default.yaml`` when building the model.
    """

    def __init__(self, **kargs):
        super().__init__()
        self.save_hyperparameters()
        self.load_model()
        self.val_metrics = {"rec": [], "refined_rec": [], "rmsd": [], "rmsd_bb": [], "rmsd_sc": []}
        self.test_metrics = {"rec": [], "refined_rec": [], "rmsd": []}
        self.val_per_sample, self.test_per_sample = [], []
        # Always created so later code never depends on which optional hparams were passed;
        # they are only written to disk when the matching hparam is set (see cal_metric).
        self.loss_dump_dict = {}
        self.features_dump_dict = {"residues": [], "gt_features": [], "af_features": []}
        # "labels" is added lazily by predict_step, only when batches carry labels.
        self.predict_dict = {'name': [], "seq": [], 'coords': [], "afdb_coords": [], "valid_mask": []}
        self.ce_loss = nn.CrossEntropyLoss()
        self.tokenizer = BioTokenizer()

    def _hparam(self, name, default=None):
        """Return hyper-parameter ``name``, or ``default`` when it was not passed."""
        return getattr(self.hparams, name, default)

    @staticmethod
    def _median(values):
        """Median of ``values`` as a float; NaN (without numpy's empty-slice warning) if empty."""
        if not values:
            return float("nan")
        return float(np.median(values))

    def configure_optimizers(self):
        """AdamW over the refiner's parameters with a per-step cosine-annealed learning rate."""
        weight_decay = self._hparam('weight_decay', 0)
        optimizer = torch.optim.AdamW(
            self.refiner.parameters(),
            lr=self.hparams.lr,
            weight_decay=weight_decay,
            betas=(0.9, 0.98),
            eps=1e-8,
        )
        scheduler = {
            "scheduler": lrs.CosineAnnealingLR(optimizer, T_max=self.hparams.lr_decay_steps),
            "interval": "step",  # T_max counts optimizer steps, not epochs
        }
        return [optimizer], [scheduler]

    def forward(self, batch, batch_idx):
        # Placeholder: callers below expect a dict with at least 'loss' and 'preds'.
        pass # will release the complete code upon the acceptance

    def training_step(self, batch, batch_idx, **kwargs):
        """Run the model and log/return its training loss."""
        ret = self(batch, batch_idx)
        loss = ret['loss']
        self.log("train_loss_str", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and, every ``log_epoch`` epochs, collect RMSDs and write PDBs.

        Ground-truth PDBs are written only at epoch 0. Predicted PDBs are written
        to ``<savedir>/pdb_val-<epoch>``.

        Returns:
            The dict of logged values (``{"val_loss": float}`` when ``ret`` has a loss).
        """
        ret = self(batch, batch_idx)
        pr_X = ret['preds']
        # NOTE(review): X is sliced to its first 4 entries along dim 1, so
        # ``gt_X.shape[-2] == 14`` below can only be true if X carries an extra
        # axis — confirm whether the side-chain (14-atom) branch is reachable.
        gt_X, batch_id = batch['X'][:, :4], batch['batch_id']

        epoch = self.trainer.current_epoch
        if epoch % self.hparams.log_epoch == 0:
            save_gt = epoch == 0  # ground truth does not change, export it once
            if save_gt:
                gt_save_dir = osp.join(self.hparams.savedir, 'pdb_gt-' + str(epoch))
                os.makedirs(gt_save_dir, exist_ok=True)

            rmsd_bb, rmsd_sc, rmsd = batch_rmsd(pr_X, gt_X, batch_id, bb_sc=(gt_X.shape[-2] == 14))
            self.val_metrics['rmsd_bb'].extend(rmsd_bb)
            self.val_metrics['rmsd_sc'].extend(rmsd_sc)
            self.val_metrics['rmsd'].extend(rmsd)
            pdb_save_dir = osp.join(self.hparams.savedir, 'pdb_val-' + str(epoch))
            os.makedirs(pdb_save_dir, exist_ok=True)

            for i, pdb_id in enumerate(batch['title']):
                bmask = batch_id == i
                x = gt_X[bmask].unsqueeze(0)
                c = batch["C"][bmask].unsqueeze(0)
                s = batch["S"][bmask].unsqueeze(0)
                try:
                    if save_gt:
                        gt_prot = Protein.from_XCS(x, c, s)
                        gt_prot.to_PDB(osp.join(gt_save_dir, pdb_id + '_gt.pdb'))
                    pr_prot = Protein.from_XCS(pr_X[bmask].unsqueeze(0), c, s)
                    pr_prot.to_PDB(osp.join(pdb_save_dir, pdb_id + '_pr.pdb'))
                    # float() keeps the entry JSON-serialisable even if batch_rmsd yields tensors.
                    self.val_per_sample.append({"pdb id": pdb_id, "rmsd_bb": float(rmsd_bb[i])})
                except Exception as e:
                    # Best effort: one malformed structure must not abort validation.
                    print(f"{pdb_id}, an error occurred: {e}")

        log_dict = {'val_' + k: v.item() for k, v in ret.items() if k in ['loss']}
        self.log_dict(log_dict)
        return log_dict

    def test_step(self, batch, batch_idx):
        """Run the model on a test batch and record per-sample RMSD.

        RMSD is computed when ``loss_type == "rmsd"`` or when ``loss_dump`` is
        requested, because the dump needs per-sample RMSD whatever the loss type.
        The RMSDs are appended to ``test_metrics['rmsd']`` for ``cal_metric``.
        """
        ret = self(batch, batch_idx)
        loss_dump = self._hparam('loss_dump')
        if self.hparams.loss_type != "rmsd" and not loss_dump:
            return
        pr_X = ret['preds']
        gt_X, batch_id = batch['X'][:, :4], batch['batch_id']
        has_side_chains = gt_X.shape[-2] == 14
        rmsd_bb, rmsd_sc, rmsd = batch_rmsd(pr_X, gt_X, batch_id, bb_sc=has_side_chains)
        self.test_metrics['rmsd'].extend(rmsd)
        if loss_dump:
            # All-atom RMSD when side chains are present, otherwise backbone RMSD.
            per_sample = rmsd if has_side_chains else rmsd_bb
            for index, title in enumerate(batch["title"]):
                self.loss_dump_dict[title] = float(per_sample[index])

    def on_validation_epoch_end(self):
        """Log median validation RMSDs and write per-sample results, then reset buffers."""
        epoch = self.trainer.current_epoch
        if epoch % self.hparams.log_epoch != 0:
            return
        # Despite the "avg_" names these are medians; the keys are kept for logger compatibility.
        self.log("avg_rmsd", self._median(self.val_metrics['rmsd']), on_step=False, on_epoch=True, prog_bar=True)
        self.log("avg_rmsd_bb", self._median(self.val_metrics['rmsd_bb']), on_step=False, on_epoch=True, prog_bar=True)
        self.log("avg_rmsd_sc", self._median(self.val_metrics['rmsd_sc']), on_step=False, on_epoch=True, prog_bar=True)
        self.val_metrics = {key: [] for key in self.val_metrics}

        pdb_save_dir = osp.join(self.hparams.savedir, 'pdb_val-' + str(epoch))
        # The directory is normally created in validation_step, but not if no batch ran.
        os.makedirs(pdb_save_dir, exist_ok=True)
        with open(osp.join(pdb_save_dir, 'log.json'), 'w', encoding='utf-8') as f:
            json.dump(self.val_per_sample, f, ensure_ascii=False, indent=4)
        self.val_per_sample = []

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Run the model and append per-protein predictions to ``predict_dict``.

        ``batch["num_nodes"]`` gives the per-protein lengths used to split the
        batch's concatenated tensors back into one nested list per protein.
        ``batch`` itself is updated in place.
        """
        def restore_batch_length(data, split_indices):
            """Split ``data`` into consecutive chunks of the given lengths, as Python lists."""
            restored = []
            start_idx = 0
            for length in split_indices:
                restored.append(data[start_idx:start_idx + length].tolist())
                start_idx += length
            return restored

        ret = self(batch, batch_idx)
        batch_predict = batch
        batch_predict["afX"] = ret["preds"]
        split_indices = batch_predict["num_nodes"].tolist()
        batch_predict["S"] = self.tokenizer.decode(restore_batch_length(batch_predict["S"], split_indices))
        batch_predict["X"] = restore_batch_length(batch_predict["X"], split_indices)
        batch_predict["afX"] = restore_batch_length(batch_predict["afX"], split_indices)
        batch_predict["valid_mask"] = restore_batch_length(batch_predict["mask"], split_indices)

        self.predict_dict["name"].extend(batch_predict["title"])
        self.predict_dict["seq"].extend(batch_predict["S"])
        self.predict_dict["coords"].extend(batch_predict["X"])
        self.predict_dict["afdb_coords"].extend(batch_predict["afX"])
        self.predict_dict["valid_mask"].extend(batch_predict["valid_mask"])
        if "labels" in batch_predict:
            batch_predict["labels"] = restore_batch_length(batch_predict["labels"], split_indices)
            # "labels" is not created in __init__; create it on first use.
            self.predict_dict.setdefault("labels", []).extend(batch_predict["labels"])

    def cal_metric(self):
        """Print test metrics and write result/dump files. Runs on global rank 0 only.

        Always writes ``<savedir>/pdb_test.json``. When the corresponding
        hparam is set, it also writes the loss dump, the feature dump and the
        predictions (one JSON record per line).
        """
        if self.trainer.global_rank != 0:
            return
        median_rec = self._median(self.test_metrics['rec'])
        median_ref_rec = self._median(self.test_metrics['refined_rec'])
        avg_rmsd = self._median(self.test_metrics['rmsd'])
        print(f"median_rec: {median_rec:.4f}, median_ref_rec: {median_ref_rec:.4f}, avg_rmsd: {avg_rmsd:.4f}")

        savedir = self.hparams.savedir
        os.makedirs(savedir, exist_ok=True)
        with open(osp.join(savedir, 'pdb_test.json'), 'w', encoding='utf-8') as f:
            json.dump(self.test_per_sample, f, ensure_ascii=False, indent=4)

        loss_dump = self._hparam('loss_dump')
        if loss_dump:
            with open(osp.join(savedir, loss_dump), "w") as f:
                json.dump(self.loss_dump_dict, f)

        features_dump = self._hparam('features_dump')
        if features_dump:
            with open(osp.join(savedir, features_dump), "w") as f:
                json.dump(self.features_dump_dict, f)

        predict = self._hparam('predict')
        if predict:
            predict_filepath = osp.join(savedir, predict)
            data_length = len(self.predict_dict["name"])
            print(data_length)
            with open(predict_filepath, "w") as f:
                for i in range(data_length):
                    record = {key: value[i] for key, value in self.predict_dict.items()}
                    f.write(json.dumps(record) + '\n')

    def load_model(self):
        """Build ``self.refiner`` from the default config, overridden by this module's hparams."""
        # NOTE: the path is relative, so this assumes the process runs from the repository root.
        params = OmegaConf.load('src/models/configs/default.yaml')
        params.update(self.hparams)
        self.refiner = RefinerPR(params)