"""Train source-domain LoRA adapters and evaluate.

This module contains the training loop for either a single universal LoRA
adapter (source_lora) or per-domain LoRA adapters (source_lora_per_domain).
The trained adapters are evaluated on held-out domains or, optionally, on
in-domain validation data.

LoRA is implemented using the CLIPORA library:
GitHub - https://github.com/awilliamson10/clipora
"""

import os   
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

import open_clip
import wandb
import hydra
from omegaconf import DictConfig, OmegaConf

from datasets.get_dataset import get_dataloader
from experiments.helpers import encode_text
from experiments.eval_source_models import eval_model
from experiments.helpers import load_lora_model 

def _checkpoint_filename(
        experiment_name: str,
        dataset_name: str,
        trial_num: int,
        train_domains: list[str],
        test_domains: list[str]
    ) -> str:
    """Return the checkpoint file name for a finished training run."""
    stem = f"trial{trial_num}_{experiment_name}_{dataset_name}"
    if experiment_name == "source_lora":
        # The universal adapter on domainnet is tagged with its evaluation domain.
        if dataset_name == "domainnet":
            return f"{stem}_{test_domains[0]}.pt"
        return f"{stem}.pt"
    # source_lora_per_domain: one adapter per training domain.
    return f"{stem}_{train_domains[0]}.pt"


def _build_log_dict(
        dataset_name: str,
        domain_prefix: str,
        epoch: int,
        train_results: dict,
        eval_results: dict,
        eval_tag: str
    ) -> dict:
    """Flatten one epoch's train/eval metrics into a W&B log dictionary."""
    prefix = f"{dataset_name}/{domain_prefix}"
    log_dict = {f"{prefix}/epoch": epoch}
    # NOTE: metric names that already start with "train_" (e.g. "train_loss")
    # are logged as "train_train_*"; kept as-is so existing W&B keys don't move.
    for metric_name, metric_value in train_results.items():
        clean_name = metric_name.lower().replace(" ", "_")
        log_dict[f"{prefix}/train_{clean_name}"] = metric_value
    for metric_name, metric_value in eval_results.items():
        clean_name = metric_name.lower().replace(" ", "_")
        log_dict[f"{prefix}/{eval_tag}_{clean_name}"] = metric_value
    return log_dict


