import os
import timm
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import pandas as pd
from torch.utils.tensorboard import SummaryWriter

from open_clip import create_model_and_transforms

from data import CustomDataset
from engine import train_epoch, validate_epoch
from models import (
    ClassificationWrapper, EnsembleClassifier,
    freeze_swin_layers, freeze_generic_backbone,
    get_model_dimensions
)

# timm model identifiers for the generic ``args.model_choice`` values that are
# built as plain timm backbones by ``_build_model``.
_TIMM_MODEL_NAMES = {
    "resnet18": "resnet18",
    "resnet34": "resnet34",
    "resnet50": "resnet50",
    "resnet101": "resnet101",
    "resnet152": "resnet152",
    "mobilevit": "mobilevit_s",
    "vit": "vit_base_patch16_224",
    "eva02_base": "eva02_base_patch14_224.mim_in22k",
    "coatnet3_rw": "coatnet_3_rw_224.sw_in12k",
    "convnextv2_large": "convnextv2_large.fcmae_ft_in22k_in1k_384",
    "convnextv2_tiny": "convnextv2_tiny.fcmae_ft_in22k_in1k_384",
    "maxvit_small": "maxvit_small_tf_512.in1k",
}

# Single-backbone model choices whose requires_grad layout is sanity-checked
# by ``_check_parameter_freezing``.
# NOTE(review): "eva02_base" and "coatnet3_rw" are not checked -- confirm
# whether that is intended.
_FREEZE_CHECKED_CHOICES = (
    "swin", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "mobilevit", "vit", "convnextv2_large", "convnextv2_tiny", "maxvit_small",
)

# Parameter-name substrings that identify a linear classification head.
_HEAD_KEYWORDS = ("head.", ".classifier", ".fc.weight", ".fc.bias")

