from _common import *

# Module-level logger named after this module; used by every function and class below.
log = logging.getLogger(__name__)

from collections import defaultdict
from peft import PeftModel
import lightning as L
import lightning.pytorch as pl
import peft
from flan_t5_checkpoint_path import finetuned_model_path
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, default_data_collator

from datasets import DatasetDict, load_dataset, load_from_disk
from src.tasks.arithmetic import state_dict_sub
from src.utils import num_devices

# Disable tokenizers parallelism for this process by setting the environment variable
# before any tokenizer is created.
# NOTE(review): presumably this avoids the fork-related warning from the `tokenizers`
# library when DataLoader workers are spawned (num_workers > 0) — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def remove_special_tokens(tokenizer, token_list: list):
    """
    Keep only the regular (non-special) token ids of a sequence.

    Scanning stops at the first entry equal to -100; nothing at or after that
    position is returned. Before that point, an id is dropped when it is listed in
    ``tokenizer.all_special_ids`` or when it is not strictly positive.

    Parameters:
        tokenizer (Tokenizer): Any object exposing an ``all_special_ids`` collection.
        token_list (list): The sequence of token ids to filter. Elements only need to
            support ``==``, ``>`` and truth-testing of the result, so both plain ints
            and scalar tensors work.

    Returns:
        list: The retained elements of ``token_list``, in their original order.
    """
    # Read the special ids once instead of once per token: on some tokenizers this
    # attribute is a computed property (NOTE(review): true for Hugging Face
    # tokenizers as far as I know — confirm), so reading it inside the loop repeats
    # the same work for every element.
    # Membership is intentionally kept equality-based (no set) so that tensor-valued
    # tokens, which hash by identity, are still matched against the int ids.
    special_ids = tokenizer.all_special_ids

    ret = []
    for token in token_list:
        if token == -100:
            # -100 marks the end of the useful part of the sequence.
            break
        if token > 0 and token not in special_ids:
            ret.append(token)
    return ret


def evaluate_accuracy(model, val_loader: DataLoader, tokenizer):
    """
    Compute exact-match accuracy of a generative model over a validation loader.

    For every batch the model generates from ``batch["input_ids"]`` only; the
    generated ids and the cleaned label ids (see ``remove_special_tokens``) are both
    decoded to text, and a sample counts as correct when the two strings are equal.

    Parameters:
        model (nn.Module): The language model to be evaluated. It is switched to
            eval mode and must provide ``generate``.
        val_loader (DataLoader): Iterable of batches; each batch must provide
            ``"input_ids"`` and ``"labels"``.
        tokenizer (Tokenizer): The tokenizer used to decode predictions and labels.

    Returns:
        float: ``correct / total`` over all compared samples, or ``nan`` when the
        loader yields no samples (accuracy is undefined in that case).
    """
    from tqdm import tqdm

    correct = 0
    total = 0

    model = model.eval()
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="round", leave=False):
            outputs = model.generate(batch["input_ids"])
            output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

            label_ids = [remove_special_tokens(tokenizer, l) for l in batch["labels"]]
            label_text = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

            # Exact string match between each prediction and its reference.
            matches = [pred == gold for pred, gold in zip(output_text, label_text)]
            correct += sum(matches)
            total += len(matches)

    if total == 0:
        # An empty loader used to surface as an opaque ZeroDivisionError here.
        log.warning("evaluate_accuracy: no samples were evaluated, returning NaN")
        return float("nan")

    return correct / total


