import os
import clip
import torch.nn as nn
from datasets import Action_DATASETS
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm
import argparse
import shutil
from pathlib import Path
import yaml
from dotmap import DotMap
import pprint
import numpy
from modules.Visual_Prompt import visual_prompt
from utils.Augmentation import get_augmentation
import torch
from utils.Text_Prompt import *
import numpy as np
import editdistance
import copy
import json
import pandas as pd
import pickle

class TextCLIP(nn.Module):
    """``nn.Module`` adapter whose forward pass is the wrapped model's ``encode_text``."""

    def __init__(self, model):
        super().__init__()
        # Wrapped CLIP-like model; only its ``encode_text`` method is used.
        self.model = model

    def forward(self, text):
        """Return ``self.model.encode_text(text)``."""
        embeddings = self.model.encode_text(text)
        return embeddings

class ImageCLIP(nn.Module):
    """``nn.Module`` adapter whose forward pass is the wrapped model's ``encode_image``."""

    def __init__(self, model):
        super().__init__()
        # Wrapped CLIP-like model; only its ``encode_image`` method is used.
        self.model = model

    def forward(self, image):
        """Return ``self.model.encode_image(image)``."""
        features = self.model.encode_image(image)
        return features

def validate(epoch, val_loader, val_data, device, model, fusion_model, config,
             clean_label_list="./lists/ucf_labels_sep.csv"):
    """Zero-shot evaluation with the current class-name list, plus label-quality metrics.

    Class names come from ``val_data.classes`` (an iterable of ``(index, name)``
    pairs) or, if that is unavailable, from the ``name`` column of
    ``config.data.label_list``. Each name is expanded with 16 prompt templates and
    encoded with ``model.encode_text``. The normalised text features are then
    compared to the image features yielded by ``val_loader``. Per-template softmax
    scores are averaged before ranking.

    The class names are also compared with the reference names in
    ``clean_label_list``. Two comparisons are made: character edit distance, and
    cosine similarity of CLIP ViT-B/16 text embeddings.

    Args:
        epoch: Unused; kept for interface compatibility.
        val_loader: Iterable of ``(image_features, class_id)`` batches.
        val_data: Dataset; its ``classes`` attribute is used when present.
        device: Torch device used for text and image features.
        model: CLIP-like model exposing ``encode_text``.
        fusion_model: Only switched to eval mode here.
        config: DotMap-style config; ``config.data.label_list`` is the fallback
            label CSV.
        clean_label_list: CSV (with a ``name`` column) of reference class names
            used for the label-quality metrics. Defaults to the UCF101 list that
            was previously hard-coded.

    Returns:
        Top-1 accuracy in percent.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    text_aug = ["a photo of action {}", "a picture of action {}", "Human action of {}", "{}, an action",
                "{} this is an action", "{}, a video of action", "Playing action of {}", "{}",
                "Playing a kind of action, {}", "Doing a kind of action, {}", "Look, the human is {}",
                "Can you recognize the action of {}?", "Video classification of {}", "A video of {}",
                "The man is {}", "The woman is {}"]
    num_text_aug = len(text_aug)

    try:
        textlist = [c for _, c in val_data.classes]
    except (AttributeError, TypeError, ValueError):
        # e.g. a TensorDataset has no ``classes``: fall back to the label CSV.
        textlist = pd.read_csv(config.data.label_list).name.values.tolist()

    # Template-major token layout: every class for template 0, then template 1, ...
    classes = torch.cat([torch.cat([clip.tokenize(txt.format(c)) for c in textlist]) for txt in text_aug])

    model.eval()
    fusion_model.eval()
    num = corr_1 = corr_5 = 0
    k5 = min(5, len(textlist))  # top-5 is undefined with fewer than 5 classes

    with torch.no_grad():
        text_features = model.encode_text(classes.to(device))
        # Normalise once rather than on every batch.
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        for image_features, class_id in tqdm(val_loader):
            image_features = image_features.to(device)
            class_id = class_id.to(device).reshape(-1)
            b = image_features.shape[0]
            similarity = 100.0 * image_features @ text_features.T
            # [b, num_text_aug * num_classes] -> softmax per template, then average templates.
            similarity = similarity.view(b, num_text_aug, -1).softmax(dim=-1).mean(dim=1)
            indices_1 = similarity.topk(1, dim=-1).indices
            indices_5 = similarity.topk(k5, dim=-1).indices
            num += b
            corr_1 += (indices_1[:, 0] == class_id).sum().item()
            corr_5 += (indices_5 == class_id[:, None]).any(dim=1).sum().item()

    if num == 0:
        raise ValueError("val_loader yielded no samples; cannot compute accuracy")
    top1 = float(corr_1) / num * 100
    top5 = float(corr_5) / num * 100

    # Separate CLIP ViT-B/16 used only for the semantic-similarity metric.
    clipmodel, _ = clip.load('ViT-B/16', device=device, jit=False)

    def _label_quality(ori_csv, new_labels):
        """Score ``new_labels`` against the reference names in ``ori_csv``.

        Returns (total edits / reference word count, mean edits per label,
        fraction of labels within 0 / 1 / 2 edits, mean cosine similarity of the
        paired CLIP text embeddings).
        """
        ori_textlist = pd.read_csv(ori_csv).name.values
        num_words = sum(len(text.split(" ")) for text in ori_textlist)
        counter = np.array([editdistance.eval(ori, new_labels[idx]) for idx, ori in enumerate(ori_textlist)])

        with torch.no_grad():
            new_feat = clipmodel.encode_text(torch.cat([clip.tokenize(text, context_length=77) for text in new_labels]).to(device))
            ori_feat = clipmodel.encode_text(torch.cat([clip.tokenize(text, context_length=77) for text in ori_textlist]).to(device))
        new_feat = new_feat / new_feat.norm(dim=-1, keepdim=True)
        ori_feat = ori_feat / ori_feat.norm(dim=-1, keepdim=True)

        n = len(counter)
        return (counter.sum() / num_words, counter.mean(), np.sum(counter == 0) / n,
                np.sum(counter <= 1) / n, np.sum(counter <= 2) / n,
                (new_feat @ ori_feat.T).diag().mean().cpu().numpy())

    (edit_distance_per_word, edit_distance_per_label, acc_of_label_0, acc_of_label_1,
     acc_of_label_2, semantic_cos_sim) = _label_quality(clean_label_list, textlist)

    print('valtop1 {:.03f},'.format(top1),
        'valtop5 {:.03f}.'.format(top5),
        'edit distance per word {:.03f}.'.format(float(edit_distance_per_word)),
        'edit distance per label {:.03f}.'.format(float(edit_distance_per_label)),
        'acc of label 0 {:.03f}.'.format(float(acc_of_label_0)),
        'acc of label 1 {:.03f}.'.format(float(acc_of_label_1)),
        'acc of label 2 {:.03f}.'.format(float(acc_of_label_2)),
        'semantic cos sim {:.03f},'.format(float(semantic_cos_sim)))
    return top1
    
def validate_DENOISER(epoch, val_loader, val_data, device, model, fusion_model, config, working_dir,
                      clean_label_list="./lists/ucf_labels_sep.csv"):
    """Iteratively repair noisy class names with a word-by-word beam search.

    For every word of every class name, the ``top_k`` corpus words closest in
    character edit distance are retrieved once up front. Iteration ``i`` does two
    things. First, it commits the candidate chosen in the previous iteration for
    word ``i - 1``. Second, it proposes each of the ``top_k`` candidates for word
    ``i`` in every beam's current label.

    All proposals are scored against the image features in ``val_loader``, and
    the best ``num_beams`` are kept per class. Two weights are combined:

    * intra-modal weight: softmax of ``-edit_distance / temp[i]`` over each
      beam's ``top_k`` candidates;
    * inter-modal weight: template-averaged image-text logit of a proposal,
      averaged over the samples whose current top-1 prediction is that class.
      Only the first ``percentage`` % of samples are used. If no sample is
      predicted as a class, only the intra-modal weight is used for it.

    After each iteration:

    * each beam's labels are written to
      ``working_dir/config.DENOISER.result.format(beam, idx_sim)``;
    * top-1/top-5 accuracy and label-quality metrics against
      ``clean_label_list`` are printed.

    Args:
        epoch: Unused; kept for interface compatibility.
        val_loader: Iterable of ``(image_features, class_id)`` batches.
        val_data: Dataset; its ``classes`` attribute is used when present.
        device: Torch device.
        model: CLIP-like model exposing ``encode_text``.
        fusion_model: Only switched to eval mode here.
        config: DotMap-style config with a ``DENOISER`` section (``corpus``,
            ``top_k``, ``num_beams``, ``num_iter``, ``temp``, ``percentage``,
            ``intra``, ``inter``, ``Q_style``, ``result``, ``idx_sim``).
        working_dir: Directory receiving the per-beam label files.
        clean_label_list: CSV (with a ``name`` column) of reference class names
            used for the label-quality metrics. Defaults to the UCF101 list that
            was previously hard-coded.

    Returns:
        Top-1 accuracy (percent) of the last iteration.

    Raises:
        ValueError: on any of the following:
            * an unknown ``temp`` or ``Q_style``;
            * ``num_iter < 1``;
            * ``num_beams > top_k``;
            * neither ``intra`` nor ``inter`` enabled;
            * an empty ``val_loader``.
    """
    # Separate CLIP ViT-B/16 used only for the semantic-similarity metric.
    clipmodel, _ = clip.load('ViT-B/16', device=device, jit=False)

    text_aug = ["a photo of action {}", "a picture of action {}", "Human action of {}", "{}, an action",
                "{} this is an action", "{}, a video of action", "Playing action of {}", "{}",
                "Playing a kind of action, {}", "Doing a kind of action, {}", "Look, the human is {}",
                "Can you recognize the action of {}?", "Video classification of {}", "A video of {}",
                "The man is {}", "The woman is {}"]
    num_augs = len(text_aug)
    cfg = config["DENOISER"]

    with open(cfg["corpus"], encoding="utf-8") as f:
        corpus_textlist = list(json.load(f).keys())
    try:
        textlist = [c for _, c in val_data.classes]
    except (AttributeError, TypeError, ValueError):
        # e.g. a TensorDataset has no ``classes``: fall back to the label CSV.
        textlist = pd.read_csv(config.data.label_list).name.values.tolist()
    num_classes = len(textlist)

    top_k = cfg["top_k"]
    num_beams = cfg["num_beams"]
    num_iter = cfg["num_iter"]
    percentage = cfg["percentage"]
    intra = cfg["intra"]
    inter = cfg["inter"]
    Q_style = cfg["Q_style"]

    # Explicit checks: ``assert`` would be stripped under ``python -O``.
    if num_iter < 1:
        raise ValueError(f"DENOISER.num_iter must be >= 1, got {num_iter}")
    if num_beams > top_k:
        raise ValueError(f"DENOISER.num_beams ({num_beams}) must not exceed DENOISER.top_k ({top_k})")
    if not (intra or inter):
        raise ValueError("at least one of DENOISER.intra / DENOISER.inter must be enabled")
    if Q_style not in ("max", "mean"):
        raise ValueError(f"unknown DENOISER.Q_style: {Q_style!r}")

    # Temperature of the intra-modal (edit-distance) softmax, one value per iteration.
    if cfg["temp"] == "linear":
        temp = np.linspace(0.01, 1, num_iter)
    elif cfg["temp"] == "log":
        temp = np.logspace(-2, 0, num_iter)
    elif cfg["temp"] == "constant":
        temp = np.ones(num_iter)
    else:
        raise ValueError(f"unknown DENOISER.temp schedule: {cfg['temp']!r}")

    class CorpusDataset(Dataset):
        """Item ``c``: for each word of class name ``c``, the ``top_k`` nearest corpus words and their edit distances."""

        def __init__(self, corpus_textlist, textlist, top_k):
            self.corpus_words = list(corpus_textlist)
            self.corpus_array = np.array(corpus_textlist)
            self.textlist = textlist
            self.top_k = top_k

        def __len__(self):
            return len(self.textlist)

        def __getitem__(self, c):
            retrieval_dict = {}
            retrieval_distance_dict = {}
            for subword in self.textlist[c].split(" "):
                # Compute distances once; the retrieved words and their distances share one ordering.
                dists = np.array([editdistance.eval(subword, word) for word in self.corpus_words])
                order = np.argsort(dists)[:self.top_k]
                retrieval_dict[subword] = self.corpus_array[order]
                retrieval_distance_dict[subword] = torch.from_numpy(dists[order])
            return retrieval_dict, retrieval_distance_dict

    retrieval_dict = {}
    retrieval_distance_dict = {}
    corpusloader = DataLoader(CorpusDataset(corpus_textlist, textlist, top_k), batch_size=1, num_workers=16,
                              shuffle=False, drop_last=False, collate_fn=lambda x: x)
    for batch in tqdm(corpusloader):
        words, distances = batch[0]
        retrieval_dict.update(words)
        retrieval_distance_dict.update(distances)
    # Placeholder "word" proposed once the iteration index runs past a label's length.
    retrieval_dict[""] = [""] * top_k
    retrieval_distance_dict[""] = torch.from_numpy(np.zeros(top_k))

    # At initialisation there is a single beam: the original labels.
    previous_selection = torch.zeros(len(textlist), 1)  # [num_classes, beams]; [num_classes, 1] at init
    new_textlist = np.stack([textlist])  # [beams, num_classes]; [1, num_classes] at init

    class TextDataset(Dataset):
        """Item ``c``: the committed label of class ``c`` in beam ``m`` plus its tokenised proposals for word ``i``."""

        def __init__(self, i, m, textlist, text_aug, retrieval_dict, retrieval_distance_dict, previous_selection):
            self.i = i
            self.m = m
            self.textlist = textlist
            self.retrieval_dict = retrieval_dict
            self.retrieval_distance_dict = retrieval_distance_dict
            self.text_aug = text_aug
            self.previous_selection = previous_selection

        def __len__(self):
            return len(self.textlist)

        def replace_and_propose(self, text, c, m, i, retrieval_dict, retrieval_distance_dict, previous_selection):
            words = text.split(" ")
            # Commit the candidate selected last iteration for word i-1 of beam m.
            if 0 < i <= len(words):
                words[i - 1] = retrieval_dict[words[i - 1]][previous_selection[c, m]]
            current = words[i] if i < len(words) else ""
            pre, post = words[:i], words[i + 1:]
            proposals = [" ".join((*pre, candidate, *post)) for candidate in retrieval_dict[current]]
            return " ".join(words), proposals, retrieval_distance_dict[current]  # distances: [top_k]

        def __getitem__(self, c):
            T_imc, T_imck, D_imck = self.replace_and_propose(self.textlist[c], c, self.m, self.i, self.retrieval_dict,
                                                             self.retrieval_distance_dict, self.previous_selection)
            T_imck_tokenized = [torch.cat([clip.tokenize(txt.format(T), context_length=77) for T in T_imck])
                                for txt in self.text_aug]
            return T_imc, T_imck_tokenized, D_imck

    def _result_path(m):
        """Path of the label file written for beam ``m``."""
        return os.path.join(working_dir, cfg["result"].format(m, cfg["idx_sim"]))

    def _label_quality(ori_csv, new_path):
        """Score the labels in ``new_path`` (one per line) against the reference CSV ``ori_csv``.

        Returns (total edits / reference word count, mean edits per label,
        fraction of labels within 0 / 1 / 2 edits, mean cosine similarity of the
        paired CLIP text embeddings).
        """
        ori_textlist = pd.read_csv(ori_csv).name.values
        num_words = sum(len(text.split(" ")) for text in ori_textlist)
        with open(new_path) as f:
            new_labels = [label.strip('\n') for label in f.readlines()]
        counter = np.array([editdistance.eval(ori, new_labels[idx]) for idx, ori in enumerate(ori_textlist)])

        with torch.no_grad():
            new_feat = clipmodel.encode_text(torch.cat([clip.tokenize(text, context_length=77) for text in new_labels]).to(device))
            ori_feat = clipmodel.encode_text(torch.cat([clip.tokenize(text, context_length=77) for text in ori_textlist]).to(device))
        new_feat = new_feat / new_feat.norm(dim=-1, keepdim=True)
        ori_feat = ori_feat / ori_feat.norm(dim=-1, keepdim=True)

        n = len(counter)
        return (counter.sum() / num_words, counter.mean(), np.sum(counter == 0) / n,
                np.sum(counter <= 1) / n, np.sum(counter <= 2) / n,
                (new_feat @ ori_feat.T).diag().mean().cpu().numpy())

    model.eval()
    fusion_model.eval()
    k5 = min(5, num_classes)  # top-5 is undefined with fewer than 5 classes

    for i in range(num_iter):
        current_num_beams = num_beams if i != 0 else 1
        num, corr_1, corr_5 = 0, 0, 0
        t_i, T_i, D_i = [], [], []
        for m in range(current_num_beams):
            t_im, T_im, D_im = [], [], []
            textdataset = TextDataset(i, m, new_textlist[m], text_aug, retrieval_dict, retrieval_distance_dict,
                                      previous_selection)
            textloader = DataLoader(textdataset, batch_size=1, num_workers=16, shuffle=False, drop_last=False)
            for T_imc, T_imck_tokenized, D_imck in tqdm(textloader):
                # [num_augs * top_k, 77] prompt tokens for one class.
                tokens = torch.cat([tok.squeeze(0) for tok in T_imck_tokenized], dim=0)
                with torch.no_grad():
                    t_imc = model.encode_text(tokens.to(device))
                t_im.append(t_imc.reshape(num_augs, top_k, -1))
                T_im.append(T_imc[0])
                D_im.append(D_imck[0])

            t_i.append(torch.stack(t_im, 2))  # [num_augs, top_k, num_classes, dim]
            D_i.append(torch.stack(D_im, 1).to(device))  # [top_k, num_classes]
            T_i.append(T_im)

            with open(_result_path(m), 'w') as f:
                f.writelines(text + '\n' for text in T_im)

        text_features = torch.stack(t_i, 1)  # [num_augs, current_num_beams, top_k, num_classes, dim]
        new_textlist = np.stack(T_i, 0)  # [current_num_beams, num_classes]
        text_features = text_features.reshape(-1, text_features.shape[-1])

        # Softmax of -distance / temp[i] over each beam's top_k candidates.
        intra_modal_weight = torch.softmax(-torch.stack(D_i, 0) / temp[i], dim=1)  # [beams, top_k, num_classes]
        intra_modal_weight = intra_modal_weight.reshape(current_num_beams * top_k, num_classes)

        with torch.no_grad():
            # Normalise once rather than on every batch.
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            inter_modal_weight, similarity_all_in_one = [], []
            for image_features, class_id in tqdm(val_loader):
                image_features = image_features.to(device)
                class_id = class_id.to(device).reshape(-1)
                b = image_features.shape[0]

                logits = image_features @ text_features.T
                logits = logits.reshape(b, num_augs, current_num_beams * top_k, num_classes)
                # Only the template-averaged logits are used for the update. Reducing per batch
                # gives the same values as after concatenation, using 1/num_augs of the memory.
                inter_modal_weight.append(logits.mean(1))

                if Q_style == "max":
                    # Per template and class, keep the candidate with the highest (weighted) logit.
                    score = logits * intra_modal_weight if intra else logits
                    best = torch.argmax(score, dim=-2, keepdim=True)
                    logits = torch.take_along_dim(logits, best, dim=-2).squeeze(2)  # [b, num_augs, num_classes]
                else:  # "mean"
                    logits = torch.mean(logits * intra_modal_weight if intra else logits, dim=-2)

                similarity = logits.softmax(dim=-1).mean(dim=1)  # [b, num_classes]
                similarity_all_in_one.append(similarity)

                indices_1 = similarity.topk(1, dim=-1).indices
                indices_5 = similarity.topk(k5, dim=-1).indices
                num += b
                corr_1 += (indices_1[:, 0] == class_id).sum().item()
                corr_5 += (indices_5 == class_id[:, None]).any(dim=1).sum().item()

            if num == 0:
                raise ValueError("val_loader yielded no samples; cannot score label candidates")
            top1 = float(corr_1) / num * 100
            top5 = float(corr_5) / num * 100

            # Only the first ``percentage`` % of samples (in loader order) drive the label update,
            # simulating the case where only part of the data is available.
            inter_modal_weight = torch.cat(inter_modal_weight, dim=0)  # [N, beams*top_k, num_classes]
            num_used = round(len(inter_modal_weight) * percentage / 100)
            inter_modal_weight = inter_modal_weight[:num_used]
            similarity_all_in_one = torch.cat(similarity_all_in_one, dim=0)[:num_used]
            pseudo_labels = similarity_all_in_one.topk(1).indices  # [num_used, 1]

            # Template-averaged logit of every candidate for each sample's predicted class.
            weight = torch.take_along_dim(inter_modal_weight, pseudo_labels[:, None, :], dim=-1).squeeze(-1)
            pseudo_labels = pseudo_labels[:, 0]

            new_textlist_tmp = copy.deepcopy(new_textlist)
            if i == 0:
                new_textlist = np.repeat(new_textlist, num_beams, 0)
            selections = []
            for ii in range(num_classes):
                total_weight = intra_modal_weight[:, ii]
                if inter:
                    members = weight[pseudo_labels == ii]
                    # With no sample predicted as class ii the mean would be NaN and
                    # poison topk; keep the edit-distance weight alone in that case.
                    if len(members) > 0:
                        inter_weight = members.mean(0)
                        total_weight = inter_weight * total_weight if intra else inter_weight
                beam_idxs = total_weight.topk(num_beams).indices
                for j, flat_idx in enumerate(beam_idxs.tolist()):
                    # flat_idx = parent_beam * top_k + candidate: move the parent beam's label into slot j.
                    new_textlist[j, ii] = new_textlist_tmp[flat_idx // top_k, ii]
                selections.append(beam_idxs % top_k)
            previous_selection = torch.stack(selections, dim=0).cpu()  # [num_classes, num_beams]

            metrics = [_label_quality(clean_label_list, _result_path(m)) for m in range(current_num_beams)]
            (edit_distance_per_word_list, edit_distance_per_label_list, acc_of_label_0_list,
             acc_of_label_1_list, acc_of_label_2_list, semantic_cos_sim_list) = zip(*metrics)

            print('valtop1 {:.03f},'.format(top1),
                'valtop5 {:.03f}.'.format(top5),
                'edit distance per word {:.03f}.'.format(np.mean(edit_distance_per_word_list)),
                'edit distance per label {:.03f}.'.format(np.mean(edit_distance_per_label_list)),
                'acc of label 0 {:.03f}.'.format(np.mean(acc_of_label_0_list)),
                'acc of label 1 {:.03f}.'.format(np.mean(acc_of_label_1_list)),
                'acc of label 2 {:.03f}.'.format(np.mean(acc_of_label_2_list)),
                'semantic cos sim {:.03f},'.format(np.mean(semantic_cos_sim_list)))

    return top1

def str2bool(value):
    """Parse a command-line boolean such as ``"True"``, ``"false"``, ``"1"`` or ``"no"``.

    Used instead of ``eval()`` so that command-line text is never executed as Python.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


