import argparse
import os
import numpy as np
import yaml
from copy import deepcopy
from PIL import Image

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

import torchvision.models as models

from clip.custom_clip import get_coop
from data.imagnet_prompts import imagenet_classes
from data.datautils import AugMixAugmenter, build_dataset
from utils.tools import Summary, AverageMeter, ProgressMeter, accuracy, load_config, set_random_seed
from data.cls_to_names import *
from data.fewshot_datasets import fewshot_datasets
from data.imagenet_variants import thousand_k_to_200, imagenet_a_mask, imagenet_r_mask, imagenet_v_mask
from test_time_tuning import test_time_adapt_eval
# from croc import GraphSmoothingTTA

# Alphabetically sorted names of every all-lowercase, non-dunder, callable
# attribute found in the torchvision ``models`` namespace.
model_names = sorted(
    key
    for key, entry in models.__dict__.items()
    if key.islower() and not key.startswith("__") and callable(entry)
)


def main():
    """
    Script entry point: parse the CLI options, fix the RNG seed, run the worker.

    Reads the module-level ``parser``, so that object must exist before this
    function is called. A falsy ``--seed`` value (e.g. ``0``) is replaced by an
    integer drawn from ``[0, 10000)`` and written back to ``args.seed``.
    """
    args = parser.parse_args()

    # Resolve the seed first so that everything downstream sees the final value.
    chosen_seed = args.seed
    if not chosen_seed:
        chosen_seed = torch.randint(0, 10000, (1,)).item()
        args.seed = chosen_seed
    set_random_seed(chosen_seed)
    print(f"Using Seed: {args.seed} for training")

    assert args.gpu is not None, "A GPU ID must be specified."
    main_worker(args.gpu, args)


def _apply_dataset_config(args, set_id):
    """Overlay the optional ``<config_dir>/<set_id>.yaml`` onto ``args`` in place.

    Values are written with ``setattr`` on the shared namespace, so they stay in
    effect for every dataset evaluated afterwards unless overridden again.
    """
    config_file_path = os.path.join(args.config_dir, f"{set_id}.yaml")

    if not os.path.exists(config_file_path):
        print(f"  > No config file found for {set_id}: {config_file_path}. Using default or previously loaded args.")
        return

    print(f"  > Found config file: {config_file_path}")
    try:
        with open(config_file_path, 'r') as f:
            # An empty YAML document parses to None; treat it as "no overrides".
            dataset_params = yaml.safe_load(f) or {}
        for key, value in dataset_params.items():
            setattr(args, key, value)  # Inject YAML key/values into args

        print(f"  > Loaded and updated args from {set_id}.yaml.")
        print("  Current dataset-specific args:")
        print(vars(args))  # Inspect updated args
    except Exception as e:
        # Deliberately best-effort: a broken config must not abort the whole run.
        print(f"  > Error: failed to load or parse {config_file_path} - {e}")


def _build_eval_transform(args, set_id, normalize):
    """Return ``(data_transform, loader_batch_size)`` for one dataset."""
    if args.shifter:
        base_transform = transforms.Compose([
            transforms.Resize(args.resolution, interpolation=BICUBIC),
            transforms.CenterCrop(args.resolution)])
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            normalize])
        data_transform = AugMixAugmenter(base_transform, preprocess, n_views=args.batch_size - 1,
                                         augmix=(len(set_id) > 1 and args.shifter))
        # With the shifter, the loader yields one source image at a time; the
        # augmenter is configured with ``batch_size - 1`` extra views.
        return data_transform, 1

    data_transform = transforms.Compose([
        transforms.Resize(args.resolution, interpolation=BICUBIC),
        transforms.CenterCrop(args.resolution),
        transforms.ToTensor(),
        normalize,
    ])
    return data_transform, args.batch_size


def _resolve_classnames(set_id):
    """Return the class-name list for ``set_id`` without using ``eval``.

    Raises:
        ValueError: if a multi-character ``set_id`` has no matching
            ``<set_id.lower()>_classes`` name in this module.
    """
    if len(set_id) > 1:  # Fine-grained classification datasets (e.g., 'Caltech101')
        # These lists arrive in module scope via ``from data.cls_to_names import *``.
        lookup = f"{set_id.lower()}_classes"
        try:
            return globals()[lookup]
        except KeyError:
            raise ValueError(
                f"Unknown dataset '{set_id}': no class list named '{lookup}' is available."
            ) from None

    # Single-character ids are ImageNet variants (A, R, K, V, I).
    masks = {'A': imagenet_a_mask, 'R': imagenet_r_mask, 'V': imagenet_v_mask}
    if set_id not in masks:  # K, I (full ImageNet) variants
        return imagenet_classes

    label_mask = masks[set_id]
    if set_id == 'R':
        # The R mask is consumed as one truthy/falsy flag per ImageNet class.
        return [imagenet_classes[i] for i, m in enumerate(label_mask) if m]
    # The A and V masks are consumed as lists of class indices.
    return [imagenet_classes[i] for i in label_mask]