def evaluate_spearman_rho(model, val_loader: DataLoader, tokenizer):
    """
    Compute Spearman's rank correlation between generated and reference scores.

    For every batch the model generates from ``batch["input_ids"]``; the generated ids
    and the cleaned label ids (see ``remove_special_tokens``) are decoded to text.
    Each decoded string is then parsed as a float, with unparsable strings mapped to
    ``0.0``, and the correlation is computed over the whole validation set at once.

    Parameters:
        model (nn.Module): The language model to be evaluated. It is switched to
            eval mode and must provide ``generate``.
        val_loader (DataLoader): Iterable of batches; each batch must provide
            ``"input_ids"`` and ``"labels"``.
        tokenizer (Tokenizer): The tokenizer used to decode predictions and labels.

    Returns:
        float: The Spearman's rho between the model's predictions and the actual
        labels (the first element of ``scipy.stats.spearmanr``'s result).
    """
    from scipy.stats import spearmanr
    from tqdm import tqdm

    def parse_float(s: str) -> float:
        """Parse ``s`` as a float, falling back to 0.0 for non-numeric text."""
        try:
            return float(s)
        except (TypeError, ValueError):
            # Only conversion failures are treated as "not a number"; a bare
            # ``except`` would also swallow KeyboardInterrupt / SystemExit.
            return 0.0

    model = model.eval()
    all_preds: List[str] = []
    all_labels: List[str] = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="round", leave=False):
            outputs = model.generate(batch["input_ids"])
            output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)

            label_ids = [remove_special_tokens(tokenizer, l) for l in batch["labels"]]
            label_text = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

            all_preds.extend(output_text)
            all_labels.extend(label_text)

    # Convert the decoded strings to numeric arrays, then correlate their ranks.
    pred_values = np.array([parse_float(pred) for pred in all_preds])
    label_values = np.array([parse_float(label) for label in all_labels])
    rho = spearmanr(pred_values, label_values)[0]
    return rho


# Registry that maps a task (dataset) name to the function used to score it.
# Looking up a task that has no explicit entry yields `evaluate_accuracy`;
# the only explicit entry is "glue-stsb", which is scored with `evaluate_spearman_rho`.
metric_func: Dict[str, Any] = defaultdict(
    lambda: evaluate_accuracy,
    {"glue-stsb": evaluate_spearman_rho},
)