def _run_source_training(
        cfg: DictConfig,
        hp,
        trial_num: int,
        train_domains: list[str],
        test_domains: list[str],
        create_val_from_train: bool
    ) -> dict:
    """Build the model and data, train for ``hp.num_epochs``, evaluate and save.

    Assumes a W&B run is already active (metrics are sent with ``wandb.log``).
    Returns the evaluation metrics of the last epoch, or ``{}`` if no epoch ran.
    """
    dataset_name = cfg.dataset.name
    backbone = cfg.backbone

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("--------------------------------")
    print("Experiment name:", cfg.experiment.name)
    print("Dataset name:", dataset_name)
    print("Backbone:", backbone)
    print("Device:", device)
    print("Train domains:", train_domains)
    print("Test domains:", test_domains)
    print("--------------------------------")

    # The pretrained model is only needed here for its train transform and to
    # encode the class prompts; it is replaced by the LoRA model below.
    model, preprocess_train, _ = open_clip.create_model_and_transforms(backbone, pretrained="openai")
    tokenizer = open_clip.get_tokenizer(backbone)
    model.to(device)

    train_dataset, train_loader = get_dataloader(
        dataset_name=dataset_name,
        trial_num=trial_num,
        path_to_annotations_dir=cfg.dataset.annotations_dir if dataset_name == "domainnet" else None,
        path_to_dataset_dir=cfg.dataset.data_dir,
        domains=train_domains,
        classes=cfg.dataset.classes,
        transform=preprocess_train,
        split_type="train",
        batch_size=cfg.experiment.batch_size,
        num_workers=cfg.experiment.num_workers,
        pin_memory=cfg.experiment.pin_memory,
        is_training=True,
        create_val_from_train=create_val_from_train
    )
    num_batches = len(train_loader)
    if num_batches == 0:
        # Fail early with a clear message instead of a ZeroDivisionError
        # when the per-epoch averages are computed.
        raise ValueError(
            f"Training loader for {dataset_name} on domains {train_domains} is empty."
        )

    image_loss_func = nn.CrossEntropyLoss()

    text_features = encode_text(
        cfg.dataset.classes,
        cfg.dataset.prompt_templates,
        model,
        tokenizer,
        device
    )

    # Load the LoRA model (image encoder only)
    model = load_lora_model(cfg.experiment, backbone)
    model.to(device)

    # Freeze the base model: only parameters with "lora" in their name train.
    for name, param in model.named_parameters():
        if "lora" not in name:
            param.requires_grad = False

    print("--------------------------------")
    print("Model parameters: ", sum(p.numel() for p in model.parameters()))
    print("Model parameters to train: ", sum(p.numel() for p in model.parameters() if p.requires_grad))

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name, param.shape)

    print("Length of dataset: ", len(train_dataset))
    print("--------------------------------")

    optimizer = optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=hp.lr_clip,
        betas=cfg.experiment.optimizer.betas,
        eps=cfg.experiment.optimizer.eps,
        weight_decay=cfg.experiment.optimizer.weight_decay
    )

    is_wilds = dataset_name in cfg.constants.wilds_datasets

    # Decide once which domains/split are evaluated after every epoch.
    if create_val_from_train:
        eval_domains = train_domains
        eval_split = "test" if dataset_name == "domainnet" else "train"
        eval_tag, eval_label = "val", "Val"
    else:
        eval_domains = test_domains
        eval_split = "test"
        eval_tag, eval_label = "test", "Test"

    # Falls back to "global" for both None and an empty domain list.
    domain_prefix = test_domains[0] if test_domains else "global"

    # Defined up front so the function still returns a dict when num_epochs == 0.
    test_results = {}

    for epoch in range(hp.num_epochs):

        print("----------------------------------------------------")
        print("Epoch: ", epoch)
        print("----------------------------------------------------")

        # Re-enter train mode each epoch, after the per-epoch evaluation.
        model.train()
        total_loss, total_correct, total_samples = 0.0, 0, 0
        all_preds = []
        all_labels = []
        all_metadata = []

        for images, labels, metadata in tqdm(train_loader):
            optimizer.zero_grad()
            images, labels = images.to(device), labels.to(device)

            image_features = model.encode_image(images).float()
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            logits = model.logit_scale.exp() * image_features @ text_features
            loss = image_loss_func(logits, labels)
            loss.backward()
            optimizer.step()

            preds = logits.argmax(dim=1)

            # Per-sample outputs are only consumed by the WILDS evaluator.
            if is_wilds:
                all_preds.append(preds)
                all_labels.append(labels)
                all_metadata.append(metadata)

            total_loss += loss.item()
            total_correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

        train_loss = total_loss / num_batches
        train_accuracy = 100 * total_correct / total_samples

        if is_wilds:
            # Use WILDS built-in evaluation
            train_results, _ = train_dataset.eval(
                torch.cat(all_preds, dim=0).cpu(),
                torch.cat(all_labels, dim=0).cpu(),
                torch.cat(all_metadata, dim=0).cpu()
            )
            train_results["train_loss"] = train_loss
        else:
            train_results = {"train_accuracy": train_accuracy, "train_loss": train_loss}

        test_results = eval_model(cfg, trial_num, model, eval_domains, eval_split, create_val_from_train)

        wandb.log(_build_log_dict(
            dataset_name, domain_prefix, epoch, train_results, test_results, eval_tag
        ))

        print("--------------------------------")
        print(f"Train results on {train_domains}: \n", train_results)
        print(f"{eval_label} results on {eval_domains}: \n", test_results)
        print("--------------------------------")

    if not create_val_from_train:
        # Save the model
        save_dir = os.path.join(cfg.constants.model_save_path, f"{cfg.experiment.name}_{dataset_name}_{backbone.replace('/', '-')}")
        os.makedirs(save_dir, exist_ok=True)
        model_name = _checkpoint_filename(
            cfg.experiment.name, dataset_name, trial_num, train_domains, test_domains
        )
        torch.save({'model_state_dict': model.state_dict()}, os.path.join(save_dir, model_name))

    return test_results


def source_train(
        cfg: DictConfig,
        trial_num: int,
        train_domains: list[str],
        test_domains: list[str],
        *,
        create_val_from_train: bool = False
    ):
    """Train a LoRA module on CLIP's vision encoder and evaluate it.

    A W&B run is opened for the trial, the model is trained for the configured
    number of epochs with evaluation after every epoch, and the run is closed
    again even when training fails (it is then marked as failed).

    Parameters
    ----------
    cfg : DictConfig
        Hydra configuration with experiment, dataset, constants, and backbone.
    trial_num : int
        Trial index for seeding and checkpoint naming.
    train_domains : list[str]
        Domains used for training (single or multiple, depending on experiment).
    test_domains : list[str]
        Domains used for evaluation; may be held out or created from train.
    create_val_from_train : bool, optional
        Set True to evaluate on in-domain validation data (intended for wilds
        datasets) instead of ``test_domains``. In that mode metrics are logged
        with a ``val_`` prefix and no checkpoint is saved. Defaults to False.

    Returns
    -------
    dict
        Evaluation metrics on the test/val split from the last epoch
        (empty if the configured number of epochs is 0).

    Raises
    ------
    ValueError
        If the training dataloader contains no batches.
    """

    dataset_name = cfg.dataset.name
    backbone = cfg.backbone
    hp = cfg.experiment[dataset_name]

    wandb.init(
        entity=cfg.constants.wandb_entity,
        project=cfg.experiment.name,
        name=f"{dataset_name}_{backbone.replace('/', '-')}_trial{trial_num}",
        config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=False),
    )

    try:
        test_results = _run_source_training(
            cfg, hp, trial_num, train_domains, test_domains, create_val_from_train
        )
    except BaseException:
        # Close the run and mark it failed so a crashed trial does not leave
        # an open run behind for the next trial; the error still propagates.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()

    return test_results

