import warnings
import torch
import torch.nn as nn
from sklearn.metrics import balanced_accuracy_score
import numpy as np
from tqdm import tqdm
import json

from lib.strategies.base import BaseStrategy

# Executed at import time: installs an "ignore" filter with no category or
# module restriction, so warnings are suppressed for the whole process.
warnings.filterwarnings("ignore")


class FlavaEWCStrategy(BaseStrategy):
    """Train over a stream of experiences with an EWC-style penalty.

    For every experience the model is trained on that experience's data,
    evaluated, and then handed to ``self.model.adaptation`` together with the
    new number of output classes.  During training, the task loss is
    augmented with a quadratic penalty built from
    ``self.model.fisher_scores_history`` and ``self.model.parameters_history``
    (one entry per previously completed experience).

    ``self.optimizer``, ``self.loss``, ``self.train``, ``self._get_dataloader``
    and the ``_test_*`` helpers are not defined here; they are expected to be
    provided by ``BaseStrategy``.
    """

    def __init__(
        self,
        model: nn.Module,
        stream: object,
        n_epochs: int,
        lr: float,
        batch_size: int,
        output_filename: str,
        device: torch.device,
        ewc_lambda: float = 0.4,
        **kwargs
    ):
        """Forward the common settings to ``BaseStrategy``.

        Args:
            model: Network to train; must expose ``feature_extractor``,
                ``fisher_scores_history``, ``parameters_history`` and
                ``adaptation`` (all used by this class).
            stream: Object exposing ``n_classes_per_experience``,
                ``n_experiences`` and ``train_stream``.
            n_epochs: Passed through to ``BaseStrategy``.
            lr: Passed through to ``BaseStrategy``.
            batch_size: Passed through to ``BaseStrategy``.
            output_filename: Stem of the JSON results file written by ``run``.
            device: Device the inputs and targets are moved to.
            ewc_lambda: Weight of the quadratic penalty added to the loss.
                Defaults to ``0.4``, the value that used to be hard-coded.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(
            model=model,
            stream=stream,
            n_epochs=n_epochs,
            lr=lr,
            batch_size=batch_size,
            device=device,
            output_filename=output_filename
        )
        self.ewc_lambda = ewc_lambda

    def forward(self, inputs):
        """Return the wrapped model's output for ``inputs``."""
        return self.model(inputs)

    def _training_step(self, epoch, dataloader):
        """Run one training epoch over ``dataloader``.

        Args:
            epoch: Zero-based epoch index (only used for the progress bar).
            dataloader: Iterable yielding ``(inputs, targets)`` batches.

        Returns:
            None. Running loss/accuracy are shown on the progress bar only.
        """
        epoch_loss = []
        epoch_acc = []

        # Number of output units belonging to earlier experiences.  It does
        # not change within an epoch, so compute it once.
        n_seen_classes_until_now = sum(
            self.stream.n_classes_per_experience[:self.current_experience]
        )

        # Select penalised parameters by ``requires_grad`` rather than by
        # ``grad is not None``: right after ``optimizer.zero_grad()`` the
        # gradients may all be ``None`` (the default when grads are reset to
        # ``None``), which would silently disable the penalty.
        trainable_params = [
            (name, param)
            for name, param in self.model.feature_extractor.named_parameters()
            if param.requires_grad
        ]

        with tqdm(dataloader, unit="batch") as tepoch:
            for inputs, targets in tepoch:
                tepoch.set_description(f"Exp: {self.current_experience+1} | Epoch: {epoch+1}")

                # ``Tensor.to`` returns the moved object instead of moving it
                # in place, so the result has to be kept.
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                self.optimizer.zero_grad()
                logits = self.forward(inputs)

                # Mask the outputs of previously seen classes with the most
                # negative finite value of the logits' own dtype, so the fill
                # value is always representable.
                logits[:, :n_seen_classes_until_now] = torch.finfo(logits.dtype).min

                J = self.loss(logits, targets)

                # Quadratic penalty, one term per stored experience.
                for exp_id in range(len(self.model.fisher_scores_history)):
                    fisher = self.model.fisher_scores_history[exp_id]
                    anchor = self.model.parameters_history[exp_id]
                    for name, param in trainable_params:
                        # Parameters without a stored Fisher score contribute
                        # nothing to the penalty.
                        if name not in fisher:
                            continue
                        J = J + self.ewc_lambda * torch.sum(
                            fisher[name] * (anchor[name] - param) ** 2
                        )

                J.backward()
                self.optimizer.step()

                epoch_loss.append(float(J))
                # scikit-learn expects (y_true, y_pred) in that order; the
                # balanced accuracy is not symmetric in its arguments.
                epoch_acc.append(
                    balanced_accuracy_score(
                        targets.cpu().numpy(),
                        logits.argmax(dim=-1).cpu().numpy(),
                    )
                )

                torch.cuda.empty_cache()

                tepoch.set_postfix(
                    loss=round(np.mean(epoch_loss), 3),
                    acc=round(np.mean(epoch_acc), 3)
                )

    def run(self):
        """Train and evaluate on every experience, then dump results to JSON.

        ``self.res`` is filled with two dictionaries keyed by ``"exp <k>"``
        (one-based): ``current_vs_target`` holds the list of values returned
        by ``_test_target_experience`` for each experience seen so far, and
        ``current_vs_overall`` holds the value returned by
        ``_test_all_experiences``.  The result is written to
        ``./output/<output_filename>.json``.
        """
        self.res = dict(
            current_vs_target=dict(),
            current_vs_overall=dict()
        )

        n_classes_per_experience = list(self.stream.n_classes_per_experience)
        n_output_classes = n_classes_per_experience[0]
        # Shift by one so that index ``exp_id`` gives the number of classes
        # the *next* experience adds; nothing is added after the last one.
        n_classes_per_experience = n_classes_per_experience[1:] + [0]

        for exp_id in range(self.stream.n_experiences):
            self.current_experience = exp_id

            train_data = self.stream.train_stream[exp_id]
            train_dataloader = self._get_dataloader(train_data, True)
            self.train(train_dataloader)

            # Test on every experience seen so far, then on all of them.
            exp_key = f"exp {exp_id+1}"
            self.res["current_vs_target"][exp_key] = []
            for target_exp_id in tqdm(range(self.current_experience+1)):
                acc = self._test_target_experience(target_exp_id)
                self.res["current_vs_target"][exp_key].append(acc)

            self.res["current_vs_overall"][exp_key] = self._test_all_experiences()

            n_output_classes += n_classes_per_experience[exp_id]
            self.model.adaptation(
                n_output_classes=n_output_classes,
                dataloader=train_dataloader,
                optimizer=self.optimizer,
                loss=self.loss
            )

        # Context manager guarantees the file is flushed and closed.
        with open(f"./output/{self.output_filename}.json", "w") as out_file:
            json.dump(self.res, out_file)