import os
import clip
import torch.nn as nn
from datasets import Action_DATASETS
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm
import argparse
import shutil
from pathlib import Path
import yaml
from dotmap import DotMap
import pprint
import numpy
from modules.Visual_Prompt import visual_prompt
from utils.Augmentation import get_augmentation
import torch
from utils.Text_Prompt import *
import numpy as np
import editdistance
import copy
import json
import pandas as pd
import pickle

class TextCLIP(nn.Module):
    """Adapter module whose forward pass runs ``model.encode_text``.

    Wrapping the text tower in its own module lets it be handed to wrappers
    such as ``torch.nn.DataParallel``, which call ``forward``.
    """

    def __init__(self, model):
        super().__init__()
        # Underlying model providing ``encode_text``.
        self.model = model

    def forward(self, text):
        """Return the result of ``self.model.encode_text(text)``."""
        encode = self.model.encode_text
        return encode(text)

class ImageCLIP(nn.Module):
    """Adapter module whose forward pass runs ``model.encode_image``.

    Wrapping the image tower in its own module lets it be handed to wrappers
    such as ``torch.nn.DataParallel``, which call ``forward``.
    """

    def __init__(self, model):
        super().__init__()
        # Underlying model providing ``encode_image``.
        self.model = model

    def forward(self, image):
        """Return the result of ``self.model.encode_image(image)``."""
        encode = self.model.encode_image
        return encode(image)

def _tokenize_prompts(textlist):
    """Tokenize every prompt template filled in with every class name.

    Rows are ordered template-major: all classes for template 0, then all
    classes for template 1, and so on. ``validate`` relies on this order when
    it reshapes the similarity matrix to ``(batch, num_templates, num_classes)``.

    Args:
        textlist: Class names, in class-index order.

    Returns:
        ``(tokens, num_templates)`` where ``tokens`` concatenates the
        ``clip.tokenize`` output of every filled-in prompt.
    """
    templates = [
        "a photo of action {}", "a picture of action {}", "Human action of {}", "{}, an action",
        "{} this is an action", "{}, a video of action", "Playing action of {}", "{}",
        "Playing a kind of action, {}", "Doing a kind of action, {}", "Look, the human is {}",
        "Can you recognize the action of {}?", "Video classification of {}", "A video of {}",
        "The man is {}", "The woman is {}",
    ]
    tokens = torch.cat([clip.tokenize(t.format(c)) for t in templates for c in textlist])
    return tokens, len(templates)


def _count_topk_hits(similarity, class_id, k):
    """Count rows of ``similarity`` whose label is among the ``k`` highest scores.

    ``k`` is clamped to the number of classes so ``topk`` cannot fail on
    datasets with fewer than ``k`` classes. ``class_id`` is assumed to hold
    one label per row.
    """
    k = min(k, similarity.shape[-1])
    _, top = similarity.topk(k, dim=-1)
    return int((top == class_id.reshape(-1, 1)).any(dim=-1).sum().item())


def _label_quality_metrics(ori_label_csv, new_labels, clipmodel, device):
    """Compare ``new_labels`` with the reference labels in ``ori_label_csv``.

    Args:
        ori_label_csv: CSV file whose ``name`` column holds the reference labels.
        new_labels: Labels to evaluate, aligned index-by-index with the reference.
        clipmodel: Model exposing ``encode_text``, used for the semantic metric.
        device: Device the tokenized labels are moved to.

    Returns:
        A 6-tuple:

        * total edit distance / total number of space-separated reference words;
        * mean edit distance per label;
        * fraction of labels with distance 0;
        * fraction of labels with distance <= 1;
        * fraction of labels with distance <= 2;
        * mean cosine similarity between matching label embeddings.

    Raises:
        ValueError: if the two label lists differ in length.
    """
    ori_labels = pd.read_csv(ori_label_csv).name.values
    if len(new_labels) != len(ori_labels):
        raise ValueError(
            f"label count mismatch: {len(new_labels)} evaluated labels vs "
            f"{len(ori_labels)} reference labels in {ori_label_csv}")

    distances = np.array([editdistance.eval(o, n) for o, n in zip(ori_labels, new_labels)])
    num_words = sum(len(label.split(" ")) for label in ori_labels)

    def _encode(labels):
        tokens = torch.cat([clip.tokenize(t, context_length=77) for t in labels]).to(device)
        with torch.no_grad():
            feats = clipmodel.encode_text(tokens)
        return feats / feats.norm(dim=-1, keepdim=True)

    new_feat = _encode(new_labels)
    ori_feat = _encode(ori_labels)
    # Row-wise dot product == diagonal of new_feat @ ori_feat.T, without the N x N matrix;
    # computed in float32 to avoid half-precision accumulation error.
    cos_sim = (new_feat.float() * ori_feat.float()).sum(dim=-1).mean().item()

    return (distances.sum() / num_words,
            distances.mean(),
            np.mean(distances == 0),
            np.mean(distances <= 1),
            np.mean(distances <= 2),
            cos_sim)


