import csv
import numpy as np
import torch
from pathlib import Path
from collections import Counter
import h5py
from models.EfficientNet import CSTA_EfficientNet
from models.GoogleNet import CSTA_GoogleNet
from models.MobileNet import CSTA_MobileNet
from models.ResNet import CSTA_ResNet
from collections import defaultdict
import re

# Count the number of parameters
def count_parameters(model,model_name):
    """Return the trainable-parameter count of ``model``, excluding its head.

    Parameters whose name contains the backbone's head key are skipped:
    ``'fc'`` for GoogleNet/ResNet variants and ``'classifier'`` for
    EfficientNet/MobileNet variants.

    Args:
        model: a ``torch.nn.Module``.
        model_name: backbone identifier (plain or ``*_Attention`` variant).

    Returns:
        The count divided by ``1024 * 1024`` (i.e. in units of 2**20, not 1e6).

    Raises:
        ValueError: if ``model_name`` is not a recognised backbone.
    """
    if model_name in ('GoogleNet', 'GoogleNet_Attention', 'ResNet', 'ResNet_Attention'):
        head_key = 'fc'
    elif model_name in ('EfficientNet', 'EfficientNet_Attention', 'MobileNet', 'MobileNet_Attention'):
        head_key = 'classifier'
    else:
        # Previously an unknown name fell through to an UnboundLocalError.
        raise ValueError(f"Unknown model_name: {model_name!r}")
    total = sum(
        param.numel()
        for name, param in model.named_parameters()
        if param.requires_grad and head_key not in name
    )
    return total / (1024 * 1024)

def load_h5_data_to_memory(dataset_name):
    """Read selected per-video arrays from the dataset's HDF5 file into memory.

    The file read is ``./data/eccv16_dataset_<name>_google_pool5.h5`` with
    ``<name>`` the lower-cased ``dataset_name``.

    Returns:
        A ``defaultdict(dict)`` mapping each top-level HDF5 key (video) to a
        dict of numpy arrays. For 'QFVS'/'QFVS2' the entries are 'gtscore' and
        'n_frames'; for every other dataset they are 'user_summary',
        'change_points', 'n_frames' and 'positions' (read from 'picks').
    """
    # Output key -> dataset name inside each video's HDF5 group.
    if dataset_name in ('QFVS', 'QFVS2'):
        fields = {
            'gtscore': 'gtscore',
            'n_frames': 'n_frames',
        }
    else:
        # 'gtscore' is intentionally not loaded for these datasets.
        fields = {
            'user_summary': 'user_summary',
            'change_points': 'change_points',
            'n_frames': 'n_frames',
            'positions': 'picks',
        }

    source = f'./data/eccv16_dataset_{dataset_name.lower()}_google_pool5.h5'
    loaded = defaultdict(dict)
    with h5py.File(source, 'r') as hdf:
        for video_key in hdf.keys():
            per_video = loaded[video_key]
            for out_key, h5_key in fields.items():
                per_video[out_key] = np.array(hdf[f"{video_key}/{h5_key}"])
    return loaded

