import torch
import cv2
import random
import os.path as osp
import my_dataset.datasets as datasets
import argparse
from scipy.stats import spearmanr, pearsonr
from scipy.stats.stats import kendalltau as kendallr
import numpy as np
from time import time
from tqdm import tqdm
import pickle
import math
import yaml
from functools import reduce
from thop import profile
from Q_CLIP import CLIP_VQA
import copy
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def train_test_split(dataset_path, ann_file, ratio=0.8, seed=42):
    """Read an annotation file and split its entries into shuffled train/test lists.

    Every line of ``ann_file`` must contain exactly four comma-separated fields:
    the first is a filename (joined onto ``dataset_path``) and the last is parsed
    as a float label; the two middle fields are ignored.

    Note: this seeds the *global* ``random`` module with ``seed`` before
    shuffling, so the split is reproducible but the global RNG state changes.

    Returns:
        ``(train, test)`` — lists of ``{"filename": str, "label": float}`` dicts.
        ``train`` holds the first ``int(ratio * n)`` shuffled entries and
        ``test`` the remainder.
    """
    random.seed(seed)
    entries = []
    with open(ann_file, "r") as handle:
        for raw in handle:
            name, _, _, score = raw.strip().split(",")
            entries.append({"filename": osp.join(dataset_path, name), "label": float(score)})
    random.shuffle(entries)
    cut = int(ratio * len(entries))
    return entries[:cut], entries[cut:]


def rank_loss(y_pred, y):
    """Pairwise ranking loss that penalises predicted orderings disagreeing with targets.

    For each ordered pair (i, j) the term
    ``relu((y_pred[i] - y_pred[j]) * sign(y[j] - y[i]))`` is positive only when
    the predicted order contradicts the ground-truth order. The terms are summed,
    averaged over the ``n * (n - 1)`` ordered pairs, and divided by
    ``1 + max(term)`` to bound the magnitude.

    Args:
        y_pred: predicted scores; assumed to be a column of shape (n, 1) since the
            pairwise matrix is formed as ``y_pred - y_pred.t()``.
        y: ground-truth scores, same shape as ``y_pred``.

    Returns:
        Scalar float32 tensor. With fewer than two samples there are no pairs,
        so the loss is 0 instead of dividing by ``n - 1 == 0`` (which yields NaN
        and would corrupt the weights on backward).
    """
    n = y_pred.shape[0]
    if n < 2:
        # No pairs to rank. Multiply by zero rather than creating a fresh tensor
        # so the result stays attached to the autograd graph.
        return (y_pred.sum() * 0.0).float()
    ranking_loss = torch.nn.functional.relu(
        (y_pred - y_pred.t()) * torch.sign(y.t() - y)
    )
    scale = 1 + torch.max(ranking_loss)
    return (torch.sum(ranking_loss) / n / (n - 1) / scale).float()


def plcc_loss(y_pred, y):
    """Correlation-oriented regression loss between predictions and targets.

    Both tensors are z-normalised with the population std (``unbiased=False``),
    adding ``1e-8`` to the std to avoid division by zero. The loss averages two
    MSE terms, each divided by 4:

    * MSE between the normalised predictions and normalised targets;
    * MSE between the normalised predictions scaled by ``rho`` and the
      normalised targets, where ``rho = mean(z_pred * z_true)``.

    Returns a scalar float32 tensor.
    """
    def _standardize(t):
        spread, centre = torch.std_mean(t, unbiased=False)
        return (t - centre) / (spread + 1e-8)

    z_pred = _standardize(y_pred)
    z_true = _standardize(y)
    mse = torch.nn.functional.mse_loss

    direct_term = mse(z_pred, z_true) / 4
    rho = (z_pred * z_true).mean()
    scaled_term = mse(rho * z_pred, z_true) / 4
    return ((direct_term + scaled_term) / 2).float()