def validate(epoch, val_loader, val_data, device, model, fusion_model, config,
             ori_label_list="./lists/ucf_labels_sep.csv", metric_clip_arch='ViT-B/16'):
    """Zero-shot evaluation of precomputed image features, plus label-quality metrics.

    Every class name is inserted into 16 prompt templates and encoded with
    ``model.encode_text``. Each batch of image features is scored against all
    prompts: softmax over classes, averaged over templates. Top-1/top-5
    accuracy is accumulated from these scores. The class names are then
    compared with ``ori_label_list`` (edit distance and CLIP text-embedding
    cosine similarity), and everything is printed.

    Args:
        epoch: Unused; kept for interface compatibility.
        val_loader: Yields ``(image_features, class_id)`` batches. The image
            features are not normalised here; presumably they are already
            L2-normalised (TODO confirm).
        val_data: Dataset; if it has a ``classes`` attribute of
            ``(index, name)`` pairs those names are used, otherwise names come
            from the ``name`` column of ``config.data.label_list``.
        device: Torch device the tensors are moved to.
        model: Model exposing ``encode_text``.
        fusion_model: Only switched to eval mode; not otherwise used.
        config: Config object; ``config.data.label_list`` is read.
        ori_label_list: Reference-label CSV (``name`` column) for the label metrics.
        metric_clip_arch: CLIP architecture loaded to embed labels for the
            semantic-similarity metric.

    Returns:
        Top-1 accuracy in percent.

    Raises:
        RuntimeError: if ``val_loader`` yields no samples.
        ValueError: if the reference and evaluated label lists differ in length.
    """
    try:
        textlist = [c for _, c in val_data.classes]
    except (AttributeError, TypeError, ValueError):
        # e.g. a TensorDataset has no ``classes`` attribute.
        textlist = pd.read_csv(config.data.label_list).name.values.tolist()

    classes, num_text_aug = _tokenize_prompts(textlist)

    model.eval()
    fusion_model.eval()
    num = 0
    corr_1 = 0
    corr_5 = 0

    with torch.no_grad():
        text_features = model.encode_text(classes.to(device))
        # Text features are the same for every batch: normalise them once.
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        for image_features, class_id in tqdm(val_loader):
            image_features = image_features.to(device)
            class_id = class_id.to(device)
            b = image_features.shape[0]
            similarity = 100.0 * image_features @ text_features.T
            # (b, T*C) -> (b, T, C): softmax over classes, then average over the T templates.
            similarity = similarity.view(b, num_text_aug, -1).softmax(dim=-1).mean(dim=1)
            num += b
            corr_1 += _count_topk_hits(similarity, class_id, 1)
            corr_5 += _count_topk_hits(similarity, class_id, 5)

    if num == 0:
        raise RuntimeError(
            "validate: val_loader produced no samples "
            "(is the dataset smaller than batch_size with drop_last=True?)")
    top1 = corr_1 / num * 100
    top5 = corr_5 / num * 100

    clipmodel, _ = clip.load(metric_clip_arch, device=device, jit=False)
    (edit_distance_per_word, edit_distance_per_label, acc_of_label_0,
     acc_of_label_1, acc_of_label_2, semantic_cos_sim) = _label_quality_metrics(
        ori_label_list, textlist, clipmodel, device)

    print('valtop1 {:.03f},'.format(top1),
          'valtop5 {:.03f}.'.format(top5),
          'edit distance per word {:.03f}.'.format(edit_distance_per_word),
          'edit distance per label {:.03f}.'.format(edit_distance_per_label),
          'acc of label 0 {:.03f}.'.format(acc_of_label_0),
          'acc of label 1 {:.03f}.'.format(acc_of_label_1),
          'acc of label 2 {:.03f}.'.format(acc_of_label_2),
          'semantic cos sim {:.03f},'.format(semantic_cos_sim))
    return top1