def _to_float(value):
    """Return a metric as a Python number, unwrapping objects exposing ``.item()``."""
    return value.item() if hasattr(value, "item") else float(value)


def _summary_lines(results, args):
    """Build the result-summary table once so console and log file stay identical."""
    lines = [
        "\n======== Result Summary ========",
        "Parameters: nstep\tlr\tbs",
        f"Parameters: {args.tta_steps}\t{args.lr}\t{args.batch_size}",
        "{:<15}{:<15}{:<15}".format("[set_id]", "Top-1 acc.", "Top-5 acc."),
    ]
    for set_id, metrics in results.items():
        top1_accuracy = _to_float(metrics[0])
        top5_accuracy = _to_float(metrics[1])
        lines.append(f"{set_id:<15}{top1_accuracy:<15.2f}{top5_accuracy:<15.2f}")
    return lines


def main_worker(gpu, args):
    """
    Evaluate the model on every dataset listed in ``args.test_sets``.

    For each slash-separated dataset id the function applies an optional
    per-dataset YAML config, builds the transform and data loader, resets the
    model's class names / text features and calls ``test_time_adapt_eval``.
    A summary is printed and appended to ``./logs/<args.name>.txt``; the log
    directory is created if it is missing.

    Args:
        gpu (int): ID of the GPU to use; stored on ``args.gpu``.
        args (argparse.Namespace): Command-line arguments. Mutated in place
            (``gpu``, ``coop`` and any keys coming from dataset YAML files).

    Raises:
        ValueError: if a dataset id has no matching class-name list.
    """
    args.gpu = gpu

    # Create cache directory if it does not exist (no check-then-create race)
    os.makedirs(args.cache_path, exist_ok=True)

    set_random_seed(args.seed)  # Ensure seed is set for this worker as well

    print(f"Using GPU: {args.gpu} for training")

    # Determine if CoOp pre-trained model is used
    args.coop = bool(args.load)

    # Initialize CLIP model with CoOp or custom prompt learning
    model = get_coop(args, args.arch, args.test_sets, args.gpu)

    print("===================== Settings ======================")
    print(f"Template: {args.with_templates}, Pooling Type: {args.pooling_type}")
    print(f"CoOp Enabled: {args.coop}")
    print(f"Shifter Enabled: {args.shifter}, TTA Steps: {args.tta_steps}")
    print(f"Seed: {args.seed}, Batch Size: {args.batch_size}, Architecture: {args.arch}")
    print("=====================================================")

    # Freeze all model parameters by default
    for _, param in model.named_parameters():
        param.requires_grad_(False)
    print(f"=> Model created: visual backbone {args.arch}")

    # Move model to GPU when one is available
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        print('Using CPU, this will be slow.')

    # Setup automatic mixed-precision (AMP) loss scaling
    scaler = torch.cuda.amp.GradScaler(init_scale=1000)
    print('=> Using native Torch AMP. Training in mixed precision.')

    cudnn.benchmark = True  # Enable cuDNN autotuner for faster convolutions

    # Normalization statistics from CLIP
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    # Iterate through evaluation datasets
    results = {}
    for set_id in args.test_sets.split("/"):
        print(f"======================> Evaluating: {set_id} <======================")

        _apply_dataset_config(args, set_id)

        data_transform, batchsize = _build_eval_transform(args, set_id, normalize)
        classnames = _resolve_classnames(set_id)

        # Reset model's class names and initialize text features
        model.reset_classnames(args, classnames, set_id)
        model.init_text_features()
        if use_cuda:
            # Ensure model is on the correct device after text feature init.
            # Guarded so the CPU fallback above does not crash here.
            model = model.cuda(args.gpu)

        # The shifter (GraphSmoothingTTA) is currently disabled, so no shifter
        # module or optimizer is created regardless of ``args.shifter``.
        # To re-enable:
        # shifter = GraphSmoothingTTA(model.emb_dim, class_num=model.class_number, type=model.dtype)
        # shifter = shifter.cuda(args.gpu)
        # shifter.reset()
        # optimizer = torch.optim.AdamW(list(shifter.parameters()), args.lr)
        shifter = None
        optimizer = None

        # Build validation dataset and loader
        val_dataset = build_dataset(set_id, data_transform, args.data, mode=args.dataset_mode)
        print(f"Number of test samples: {len(val_dataset)}")
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batchsize, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        # Perform test-time adaptation and evaluation
        results[set_id] = test_time_adapt_eval(val_loader, model, shifter, optimizer, scaler, args, len(val_dataset))

        # Clean up
        del val_dataset, val_loader
        if shifter is not None:
            del shifter
            torch.cuda.empty_cache()  # Clear GPU memory after each dataset if shifter was used

        try:
            print(f"=> Acc. on testset [{set_id}]: @1 {results[set_id][0]:.2f}/ @5 {results[set_id][1]:.2f}")
        except TypeError:
            # Handle cases where results might not be a tuple (e.g., if only one metric is returned)
            print(f"=> Acc. on testset [{set_id}]: {results[set_id]}")

    # Print and save summary of results
    summary = _summary_lines(results, args)
    for line in summary:
        print(line)
    print("\n")

    # Save results to a log file; create the directory first so a missing
    # ./logs cannot discard the results of a finished run.
    log_file_path = f'./logs/{args.name}.txt'
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
    with open(log_file_path, 'a') as f:
        for line in summary:
            f.write(line + "\n")

        f.write("\n")
        f.write(f"Template: {args.with_templates}, Pooling Type: {args.pooling_type}\n")
        f.write(f"CoOp Enabled: {args.coop}, inject_rate: {args.inject_rate}\n")
        f.write(f"Shifter Enabled: {args.shifter}, Per Label: {args.per_label}\n")
        f.write(f"Seed: {args.seed}, Batch Size: {args.batch_size}, Architecture: {args.arch}\n")
        f.write("======== End ========\n")
    print(f"Results saved to {log_file_path}")