def rescaled_l2_loss(y_pred, y, eps=1e-8):
    """MSE between z-normalised predictions and z-normalised targets.

    Each tensor is centred on its mean and divided by its (unbiased) standard
    deviation plus ``eps``, so only the shape of the score distribution matters,
    not its offset or scale.

    Args:
        y_pred: predicted scores.
        y: ground-truth scores, same shape as ``y_pred``.
        eps: added to both standard deviations to avoid division by zero.
            Before this change the body referenced an undefined ``eps``, so every
            call raised ``NameError``.

    Returns:
        Scalar tensor with the mean squared error of the rescaled values.
    """
    y_pred_rs = (y_pred - y_pred.mean()) / (y_pred.std() + eps)
    y_rs = (y - y.mean()) / (y.std() + eps)
    return torch.nn.functional.mse_loss(y_pred_rs, y_rs)


def rplcc_loss(y_pred, y, eps=1e-8):
    """Return ``(1 - PLCC) / 2`` between predictions and targets.

    The Pearson correlation is the sample covariance divided by the product of
    the sample standard deviations, with ``eps`` added to each std for numerical
    safety. The result is 0 for perfect positive correlation and 1 for perfect
    negative correlation.

    ``torch.cov`` accepts a single (variables, observations) matrix and its
    other arguments are keyword-only. The previous call ``torch.cov(y_pred, y)``
    therefore raised ``TypeError``. Here both tensors are flattened and stacked
    as two variables, and the off-diagonal entry is read.
    """
    pair = torch.stack([y_pred.reshape(-1), y.reshape(-1)])
    cov = torch.cov(pair)[0, 1]  # correction=1, matching torch.std's default
    std = (torch.std(y_pred) + eps) * (torch.std(y) + eps)
    return (1 - cov / std) / 2


def self_similarity_loss(f, f_hat, f_hat_detach=False):
    """Return ``1 - mean cosine similarity`` between ``f`` and ``f_hat`` along dim 1.

    When ``f_hat_detach`` is true, ``f_hat`` is detached so gradients flow only
    into ``f``.
    """
    target = f_hat.detach() if f_hat_detach else f_hat
    row_cos = torch.nn.functional.cosine_similarity(f, target, dim=1)
    return 1 - row_cos.mean()


def contrastive_similarity_loss(f, f_hat, f_hat_detach=False, eps=1e-8):
    """Ratio of row-wise dissimilarity to column-wise dissimilarity.

    Numerator: ``1 - mean cosine similarity of f and f_hat along dim 1``.
    Denominator: ``1 - mean cosine similarity along dim 0``, plus ``eps`` to
    avoid division by zero. When ``f_hat_detach`` is true, no gradient flows
    into ``f_hat``.
    """
    cos = torch.nn.functional.cosine_similarity
    other = f_hat.detach() if f_hat_detach else f_hat
    numerator = 1 - cos(f, other, dim=1).mean()
    denominator = 1 - cos(f, other, dim=0).mean() + eps
    return numerator / denominator


def rescale(pr, gt=None):
    """Standardise predictions, optionally mapping them onto the targets' mean/std.

    Args:
        pr: array-like of predicted scores (lists are accepted).
        gt: optional array-like of ground-truth scores. If given, the
            standardised predictions are rescaled to ``gt``'s mean and
            population std.

    Returns:
        A float ``np.ndarray`` the same length as ``pr``. If ``pr`` is constant
        (std == 0), the standardised values are all 0. The result is then
        ``mean(gt)`` everywhere when ``gt`` is given. Without this guard the
        division produces NaN and a RuntimeWarning.
    """
    pr = np.asarray(pr, dtype=float)
    std = np.std(pr)
    # Constant predictions carry no scale information; avoid 0/0 -> NaN.
    z = (pr - np.mean(pr)) / std if std > 0 else np.zeros_like(pr)
    if gt is None:
        return z
    gt = np.asarray(gt, dtype=float)
    return z * np.std(gt) + np.mean(gt)



