import os
from collections import defaultdict
from tqdm import tqdm
import numpy as np
import json
from datetime import datetime
import pandas as pd

import src.dataset as dataset_utils
from src.model._base_model import reverse_complement, tackle_sequences


def split_by_names(names, sequences):
    """Group ``sequences`` by their paired entry in ``names``.

    Returns one ``(name_list, sequence_list)`` tuple per distinct name, in the
    order each name first appears. ``name_list`` is that name repeated once per
    grouped sequence, so both lists in a tuple have equal length.
    """
    grouped = {}
    for key, sequence in zip(names, sequences):
        grouped.setdefault(key, []).append(sequence)
    return [([key] * len(members), members) for key, members in grouped.items()]


def reverse_seqs(seqs: list[str], seq_type: str):
    """Return the reverse strand of every sequence in ``seqs``.

    Nucleotide sequences (``"rna"`` / ``"dna"``) are reverse-complemented via
    ``reverse_complement``; protein sequences are simply reversed.

    Raises:
        ValueError: if ``seq_type`` is not one of ``"protein"``, ``"rna"``, ``"dna"``.
    """
    if seq_type in ["rna", "dna"]:
        return [reverse_complement(s, seq_type) for s in seqs]
    if seq_type == "protein":
        return [s[::-1] for s in seqs]
    raise ValueError(f"Invalid sequence type: {seq_type}")


