#!/usr/bin/env python3
"""
major actions here: fine-tune the features and evaluate different settings
"""
import os
from socket import TIPC_DEST_DROPPABLE
import torch
import warnings

import numpy as np
import random

from time import sleep
from random import randint

import src.utils.logging as logging
from src.configs.config import get_cfg
from src.data import loader as data_loader
from src.engine.evaluator import Evaluator
from src.engine.trainer import Trainer
from src.models.build_model import build_model
from src.utils.file_io import PathManager
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from tensorboardX import SummaryWriter
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
import tensorflow_datasets as tfds
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader


from launch import default_argument_parser, logging_train_setup
# Silence every warning category process-wide for the rest of the run.
warnings.filterwarnings("ignore")
# Expose only CUDA device index 0 to this process; this overwrites any
# CUDA_VISIBLE_DEVICES value already present in the environment.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


class _Food101Dataset(Dataset):
    """Map-style dataset over an explicit list of image paths.

    The label of an image is looked up from the name of its parent directory.
    The class lives at module level (instead of inside the factory function)
    so that instances can be pickled when ``DataLoader`` workers are started
    with the ``spawn`` method.
    """

    def __init__(self, file_paths, class_to_idx, transform=None):
        self.file_paths = file_paths
        self.class_to_idx = class_to_idx
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        img_path = self.file_paths[idx]
        image = Image.open(img_path).convert('RGB')
        # The parent directory name is the class name.
        class_name = os.path.basename(os.path.dirname(img_path))
        label = self.class_to_idx[class_name]

        if self.transform is not None:
            image = self.transform(image)

        return image, label


