import h5py
import numpy as np
import os
import torch
import time
from config import get_config
from dataset import create_dataloader
from evaluation_metrics import get_corr_coeff, get_summ_f1score, get_summ_diversity
from generate_summary import generate_summary
from model import set_model
from utils import report_params, print_args, get_gt, get_exp_name, load_h5_data_to_memory
from pathlib import Path
from collections import defaultdict

# Cap CPU parallelism at 8 threads for OpenMP, MKL and PyTorch intra-op work.
# NOTE(review): these variables are set after numpy/torch have been imported;
# whether the native libraries still pick them up at this point should be
# confirmed. torch.set_num_threads() below applies the limit to PyTorch itself.
os.environ.update({"OMP_NUM_THREADS": "8", "MKL_NUM_THREADS": "8"})
torch.set_num_threads(8)



def _apply_accumulated_update(optimizer, summed_loss, n_items):
    """Take one optimizer step on the mean of ``n_items`` accumulated losses."""
    optimizer.zero_grad()
    (summed_loss / n_items).backward()
    optimizer.step()


def _evaluate(model, test_loader, h5_data, user_scores, device):
    """Score every SumMe/TVSum item of ``test_loader`` with ``model``.

    Returns four parallel lists with one entry per evaluated item:
    ``(spears, kendalls, f1scores, coeffs)`` where each coeff is the
    average of that item's Spearman and Kendall values.
    """
    val_spears, val_kendalls, val_f1scores, val_coeffs = [], [], [], []
    model.eval()
    with torch.no_grad():
        for feature, _, dataset_name, video_num in test_loader:
            # Items from any other dataset are run through nothing and skipped.
            if dataset_name not in ['SumMe', 'TVSum']:
                continue

            output = model(feature.to(device))

            video = h5_data[video_num]
            user_summary = video['user_summary']
            scores = output.cpu().numpy().tolist()
            summary = generate_summary(
                [video['change_points']], [scores], [video['n_frames']], [video['positions']]
            )[0]

            if dataset_name == 'SumMe':
                # SumMe: correlate the generated summary with the user
                # summaries; F1 aggregated with 'max'.
                spear, kendall = get_corr_coeff([summary], [video_num], dataset_name, user_summary)
                f1_score = get_summ_f1score(summary, user_summary, 'max')
            else:
                # TVSum: correlate the raw predicted scores with the ground
                # truth user scores; F1 aggregated with 'avg'.
                spear, kendall = get_corr_coeff([scores], [video_num], dataset_name, user_scores)
                f1_score = get_summ_f1score(summary, user_summary, 'avg')

            val_spears.append(spear)
            val_kendalls.append(kendall)
            val_f1scores.append(f1_score)
            val_coeffs.append((spear + kendall) / 2)
    return val_spears, val_kendalls, val_f1scores, val_coeffs


