import os
import clip
import torch.nn as nn
from datasets import Action_DATASETS
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm
import argparse
import shutil
from pathlib import Path
import yaml
from dotmap import DotMap
import pprint
import numpy
from modules.Visual_Prompt import visual_prompt
from utils.Augmentation import get_augmentation
import torch
from utils.Text_Prompt import *
import numpy as np
import editdistance
import copy
import json
import pandas as pd
import pickle

class TextCLIP(nn.Module):
    """Module wrapper that exposes a CLIP model's text encoder as ``forward``.

    Wrapping the encoder in its own ``nn.Module`` lets it be handed to
    utilities that operate on modules (e.g. ``torch.nn.DataParallel``).
    """

    def __init__(self, model):
        """Keep a reference to ``model``, which must provide ``encode_text``."""
        super().__init__()
        self.model = model

    def forward(self, text):
        """Return ``self.model.encode_text(text)``."""
        encode = self.model.encode_text
        return encode(text)

class ImageCLIP(nn.Module):
    """Module wrapper that exposes a CLIP model's image encoder as ``forward``.

    Wrapping the encoder in its own ``nn.Module`` lets it be handed to
    utilities that operate on modules (e.g. ``torch.nn.DataParallel``).
    """

    def __init__(self, model):
        """Keep a reference to ``model``, which must provide ``encode_image``."""
        super().__init__()
        self.model = model

    def forward(self, image):
        """Return ``self.model.encode_image(image)``."""
        encode = self.model.encode_image
        return encode(image)
    
def cache_features(val_loader, device, model, fusion_model, config):
    """Encode every batch of ``val_loader`` and pickle the results to ``./features``.

    For each ``(image, class_id)`` batch the frames are encoded with
    ``model.encode_image``, regrouped per clip, passed through
    ``fusion_model`` and L2-normalised along the last dimension.  The
    concatenated features are written to ``./features/image_features_list``
    and the concatenated labels to ``./features/class_id_list`` (both as
    pickled CPU tensors).

    Args:
        val_loader: Iterable yielding ``(image, class_id)`` tensor pairs.
        device: Device the frames are moved to before encoding.
        model: Object providing ``encode_image``.
        fusion_model: Callable applied to the per-clip frame features.
        config: Configuration object; only ``config.data.num_segments`` is read.

    Returns:
        None.  The results are only written to disk.

    Raises:
        Whatever ``torch.cat`` raises for an empty list if ``val_loader``
        yields no batches; nothing is written in that case.
    """
    image_features_list = []
    class_id_list = []

    # NOTE(review): neither model is switched to eval() here; if either contains
    # dropout and is still in training mode the cached features will be noisy.
    # Confirm the callers' intent before relying on these caches.
    with torch.no_grad():
        for image, class_id in tqdm(val_loader):
            # Regroup the flat frame/channel axis into (clip, segment, 3, H, W).
            image = image.view((-1, config.data.num_segments, 3) + image.size()[-2:])
            b, t, c, h, w = image.size()
            # Encode every frame independently, then restore the per-clip grouping.
            image_input = image.to(device).view(-1, c, h, w)
            image_features = model.encode_image(image_input).view(b, t, -1)
            image_features = fusion_model(image_features)
            image_features /= image_features.norm(dim=-1, keepdim=True)

            # Move results off the device right away so accelerator memory does
            # not grow with the size of the dataset.
            image_features_list.append(image_features.cpu())
            class_id_list.append(class_id.cpu())

    # Concatenate before touching the filesystem so a failure leaves no empty files.
    all_image_features = torch.cat(image_features_list)
    all_class_ids = torch.cat(class_id_list)

    os.makedirs("./features", exist_ok=True)
    with open("./features/image_features_list", "wb") as feature_file:
        pickle.dump(all_image_features, feature_file)
    with open("./features/class_id_list", "wb") as label_file:
        pickle.dump(all_class_ids, label_file)

def _parse_literal_flag(value):
    """Parse a CLI string such as ``"True"``/``"False"`` into a Python literal.

    Uses ``ast.literal_eval`` so that, unlike ``eval``, arbitrary expressions
    supplied on the command line are never executed.

    Raises:
        ValueError: If ``value`` is not a valid Python literal.
    """
    import ast

    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError) as err:
        raise ValueError(
            f"expected a Python literal such as True or False, got {value!r}"
        ) from err