class Program:
    """
    Evaluate a pretrained model and each individually fine-tuned variant of it.

    The pretrained model is scored on every dataset in ``cfg.test_datasets``; then, for
    each dataset, the matching fine-tuned weights (full weights or LoRA weights,
    depending on ``cfg.peft.peft_config``) are applied to a fresh copy of the
    pretrained model, which is scored on every dataset as well. The resulting table is
    written to ``individuals.csv`` inside the results directory.
    """

    def __init__(self, cfg: DictConfig):
        """
        Store the configuration, seed the RNGs, prepare the output paths and launch Fabric.

        Args:
            cfg (DictConfig): Run configuration. Reads ``seed`` (optional),
                ``model.name``, ``peft.peft_config`` and ``peft.name``.
        """
        self.cfg = cfg
        if hasattr(cfg, "seed") and cfg.seed is not None:
            log.info(f"set seed to {cfg.seed}")
            L.seed_everything(cfg.seed)

        # Results of PEFT runs go to a separate "<model>_<peft>" directory.
        if cfg.peft.peft_config is None:
            self.results_dir = RESULTS_DIR / cfg.model.name
        else:
            self.results_dir = RESULTS_DIR / (cfg.model.name + "_" + cfg.peft.name)
        self.results_dir.mkdir(parents=True, exist_ok=True)
        self.results_path = self.results_dir / "individuals.csv"

        # Evaluation always runs on a single CUDA device.
        self.fabric = L.Fabric(accelerator="cuda", devices=1)
        self.fabric.launch()

    def run(self):
        """Load models and datasets, then evaluate every individual model."""
        self.load_models()
        self.load_datasets()

        self.eval_individuals()

    def _build_finetuned_model(self, dataset_name, task_vector):
        """Return a fresh copy of the pretrained model with ``task_vector`` applied."""
        model = deepcopy(self.pretrained_model)

        if self.cfg.peft.peft_config is None:
            # Full fine-tuning: add the task vector onto the pretrained weights in place.
            # The state dict is built once; its tensors share storage with the model's
            # parameters, so the in-place additions update the model itself.
            state = model.state_dict()
            for n, p in task_vector.items():
                state[n] += p
            return model

        # PEFT: wrap the copy with freshly initialised adapters, then load the stored
        # adapter weights over them.
        peft_config = instantiate(self.cfg.peft.peft_config)
        peft_config.target_modules = list(peft_config.target_modules)
        model = peft.get_peft_model(model, peft_config)

        lora_state_dict = {
            f"base_model.model.{k}": v
            for k, v in task_vector.items()
        }
        load_result = model.load_state_dict(lora_state_dict, strict=False)

        # strict=False ignores keys that match no parameter of the model. Surface
        # that, otherwise a naming mismatch would silently evaluate a model whose
        # adapters were never loaded.
        unexpected = list(load_result.unexpected_keys)
        if unexpected:
            log.warning(
                f"{len(unexpected)} of {len(lora_state_dict)} LoRA tensors for "
                f"{dataset_name} did not match any model parameter and were ignored "
                f"(first: {unexpected[0]})"
            )
        return model

    def eval_individuals(self):
        """
        Score the pretrained model and every fine-tuned model on all test datasets.

        Builds a table with one row per model (column ``"model"`` holds ``"pretrained"``
        or the name of the dataset the model was fine-tuned on) and one column per test
        dataset. The table is printed after each fine-tuned model and finally written
        to ``self.results_path`` as CSV.
        """
        results = defaultdict(list)

        log.info("Evaluating pretrained model")
        model = deepcopy(self.pretrained_model)
        model = self.fabric.setup_module(model)
        _results = self.eval_model_on_datasets(model)
        results["model"].append("pretrained")
        for dataset_name, score in _results.items():
            results[dataset_name].append(score)

        for dataset_name, task_vector in track(
            zip(self.cfg.test_datasets, self.task_vectors),
            "Evaluating fine-tuned models",
        ):
            model = self._build_finetuned_model(dataset_name, task_vector)
            model = self.fabric.setup_module(model)
            _results = self.eval_model_on_datasets(model)
            results["model"].append(dataset_name)
            for k, v in _results.items():
                results[k].append(v)

            # Show the intermediate table so progress is visible during long runs.
            print(pd.DataFrame(results))

        pd.DataFrame(results).to_csv(self.results_path, index=False)

    @torch.no_grad()
    def eval_model_on_datasets(self, model):
        """
        Evaluates the given model on the test datasets specified in the configuration.
        It iterates over each test dataset and its corresponding data loader, computes the metric for the dataset
        (looked up in ``metric_func`` by dataset name), and logs the score. If the `fast_dev_run` configuration
        option is set, it only evaluates on one batch per dataset.
        The results are returned as a dictionary mapping dataset names to scores.

        Args:
            model (torch.nn.Module): The model to evaluate.

        Returns:
            dict: A dictionary mapping dataset names to evaluation scores.
        """
        results = {}
        model.eval()
        for dataset_name, test_loader in tqdm(
            zip(self.cfg.test_datasets, self.test_loaders),
            desc="evaluating",
            total=len(self.cfg.test_datasets),
            leave=False,
        ):
            if self.cfg.fast_dev_run:
                log.info("fast_dev_run is True, only evaluate on one batch")
                # Replace the loader by a one-element list holding its first batch.
                test_loader = [next(iter(test_loader))]
            score = metric_func[dataset_name](model, test_loader, self.tokenizer)
            log.info(f"Eval: dataset: {dataset_name} score: {score}")
            results[dataset_name] = score
        return results

    def load_models(self, task_vector_device=torch.device("cpu")):
        """
        Load the tokenizer, the frozen pretrained model and one task vector per test dataset.

        Without PEFT, a task vector is ``state_dict_sub(finetuned_weights, pretrained_state_dict)``
        (presumably the element-wise difference — see ``src.tasks.arithmetic``). With PEFT,
        it is the stored LoRA state dict with the ``"base_model.model."`` text removed
        from its keys.

        Args:
            task_vector_device (torch.device): Device the task vectors are moved to.

        Raises:
            FileNotFoundError: If a weight file does not exist.
            ValueError: If a LoRA weight file contains no tensors.

        Side Effects:
            Sets ``tokenizer``, ``pretrained_model``, ``task_vectors`` and, when PEFT is
            configured, ``peft_model``.
        """
        cfg = self.cfg

        log.info("Loading pretrained model and tokenizer")
        self.tokenizer = instantiate(self.cfg.model.tokenizer)
        self.pretrained_model = instantiate(self.cfg.model.model).eval()
        # The model is only evaluated, so gradients are never needed.
        for p in self.pretrained_model.parameters():
            p.requires_grad_(False)

        if cfg.peft.peft_config is not None:
            peft_config = instantiate(cfg.peft.peft_config)
            peft_config.target_modules = list(peft_config.target_modules)
            if hasattr(cfg.peft, "seed") and cfg.peft.seed is not None:
                L.seed_everything(cfg.peft.seed)
            self.peft_model = peft.get_peft_model(deepcopy(self.pretrained_model), peft_config)
            self.peft_model.print_trainable_parameters()

        task_vectors = []
        for dataset_name in track(cfg.test_datasets, description="Loading finetuned weights"):
            weight_file = finetuned_model_path(
                cfg.model.name, template_name="glue_v1", dataset_name=dataset_name
            )

            # PEFT weights live in a sub-directory named after the PEFT method.
            if cfg.peft.peft_config is not None:
                weight_file = weight_file.parent / cfg.peft.name / weight_file.name

            if not weight_file.exists():
                raise FileNotFoundError(f"Model weight file not found: {weight_file}")

            log.info(f"Loading weights from: {weight_file}")

            try:
                if cfg.peft.peft_config is None:
                    # NOTE(security): torch.load unpickles the file; only load
                    # checkpoints from trusted sources.
                    weights = torch.load(weight_file, map_location="cpu")
                    task_vectors.append(state_dict_sub(weights, self.pretrained_model.state_dict(), strict=False))
                else:
                    if weight_file.suffix == '.safetensors':
                        from safetensors import safe_open
                        lora_weights = {}
                        with safe_open(weight_file, framework="pt", device="cpu") as f:
                            for k in f.keys():
                                lora_weights[k] = f.get_tensor(k)
                    else:
                        # NOTE(security): see the torch.load note above.
                        lora_weights = torch.load(weight_file, map_location="cpu")

                    if not lora_weights:
                        raise ValueError(f"Loaded empty LoRA weights from {weight_file}")

                    log.info(f"Loaded {len(lora_weights)} LoRA parameters for {dataset_name}")

                    _task_vector = {
                        k.replace("base_model.model.", ""): v
                        for k, v in lora_weights.items()
                    }

                    task_vectors.append(_task_vector)

            except Exception as e:
                # Add the file name to the log, then let the original error propagate.
                log.error(f"Failed to load weights from {weight_file}: {str(e)}")
                raise

        self.task_vectors = [
            {k: v.to(task_vector_device) for k, v in tv.items()}
            for tv in task_vectors
        ]

        log.info("Successfully loaded all models and task vectors")

    def _setup_dataloaders(self, dataloaders):
        """Wrap ``dataloaders`` with Fabric and always return them as a list."""
        wrapped = self.fabric.setup_dataloaders(*dataloaders)
        # Fabric returns the bare dataloader (not a one-element list) when exactly one
        # is passed; without this normalisation, zipping dataset names with the result
        # would iterate over that loader's batches instead of over loaders.
        if isinstance(wrapped, (list, tuple)):
            return list(wrapped)
        return [wrapped]

    def load_datasets(self):
        """
        Loads the datasets specified in the configuration. If a dataset is found in the cache, it is loaded from there.
        Otherwise, it is instantiated from its YAML configuration; when that configuration defines a preprocessor,
        the dataset is mapped through it and the mapped dataset is saved to the cache (a dataset without a
        preprocessor is not cached).
        The method then picks the "validation" split, falling back to "validation_matched", and raises an error if
        neither exists. The chosen split is added to the list of test datasets.
        Finally, it sets up the data loaders for the test datasets, both with and without shuffling, and creates an
        endlessly cycling iterator for each shuffled test loader.

        Raises:
            ValueError: If a dataset has no validation split.

        Side Effects:
            Sets the instance variables `test_datasets`, `test_loaders`, `shuffled_test_loaders`, and
            `shuffled_test_loader_iters`.
        """
        test_datasets = []
        for dataset_name in track(
            self.cfg.test_datasets, description="loading datasets"
        ):
            cache_dir = CACHE_DIR / "datasets" / dataset_name
            cache_dir_str = str(cache_dir)

            if cache_dir.exists():
                log.info(f"loading dataset {dataset_name} from cache")
                try:
                    dataset = load_from_disk(cache_dir_str)
                except Exception as e:
                    log.error(f"Failed to load dataset from cache: {str(e)}")
                    raise
            else:
                log.info(f"dataset {dataset_name} not found in cache, creating new")
                config = OmegaConf.load(
                    CONFIG_DIR / "datasets" / f"{dataset_name}.yaml"
                )
                dataset = instantiate(config.datasets)

                if hasattr(config, "preprocessor"):
                    log.info("preprocessing the dataset")
                    # Only the file name of the configured template is kept; it is
                    # always resolved inside the "glue_v1" template directory.
                    preprocessor = instantiate(
                        config.preprocessor,
                        template_file=str(TEMPLATE_DIR / "glue_v1" / os.path.basename(config.preprocessor.template_file)),
                        tokenizer=self.tokenizer,
                        tokenizer_kwargs=self.cfg.model.tokenizer_kwargs
                        if hasattr(self.cfg.model, "tokenizer_kwargs")
                        else None,
                    )
                    dataset = dataset.map(
                        preprocessor,
                        **config.map_kwargs if hasattr(config, "map_kwargs") else {},
                    )
                    dataset.save_to_disk(cache_dir_str)

            if "validation" in dataset:
                test_dataset = dataset["validation"]
            elif "validation_matched" in dataset:
                test_dataset = dataset["validation_matched"]
            else:
                raise ValueError(f"dataset {dataset_name} has no validation split.")
            test_datasets.append(test_dataset)

        self.test_datasets = test_datasets

        # Deterministic loaders used for scoring; the batch size is tripled here.
        self.test_loaders = self._setup_dataloaders(
            [
                DataLoader(
                    test_dataset,
                    batch_size=self.cfg.batch_size * 3,
                    num_workers=self.cfg.num_workers,
                    shuffle=False,
                    collate_fn=default_data_collator,
                )
                for test_dataset in test_datasets
            ]
        )
        self.shuffled_test_loaders = self._setup_dataloaders(
            [
                DataLoader(
                    test_dataset,
                    batch_size=self.cfg.batch_size,
                    num_workers=self.cfg.num_workers,
                    shuffle=True,
                    collate_fn=default_data_collator,
                )
                for test_dataset in test_datasets
            ]
        )

        # NOTE: itertools.cycle stores every batch of the first pass and replays the
        # stored batches afterwards, so later passes are not reshuffled.
        self.shuffled_test_loader_iters = {
            dataset_name: iter(itertools.cycle(test_loader))
            for dataset_name, test_loader in zip(
                self.cfg.test_datasets, self.shuffled_test_loaders
            )
        }


@hydra.main(str(CONFIG_DIR), "flan_t5_default", None)
def main(cfg: DictConfig):
    """Script entry point: build a `Program` from the composed config and run it."""
    # NOTE(review): the three positional decorator arguments are presumably Hydra's
    # config path, config name and version base — confirm against the installed version.
    runner = Program(cfg)
    runner.run()


# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