def get_exp_name(config):
    """Build the experiment name and hyper-parameter tag for ``config``.

    The process is terminated instead of returning when:
      * ``exp_name`` is in the skip list built by ``get_skip_exps`` from
        ``config.skip_files`` / ``config.skip_thredhold`` (exit status 17);
      * ``config.result_path`` exists and already contains a line with
        ``"(<hparams_info>)[<exp_name>]"`` (exit status 17);
      * ``config.check`` is truthy (exit status 0).

    Returns:
        ``(exp_name, hparams_info)``. ``exp_name`` is ``"CSTA"`` when
        ``config.pt_ckpt_path == 'CSTA'``, otherwise
        ``"<checkpoint parent dir stem>/<checkpoint stem>"``.
    """
    # ``exit`` is injected by the ``site`` module for interactive use and may be
    # absent (python -S, frozen apps); ``sys.exit`` is always available.
    import sys

    # Hyperparameters
    hparams_info = f"bs={config.batch_size},e={config.epochs},d={config.Dropout_ratio if config.Dropout_on else 0},r={config.repeat},s={config.seed},lr={config.learning_rate},wd={config.weight_decay}"

    if config.pt_ckpt_path != 'CSTA':
        ckpt = Path(config.pt_ckpt_path)
        exp_name = f"{ckpt.parent.stem}/{ckpt.stem}"
    else:
        exp_name = "CSTA"

    skip_exps = get_skip_exps(config.skip_files, thredhold=config.skip_thredhold)
    if exp_name in skip_exps:
        print(f"Experiment {exp_name} is in skip list, skipping...")
        sys.exit(17)

    # A previous run with identical hyper-parameters and checkpoint is
    # recognised by this tag appearing anywhere in the result file.
    tag = f"({hparams_info})[{exp_name}]"
    result_path = Path(config.result_path)
    if result_path.exists():
        with open(result_path, 'r') as f:
            if any(tag in line for line in f):
                print(f"Experiment {tag} already exists, skipping...")
                sys.exit(17)
    if config.check:
        sys.exit(0)
    return exp_name,hparams_info

# Build the selected model and print its number of parameters
def report_params(model_name,
                  Scale,
                  Softmax_axis,
                  Balance,
                  Positional_encoding,
                  Positional_encoding_shape,
                  Positional_encoding_way,
                  Dropout_on,
                  Dropout_ratio,
                  Classifier_on,
                  CLS_on,
                  CLS_mix,
                  key_value_emb,
                  Skip_connection,
                  Layernorm):
    """Instantiate the CSTA model selected by ``model_name`` and print its size.

    Every argument (including ``model_name``) is forwarded unchanged, by
    keyword, to the chosen ``CSTA_*`` constructor. The printed value comes from
    ``count_parameters`` and is formatted as ``"PARAMS: <x.xx>M"``.

    Raises:
        ValueError: if ``model_name`` is not a supported backbone.
    """
    # Plain and '_Attention' names map to the same constructor.
    model_classes = {
        'EfficientNet': CSTA_EfficientNet,
        'EfficientNet_Attention': CSTA_EfficientNet,
        'GoogleNet': CSTA_GoogleNet,
        'GoogleNet_Attention': CSTA_GoogleNet,
        'MobileNet': CSTA_MobileNet,
        'MobileNet_Attention': CSTA_MobileNet,
        'ResNet': CSTA_ResNet,
        'ResNet_Attention': CSTA_ResNet,
    }
    if model_name not in model_classes:
        # Previously an unknown name fell through to an UnboundLocalError on ``model``.
        raise ValueError(f"Unknown model_name: {model_name!r}")

    model = model_classes[model_name](
        model_name=model_name,
        Scale=Scale,
        Softmax_axis=Softmax_axis,
        Balance=Balance,
        Positional_encoding=Positional_encoding,
        Positional_encoding_shape=Positional_encoding_shape,
        Positional_encoding_way=Positional_encoding_way,
        Dropout_on=Dropout_on,
        Dropout_ratio=Dropout_ratio,
        Classifier_on=Classifier_on,
        CLS_on=CLS_on,
        CLS_mix=CLS_mix,
        key_value_emb=key_value_emb,
        Skip_connection=Skip_connection,
        Layernorm=Layernorm
    )
    print(f"PARAMS: {count_parameters(model,model_name):.2f}M")