_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _build_transform():
    """Return a fresh 224x224 resize + ImageNet-normalization pipeline.

    The same pipeline is used for training and validation, so no data
    augmentation is applied.
    NOTE(review): the convnextv2 *_384 and maxvit_small_tf_512 checkpoints
    were pretrained at larger resolutions (maxvit's window partitioning may
    reject 224x224 input) -- confirm 224 is intended for those models.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])


def _build_loaders(dataset_dir, args):
    """Build the 'Train'/'Test' datasets and their DataLoaders.

    Returns:
        (train_loader, val_loader, class_names)

    Raises:
        ValueError: if either split is empty (a zero batch size would
        otherwise be passed to DataLoader).
    """
    shared = dict(
        root_dir=dataset_dir,
        num_classes=args.num_classes,
        support_size=args.support_size,
        random_support=(args.support_selection == "random"),
    )
    train_dataset = CustomDataset(split='Train', transform=_build_transform(), **shared)
    val_dataset = CustomDataset(split='Test', transform=_build_transform(), **shared)

    for split, dataset in (("Train", train_dataset), ("Test", val_dataset)):
        if len(dataset) == 0:
            raise ValueError(f"{split} split under {dataset_dir!r} is empty.")

    # Cap the batch size at the dataset size so tiny datasets still form a batch.
    train_loader = DataLoader(train_dataset, batch_size=min(args.batch_size, len(train_dataset)),
                              shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=min(args.batch_size, len(val_dataset)),
                            shuffle=False, num_workers=4)

    # NOTE(review): these are the stringified, sorted *values* of label_map.
    # If label_map maps class name -> index, they will be "0", "1", ...
    # rather than class names -- confirm intended.
    class_names = [str(k) for k in sorted(train_dataset.label_map.values())]
    return train_loader, val_loader, class_names


def _wrap_backbone(backbone, args, fe_dim, t_dim, is_clip):
    """Wrap ``backbone`` in a ClassificationWrapper configured from ``args``."""
    return ClassificationWrapper(
        support_size=args.support_size,
        model=backbone,
        num_classes=args.num_classes,
        output_dim=(224, 224),
        img_dim=(224, 224),
        is_clip=is_clip,
        head_type=args.head_type,
        fe_dim=fe_dim,
        t_dim=t_dim,
        elmes_dim=256,
        no_query_fusion=(args.prototype_construction == "aggregate"),
    )


def _load_swin(args, device, num_classes):
    """Create the pretrained timm Swin model, optionally frozen, in eval mode on ``device``."""
    swin_model = timm.create_model(args.swin_arch, pretrained=True, num_classes=num_classes)
    if args.freeze_backbone:
        freeze_swin_layers(swin_model, trainable_layers=0)
    swin_model.to(device)
    swin_model.eval()
    return swin_model


def _load_clip(device):
    """Create the open_clip ViT-B-32 (datacomp_xl_s13b_b90k) model in eval mode on ``device``.

    Only a ``trainable_layers`` attribute is set here; requires_grad is not
    touched -- presumably ClassificationWrapper reads that attribute (TODO confirm).
    NOTE(review): open_clip's own preprocessing transform is discarded; the
    datasets use ImageNet mean/std instead -- confirm the wrapper compensates.
    """
    clip_model, _clip_preprocess, _ = create_model_and_transforms(
        "ViT-B-32", pretrained="datacomp_xl_s13b_b90k", device=device
    )
    clip_model.to(device)
    clip_model.eval()
    clip_model.trainable_layers = 0
    return clip_model


def _build_model(args, device):
    """Construct the model selected by ``args.model_choice``.

    Raises:
        ValueError: for an unknown ``args.model_choice``.
    """
    choice = args.model_choice

    if choice == "swin":
        linear_head = args.head_type == "linear"
        # num_classes=0 strips timm's classifier so the wrapper supplies the head.
        swin_model = _load_swin(args, device, args.num_classes if linear_head else 0)
        swin_fe_dim, swin_t_dim = get_model_dimensions(args.swin_arch)
        if linear_head:
            print("Using Swin model with pretrained linear head.")
            return swin_model
        model = _wrap_backbone(swin_model, args, swin_fe_dim, swin_t_dim, is_clip=False)
        print("Using Swin model with transformer head.")
        return model

    if choice == "clip":
        clip_model = _load_clip(device)
        clip_fe_dim, clip_t_dim = get_model_dimensions('clip')
        model = _wrap_backbone(clip_model, args, clip_fe_dim, clip_t_dim, is_clip=True)
        print("Using only CLIP model.")
        return model

    if choice == "ensemble":
        swin_model = _load_swin(args, device, 0)
        swin_fe_dim, swin_t_dim = get_model_dimensions(args.swin_arch)
        swin_wrapper = _wrap_backbone(swin_model, args, swin_fe_dim, swin_t_dim, is_clip=False)

        clip_model = _load_clip(device)
        clip_fe_dim, clip_t_dim = get_model_dimensions('clip')
        clip_wrapper = _wrap_backbone(clip_model, args, clip_fe_dim, clip_t_dim, is_clip=True)

        model = EnsembleClassifier(swin_wrapper, clip_wrapper)
        print("Using ensemble of Swin and CLIP models.")
        return model

    model_str = _TIMM_MODEL_NAMES.get(choice)
    if model_str is None:
        raise ValueError(f"Unknown model choice: {args.model_choice}")

    fe_dim, t_dim = get_model_dimensions(model_str)
    if args.head_type == "linear":
        backbone_model = timm.create_model(model_str, pretrained=True, num_classes=args.num_classes)
        if args.freeze_backbone:
            freeze_generic_backbone(backbone_model)
        print(f"Using {model_str} with pretrained linear head.")
        return backbone_model

    backbone_model = timm.create_model(model_str, pretrained=True, num_classes=0)
    if args.freeze_backbone:
        for param in backbone_model.parameters():
            param.requires_grad = False
    model = _wrap_backbone(backbone_model, args, fe_dim, t_dim, is_clip=False)
    print(f"Using {model_str} with transformer head and support set of size {args.support_size}.")
    return model


def _is_head_param(name):
    """Return True if parameter ``name`` looks like part of a linear classification head."""
    return any(keyword in name for keyword in _HEAD_KEYWORDS)


def _check_parameter_freezing(model, args):
    """Sanity-check which parameters of ``model`` require gradients.

    Raises AssertionError when the trainable/frozen split does not match
    ``args.head_type`` / ``args.freeze_backbone``. These are plain ``assert``
    statements, so they are skipped under ``python -O``.
    """
    choice = args.model_choice
    linear_head = args.head_type == "linear"

    if choice in _FREEZE_CHECKED_CHOICES:
        if linear_head:
            for name, param in model.named_parameters():
                if _is_head_param(name):
                    assert param.requires_grad, f"Parameter {name} in classification head should be trainable."
                elif args.freeze_backbone:
                    assert not param.requires_grad, f"Parameter {name} in backbone should be frozen."
                else:
                    assert param.requires_grad, f"Parameter {name} should be trainable when --freeze_backbone is not set."
        else:
            backbone = model.model if isinstance(model, ClassificationWrapper) else model
            for name, param in backbone.named_parameters():
                if args.freeze_backbone:
                    assert not param.requires_grad, f"Backbone parameter {name} should be frozen for transformer head when --freeze_backbone is set."
                else:
                    assert param.requires_grad, f"Backbone parameter {name} should be trainable when --freeze_backbone is not set."
            for name, param in model.head.named_parameters():
                assert param.requires_grad, f"Head parameter {name} should be trainable for transformer head."

    elif choice == "clip":
        # The CLIP backbone is expected to be frozen regardless of --freeze_backbone.
        for name, param in model.model.named_parameters():
            assert not param.requires_grad, f"CLIP backbone parameter {name} should be frozen for transformer head."
        for name, param in model.head.named_parameters():
            assert param.requires_grad, f"CLIP head parameter {name} should be trainable for transformer head."

    elif choice == "ensemble":
        # The Swin backbone of the ensemble is only frozen when --freeze_backbone
        # is set, so the expectation must follow that flag (as for plain Swin).
        swin_wrapper = model.wrapper1
        if linear_head:
            for name, param in swin_wrapper.named_parameters():
                if _is_head_param(name):
                    assert param.requires_grad, f"Parameter {name} in swin wrapper's classification head should be trainable."
                elif args.freeze_backbone:
                    assert not param.requires_grad, f"Parameter {name} in swin wrapper's backbone should be frozen."
                else:
                    assert param.requires_grad, f"Parameter {name} in swin wrapper's backbone should be trainable when --freeze_backbone is not set."
        else:
            for name, param in swin_wrapper.model.named_parameters():
                if args.freeze_backbone:
                    assert not param.requires_grad, f"Parameter {name} in swin backbone of ensemble should be frozen."
                else:
                    assert param.requires_grad, f"Parameter {name} in swin backbone of ensemble should be trainable when --freeze_backbone is not set."
            for name, param in swin_wrapper.head.named_parameters():
                assert param.requires_grad, f"Parameter {name} in swin transformer head of ensemble should be trainable."


def run_experiment(dataset_dir, args, log_dir_suffix):
    """
    Build datasets/dataloaders, construct the requested model/backbone/head,
    run the training/validation loop with a cosine LR schedule and TensorBoard
    logging, and return the best validation accuracy.

    Args:
        dataset_dir: root directory containing the 'Train' and 'Test' splits.
        args: parsed options; reads model_choice, head_type, swin_arch,
            freeze_backbone, num_classes, support_size, support_selection,
            prototype_construction, batch_size, lr and epochs.
        log_dir_suffix: TensorBoard run name, written under ``runs/``.

    Returns:
        (best_val_acc, best_epoch). ``best_epoch`` is 1-based, and -1 only
        when ``args.epochs`` is 0.

    Raises:
        ValueError: unknown ``args.model_choice`` or an empty dataset split.
        AssertionError: parameter freezing does not match the options.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, val_loader, class_names = _build_loaders(dataset_dir, args)

    model = _build_model(args, device)
    model.to(device)
    _check_parameter_freezing(model, args)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    criterion = nn.CrossEntropyLoss()

    writer = SummaryWriter(log_dir=os.path.join("runs", log_dir_suffix))
    best_val_acc = 0.0
    best_epoch = -1
    try:
        for epoch in range(1, args.epochs + 1):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device, writer, epoch)
            val_loss, val_acc = validate_epoch(model, val_loader, criterion, device, writer, epoch, class_names)
            print(f"Epoch {epoch}/{args.epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}, Val Acc = {val_acc*100:.2f}%")
            # Always record the first epoch, so a run stuck at 0% accuracy
            # still reports a real epoch instead of -1.
            if best_epoch == -1 or val_acc > best_val_acc:
                best_val_acc = val_acc
                best_epoch = epoch
            scheduler.step()
    finally:
        # Flush and release the event file even if an epoch raises.
        writer.close()

    return best_val_acc, best_epoch
