from pathlib import Path


import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import os
import json
import time


import hydra

# Distributed computing
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from src.nlp.distributed_utils import ddp_setup, aggregate_results, summarise_final_results


# Custom Library
from src.nlp.train_utils import get_available_device
from precompute_llama_embeddings import precompute_llama_embeddings, load_precomputed_llama_embeddings_dataset

# Models
from src.nlp.models.MistralModel import MistralModel
from src.nlp.models.QwenModel import QwenModel
from src.nlp.models.NanoGPTWrapper import NanoGPTWrapper
from src.nlp.models.BertModel import BertModel
from src.nlp.models.LlamaModel import LlamaModel
from src.nlp.models.LlamaModelLMHead import LlamaModelLMHead

# Experiment loader
from src.nlp.experiments.next_token_tinystories import NextTokenPrediction
from src.nlp.experiments.hellaswag import HellaswagMatching
from src.nlp.experiments.stsb_similarity import STSBSimilarity
from src.nlp.experiments.classification import ClassificationTask

import random
import numpy as np
import torch


def set_random_seed(seed):
    """
    Seed every random number generator used by this script with ``seed``.

    Covers Python's ``random`` module, NumPy's global generator and PyTorch
    (the default generator plus the CUDA generators), and switches cuDNN to
    its deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed handed unchanged to each seeding function.
    """
    # Order matches the original: Python, NumPy, torch default generator,
    # current CUDA device, then every CUDA device (multi-GPU setups).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True  # Ensure deterministic behavior
    cudnn_backend.benchmark = False     # Disable auto-tuning for reproducibility