def main():
    """Evaluate precomputed image features against a generated label list.

    Steps:

    * parse the command line and load the YAML config;
    * point ``config.data.label_list`` at
      ``<working_dir>/<name>_labels_<idx_sim>.csv`` and snapshot the config
      into the working directory;
    * build the CLIP model and load the pickled features from ``./features/``;
    * optionally restore ``config.pretrain``;
    * run ``validate``.
    """
    global args, best_prec1
    global global_step
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-cfg', type=str, default='./configs/ucf101/ucf_zero_shot.yaml')  # also: './configs/hmdb51/hmdb_zero_shot.yaml', './configs/k700/k700_zero_shot.yaml'
    parser.add_argument('--name', type=str, default='ucf')
    parser.add_argument('--log_time', type=str, default='20240105_172223')
    parser.add_argument('--idx_sim', type=int, default=1)

    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.full_load(f)
    working_dir = os.path.join('./exp', config['network']['type'], config['network']['arch'],
                               config['data']['dataset'], args.log_time)

    config["data"]["label_list"] = os.path.join(working_dir, f"{args.name}_labels_{args.idx_sim}.csv")

    # Create the working directory *before* writing config.yaml into it.
    Path(working_dir).mkdir(parents=True, exist_ok=True)
    with open(os.path.join(working_dir, "config.yaml"), 'w') as f:
        yaml.dump(data=config, stream=f)

    print('-' * 80)
    print(' ' * 20, "working dir: {}".format(working_dir))
    print('-' * 80)

    print('-' * 80)
    print(' ' * 30, "Config")
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(config)
    print('-' * 80)

    config = DotMap(config)

    # Snapshot 'test_vanilla.py' (presumably this script -- TODO confirm) next to the results.
    shutil.copy('test_vanilla.py', working_dir)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model, clip_state_dict = clip.load(config.network.arch, device=device, jit=False, tsm=config.network.tsm,
                                       T=config.data.num_segments, dropout=config.network.drop_out,
                                       emb_dropout=config.network.emb_dropout)  # Must set jit=False for training  ViT-B/32

    transform_val = get_augmentation(False, config)  # not used below: the features are precomputed

    fusion_model = visual_prompt(config.network.sim_header, clip_state_dict, config.data.num_segments)

    model_text = TextCLIP(model)
    model_image = ImageCLIP(model)

    # Use `device` rather than `.cuda()` so the CPU branch below is actually reachable.
    model_text = torch.nn.DataParallel(model_text).to(device)
    model_image = torch.nn.DataParallel(model_image).to(device)
    fusion_model = torch.nn.DataParallel(fusion_model).to(device)

    # NOTE: pickle.load can execute arbitrary code -- only load trusted feature files.
    with open("./features/image_features_list", "rb") as f:
        image_features = pickle.load(f)
    with open("./features/class_id_list", "rb") as f:
        class_ids = pickle.load(f)
    val_data = TensorDataset(image_features, class_ids)

    # NOTE(review): drop_last=True together with shuffle=True excludes a random tail
    # subset of samples from the reported accuracy.
    val_loader = DataLoader(val_data, batch_size=config.data.batch_size, num_workers=config.data.workers,
                            shuffle=True,  # remember to set to true for batch E-M
                            pin_memory=True, drop_last=True,
                            # DataLoader rejects persistent_workers=True when num_workers == 0.
                            persistent_workers=config.data.workers > 0)

    if device == "cpu":
        model_text.float()
        model_image.float()
    else:
        clip.model.convert_weights(
            model_text)  # Actually this line is unnecessary since clip by default already on float16
        clip.model.convert_weights(model_image)

    start_epoch = config.solver.start_epoch

    if config.pretrain:
        if os.path.isfile(config.pretrain):
            print(("=> loading checkpoint '{}'".format(config.pretrain)))
            # map_location lets GPU-saved checkpoints load on CPU-only machines.
            checkpoint = torch.load(config.pretrain, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            fusion_model.load_state_dict(checkpoint['fusion_model_state_dict'])
            del checkpoint
        else:
            print(("=> no checkpoint found at '{}'".format(config.pretrain)))

    prec1 = validate(start_epoch, val_loader, val_data, device, model, fusion_model, config)

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