class BaseTask:
    """Score dataset sequences with a model, append results to per-name CSVs, and report metrics.

    Constructing an instance immediately runs the whole pipeline via
    ``self.run()``. The pipeline sets up the output directory, loads any
    previously written results, builds the dataset and dataloader, scores
    the sequences and reports metrics.
    """

    # Column order of every per-name result CSV written by ``_append_rows``.
    _CSV_HEADER = "time,label,pred_all,pred_for,pred_rev,sequence\n"

    def __init__(self, cfg, model):
        self.cfg = cfg
        self.model = model
        self.metrics = self.cfg.metrics
        self.run()

    def set_cfg(self):
        """Copy the generation config, create the output directory and dump hyperparameters.

        Every non-dunder attribute of ``self.cfg`` is written to
        ``{output_dir}/cfg.json``.
        """
        self.generate_cfg = self.cfg.generate_cfg.copy()
        # NOTE(review): str.replace substitutes EVERY occurrence of "date", including
        # inside unrelated path components (e.g. "update"); confirm the placeholder convention.
        self.output_dir = self.cfg.output_dir.replace(
            "date", datetime.now().strftime("%Y%m%d_%H%M%S")
        )
        os.makedirs(self.output_dir, exist_ok=True)
        print(f"Output Directory: {self.output_dir}")

        hyperparams = {
            attr: getattr(self.cfg, attr)
            for attr in dir(self.cfg)
            if not attr.startswith("__")
        }
        # `with` guarantees the handle is closed even if serialization fails;
        # utf-8 is required because ensure_ascii=False may emit non-ASCII text.
        with open(os.path.join(self.output_dir, "cfg.json"), "w", encoding="utf-8") as f:
            json.dump(hyperparams, f, indent=4, ensure_ascii=False)

    def load_existed_seqs(self):
        """Load sequences already scored in ``self.output_dir`` so a rerun can resume.

        Returns:
            ``defaultdict(dict)`` mapping ``name`` (the CSV file name without its
            extension) to ``{sequence: [time, label, pred_all, pred_for, pred_rev]}``.
            The values are the raw CSV strings. Empty or header-only files are
            skipped.
        """
        existed_seqs = defaultdict(dict)
        if not os.path.exists(self.output_dir):
            return existed_seqs

        for file in os.listdir(self.output_dir):
            if not file.endswith(".csv"):
                continue
            try:
                # Files are written with a header row. dtype=str plus keep_default_na=False
                # keep sequences verbatim; e.g. a peptide "NA" must not become NaN.
                df = pd.read_csv(
                    os.path.join(self.output_dir, file),
                    dtype=str,
                    keep_default_na=False,
                )
            except pd.errors.EmptyDataError:  # zero-byte file
                continue
            if df.empty:
                continue
            # splitext (not split(".")) keeps names that themselves contain dots.
            existed_seqs[os.path.splitext(file)[0]] = {
                row[-1]: row[:-1] for row in df.values.tolist()
            }  # [time, label, ...predictions based on score_modes...]
        return existed_seqs

    def run(self):
        """Run the full pipeline: config, resume state, dataset, scoring, metrics."""
        self.set_cfg()
        existed_seqs = self.load_existed_seqs()

        # Load the dataset
        self.dataset = getattr(dataset_utils, self.cfg.dataset)(**self.cfg.dataset_cfg)
        self.dataloader = dataset_utils.get_dataloader(self.cfg, self.dataset)

        # Run model
        self.run_dataloader(self.dataloader, existed_seqs)

        # Calc metrics
        metrics = self.calc_metrics(is_print=True)
        self.calc_average_metrics(metrics, is_print=True)

    def generate(self, mutated_seqs, _generate_cfg):
        """Call the model method named by ``cfg.generate_mode`` on ``mutated_seqs``."""
        mutated_seqs = getattr(self.model, self.cfg.generate_mode)(
            sequences=mutated_seqs,
            **_generate_cfg,
        )
        return mutated_seqs

    @staticmethod
    def _group_indices(names):
        """Map each distinct name to the batch positions where it occurs, in first-appearance order."""
        groups = defaultdict(list)
        for idx, name in enumerate(names):
            groups[name].append(idx)
        return groups

    def _expand_context(self, name, dataset, generate_cfg):
        """Generate the ``[left, right]`` flanking context shared by every variant of ``name``.

        ``extra_length`` sets the left length and ``new_length`` sets the right
        length. Either value may come from ``cfg.expansion_dict[name]``. A float
        in (0, 1) is interpreted as a fraction of the original sequence length.
        """
        origin_seqs = tackle_sequences(
            [dataset.origins[name]], self.model.tokenizer_tackle, self.model.kmer
        )
        cfg = generate_cfg.copy()
        if hasattr(self.cfg, "expansion_dict"):
            assert name in self.cfg.expansion_dict, (
                f"Name {name} not in expansion_dict"
            )
            cfg["new_length"] = self.cfg.expansion_dict[name][0]
            cfg["extra_length"] = self.cfg.expansion_dict[name][1]

        len_l = cfg.get("extra_length", 0)
        len_r = cfg.get("new_length", 0)

        # Support fractional lengths as percentage of original sequence length
        origin_seq_len = len(origin_seqs[0])
        if isinstance(len_l, float) and 0 < len_l < 1:
            len_l = int(origin_seq_len * len_l)
        if isinstance(len_r, float) and 0 < len_r < 1:
            len_r = int(origin_seq_len * len_r)

        cfg["extra_length"] = len_l
        cfg["new_length"] = len_r
        generated = self.generate(origin_seqs, cfg)[0][0]
        # Reverse/slice/reverse takes the last len_r items. A plain [-len_r:] would
        # wrongly return the whole sequence when len_r == 0.
        return [generated[:len_l], generated[::-1][:len_r][::-1]]

    def _append_rows(self, names, data, labels, preds, preds_for, preds_rev, existed_seqs):
        """Append one CSV row per batch item to ``{output_dir}/{name}.csv``.

        Sequences already recorded (present in ``existed_seqs``) are skipped, so
        resuming a partially written batch does not duplicate rows. An empty
        prediction list produces an empty cell in its column.
        """
        for i, name in enumerate(names):
            if data[i] in existed_seqs[name]:
                continue
            path = f"{self.output_dir}/{name}.csv"
            if not os.path.exists(path):
                with open(path, "w") as f:
                    f.write(self._CSV_HEADER)
            with open(path, "a") as f:
                now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                # Build CSV line based on actually computed values
                csv_parts = [now, str(labels[i].item())]
                csv_parts.append(str(preds[i].item()) if len(preds) > 0 else "")
                csv_parts.append(str(preds_for[i].item()) if len(preds_for) > 0 else "")
                csv_parts.append(str(preds_rev[i].item()) if len(preds_rev) > 0 else "")
                csv_parts.append(data[i])
                f.write(",".join(csv_parts) + "\n")

    def run_dataloader(self, dataloader, existed_seqs):
        """Score every sequence yielded by ``dataloader`` and append one CSV row per item.

        Items in each ``(names, data, labels)`` batch are grouped by name, so all
        variants of a name share one generated flanking context (cached per name).
        Forward and/or reverse scores are computed according to
        ``cfg.score_modes``. ``pred_all`` is their mean when both are available.
        Metrics are recomputed silently every ``cfg.eval_interval`` (default
        2000) items.

        Args:
            dataloader: iterable of ``(names, data, labels)`` batches. Its
                ``.dataset`` must expose ``origins``, ``seq_type`` and ``__len__``.
            existed_seqs: mapping returned by :meth:`load_existed_seqs`.
        """
        _generate_cfg = self.generate_cfg.copy()
        _generate_cfg["with_original"] = True
        expanded_seqs = {}  # name -> [left_context, right_context]
        cnt, cnt_eval = 0, 0

        # Loop-invariant scoring configuration.
        score_modes = getattr(self.cfg, "score_modes", ["all", "for", "rev"])
        if isinstance(score_modes, str):
            score_modes = [score_modes]
        compute_for = any(mode in score_modes for mode in ("all", "for"))
        compute_rev = any(mode in score_modes for mode in ("all", "rev"))
        repeat = _generate_cfg.get("repeat", 1)

        with tqdm(total=len(dataloader.dataset), desc="Pred", dynamic_ncols=True) as pbar:
            for names, data, labels in dataloader:
                # Skip if all seqs are existed
                if all(seq in existed_seqs[name] for name, seq in zip(names, data)):
                    pbar.update(len(names))
                    continue

                # Pre-sized so each prediction lands at its item's ORIGINAL batch
                # position, even when names are interleaved (e.g. [a, b, a]).
                preds_for = [None] * len(names) if compute_for else []
                preds_rev = [None] * len(names) if compute_rev else []

                for name, indices in self._group_indices(names).items():
                    # Repeat the sequences if needed, then tackle them to the kmer length
                    batch = [data[i] * repeat for i in indices]
                    batch = tackle_sequences(
                        batch, self.model.tokenizer_tackle, self.model.kmer
                    )

                    # Always use consistent generation: all variants share same generated context
                    if name not in expanded_seqs:
                        expanded_seqs[name] = self._expand_context(
                            name, dataloader.dataset, _generate_cfg
                        )
                    seq_l, seq_r = expanded_seqs[name]
                    batch = [seq_l + seq + seq_r for seq in batch]

                    # The second return value of score_sequences is not used here.
                    if compute_for:
                        pred_for, _ = self.model.score_sequences(batch)
                        for idx, pred in zip(indices, pred_for):
                            preds_for[idx] = pred
                    if compute_rev:
                        pred_rev, _ = self.model.score_sequences(
                            reverse_seqs(batch, dataloader.dataset.seq_type),
                        )
                        for idx, pred in zip(indices, pred_rev):
                            preds_rev[idx] = pred

                # Calculate "all" predictions (average of forward and reverse)
                if preds_for and preds_rev:
                    preds = [(f + r) / 2 for f, r in zip(preds_for, preds_rev)]
                else:
                    preds = []

                self._append_rows(
                    names, data, labels, preds, preds_for, preds_rev, existed_seqs
                )
                pbar.update(len(names))
                cnt += len(names)

                if cnt - cnt_eval >= getattr(self.cfg, "eval_interval", 2000):
                    metrics = self.calc_metrics(is_print=False)
                    self.calc_average_metrics(metrics, is_print=False)
                    cnt_eval = cnt

    def calc_metrics(self, is_print=True):
        """Compute every configured metric for every per-name result CSV.

        Returns:
            ``defaultdict(dict)`` mapping ``name`` (the file name without ``.csv``)
            to ``{f"{mode}_{metric}": value}``. Summary files whose names start
            with ``rnagym_`` and header-only files are skipped.
        """
        results = defaultdict(dict)
        score_map = {"all": "pred_all", "for": "pred_for", "rev": "pred_rev"}
        # Determine score modes to calculate; use defaults if not provided in cfg.
        # Normalise a bare string FIRST so the "all" check is list membership,
        # not a substring test.
        score_modes_to_run = getattr(self.cfg, "score_modes", ["all", "for", "rev"])
        if isinstance(score_modes_to_run, str):
            score_modes_to_run = [score_modes_to_run]
        if "all" in score_modes_to_run:
            score_modes_to_run = ["all", "for", "rev"]

        for file in os.listdir(self.output_dir):
            if not file.endswith(".csv") or file.startswith("rnagym_"):  # skip summary CSVs
                continue

            name = os.path.splitext(file)[0]  # extract name from "name.csv"
            df = pd.read_csv(os.path.join(self.output_dir, file))
            if df.empty:
                continue
            _labels_all = np.array(df["label"].tolist(), dtype=np.float32)

            for mode in score_modes_to_run:
                _preds_all = df[score_map[mode]].tolist()
                for metric in self.metrics:
                    results[name][f"{mode}_{metric}"] = dataset_utils.calc_metrics(
                        _labels_all, _preds_all, metric
                    )

            dataset_utils.report_metrics(
                name,
                results[name],
                self.cfg.use_wandb,
                is_print,
            )
        return results

    def calc_average_metrics(self, metrics, is_print=True):
        """Report the mean of each metric key across all names under the label ``"Average"``."""
        per_metric = defaultdict(list)
        for name_metrics in metrics.values():
            for metric, value in name_metrics.items():
                per_metric[metric].append(value)

        average_metrics = {
            metric: np.mean(values) for metric, values in per_metric.items()
        }

        dataset_utils.report_metrics(
            "Average",
            average_metrics,
            self.cfg.use_wandb,
            is_print,
        )
