import argparse
import torch
import numpy as np
from TIEM.mask_generator import TIEM
from train_model import train
from models.whiteboxmodel import WhiteBoxModel
import matplotlib.pyplot as plt
from utils.funcs import *
from PIL import Image
from data_preprocessing import create_video_dataset
# Seed PyTorch's global RNG at import time so that random tensors created
# later (e.g. the torch.rand input built for figure 5 in main) are repeatable.
torch.manual_seed(0)


# TIEM keyword arguments that are identical for every experiment in this script.
_TIEM_COMMON = dict(max_iter=2000, print_iter=400, perturb_type="blur", variant="preserve")

# Figure 5 (white-box) experiments, run in this order. "version" is forwarded
# to WhiteBoxModel, "output" to visualize_mask, and "tiem" to TIEM.
_WHITE_BOX_CONFIGS = {
    "gentle": {
        "version": "moving_diagonal",
        "output": "white_box_gentle.png",
        "tiem": dict(learning_rate=8e-2, reward_weight=100, regul_weight=3000,
                     areas=0.07, sigma=10, alpha=0.99),
    },
    "dynamic": {
        "version": "solid_time",
        "output": "white_box_dynamic.png",
        "tiem": dict(learning_rate=5e-2, reward_weight=100, regul_weight=1300,
                     areas=0.05, sigma=10, alpha=0.6),
    },
}

# Figures 6 and 9 (black-box): per-experiment TIEM hyper-parameters.
_BLACK_BOX_CONFIGS = {
    "frontcrawl": dict(learning_rate=1e-2, reward_weight=1500, regul_weight=3000,
                       areas=0.12, sigma=13, alpha=0.8),
    "breaststroke": dict(learning_rate=8e-2, reward_weight=300, regul_weight=3000,
                         areas=0.1, sigma=13, alpha=0.8),
    "floorgymnetics": dict(learning_rate=5e-2, reward_weight=300, regul_weight=1000,
                           areas=0.07, sigma=13, alpha=0.6),
}
_BLACK_BOX_ARCHS = ("R2p1d", "R50LSTM")

# The result plots use a fixed 4x4 grid, i.e. 16 frames / 16 windows.
_GRID_SIDE = 4
_NUM_FRAMES = _GRID_SIDE * _GRID_SIDE


def _run_white_box_examples(device):
    """Run TIEM on each white-box configuration and save its mask image."""
    for cfg in _WHITE_BOX_CONFIGS.values():
        # A fresh random input is drawn per experiment (same RNG order as before).
        img = torch.rand(1, 3, 16, 100, 100) * 10
        # A first model is run once only to obtain ``tmp_out``, which is then
        # passed as the first constructor argument of the model being explained.
        probe = WhiteBoxModel(0, version=cfg["version"], device=device)
        probe.forward(img)
        model = WhiteBoxModel(probe.tmp_out, version=cfg["version"], device=device)
        mask, _, _ = TIEM(model=model.to(device), input=img.to(device),
                          target=torch.LongTensor([1]).to(device),
                          **_TIEM_COMMON, **cfg["tiem"])
        visualize_mask(mask, cfg["output"])


def _show_masked_frames(video, mask):
    """Show the first 16 frames of ``video`` multiplied by ``mask`` in a 4x4 grid."""
    # assumes video is (1, C, T, H, W) so that frames becomes (T, H, W, C) — TODO confirm
    frames = np.copy(video.squeeze().permute(1, 2, 3, 0))
    mask_np = mask.detach().cpu().squeeze().numpy()  # converted once, not per frame
    height, width = frames.shape[1:3]
    for i in range(_NUM_FRAMES):
        frames[i] = frames[i] * mask_np[i].reshape(height, width, 1)

    fig, axes = plt.subplots(_GRID_SIDE, _GRID_SIDE, figsize=(20, 20))
    for i, ax in enumerate(axes.flatten()):
        ax.imshow(Image.fromarray((frames[i] * 255).astype(np.uint8)))
        ax.axis('off')
        ax.set_title(f"Frame= {i}", fontsize=20)
    plt.tight_layout()
    plt.show()


def _show_tis_per_window(tis_per_window):
    """Plot the first 16 entries of ``tis_per_window``, one per subplot of a 4x4 grid."""
    # First colour of the active property cycle, obtained through public API
    # (the private ``Axes._get_lines.prop_cycler`` is unavailable in newer matplotlib).
    line_color = plt.rcParams['axes.prop_cycle'].by_key()['color'][0]
    fig, axs = plt.subplots(_GRID_SIDE, _GRID_SIDE, figsize=(20, 20))
    for idx, ax in enumerate(axs.flatten()):
        data = tis_per_window[idx]
        if len(data) == 1:
            # A single value cannot be drawn as a line plot; use a horizontal line.
            ax.axhline(y=data[0], color=line_color, linestyle='-', linewidth=2)
        else:
            ax.plot(data, color=line_color, linestyle='-', linewidth=2)
        ax.set_title(f"w={idx + 1}", fontsize=20)
        ax.set_ylim(-0.1, 1.1)
        ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
        ax.tick_params(axis='both', which='major', labelsize=20)
    plt.tight_layout()
    plt.show()