# Print all arguments and GPU setting
def print_args(args):
    """Print the run's arguments and the CUDA / cuDNN / GPU environment.

    Args:
        args: object exposing ``kwargs`` (printed verbatim) and ``device``.
            ``device`` is presumably a string such as ``'cpu'``, ``'cuda'``
            or ``'cuda:1'``; GPU details are printed only when it contains
            ``'cuda'``.
    """
    print(args.kwargs)
    print(f"CUDA: {torch.version.cuda}")
    print(f"cuDNN: {torch.backends.cudnn.version()}")
    if 'cuda' in args.device:
        available = torch.cuda.is_available()
        print(f"GPU: {available}")
        print(f"GPU count: {torch.cuda.device_count()}")
        # Querying a device name without a usable CUDA runtime raises, and the
        # selected device (e.g. 'cuda:1') is not necessarily index 0.
        if available:
            print(f"GPU name: {torch.cuda.get_device_name(args.device)}")

# Load per-annotator ground-truth scores (TVSum only; SumMe/QFVS/QFVS2 return None)
def get_gt(dataset, annot_path="./data/ydata-anno.tsv"):
    """Load per-annotator ground-truth importance scores.

    For ``'TVSum'`` the tab-separated annotation file at ``annot_path`` is
    read. Column 0 is used as the video id and column 2 as a comma-separated
    list of per-frame scores. Each annotator's scores are divided by their
    maximum and then every 15th value is kept.

    Args:
        dataset: dataset name.
        annot_path: TVSum annotation file (defaults to the project's copy).

    Returns:
        For 'TVSum': a list with one entry per video (in order of first
        appearance in the file), each a list of 1-D numpy arrays, one per
        annotator. For 'SumMe', 'QFVS' and 'QFVS2': ``None``.

    Raises:
        ValueError: for any other dataset name.
    """
    if dataset == 'TVSum':
        # newline='' is the csv module's documented requirement for reading.
        with open(annot_path, newline='') as annot_file:
            annot = list(csv.reader(annot_file, delimiter="\t"))
        # Group rows by video id instead of slicing fixed-size chunks, so videos
        # with differing numbers of annotators are still split correctly.
        scores_by_video = {}
        for row in annot:
            if not row:  # tolerate blank lines (e.g. a trailing newline)
                continue
            curr_user_score = np.array([float(num) for num in row[2].split(",")])
            # initial=-1 keeps max() defined for an empty score list.
            curr_user_score = curr_user_score / curr_user_score.max(initial=-1)
            # Keep every 15th frame (presumably to match the feature sampling rate — confirm).
            scores_by_video.setdefault(row[0], []).append(curr_user_score[::15])
        return list(scores_by_video.values())
    if dataset in ('SumMe', 'QFVS', 'QFVS2'):
        return None
    # A bare ``raise`` here used to fail with "No active exception to reraise".
    raise ValueError(f"Unknown dataset: {dataset!r}")

def get_skip_exps(file_paths,thredhold=0.24):
    """Collect experiment tags whose ``Kendall_Spear`` value is below a threshold.

    Each line of each file is searched for ``"[<tag>]: ...Kendall_Spear:<value>±"``.
    Lines that do not match are ignored.

    Args:
        file_paths: comma-separated result-file paths, or ``None``. Whitespace
            around each path and empty entries (e.g. a trailing comma) are
            ignored.
        thredhold: tags whose parsed value is strictly below this are returned.

    Returns:
        List of matching tags in file/line order (duplicates kept). Returns an
        empty list without printing when ``file_paths`` is ``None``.
    """
    skip_exps = []
    if file_paths is None:
        return skip_exps
    pattern = re.compile(
        r'\[(.*?)\]: .*?Kendall_Spear:(.*?)±'
    )
    total = 0
    for file_path in file_paths.split(','):
        file_path = file_path.strip()
        if not file_path:
            # Previously '' / ' path' were passed straight to open() and failed.
            continue
        with open(file_path, 'r') as f:
            for line in f:
                match = pattern.search(line)
                if match is None:
                    continue
                total += 1
                if float(match.group(2)) < thredhold:
                    skip_exps.append(match.group(1))
    print(f"Skip {len(skip_exps)} exps out of {total}")
    return skip_exps