def finetune_epoch(ft_loader, model, model_ema, optimizer, scheduler, device, epoch=-1, ema_decay=0.999):
    """Train ``model`` for one pass over ``ft_loader``.

    Each batch is optimised with ``plcc_loss + 0.3 * rank_loss``. The optimizer
    and the LR scheduler are each stepped once per batch. When ``model_ema`` is
    not None, its parameters are updated after every optimizer step as an
    exponential moving average of ``model``'s parameters:
    ``ema = ema_decay * ema + (1 - ema_decay) * param``. Before this change the
    update was missing, so an EMA model stayed frozen at its initial weights.
    Only parameters are averaged; buffers are not copied.

    Args:
        ft_loader: iterable of dicts with ``"video"`` and ``"gt_label"`` tensors.
        model: network being trained. ``model(video)`` is assumed to return one
            score per sample, shape (batch,) — TODO confirm against CLIP_VQA.
        model_ema: optional copy of ``model`` with the same parameter names, or None.
        optimizer: stepped once per batch.
        scheduler: stepped once per batch (a per-iteration schedule).
        device: device the inputs are moved to.
        epoch: only used in the progress-bar label.
        ema_decay: weight kept on the old EMA value at each update.

    ``model`` is left in eval mode on return.
    """
    model.train()
    # Look the EMA parameters up once; in-place updates keep these references valid.
    ema_params = dict(model_ema.named_parameters()) if model_ema is not None else None

    for data in tqdm(ft_loader, desc=f"Training in epoch {epoch}"):
        optimizer.zero_grad()
        video = data["video"].to(device)
        # Targets get a trailing dim so they line up with y_pred below.
        y = data["gt_label"].float().detach().to(device).unsqueeze(-1)

        y_pred = model(video).unsqueeze(1)
        # Plain supervised loss: correlation-style regression + weighted pairwise ranking.
        p_loss, r_loss = plcc_loss(y_pred, y), rank_loss(y_pred, y)
        loss = p_loss + 0.3 * r_loss

        loss.backward()
        optimizer.step()
        scheduler.step()

        if ema_params is not None:
            with torch.no_grad():
                for name, param in model.named_parameters():
                    ema_params[name].mul_(ema_decay).add_(param, alpha=1 - ema_decay)
    model.eval()



def inference_set(inf_loader, model, device, best_, save_model=False, suffix='s', save_name="divide"):
    """Evaluate ``model`` on ``inf_loader``, optionally checkpoint it, and update best metrics.

    Each batch's predictions are averaged into one score, and ``gt_label.item()``
    requires a single label per batch (``main`` builds validation loaders with
    ``batch_size=1``). The scores are linearly rescaled onto the ground-truth
    mean/std with ``rescale`` before computing SROCC, PLCC, KROCC and RMSE.

    Args:
        inf_loader: iterable of dicts with ``"video"`` and ``"gt_label"`` tensors.
        model: network to evaluate; switched to eval mode here.
        device: device the inputs are moved to.
        best_: previous best ``(srocc, plcc, krocc, rmse)``.
        save_model: if True, save a checkpoint whenever SROCC + PLCC beats the
            previous best. The checkpoint records this evaluation's own metrics.
        suffix: tag used in the checkpoint filename and the printed report.
        save_name: prefix of the checkpoint filename.

    Returns:
        Updated best ``(srocc, plcc, krocc, rmse)``. The correlations keep the
        maximum and RMSE keeps the minimum.
    """
    best_s, best_p, best_k, best_r = best_

    # Callers may pass a model still in training mode (finetune_epoch never
    # toggles the EMA copy), so force eval mode for dropout/batch-norm.
    model.eval()

    gt_labels, pr_labels = [], []
    with torch.no_grad():
        for data in tqdm(inf_loader, desc="Validating"):
            video = data["video"].to(device)
            pr_labels.append(np.mean(model(video).cpu().numpy()))
            gt_labels.append(data["gt_label"].item())
            del video

    gt_labels = np.asarray(gt_labels, dtype=float)
    pr_labels = rescale(pr_labels, gt_labels)

    s = spearmanr(gt_labels, pr_labels)[0]
    p = pearsonr(gt_labels, pr_labels)[0]
    k = kendallr(gt_labels, pr_labels)[0]
    r = np.sqrt(np.mean((gt_labels - pr_labels) ** 2))

    torch.cuda.empty_cache()

    if s + p > best_s + best_p and save_model:
        torch.save(
            {
                "state_dict": model.state_dict(),
                # Store the metrics of the model being saved, not the previous best.
                "validation_results": (s, p, k, r),
            },
            f"pretrained_weights/{save_name}_{suffix}_dev_v0.0.pth",
        )

    best_s, best_p, best_k, best_r = (
        max(best_s, s),
        max(best_p, p),
        max(best_k, k),
        min(best_r, r),
    )

    print(
        f"For {len(inf_loader)} videos, \nthe accuracy of the model: [{suffix}] is as follows:\n  SROCC: {s:.4f} best: {best_s:.4f} \n  PLCC:  {p:.4f} best: {best_p:.4f}  \n  KROCC: {k:.4f} best: {best_k:.4f} \n  RMSE:  {r:.4f} best: {best_r:.4f}."
    )

    return best_s, best_p, best_k, best_r