@hydra.main(version_base=None, config_path="configs/", config_name="general")
def run(cfg):
    """
    Fine-tune and evaluate the configured backbone on the configured dataset.

    For each of ``cfg.runs`` repetitions this function seeds the RNGs with
    ``cfg.seed + run index``, builds the model selected by ``cfg.backbone``
    and the experiment selected by ``cfg.datasets.name``, optionally runs a
    zero-shot evaluation, trains with Adam for ``cfg.learning.epochs`` epochs
    and writes the test metrics as JSON below ``<hydra output dir>/assets/run-<i>``.
    After all runs, rank 0 writes ``summary.json`` to the ``assets`` directory.

    Args:
        cfg: Hydra configuration object. Fields read here include ``multigpu``,
            ``pooling``, ``backbone``, ``datasets.name``, ``runs``, ``seed``,
            ``zero_shot_eval``, ``interim_eval``, ``learning.epochs``,
            ``learning.lr`` and ``learning.batch_size``.

    Raises:
        ValueError: If ``cfg.backbone`` or ``cfg.datasets.name`` is not one of
            the supported values.
    """
    run_path = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

    # Make the path relative to the (sub) project root.
    run_path = Path(run_path)
    try:
        run_path = run_path.relative_to(Path.cwd())
    except ValueError:
        # The Hydra output dir is not below the current working directory;
        # keep the path unchanged instead of crashing.
        pass
    run_path = str(run_path)  # Back to str, because rest of code expects it.

    device = get_available_device(cfg)
    print(f"Device: {device}")

    # Get Huggingface token (for LLama3, Mistral)
    hf_token = "<INSERT HF TOKEN HERE>"
    # Only export the token once the placeholder above has been replaced;
    # otherwise keep whatever HF_TOKEN / HUGGINGFACE_HUB_TOKEN the environment
    # already provides instead of overwriting it with the placeholder text.
    if not (hf_token.startswith("<") and hf_token.endswith(">")):
        os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token
        os.environ["HF_TOKEN"] = hf_token

    if cfg.multigpu:
        ddp_setup()

    # Force padding side to be left for weighted average pooling
    if cfg.pooling == "weighted_avg":
        cfg.padding_side = "left"

    main_assets_path = os.path.join(run_path, "assets")

    llama_special_case = False
    # With Llama3 + next token + attention pool we run out of memory on the L4 gpus so
    # we handle this case separately by precomputing the embeddings from the backbone
    # which is frozen anyway.
    if cfg.backbone == "llama3" and cfg.pooling == "attention_pool" and cfg.datasets.name == "next_token_tinystories":
        precomputed_files_exist = any(fname.startswith("llama-next-token-precomputed-") for fname in os.listdir("."))
        if not precomputed_files_exist:
            model = LlamaModel(cfg, device)
            experiment = NextTokenPrediction(cfg, model=model, device=device)
            experiment.set_model(model)
            experiment.distributed_setup()
            print("Precomputing Llama3 backbone embeddings...")
            precompute_llama_embeddings(model, experiment.train_dataloader, experiment.val_dataloader)
            # Release the full backbone before the per-run model is built, so
            # both are not resident at once in this memory-constrained case.
            del model
            del experiment
            torch.cuda.empty_cache()
        llama_special_case = True

    # Defined up front so the post-loop summary also works when cfg.runs == 0.
    # NOTE(review): LOCAL_RANK is the variable torchrun exports — confirm it
    # matches experiment.local_rank for the launcher in use.
    local_rank = int(os.environ.get("LOCAL_RANK", 0)) if cfg.multigpu else 0

    # `run_idx` rather than `run`, to avoid shadowing this function's own name.
    for run_idx in range(cfg.runs):
        assets_path = os.path.join(main_assets_path, f"run-{run_idx}")
        # The assets directory is created once distributed computing is set up.

        set_random_seed(cfg.seed + run_idx)

        # Load the model
        if cfg.backbone == "Mistral7B":
            model = MistralModel(cfg, device)
        elif cfg.backbone == "llama3":
            if llama_special_case:
                model = LlamaModelLMHead(cfg, device)
            else:
                model = LlamaModel(cfg, device)
        elif cfg.backbone == "Qwen2.5":
            model = QwenModel(cfg, device)
        elif cfg.backbone in ["NanoGPT", "GPT2-Custom"]:
            model = NanoGPTWrapper(cfg, device)
        elif cfg.backbone == "bert-base":
            model = BertModel(cfg, device)
        else:
            raise ValueError(f"Model {cfg.backbone} not supported.")

        # Load experiment (contains dataloaders and evaluation loops)
        if cfg.datasets.name == "next_token_tinystories":
            experiment = NextTokenPrediction(cfg, model=model, device=device)
        elif cfg.datasets.name == "hellaswag":
            experiment = HellaswagMatching(cfg, model=model, device=device)
        elif cfg.datasets.name == "stsb":
            experiment = STSBSimilarity(cfg, model=model, device=device)
        elif cfg.datasets.name in ["tweet_sentiment", "banking77"]:
            experiment = ClassificationTask(cfg, model=model, device=device)
        else:
            raise ValueError(f"Dataset {cfg.datasets.name} not supported.")

        # Set model
        experiment.set_model(model)

        # Set up distributed computing
        if cfg.multigpu:
            experiment.distributed_setup()
            local_rank = experiment.local_rank
        else:
            # This is needed so no errors are thrown when model.module is called
            # and no distributed computing is set up.
            experiment.dummy_distributed()
            local_rank = 0

        # Only rank 0 writes to the assets folder, so only rank 0 creates it.
        # Creating the run directory also creates the parent assets directory.
        if local_rank == 0:
            os.makedirs(assets_path, exist_ok=True)

        results_path = os.path.join(assets_path, f"results-{local_rank}.json")

        # Run initial eval (optional)
        if cfg.zero_shot_eval:
            results = experiment.evaluate("test")
            print(results)

            # Save this to assets folder in json
            if local_rank == 0:
                results = {k: float(v) for k, v in results.items()}
                with open(results_path, "w") as f:
                    json.dump(results, f)

        # Set up train/finetune here
        experiment.model.module.set_linear_finetune(True)

        if cfg.datasets.name == "next_token_tinystories":
            experiment.enable_gradient()
        else:
            model.lm_head.weight.requires_grad = False

        if cfg.pooling == "weighted_avg" and local_rank == 0:
            # Save the pooling weights as a numpy array before training.
            # NOTE: the post-training save below writes to the same file name
            # and therefore overwrites this snapshot.
            weights = model.pooling.weighted_average_pooling.w.detach().float().cpu().numpy()

            with open(os.path.join(assets_path, "weighted_avg_weights.npy"), "wb") as f:
                np.save(f, weights)

        if llama_special_case:
            # The precomputation of these embeddings will write one file per process/rank
            # which will be the one loaded here, meaning we don't need the distributed
            # sampler since they will be distinct already.
            dataset_train = load_precomputed_llama_embeddings_dataset("train")
            dataset_val = load_precomputed_llama_embeddings_dataset("val")
            experiment.train_dataloader = DataLoader(dataset_train, batch_size=cfg.learning.batch_size, shuffle=True)
            experiment.val_dataloader = DataLoader(dataset_val, batch_size=cfg.learning.batch_size, shuffle=False)
            # The validation split doubles as the test split in this case.
            experiment.test_dataloader = DataLoader(dataset_val, batch_size=cfg.learning.batch_size, shuffle=False)

        epochs = cfg.learning.epochs
        lr = cfg.learning.lr
        # Pass plain tensors: (name, parameter) tuples are rejected by PyTorch
        # releases that do not support named parameters in optimizers.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable_params, lr=lr)

        for epoch in range(epochs):
            loss_coll = []
            for batch in tqdm(experiment.train_dataloader):
                optimizer.zero_grad()
                loss = experiment.finetune_pass(batch, oom_case=llama_special_case)

                loss.backward()
                optimizer.step()
                loss_coll.append(loss.item())

            if loss_coll:
                print(f"Epoch {epoch + 1} train loss: {sum(loss_coll)/len(loss_coll)}")
            else:
                # Avoid a ZeroDivisionError when the dataloader yields nothing.
                print(f"Epoch {epoch + 1}: train dataloader yielded no batches.")

            if cfg.interim_eval:
                interim_results = experiment.evaluate("val", oom_case=llama_special_case)
                print(f"Epoch {epoch + 1} validation metrics:", interim_results)

        # Run final eval
        results = experiment.evaluate("test", oom_case=llama_special_case)
        final_results_path = os.path.join(assets_path, f"results-{local_rank}-finetune.json")
        print(results)

        # Save this to assets folder in a json
        if local_rank == 0:
            results = {k: float(v) for k, v in results.items()}
            with open(final_results_path, "w") as f:
                json.dump(results, f)

        # Ensure all processes are synchronized before proceeding. The barrier
        # requires an initialised process group, so it is skipped when none
        # exists (e.g. single-process runs).
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.barrier()

        if local_rank == 0:
            aggregate_results(assets_path, cfg)

            if cfg.pooling == "weighted_avg":
                # Save the pooling weights after training as a numpy array.
                weights = model.pooling.weighted_average_pooling.w.detach().float().cpu().numpy()
                np.save(os.path.join(assets_path, "weighted_avg_weights.npy"), weights)

        print(f"Results saved to {run_path}")
        # Clean up memory
        del model
        del experiment
        torch.cuda.empty_cache()
        ############### RUN LOOP ENDS HERE ###############

    # End distributed computing
    if cfg.multigpu:
        destroy_process_group()

    # Once all runs are done, aggregate results and calculate mean/std of metrics
    if local_rank == 0:
        summary = summarise_final_results(main_assets_path)
        print("Final summary of results:")
        print(summary)
        with open(os.path.join(main_assets_path, "summary.json"), "w") as f:
            json.dump(summary, f)

        print(f"Summary saved to {main_assets_path}/summary.json")
        
    

# Script entry point: `run` is wrapped by `@hydra.main`, so it is called
# without arguments here.
if __name__ == "__main__":
    run()