def run_trial(config):
    """Train and evaluate a fresh model on every split of every dataset.

    For each name in the comma-separated ``config.datasets`` the function
    iterates over the ``(train_loader, test_loader)`` pairs produced by
    ``create_dataloader``. Per split it builds a model with ``set_model``,
    optionally loads weights from ``config.pt_ckpt_path`` (skipped when the
    path equals ``'CSTA'``), and trains for ``config.epochs`` epochs with MSE
    loss and Adam. Losses of ``config.batch_size`` consecutive loader items
    are averaged before each optimizer step.

    After every epoch the model is scored on ``test_loader``. The epoch with
    the highest mean ``(spearman + kendall) / 2`` supplies the split's Kendall
    and Spearman values; the split's best F1 is tracked independently as the
    maximum mean F1 over all epochs.

    Args:
        config: Configuration object providing ``datasets``, ``batch_size``,
            ``epochs``, ``device``, ``learning_rate``, ``weight_decay``,
            ``pt_ckpt_path``, ``model_name`` and the model hyper-parameters
            forwarded to ``set_model``.

    Returns:
        dict: ``{dataset: {'Kendall': float, 'Spear': float, 'Best_F1': float}}``
        where each value is the mean over that dataset's splits.
    """
    all_scores = {}
    batch_size = int(config.batch_size)

    for dataset in config.datasets.split(','):
        h5_data = load_h5_data_to_memory(dataset)
        user_scores = get_gt(dataset)
        split_kendalls = []
        split_spears = []
        split_best_f1s = []

        for train_loader, test_loader in create_dataloader(dataset):
            model = set_model(
                model_name=config.model_name,
                Scale=config.Scale,
                Softmax_axis=config.Softmax_axis,
                Balance=config.Balance,
                Positional_encoding=config.Positional_encoding,
                Positional_encoding_shape=config.Positional_encoding_shape,
                Positional_encoding_way=config.Positional_encoding_way,
                Dropout_on=config.Dropout_on,
                Dropout_ratio=config.Dropout_ratio,
                Classifier_on=config.Classifier_on,
                CLS_on=config.CLS_on,
                CLS_mix=config.CLS_mix,
                key_value_emb=config.key_value_emb,
                Skip_connection=config.Skip_connection,
                Layernorm=config.Layernorm
            )
            # The literal path 'CSTA' means "no checkpoint": keep the weights
            # that set_model produced.
            if config.pt_ckpt_path != 'CSTA':
                checkpoint = torch.load(config.pt_ckpt_path, map_location='cpu')
                model.load_state_dict(checkpoint['state_dict'])
            model.to(config.device)
            criterion = torch.nn.MSELoss()
            optimizer = torch.optim.Adam(
                model.parameters(),
                lr=float(config.learning_rate),
                weight_decay=float(config.weight_decay),
            )

            # -1 is the "nothing selected yet" sentinel for all tracked scores.
            model_selection_coeff = -1
            model_selection_kendall = -1
            model_selection_spear = -1
            model_best_f1 = -1

            for _ in range(config.epochs):
                model.train()
                update_loss = 0.0
                batch = 0

                for feature, gtscore, _, _ in train_loader:
                    feature = feature.to(config.device)
                    gtscore = gtscore.to(config.device)
                    output = model(feature)

                    loss = criterion(output, gtscore)
                    # NOTE(review): kept from the original; for a loss that is
                    # already attached to the autograd graph this should be a
                    # no-op - confirm whether it is still needed.
                    loss.requires_grad_(True)
                    update_loss += loss
                    batch += 1

                    # Step once per `batch_size` accumulated loader items.
                    if batch == batch_size:
                        _apply_accumulated_update(optimizer, update_loss, batch)
                        update_loss = 0.0
                        batch = 0

                # Flush the final, partially filled accumulation window.
                if batch > 0:
                    _apply_accumulated_update(optimizer, update_loss, batch)

                val_spears, val_kendalls, val_f1scores, val_coeffs = _evaluate(
                    model, test_loader, h5_data, user_scores, config.device
                )

                # Model selection is driven by the mean correlation score;
                # Kendall/Spearman are recorded from that same epoch.
                epoch_coeff = np.mean(val_coeffs)
                if epoch_coeff > model_selection_coeff:
                    model_selection_coeff = epoch_coeff.item()
                    model_selection_kendall = np.mean(val_kendalls).item()
                    model_selection_spear = np.mean(val_spears).item()

                # Best F1 is tracked separately from the selected epoch.
                epoch_f1 = np.mean(val_f1scores)
                if epoch_f1 > model_best_f1:
                    model_best_f1 = epoch_f1.item()

            split_kendalls.append(model_selection_kendall)
            split_spears.append(model_selection_spear)
            split_best_f1s.append(model_best_f1)

        all_scores[dataset] = {
            'Kendall': np.mean(split_kendalls).item(),
            'Spear': np.mean(split_spears).item(),
            'Best_F1': np.mean(split_best_f1s).item(),
        }
    return all_scores