def main():
    """Parse CLI options, build the DENOISER config and run ``validate_DENOISER``.

    Command-line values override the ``DENOISER`` entries of the YAML file given by
    ``--config``. The merged config is saved to ``<working_dir>/config.yaml``, and a
    copy of this script is stored next to it.
    """
    global args, best_prec1
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-cfg', type=str, default='./configs/ucf101/ucf_zero_shot.yaml') # './configs/ucf101/ucf_zero_shot.yaml' './configs/hmdb51/hmdb_zero_shot.yaml' './configs/k700/k700_zero_shot.yaml'
    parser.add_argument('--name', type=str, default='ucf')
    parser.add_argument('--log_time', type=str, default='20240105_172223')
    parser.add_argument('--idx_sim', type=int, default=1)

    parser.add_argument('--top_k', type=int, default=10)
    parser.add_argument('--num_beams', type=int, default=1)
    parser.add_argument('--num_iter', type=int, default=10)
    parser.add_argument('--temp', type=str, default="linear", choices=["log", "linear", "constant"])
    parser.add_argument('--percentage', type=int, default=100)
    parser.add_argument('--intra', type=str, default="True")
    parser.add_argument('--inter', type=str, default="True")
    parser.add_argument('--Q_style', type=str, default="max", choices=["max", "mean"])
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.full_load(f)
    working_dir = os.path.join('./exp', config['network']['type'], config['network']['arch'], config['data']['dataset'],
                               args.log_time)

    denoiser_cfg = config["DENOISER"]
    denoiser_cfg["top_k"] = args.top_k
    denoiser_cfg["num_beams"] = args.num_beams
    denoiser_cfg["num_iter"] = args.num_iter
    denoiser_cfg["temp"] = args.temp
    denoiser_cfg["percentage"] = args.percentage
    denoiser_cfg["idx_sim"] = args.idx_sim
    denoiser_cfg["intra"] = str2bool(args.intra)
    denoiser_cfg["inter"] = str2bool(args.inter)
    denoiser_cfg["Q_style"] = args.Q_style

    config["data"]["label_list"] = os.path.join(working_dir, f"{args.name}_labels_{args.idx_sim}.csv")

    # Create the experiment directory before writing anything into it.
    Path(working_dir).mkdir(parents=True, exist_ok=True)
    with open(os.path.join(working_dir, "config.yaml"), 'w') as f:
        yaml.dump(data=config, stream=f)

    print('-' * 80)
    print(' ' * 20, "working dir: {}".format(working_dir))
    print('-' * 80)

    print('-' * 80)
    print(' ' * 30, "Config")
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(config)
    print('-' * 80)

    config = DotMap(config)

    # Snapshot this script next to the results; resolved from __file__ so the CWD does not matter.
    shutil.copy(os.path.abspath(__file__), working_dir)

    device = "cuda" if torch.cuda.is_available() else "cpu"  # If using GPU then use mixed precision training.

    model, clip_state_dict = clip.load(config.network.arch, device=device, jit=False, tsm=config.network.tsm,
                                       T=config.data.num_segments, dropout=config.network.drop_out,
                                       emb_dropout=config.network.emb_dropout)  # Must set jit=False for training  ViT-B/32

    fusion_model = visual_prompt(config.network.sim_header, clip_state_dict, config.data.num_segments)

    # .to(device) instead of .cuda() so the CPU fallback above actually works.
    model_text = torch.nn.DataParallel(TextCLIP(model)).to(device)
    model_image = torch.nn.DataParallel(ImageCLIP(model)).to(device)
    fusion_model = torch.nn.DataParallel(fusion_model).to(device)

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted feature dumps.
    with open("./features/image_features_list", "rb") as f:
        image_features = pickle.load(f)
    with open("./features/class_id_list", "rb") as f:
        class_ids = pickle.load(f)
    val_data = TensorDataset(image_features, class_ids)

    val_loader = DataLoader(val_data, batch_size=config.data.batch_size, num_workers=config.data.workers, shuffle=True, # remember to set to true for batch E-M
                            pin_memory=True, drop_last=True,
                            # persistent_workers=True raises ValueError when num_workers == 0.
                            persistent_workers=config.data.workers > 0)

    if device == "cpu":
        model_text.float()
        model_image.float()
    else:
        clip.model.convert_weights(
            model_text)  # Actually this line is unnecessary since clip by default already on float16
        clip.model.convert_weights(model_image)

    start_epoch = config.solver.start_epoch

    if config.pretrain:
        if os.path.isfile(config.pretrain):
            print(("=> loading checkpoint '{}'".format(config.pretrain)))
            # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
            checkpoint = torch.load(config.pretrain, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            fusion_model.load_state_dict(checkpoint['fusion_model_state_dict'])
            del checkpoint
        else:
            print(("=> no checkpoint found at '{}'".format(config.pretrain)))

    best_prec1 = 0.0
    validate_DENOISER(start_epoch, val_loader, val_data, device, model, fusion_model, config, working_dir)

if __name__ == "__main__":
    # Script entry point: run the DENOISER evaluation.
    main()
