import warnings
import torch
import torch.nn as nn
from sklearn.metrics import balanced_accuracy_score
import numpy as np
from tqdm import tqdm

from lib.strategies.base import BaseStrategy

# Install a process-wide "ignore" filter: every warning category raised after
# this module is imported is suppressed, not only warnings from this file.
warnings.filterwarnings("ignore")


class FlavaDualPromptStrategy(BaseStrategy):
    """Training/evaluation strategy for a prompt-based model on a stream of experiences.

    The wrapped ``model`` is called as ``model(inputs, experience_id)`` and is
    expected to return a ``(logits, key)`` pair. It must also expose
    ``feature_extractor.get_dissimilarity_score(inputs, key)``, whose mean is
    added to the classification loss during training.

    ``self.optimizer``, ``self.loss``, ``self.current_experience`` and
    ``self._get_dataloader`` are not defined in this class; they are presumably
    provided by ``BaseStrategy`` (verify against the base class).
    """

    def __init__(self,
        model: nn.Module,
        stream: object,
        n_epochs: int,
        lr: float,
        batch_size: int,
        output_filename: str,
        device: torch.device,
        **kwargs
    ):
        """Forward the configuration to ``BaseStrategy``.

        Args:
            model: Network to train; see the class docstring for its contract.
            stream: Object exposing ``n_classes_per_experience`` and
                ``test_stream`` (both are read by this class).
            n_epochs: Passed through to ``BaseStrategy``.
            lr: Passed through to ``BaseStrategy``.
            batch_size: Passed through to ``BaseStrategy``.
            output_filename: Passed through to ``BaseStrategy``.
            device: Device that batches are moved to before the forward pass.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(
            model=model,
            stream=stream,
            n_epochs=n_epochs,
            lr=lr,
            batch_size=batch_size,
            device=device,
            output_filename=output_filename
        )

    def forward(self, inputs, experience_id=None):
        """Run the wrapped model and return its output unchanged.

        Args:
            inputs: Batch forwarded as-is to the model.
            experience_id: Optional experience index forwarded to the model;
                ``None`` is what the evaluation methods pass.
        """
        return self.model(inputs, experience_id)

    def _training_step(self, epoch, dataloader):
        """Train for one epoch over ``dataloader`` and show running metrics.

        For every batch, the logits of classes introduced by earlier
        experiences are masked out, the loss is
        ``self.loss(logits, targets) + mean(dissimilarity score)``, and one
        optimizer step is taken. The running mean loss and balanced accuracy
        are shown on the progress bar; nothing is returned.

        NOTE(review): this method does not switch the model to train mode,
        while the evaluation methods call ``self.model.eval()``. Presumably the
        caller restores train mode; confirm against ``BaseStrategy``.

        Args:
            epoch: Zero-based epoch index (only used for the progress label).
            dataloader: Iterable yielding ``(inputs, targets)`` batches.
        """
        epoch_loss = []
        epoch_acc = []

        # Number of classes belonging to experiences before the current one.
        # It does not change within an epoch, so compute it once.
        n_seen_classes_until_now = sum(
            self.stream.n_classes_per_experience[:self.current_experience]
        )

        with tqdm(dataloader, unit="batch") as tepoch:
            for inputs, targets in tepoch:
                tepoch.set_description(f"Exp: {self.current_experience+1} | Epoch: {epoch+1}")

                # ``Tensor.to`` is not in-place: the result must be rebound,
                # otherwise the batch silently stays on its original device.
                inputs = inputs.to(self.device)
                targets_on_device = targets.to(self.device)

                self.optimizer.zero_grad()
                logits, key = self.forward(inputs, experience_id=self.current_experience)

                # Mask logits of previously seen classes with the most negative
                # finite value of the logits' own dtype, so they get zero
                # probability without relying on an implicit overflow to -inf.
                logits[:, :n_seen_classes_until_now] = torch.finfo(logits.dtype).min

                J = self.loss(logits, targets_on_device) + self.model.feature_extractor.get_dissimilarity_score(inputs, key).mean()
                J.backward()
                self.optimizer.step()

                epoch_loss.append(float(J))
                # sklearn's signature is (y_true, y_pred): targets go first.
                epoch_acc.append(
                    balanced_accuracy_score(
                        targets.cpu().numpy(),
                        logits.argmax(dim=-1).cpu().numpy(),
                    )
                )

                torch.cuda.empty_cache()

                tepoch.set_postfix(
                    loss=round(np.mean(epoch_loss), 3),
                    acc=round(np.mean(epoch_acc), 3)
                )

    @torch.no_grad()
    def _collect_test_predictions(self, exp_id):
        """Predict the test split of experience ``exp_id``.

        Returns:
            A ``(targets, preds)`` pair of Python lists holding the ground-truth
            labels and the arg-max class predictions, in dataloader order.
        """
        targets_ = []
        preds_ = []

        test_data = self.stream.test_stream[exp_id]
        # Second positional argument kept as in the original call; presumably
        # it disables shuffling (confirm in ``BaseStrategy._get_dataloader``).
        test_dataloader = self._get_dataloader(test_data, False)

        for inputs, targets in test_dataloader:
            inputs = inputs.to(self.device)  # rebind: ``.to`` is not in-place
            logits, _ = self.forward(inputs)
            targets_ += targets.tolist()
            preds_ += logits.argmax(dim=-1).tolist()

        return targets_, preds_

    @torch.no_grad()
    def _test_target_experience(self, target_exp_id):
        """Return the balanced accuracy on the test split of one experience.

        The model is put in eval mode and queried without an experience id.
        Logits are not restricted to the target experience's classes; an
        earlier, disabled variant masked the logits of all other classes.

        Args:
            target_exp_id: Index into ``self.stream.test_stream``.
        """
        self.model.eval()
        targets_, preds_ = self._collect_test_predictions(target_exp_id)
        return balanced_accuracy_score(targets_, preds_)

    @torch.no_grad()
    def _test_all_experiences(self):
        """Return the balanced accuracy over all experiences seen so far.

        Predictions from the test splits of experiences
        ``0 .. self.current_experience`` (inclusive) are pooled and scored
        together in a single ``balanced_accuracy_score`` call.
        """
        self.model.eval()

        targets_ = []
        preds_ = []

        for exp_id in tqdm(range(self.current_experience+1)):
            exp_targets, exp_preds = self._collect_test_predictions(exp_id)
            targets_ += exp_targets
            preds_ += exp_preds

        return balanced_accuracy_score(targets_, preds_)