def get_food101_data_loaders(data_dir, batch_size=80, num_workers=4):
    """Build train/test ``DataLoader`` objects for a Food-101 style directory.

    Reads ``<data_dir>/meta/classes.txt`` (one class name per line) and the
    split files ``meta/train.txt`` / ``meta/test.txt`` (one image path per
    line, relative to ``<data_dir>/images`` and without the ``.jpg`` suffix).
    Blank lines in these files are ignored.

    Args:
        data_dir: Root directory containing the ``images`` and ``meta`` folders.
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes used by both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The train loader shuffles and
        applies a random resized crop; the test loader keeps file order and
        applies a plain resize to 224x224.

    Raises:
        OSError: If one of the meta files cannot be opened.
    """
    images_dir = os.path.join(data_dir, 'images')
    meta_dir = os.path.join(data_dir, 'meta')

    def read_entries(filename):
        # Return the stripped, non-empty lines of a file in the meta directory.
        with open(os.path.join(meta_dir, filename), 'r', encoding='utf-8') as f:
            return [entry for entry in (line.strip() for line in f) if entry]

    # Class names, in file order; the position in the file is the label index.
    classes = read_entries('classes.txt')
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}

    # Image paths of the train and test splits.
    def read_split_file(split_filename):
        return [os.path.join(images_dir, entry + '.jpg')
                for entry in read_entries(split_filename)]

    train_files = read_split_file('train.txt')
    test_files = read_split_file('test.txt')

    _mean = IMAGENET_INCEPTION_MEAN
    _std = IMAGENET_INCEPTION_STD

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_mean, std=_std)
    ])

    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_mean, std=_std)
    ])

    train_dataset = _Food101Dataset(train_files, class_to_idx, transform=train_transform)
    test_dataset = _Food101Dataset(test_files, class_to_idx, transform=test_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader


def setup(args, lr, wd, l1, l2, flag):
    """
    Create configs and perform basic setups.

    Args:
        args: Parsed command-line namespace; ``args.config_file`` and
            ``args.opts`` are merged into the default config.
        lr: Base learning rate. It appears unscaled in the output folder name
            and is stored in ``cfg.SOLVER.BASE_LR`` scaled as
            ``lr / 256 * cfg.DATA.BATCH_SIZE``.
        wd: Weight decay, stored in ``cfg.SOLVER.WEIGHT_DECAY``.
        l1: Only used to name the output folder here.
        l2: Only used to name the output folder here.
        flag: Accepted for backward compatibility; it does not influence the
            result (both values select the same output directory).

    Returns:
        The frozen config with ``OUTPUT_DIR`` pointing at the newly created
        run directory.

    Raises:
        ValueError: If ``cfg.RUN_N_TIMES`` is below 1 or the run directory
            already exists.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    # setup dist
    #cfg.DIST_INIT_PATH = "tcp://{}:12399".format(os.environ["SLURMD_NODENAME"])

    # setup output dir
    # output_dir / data_name / feature_name / lr_wd_l1_l2 / Cifarours
    output_dir = cfg.OUTPUT_DIR
    cfg.SOLVER.WEIGHT_DECAY = wd
    # The folder name uses the unscaled learning rate.
    output_folder = os.path.join(
        cfg.DATA.NAME, cfg.DATA.FEATURE, f"lr{lr}_wd{wd}_l1{l1}_l2{l2}")
    # Linear scaling of the learning rate with the configured batch size.
    cfg.SOLVER.BASE_LR = lr / 256 * cfg.DATA.BATCH_SIZE

    # The run directory does not depend on a run counter, so checking it more
    # than once cannot change the outcome: check a single time and fail fast
    # instead of sleeping RUN_N_TIMES times before raising.
    output_path = os.path.join(output_dir, output_folder, "Cifarours")
    claimed = False
    if cfg.RUN_N_TIMES >= 1:
        # pause for a random time, so concurrent process with same setting won't interfere with each other. # noqa
        sleep(randint(3, 30))
        if not PathManager.exists(output_path):
            PathManager.mkdirs(output_path)
            cfg.OUTPUT_DIR = output_path
            claimed = True
    if not claimed:
        raise ValueError(
            f"Already run {cfg.RUN_N_TIMES} times for {output_folder}, no need to run more")

    cfg.freeze()
    return cfg


def get_loaders(cfg, logger, final):
    """Build the CIFAR-100 train and test data loaders.

    Both splits are downloaded to ``data/cifar100`` if missing, resized to
    224x224, converted to tensors and normalized with the Inception
    mean/std constants.

    Args:
        cfg: Unused; kept for interface compatibility.
        logger: Unused; kept for interface compatibility.
        final: Unused; kept for interface compatibility.

    Returns:
        A ``(trainloader, testloader)`` tuple; only the train loader shuffles.
    """
    transform = transforms.Compose([
        # interpolation=3 is the integer code of PIL's bicubic filter.
        transforms.Resize((224, 224), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
    ])

    def _make_loader(is_train):
        # One CIFAR-100 split wrapped in a loader; shuffle only when training.
        dataset = torchvision.datasets.CIFAR100(
            root='data/cifar100', train=is_train, download=True, transform=transform)
        # NOTE(review): the batch size is fixed at 80 here, while setup()
        # scales the learning rate by cfg.DATA.BATCH_SIZE -- confirm the two
        # values agree in the config being used.
        return torch.utils.data.DataLoader(
            dataset, batch_size=80, shuffle=is_train, num_workers=4)

    # load CIFAR-100
    trainloader = _make_loader(True)
    testloader = _make_loader(False)

    # To load Food-101 instead, replace the two loaders above with:
    #trainloader, testloader = get_food101_data_loaders('data/food-101')
    return trainloader, testloader


def train(cfg, args, lr, wd, l1, l2):
    """Run one training job: seed, build data/model/trainer, then train.

    Args:
        cfg: Frozen config returned by ``setup``.
        args: Parsed command-line namespace, forwarded to the logging setup.
        lr: Learning rate; used here only to name the TensorBoard directory
            and the checkpoint file.
        wd: Weight decay; used here only for the same naming.
        l1: Forwarded to ``Trainer`` and used in the run name.
        l2: Forwarded to ``Trainer`` and used in the run name.

    Returns:
        None.
    """
    # clear up residual cache from previous runs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # main training / eval actions here

    # fix the seed for reproducibility
    if cfg.SEED is not None:
        torch.manual_seed(cfg.SEED)
        np.random.seed(cfg.SEED)
        # NOTE(review): Python's `random` is seeded with the constant 0, not
        # cfg.SEED -- confirm this is intended before changing it.
        random.seed(0)

    # setup training env including loggers
    logging_train_setup(args, cfg)
    logger = logging.get_logger("visual_prompt")

    # The held-out CIFAR-100 split serves as both validation and test set.
    train_loader, val_loader = get_loaders(cfg, logger, True)
    test_loader = val_loader
    logger.info("Constructing models...")
    model, cur_device = build_model(cfg)

    logger.info("Setting up Evalutator...")
    evaluator = Evaluator()
    logger.info("Setting up Trainer...")
    flag = True
    trainer = Trainer(cfg, model, evaluator, cur_device, flag, l1, l2)

    # Hyper-parameter tag shared by the TensorBoard directory and the
    # checkpoint file; dots are replaced so the tag is filename friendly.
    run_tag = '_'.join(str(value).replace('.', '_') for value in (lr, wd, l1, l2))
    output_vit = 'save/cifar100/checkpoint_ours_vit' + run_tag + '.pth'
    # Passed through to train_final unchanged; its meaning is defined there.
    total = 100

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params (M): %.2f' % (n_parameters / 1.e6))

    #for k, p in model.enc.transformer.encoder.netA.named_parameters():
    #    p.requires_grad = True
    #for k, p in model.enc.transformer.encoder.netB.named_parameters():
    #    p.requires_grad = True

    writer = SummaryWriter('Cifar_ours_ViT_' + run_tag)
    try:
        trainer.train_final(train_loader, val_loader, test_loader, writer, output_vit, total)

        if cfg.SOLVER.TOTAL_EPOCH == 0:
            trainer.eval_classifier(test_loader, "test", 0)
    finally:
        # Close the writer even when training or evaluation raises, so the
        # event file handle is not leaked.
        writer.close()

def main(args):
    """main function to call from workflow

    Builds the config for one fixed hyper-parameter setting and runs training.

    Args:
        args: Parsed command-line namespace (config file path and overrides).
    """
    lr = 0.001
    wd = 0.01
    l1 = 0.0
    l2 = 0.0001
    cfg = setup(args, lr, wd, l1, l2, False)

    # Perform training.
    acc = train(cfg, args, lr, wd, l1, l2)
    # train() has no return statement, so only report a result when one is
    # actually produced instead of always printing "None".
    if acc is not None:
        print(acc)


# Script entry point: parse the command line with the project's default
# argument parser and hand the resulting namespace to main().
if __name__ == '__main__':
    args = default_argument_parser().parse_args()
    main(args)
