import os
# Process-wide environment settings, applied before the project and torch
# imports further down in this file.
# CUDA_VISIBLE_DEVICES="0" limits CUDA to the first GPU; the module-level
# `device` ("cuda:0") defined later in this file relies on that numbering.
# TOKENIZERS_PARALLELISM="false" presumably disables parallelism in the
# tokenizer library used by the text processor -- TODO confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import warnings
import argparse

from lib.data import data_config
from lib.data import VLExperienceStream
from lib.models import FlavaProcessor

from lib.models.upper_bound import FlavaUpperBoundCL
from lib.models.lower_bound import FlavaLowerBoundCL
from lib.models.dual_prompt import FlavaDualPromptCL
from lib.models.dual_prompt_closed_form import FlavaDualPromptCL as FlavaDualPromptClosedFormCL
from lib.models.learning_to_prompt import FlavaLearningToPromptCL

from lib.strategies.upper_bound import FlavaUpperBoundStrategy
from lib.strategies.lower_bound import FlavaLowerBoundStrategy
from lib.strategies.dual_prompt import FlavaDualPromptStrategy
from lib.strategies.dual_prompt_closed_form import FlavaDualPromptStrategy as FlavaDualPromptClosedFormStrategy
from lib.strategies.learning_to_prompt import FlavaLearningToPromptStrategy
from lib.strategies.experience_replay import FlavaExperienceReplayStrategy

import torch

# Suppress every warning for the whole run and target the first CUDA device
# (the only one exposed through CUDA_VISIBLE_DEVICES above).
warnings.filterwarnings(action="ignore")
device = torch.device("cuda", 0)

###

# Command-line interface of the experiment launcher.
#
# Every option is mandatory: the driver below dereferences all four values
# unconditionally, so a missing flag used to surface as an unrelated
# TypeError/KeyError instead of a usage message.
# All values are kept as strings because they are also concatenated into the
# output filename; numeric conversion happens at the point of use.
parser = argparse.ArgumentParser(
    description="Run one continual-learning experiment with a FLAVA-based model."
)
parser.add_argument(
    "--model",
    type=str,
    required=True,
    # Must stay in sync with the keys of `cl_configs`.
    choices=(
        "lower_bound",
        "upper_bound",
        "l2p",
        "dual_prompt",
        "dual_prompt_cf",
        "experience_replay",
    ),
    help="lower_bound | upper_bound | l2p | dual_prompt | dual_prompt_cf | experience_replay",
)
parser.add_argument(
    "--dataset",
    type=str,
    required=True,
    help="cub200 | flowers | dvm_car",
)
parser.add_argument(
    "--experiences",
    type=str,
    required=True,
    help="10 | 20",
)
parser.add_argument(
    "--run",
    type=str,
    required=True,
    # 1-based index into `seeds`; restricting the choices prevents "0" or a
    # negative value from silently selecting a seed via negative indexing.
    choices=("1", "2", "3"),
    help="1 | 2 | 3",
)

###

# Per-method experiment settings, keyed by the value passed to --model.
# Each entry names the model class to build and the strategy class that
# trains it; after "model" is removed, the remaining items of the entry are
# forwarded as keyword arguments to the strategy constructor.
cl_configs = {
    # Baselines.
    "lower_bound": {
        "lr": 1e-5,
        "n_epochs": 5,
        "model": FlavaLowerBoundCL,
        "strategy": FlavaLowerBoundStrategy,
    },
    "upper_bound": {
        "lr": 1e-5,
        "n_epochs": 5,
        "model": FlavaUpperBoundCL,
        "strategy": FlavaUpperBoundStrategy,
    },
    # Prompt-based methods (these use a larger learning rate).
    "dual_prompt": {
        "lr": 5e-3,
        "n_epochs": 5,
        "model": FlavaDualPromptCL,
        "strategy": FlavaDualPromptStrategy,
    },
    "dual_prompt_cf": {
        "lr": 5e-3,
        "n_epochs": 5,
        "model": FlavaDualPromptClosedFormCL,
        "strategy": FlavaDualPromptClosedFormStrategy,
    },
    "l2p": {
        "lr": 5e-3,
        "n_epochs": 5,
        "model": FlavaLearningToPromptCL,
        "strategy": FlavaLearningToPromptStrategy,
    },
    # Experience replay reuses the lower-bound model class and adds one
    # extra strategy argument.
    "experience_replay": {
        "lr": 1e-5,
        "n_epochs": 5,
        "model": FlavaLowerBoundCL,
        "strategy": FlavaExperienceReplayStrategy,
        "percent_samples_per_class": 0.25,
    },
}

# One seed per repetition; `--run N` (1-based) selects seeds[N - 1].
seeds = [42, 105, 200]

if __name__ == "__main__":
    args = parser.parse_args()

    # Validate --run explicitly. Plain list indexing would accept "0" or a
    # negative value and silently pick a seed through negative indexing.
    try:
        run_index = int(args.run) - 1
    except (TypeError, ValueError):
        parser.error(
            f"--run must be an integer between 1 and {len(seeds)}, got {args.run!r}"
        )
    if not 0 <= run_index < len(seeds):
        parser.error(
            f"--run must be between 1 and {len(seeds)}, got {args.run!r}"
        )

    # Reject unknown methods up front, before any data or model is loaded.
    if args.model not in cl_configs:
        parser.error(
            f"--model must be one of {', '.join(cl_configs)}, got {args.model!r}"
        )

    seed = seeds[run_index]
    torch.manual_seed(seed)

    image_processor, text_processor = FlavaProcessor.get_processor()

    stream = VLExperienceStream(
        data_path=data_config[args.dataset],
        n_experiences=int(args.experiences),
        image_processor=image_processor,
        text_processor=text_processor,
        seed=seed,
    )

    # Work on a shallow copy so the module-level `cl_configs` table is not
    # mutated when the "model" entry is removed.
    config = dict(cl_configs[args.model])
    model_cls = config.pop("model")

    # The model head is sized for the first experience of the stream.
    if model_cls is not None:
        model = model_cls(n_output_classes=stream.n_classes_per_experience[0])
    else:
        model = None

    # NOTE: "strategy" deliberately stays in `config` and is forwarded through
    # **config together with the hyper-parameters, exactly as before.
    strategy = config["strategy"](
        model=model,
        stream=stream,
        batch_size=16,
        device=device,
        output_filename=f"{args.dataset}_{args.model}_{args.experiences}_run-{args.run}",
        **config,
    )

    strategy.run()