def _indexed_anno_file(path, index):
    """Return ``path`` with the split number just before ``.txt`` replaced by ``index``.

    For example ``"labels/kadid10k_train_3.txt"`` becomes
    ``"labels/kadid10k_train_7.txt"`` for ``index=7``. Unlike fixed-width
    slicing, this works whatever the number of digits in the current path.

    Raises:
        ValueError: if ``path`` does not end in ``<digits>.txt``.
    """
    import re

    new_path, n_subs = re.subn(r"\d+(?=\.txt$)", str(index), path)
    if n_subs != 1:
        raise ValueError(f"anno_file {path!r} must end in '<split number>.txt'")
    return new_path


def main():
    """Fine-tune and evaluate on ``num_splits`` train/test splits and print the average.

    Options come from the YAML file given by ``-o/--opt``. For each split ``i``:

    * the ``anno_file`` of every ``train*`` / ``val*`` data entry is rewritten
      to end in ``{i}.txt``;
    * the model is re-created and loaded from the same checkpoint;
    * it is fine-tuned for ``opt["num_epochs"]`` epochs with AdamW and a
      warm-up + cosine per-iteration schedule, and validated after each epoch.

    The best metrics of the last validation set are averaged across splits.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o", "--opt", type=str, default="./finetune_konvid-1k.yml", help="the option file"
    )
    args = parser.parse_args()
    with open(args.opt, "r") as f:
        opt = yaml.safe_load(f)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_splits = 10
    sum_SROCC = 0
    sum_KROCC = 0
    sum_RMSE = 0
    sum_PLCC = 0

    for i in range(1, num_splits + 1):
        # Instantiate the model and load the pretrained weights on CPU first.
        model = CLIP_VQA()
        model = model.to("cpu")
        checkpoint = torch.load(
            "./pretrained_weights/...", map_location=torch.device('cpu')
        )
        # Only the weights are needed from the checkpoint.
        model.load_state_dict(checkpoint["state_dict"])
        model = model.to(device)

        # Training datasets/loaders for split i.
        train_datasets = {}
        for key in opt["data"]:
            if key.startswith("train"):
                data_args = opt["data"][key]["args"]
                data_args["anno_file"] = _indexed_anno_file(data_args["anno_file"], i)
                train_datasets[key] = getattr(datasets, opt["data"][key]["type"])(data_args)

        train_loaders = {}
        for key, train_dataset in train_datasets.items():
            train_loaders[key] = torch.utils.data.DataLoader(
                train_dataset, batch_size=opt["batch_size"], num_workers=opt["num_workers"], shuffle=True,
            )

        # Validation datasets/loaders for split i.
        val_datasets = {}
        for key in opt["data"]:
            if key.startswith("val"):
                data_args = opt["data"][key]["args"]
                data_args["anno_file"] = _indexed_anno_file(data_args["anno_file"], i)
                val_datasets[key] = getattr(datasets, opt["data"][key]["type"])(data_args)

        val_loaders = {}
        for key, val_dataset in val_datasets.items():
            val_loaders[key] = torch.utils.data.DataLoader(
                val_dataset, batch_size=1, num_workers=opt["num_workers"], pin_memory=True,
            )

        model_ema = copy.deepcopy(model) if opt["ema"] else None

        # Freeze the all_shared_mlps component (the message reads "Freezing the
        # all_shared_mlps component...").
        print("冻结 all_shared_mlps 组件...")
        for param in model.all_shared_mlps.parameters():
            param.requires_grad = False

        optimizer = torch.optim.AdamW(
            params=[{"params": model.parameters()}],
            lr=opt["optimizer"]["lr"],
            weight_decay=opt["optimizer"]["wd"],
        )

        # scheduler.step() runs once per batch over every train loader, so both
        # the warm-up length and the total length count batches across all loaders.
        warmup_iter = sum(int(opt["warmup_epochs"] * len(loader)) for loader in train_loaders.values())
        total_batches = sum(len(loader) for loader in train_loaders.values())
        max_iter = max(int(opt["num_epochs"] * total_batches), 1)

        def lr_lambda(cur_iter):
            # Linear warm-up, then cosine decay. With no warm-up, go straight to
            # the cosine branch (the old form divided 0 by 0 here).
            if warmup_iter > 0 and cur_iter <= warmup_iter:
                return cur_iter / warmup_iter
            return 0.5 * (1 + math.cos(math.pi * (cur_iter - warmup_iter) / max_iter))

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lr_lambda,
        )

        # Best (SROCC, PLCC, KROCC, RMSE) so far, for the EMA ("_s") and the
        # raw ("_n") model respectively.
        bests = {}
        bests_n = {}
        for key in val_loaders:
            bests[key] = -1, -1, -1, 1000
            bests_n[key] = -1, -1, -1, 1000

        for epoch in range(opt["num_epochs"]):
            print(f"Finetune Epoch {epoch}:")

            for key, train_loader in train_loaders.items():
                finetune_epoch(
                    train_loader, model, model_ema, optimizer, scheduler, device, epoch
                )

            for key in val_loaders:
                bests[key] = inference_set(
                    val_loaders[key],
                    model_ema if model_ema is not None else model,
                    device, bests[key], save_model=opt["save_model"], save_name=opt["name"],
                    suffix=key + "_s",
                )
                if model_ema is not None:
                    bests_n[key] = inference_set(
                        val_loaders[key],
                        model,
                        device, bests_n[key], save_model=opt["save_model"], save_name=opt["name"],
                        suffix=key + '_n',
                    )
                else:
                    bests_n[key] = bests[key]

        for key in val_loaders:
            print(
                f"""The {i}-th test result on {key},
                the best validation accuracy of the model-s is as follows:
                SROCC: {bests[key][0]:.4f}
                PLCC:  {bests[key][1]:.4f}
                KROCC: {bests[key][2]:.4f}
                RMSE:  {bests[key][3]:.4f}."""
            )
            if model_ema is not None:
                print(
                    f"""For the finetuning process on {key} with {len(val_loaders[key])} videos,
                    the best validation accuracy of the model-n is as follows:
                    SROCC: {bests_n[key][0]:.4f}
                    PLCC:  {bests_n[key][1]:.4f}
                    KROCC: {bests_n[key][2]:.4f}
                    RMSE:  {bests_n[key][3]:.4f}."""
                )

        # Only the last validation set's best metrics feed the cross-split average.
        final_key = list(val_loaders)[-1]
        sum_SROCC += bests[final_key][0]
        sum_PLCC += bests[final_key][1]
        sum_KROCC += bests[final_key][2]
        sum_RMSE += bests[final_key][3]

    print(
        f"""the average validation accuracy of the model is as follows:
                   SROCC: {sum_SROCC/num_splits}
                   PLCC:  {sum_PLCC/num_splits}
                   KROCC: {sum_KROCC/num_splits}
                   RMSE:  {sum_RMSE/num_splits}."""
    )


if __name__ == "__main__":
    main()