def main():
    """Build the CLIP and fusion models from a YAML config and cache validation features.

    Steps:
      1. Parse the command line and load the YAML config given by ``--config``.
      2. Override the ``DENOISER`` section and ``data.label_list`` from the CLI,
         then save the effective config to ``<working_dir>/config.yaml``.
      3. Build the CLIP model, the fusion model and the validation loader.
      4. Optionally restore weights from ``config.pretrain``.
      5. Call ``cache_features`` to write the features to disk.
    """
    global args
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-cfg', type=str, default='./configs/ucf101/ucf_zero_shot.yaml') # './configs/ucf101/ucf_zero_shot.yaml' './configs/hmdb51/hmdb_zero_shot.yaml' './configs/k700/k700_zero_shot.yaml'
    parser.add_argument('--name', type=str, default='ucf')
    parser.add_argument('--log_time', type=str, default='20240105_172223')
    parser.add_argument('--idx_sim', type=int, default=1)

    parser.add_argument('--top_k', type=int, default=10)
    parser.add_argument('--num_beams', type=int, default=1)
    parser.add_argument('--num_iter', type=int, default=10)
    parser.add_argument('--temp', type=str, default="linear", choices=["log", "linear", "constant"])
    parser.add_argument('--percentage', type=int, default=100)
    parser.add_argument('--intra', type=str, default="True")
    parser.add_argument('--cross', type=str, default="True")
    parser.add_argument('--Q_style', type=str, default="max", choices=["max", "mean"])
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        config = yaml.full_load(f)
    working_dir = os.path.join('./exp', config['network']['type'], config['network']['arch'], config['data']['dataset'],
                               args.log_time)

    # Command-line values take precedence over the DENOISER section of the YAML file.
    config["DENOISER"]["top_k"] = args.top_k
    config["DENOISER"]["num_beams"] = args.num_beams
    config["DENOISER"]["num_iter"] = args.num_iter
    config["DENOISER"]["temp"] = args.temp
    config["DENOISER"]["percentage"] = args.percentage
    config["DENOISER"]["idx_sim"] = args.idx_sim
    # Parsed as literals rather than eval()'d so CLI input is never executed.
    config["DENOISER"]["intra"] = _parse_literal_flag(args.intra)
    config["DENOISER"]["cross"] = _parse_literal_flag(args.cross)
    config["DENOISER"]["Q_style"] = args.Q_style

    config["data"]["label_list"] = os.path.join(working_dir, f"{args.name}_labels_{args.idx_sim}.csv")

    # The directory must exist before the effective config is written into it.
    Path(working_dir).mkdir(parents=True, exist_ok=True)
    with open(os.path.join(working_dir, "config.yaml"), 'w') as f:
        yaml.dump(data=config, stream=f)

    print('-' * 80)
    print(' ' * 20, "working dir: {}".format(working_dir))
    print('-' * 80)

    print('-' * 80)
    print(' ' * 30, "Config")
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(config)
    print('-' * 80)

    config = DotMap(config)

    # NOTE(review): 'test.py' is resolved against the current working directory;
    # confirm the script is always launched from the directory containing it.
    shutil.copy('test.py', working_dir)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model, clip_state_dict = clip.load(config.network.arch, device=device, jit=False, tsm=config.network.tsm,
                                                   T=config.data.num_segments, dropout=config.network.drop_out,
                                                   emb_dropout=config.network.emb_dropout)  # Must set jit=False for training  ViT-B/32

    transform_val = get_augmentation(False, config)

    fusion_model = visual_prompt(config.network.sim_header, clip_state_dict, config.data.num_segments)

    model_text = TextCLIP(model)
    model_image = ImageCLIP(model)

    # Move to the selected device instead of calling .cuda() unconditionally,
    # so the script also runs on hosts without a GPU.
    model_text = torch.nn.DataParallel(model_text).to(device)
    model_image = torch.nn.DataParallel(model_image).to(device)
    fusion_model = torch.nn.DataParallel(fusion_model).to(device)

    val_data = Action_DATASETS(config.data.val_list, config.data.label_list, num_segments=config.data.num_segments,
                        image_tmpl=config.data.image_tmpl,
                        transform=transform_val, random_shift=config.random_shift)

    # NOTE(review): drop_last=True means a final partial batch is not cached;
    # confirm downstream consumers expect that.
    # persistent_workers is only valid with at least one worker process.
    val_loader = DataLoader(val_data, batch_size=config.data.batch_size, num_workers=config.data.workers, shuffle=False,
                            pin_memory=True, drop_last=True,
                            persistent_workers=config.data.workers > 0)

    # Both wrappers share the same underlying CLIP model, so these calls adjust
    # the precision of `model` itself.
    if device == "cpu":
        model_text.float()
        model_image.float()
    else:
        clip.model.convert_weights(
            model_text)  # Actually this line is unnecessary since clip by default already on float16
        clip.model.convert_weights(model_image)

    if config.pretrain:
        if os.path.isfile(config.pretrain):
            print(("=> loading checkpoint '{}'".format(config.pretrain)))
            # torch.load unpickles the file: only load checkpoints from trusted sources.
            # map_location lets a GPU-saved checkpoint be restored on a CPU-only host.
            checkpoint = torch.load(config.pretrain, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            fusion_model.load_state_dict(checkpoint['fusion_model_state_dict'])
            del checkpoint
        else:
            # Execution continues with the weights returned by clip.load.
            print(("=> no checkpoint found at '{}'".format(config.pretrain)))

    cache_features(val_loader, device, model, fusion_model, config)

if __name__ == '__main__':
    main()
