import warnings
import torch
import torch.nn as nn
from sklearn.metrics import balanced_accuracy_score
import numpy as np
from tqdm import tqdm
import json

from lib.strategies.base import BaseStrategy

# Silence all warnings at import time.
# NOTE(review): with no category/module arguments this filter applies to every
# warning issued anywhere in the process after this module is imported, not just
# warnings from this file; consider narrowing it to the specific noisy category.
warnings.filterwarnings("ignore")


class FlavaDualPromptStrategy(BaseStrategy):
    """Train a prompt-based model experience by experience and log accuracies.

    The model is called with the index of the experience currently being
    trained (see ``forward``). After each experience the strategy evaluates on
    every experience seen so far and on all experiences, then grows the model's
    output size for the next experience. Results are dumped to
    ``./output/<output_filename>.json`` at the end of ``run``.
    """

    def __init__(
        self,
        model: nn.Module,
        stream: object,
        n_epochs: int,
        lr: float,
        batch_size: int,
        output_filename: str,
        device: torch.device,
        **kwargs,
    ):
        """Forward the training configuration to ``BaseStrategy``.

        Args:
            model: Model to train. ``run`` requires it to expose
                ``feature_extractor.update_e_key(dataloader)`` and
                ``adaptation(n_output_classes)``.
            stream: Benchmark stream. ``run`` reads ``n_experiences``,
                ``n_classes_per_experience`` and ``train_stream`` from it.
            n_epochs: Forwarded to ``BaseStrategy``.
            lr: Forwarded to ``BaseStrategy``.
            batch_size: Forwarded to ``BaseStrategy``.
            output_filename: Base name (without ``.json``) of the results file
                written by ``run``. Presumably stored by ``BaseStrategy`` as
                ``self.output_filename`` -- confirm in the base class.
            device: Forwarded to ``BaseStrategy``.
            **kwargs: Accepted but neither used nor forwarded.
        """
        super().__init__(
            model=model,
            stream=stream,
            n_epochs=n_epochs,
            lr=lr,
            batch_size=batch_size,
            device=device,
            output_filename=output_filename,
        )

    def forward(self, inputs):
        """Run the model on ``inputs`` for the current experience.

        ``self.current_experience`` is set by ``run`` before training each
        experience; it is passed to the model as ``experience_id``.
        """
        return self.model(inputs, experience_id=self.current_experience)

    def run(self):
        """Train on every experience in the stream, evaluating after each one.

        For the ``k``-th experience (1-based key ``"exp k"``) this records:

        * ``self.res["current_vs_target"]["exp k"]``: a list with one result of
          ``_test_target_experience`` per experience seen so far
          (indices ``0 .. k-1``);
        * ``self.res["current_vs_overall"]["exp k"]``: the result of
          ``_test_all_experiences``.

        The complete ``self.res`` dict is written as JSON to
        ``./output/<output_filename>.json`` (the directory is created if it is
        missing).
        """
        import os

        self.res = dict(
            current_vs_target=dict(),
            current_vs_overall=dict(),
        )

        classes_per_exp = self.stream.n_classes_per_experience
        # The output size starts at the class count of experience 0. After that,
        # entry i of the shifted list is the number of classes that experience
        # i+1 introduces; the trailing 0 means the adaptation performed after the
        # final experience does not add classes. ``list(...)`` also accepts a
        # tuple for ``n_classes_per_experience``.
        n_output_classes = classes_per_exp[0]
        next_exp_classes = list(classes_per_exp[1:]) + [0]

        for exp_id in range(self.stream.n_experiences):
            # Read by ``forward`` to select the experience-specific behavior.
            self.current_experience = exp_id
            exp_key = f"exp {exp_id+1}"

            train_data = self.stream.train_stream[exp_id]
            # NOTE(review): the positional ``True`` presumably enables
            # shuffling -- confirm against ``BaseStrategy._get_dataloader``.
            train_dataloader = self._get_dataloader(train_data, True)

            # NOTE(review): presumably refreshes the experience-specific prompt
            # keys from this experience's data before training; verify in the
            # feature extractor implementation.
            self.model.feature_extractor.update_e_key(train_dataloader)
            self.train(train_dataloader)

            # Test on each experience seen so far, including the current one.
            target_results = []
            self.res["current_vs_target"][exp_key] = target_results
            for target_exp_id in tqdm(range(self.current_experience + 1)):
                target_results.append(self._test_target_experience(target_exp_id))

            self.res["current_vs_overall"][exp_key] = self._test_all_experiences()

            # NOTE(review): ``adaptation`` presumably resizes the output head to
            # ``n_output_classes`` -- confirm in the model.
            n_output_classes += next_exp_classes[exp_id]
            self.model.adaptation(n_output_classes)

        os.makedirs("./output", exist_ok=True)
        # A context manager guarantees the JSON is flushed and the handle closed.
        with open(f"./output/{self.output_filename}.json", "w") as out_file:
            json.dump(self.res, out_file)