if __name__ == '__main__':
    # ``parser`` is intentionally a module-level name: main() reads it directly.
    parser = argparse.ArgumentParser(description='Test-time Prompt Tuning')
    # Core arguments
    parser.add_argument('--data', type=str, help='Path to dataset root')
    parser.add_argument('--test_sets', type=str, default='A/R/V/K/I',
                        help='Test datasets (multiple datasets split by slash)')
    parser.add_argument('--dataset_mode', type=str, default='test', help='Which split to use: train/val/test')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='ViT-B/16', help='CLIP architecture')
    parser.add_argument('--resolution', default=224, type=int, help='CLIP image resolution')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='Number of data loading workers (default: 4)')
    parser.add_argument('-b', '--batch-size', default=64, type=int, metavar='N', help='Batch size')
    parser.add_argument('--lr', '--learning-rate', default=5e-3, type=float,
                        metavar='LR', help='Initial learning rate', dest='lr')
    # Help text now states the real default (it previously claimed 10).
    parser.add_argument('-p', '--print-freq', default=200, type=int,
                        metavar='N', help='Print frequency (default: 200)')
    parser.add_argument('--gpu', default=0, type=int,
                        help='GPU ID to use.')
    parser.add_argument('--load', default=None, type=str, help='Path to a pre-trained CoOp model')
    # NOTE: main() replaces a falsy seed (e.g. 0) with a randomly drawn one.
    parser.add_argument('--seed', type=int, default=1, help='Random seed')
    parser.add_argument('--name', default='shifter', type=str, help='Name for results log file')
    parser.add_argument('--cache_path', default='./cache/', type=str, help='Path for extracted text embeddings cache')

    # Shifter and TTA arguments
    parser.add_argument('--shifter', action='store_true', default=False, help='Enable text shifter')
    parser.add_argument('--tta_steps', default=1, type=int, help='Test-time adaptation steps')
    parser.add_argument('--per_label', action='store_true', default=False, help='Perform TTA per label')
    parser.add_argument('--selection_p', default=0.1, type=float, help='Confidence selection percentile for TTA')

    # Misc
    parser.add_argument('--inject_rate', type=float, default=0.1)
    parser.add_argument('--config_dir', type=str, default='./configs',
                        help='Directory searched for per-dataset <set_id>.yaml argument overrides')

    # Prompt learning arguments
    parser.add_argument('--ctx_init', default='a_photo_of_a', type=str, help='Initial tunable prompts')
    parser.add_argument('--n_ctx', default=4, type=int, help='Number of tunable tokens')
    parser.add_argument('--with_templates', action="store_true", default=False,
                        help='Use predefined templates for prompts')
    parser.add_argument('--with_concepts', action="store_true", default=False, help='Use concepts for prompts')
    parser.add_argument('--cache_init', action="store_true", default=False, help='Initialize cache')
    parser.add_argument('--concept_type', default='wo_temp', type=str,
                        help='Concept type: w_temp|wo_class|wo_temp')  # Typo kept as in original
    parser.add_argument('--pooling_type', default='mean', type=str,
                        help='Pooling type for text features: mean|macro|class|concept')

    main()