def _show_tis(tis):
    """Plot ``tis`` against its index with "Frame" / "TIS" axis labels."""
    plt.plot(tis)
    plt.xlabel("Frame", fontsize=20)
    plt.ylabel("TIS", fontsize=20)
    plt.tight_layout()
    plt.show()


def _run_black_box_example(model_arch, ex_type, device):
    """Run TIEM on one sample video with a model from ``train`` and plot the results.

    Raises:
        ValueError: if ``model_arch`` or ``ex_type`` has no configuration here.
    """
    # Validate before any expensive work; previously these cases surfaced later
    # as an unbound ``model`` / ``mask`` or a missing "sample_videos/None.npy".
    if model_arch not in _BLACK_BOX_ARCHS:
        raise ValueError(
            f"Figures 6 and 9 require model_arch to be one of {_BLACK_BOX_ARCHS}, "
            f"got {model_arch!r}.")
    if ex_type not in _BLACK_BOX_CONFIGS:
        raise ValueError(
            f"Figures 6 and 9 require ex_type to be one of {tuple(_BLACK_BOX_CONFIGS)}, "
            f"got {ex_type!r}.")

    model = train(model_arch, device, True)
    if model_arch == "R50LSTM":
        # NOTE(review): eval() is only applied to R50LSTM, as in the original
        # code — confirm whether R2p1d should be switched to eval mode as well.
        model.eval()
    # Move the model before the first forward pass so it matches the input device.
    model = model.to(device)

    # assumes the stored array is (frames, height, width, channels) — TODO confirm
    video = torch.FloatTensor(np.load(f"sample_videos/{ex_type}.npy"))
    video = video.unsqueeze(0).permute(0, 4, 1, 2, 3)

    # The model's own prediction is the explanation target; no gradients needed.
    with torch.no_grad():
        label = model(video.to(device)).argmax(axis=1)

    mask, tis_per_window, tis = TIEM(model=model, input=video.to(device),
                                     target=label.to(device),
                                     **_TIEM_COMMON, **_BLACK_BOX_CONFIGS[ex_type])

    _show_masked_frames(video, mask)
    _show_tis_per_window(tis_per_window)
    _show_tis(tis)


def main(args):
    """Run the steps selected on the command line.

    Args:
        args: namespace with the attributes ``data_preprocessing``,
            ``train_model``, ``model_arch``, ``ex_type`` and ``figure``.

    Raises:
        ValueError: if ``figure`` is 6 or 9 and ``model_arch`` / ``ex_type``
            is missing or has no configuration.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if args.data_preprocessing:
        print("Starting data preprocessing...")
        create_video_dataset('dataset/UCF101/')

    # Model training
    if args.train_model:
        print("Starting model training...")
        if args.model_arch:
            # NOTE(review): the third argument is False here and True for the
            # figure runs; its meaning is defined in train_model — confirm.
            train(args.model_arch, device, False)
        else:
            print("Please specify the model architecture for training.")

    if args.figure == 5:
        _run_white_box_examples(device)
    elif args.figure in [6, 9]:
        _run_black_box_example(args.model_arch, args.ex_type, device)


# Experiment types that main() has a TIEM configuration for when generating
# the black-box figures (6 and 9).
BLACK_BOX_EX_TYPES = ('frontcrawl', 'breaststroke', 'floorgymnetics')


def build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description='TIEM Model Training and Evaluation')
    parser.add_argument('--figure', type=int, choices=[5, 6, 9],
                        help='Specify which figure to generate (5, 6, or 9)')
    parser.add_argument('--model_arch', type=str, choices=['R2p1d', 'R50LSTM'],
                        help='Model architecture (required for figures 6 and 9)')
    parser.add_argument('--ex_type', type=str,
                        choices=['gentle', 'dynamic', 'frontcrawl', 'breaststroke', 'floorgymnetics'],
                        help='Type of experiment to run (required for figures 6 and 9)')
    parser.add_argument('--train_model', action='store_true',
                        help='If specified, will start model training')
    parser.add_argument('--data_preprocessing', action='store_true',
                        help='If specified, will start data preprocessing')
    return parser


def check_figure_args(args):
    """Return an error message if figure 6/9 lacks usable arguments, else None.

    The help text declares ``--model_arch`` and ``--ex_type`` as required for
    figures 6 and 9; this makes that requirement explicit instead of letting
    the run fail part-way through.
    """
    if args.figure not in (6, 9):
        return None
    if args.model_arch is None:
        return "--model_arch is required for figures 6 and 9"
    if args.ex_type not in BLACK_BOX_EX_TYPES:
        return ("--ex_type must be one of "
                + ", ".join(BLACK_BOX_EX_TYPES)
                + " for figures 6 and 9")
    return None


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()
    problem = check_figure_args(args)
    if problem is not None:
        # Prints usage plus the message and exits with status 2.
        parser.error(problem)
    main(args)