if __name__=="__main__":
    # Entry point: run `config.repeat` independent trials, aggregate the
    # Kendall/Spearman results per dataset, and append a one-line summary to
    # `config.result_path`.

    # Load configurations
    config = get_config()
    exp_name, hparams_info = get_exp_name(config)

    # Print the number of parameters
    report_params(
        model_name=config.model_name,
        Scale=config.Scale,
        Softmax_axis=config.Softmax_axis,
        Balance=config.Balance,
        Positional_encoding=config.Positional_encoding,
        Positional_encoding_shape=config.Positional_encoding_shape,
        Positional_encoding_way=config.Positional_encoding_way,
        Dropout_on=config.Dropout_on,
        Dropout_ratio=config.Dropout_ratio,
        Classifier_on=config.Classifier_on,
        CLS_on=config.CLS_on,
        CLS_mix=config.CLS_mix,
        key_value_emb=config.key_value_emb,
        Skip_connection=config.Skip_connection,
        Layernorm=config.Layernorm
    )

    # repeat_res maps '<dataset>.<metric>' to one value per trial.
    repeat_res = defaultdict(list)
    for iteration in range(config.repeat):
        start_time = time.time()
        all_scores = run_trial(config)
        for dataset, metrics in all_scores.items():
            kendall, spear = metrics['Kendall'], metrics['Spear']
            repeat_res[f'{dataset}.Kendall_Spear'].append(np.mean([kendall, spear]).item())
            repeat_res[f'{dataset}.Kendall'].append(kendall)
            repeat_res[f'{dataset}.Spear'].append(spear)
        duration_hours = (time.time() - start_time) / 3600
        print(f"\n=========Iter {iteration}, execution time: {duration_hours:.3f} hours=========\n")

    # Per-metric "mean±std" over trials; the joint Kendall_Spear series is
    # reported separately below.
    exp_results = [
        f'{metric}:{np.mean(values):.3f}±{np.std(values):.3f}'
        for metric, values in repeat_res.items()
        if 'Kendall_Spear' not in metric
    ]

    dataset_names = config.datasets.split(',')
    n_datasets = len(dataset_names)

    best_info = []
    best_coeffs = []
    joint_mean = 0
    joint_var = 0
    for dataset in dataset_names:
        joint_values = repeat_res[f'{dataset}.Kendall_Spear']
        joint_mean += np.mean(joint_values)
        # Sample variance (ddof=1); with a single trial numpy yields nan here.
        joint_var += np.var(joint_values, ddof=1)
        res_info = ", ".join([f"{v:.3f}" for v in joint_values])
        exp_results.append(f'{dataset}:<{res_info}>')

        # Trial with the highest joint score for this dataset.
        best_ind = joint_values.index(max(joint_values))
        best_kendall = repeat_res[f'{dataset}.Kendall'][best_ind]
        best_spear = repeat_res[f'{dataset}.Spear'][best_ind]
        best_info.append(
            f"{dataset}.Kendall:{best_kendall:.3f}; {dataset}.Spear:{best_spear:.3f}"
        )
        best_coeffs.append(joint_values[best_ind])
    best_info = f"AVG:{np.mean(best_coeffs):.3f}; {'; '.join(best_info)}"

    # Overall score = average of the per-dataset means. Its spread is the
    # standard deviation of that average, treating the datasets as
    # independent: sqrt(sum of variances) / n. Dividing by the actual number
    # of datasets (instead of a fixed 2) keeps this right for any dataset list.
    overall_mean = joint_mean / n_datasets
    overall_std = np.sqrt(joint_var) / n_datasets
    exp_results = [f"Overall:{overall_mean:.3f}±{overall_std:.3f}"] + exp_results
    exp_results = [f"BestIter:<{best_info}>"] + exp_results

    # Append (never overwrite) so successive experiments share one results file.
    Path(config.result_path).parent.mkdir(exist_ok=True,parents=True)
    with open(config.result_path,'a') as f:
        f.write(
            f"({hparams_info})[{exp_name}]: " + ", ".join(exp_results) + "\n"
        )