import warnings
import torch
import torch.nn as nn
import json
from tqdm import tqdm
from torch.utils.data import DataLoader
from lib.data import collate_fn
from torch.utils.data.sampler import WeightedRandomSampler
from lib.data.utils import (
    merge_vl_datasets,
    subsample_dataset,
    compute_dataset_sampling_weights
)
from sklearn.metrics import balanced_accuracy_score
import numpy as np

from lib.strategies.base import BaseStrategy

# Silence every warning for the whole process at import time (presumably to
# quiet noisy per-batch metric warnings during training).
# NOTE(review): this also hides deprecation/runtime warnings from all libraries.
warnings.filterwarnings("ignore")

class FlavaExperienceReplayStrategy(BaseStrategy):
    """Experience-replay training strategy over a stream of experiences.

    For every experience in ``stream.train_stream`` the model is trained on the
    experience's own data merged with a per-class subsample of the training set
    used for the previous experience (see ``run``). After each experience the
    model is evaluated on every experience seen so far and on all experiences,
    and its output size is grown via ``model.adaptation``. Results are written
    to ``./output/<output_filename>.json``.

    Args:
        model: Network to train; ``run`` calls ``model.adaptation(n_classes)``.
        stream: Benchmark stream; ``run`` reads ``n_experiences``,
            ``n_classes_per_experience`` and ``train_stream`` from it.
        n_epochs, lr, batch_size, output_filename, device: Forwarded to
            ``BaseStrategy``.
        n_samples_per_class: Number of samples per class kept from the previous
            training set (forwarded to ``subsample_dataset``).
        percent_samples_per_class: Relative amount per class kept from the
            previous training set (forwarded to ``subsample_dataset``; whether
            it is a fraction or a 0-100 percentage depends on that helper).
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
        model: nn.Module,
        stream: object,
        n_epochs: int,
        lr: float,
        batch_size: int,
        output_filename: str,
        device: torch.device,
        n_samples_per_class: int = None,
        percent_samples_per_class: float = None,
        **kwargs
    ):
        super().__init__(
            model=model,
            stream=stream,
            n_epochs=n_epochs,
            lr=lr,
            batch_size=batch_size,
            device=device,
            output_filename=output_filename
        )
        self.n_samples_per_class = n_samples_per_class
        self.percent_samples_per_class = percent_samples_per_class

    def forward(self, inputs):
        """Return the model's output (logits) for a batch of ``inputs``."""
        return self.model(inputs)

    def _get_dataloader(self, dataset, is_train: bool = True, sampler: WeightedRandomSampler = None):
        """Build a DataLoader over ``dataset`` using the project ``collate_fn``.

        Args:
            dataset: Dataset to iterate.
            is_train: Shuffle the data when no ``sampler`` is given.
            sampler: Optional sampler; when provided it replaces shuffling
                (DataLoader does not allow both at once).
        """
        if sampler is None:
            return DataLoader(dataset, batch_size=self.batch_size, shuffle=is_train, num_workers=12, collate_fn=collate_fn)
        return DataLoader(dataset, batch_size=self.batch_size, sampler=sampler, num_workers=12, collate_fn=collate_fn)

    def _training_step(self, epoch, dataloader):
        """Run one training epoch over ``dataloader``.

        Updates the model in place and shows the running mean loss and the
        running mean of per-batch balanced accuracy in the progress bar.

        Args:
            epoch: Zero-based epoch index (displayed one-based).
            dataloader: Iterable yielding ``(inputs, targets)`` batches.
        """
        epoch_loss = []
        epoch_acc = []
        with tqdm(dataloader, unit="batch") as tepoch:
            # The description is constant within an epoch; set it once.
            tepoch.set_description(f"Exp: {self.current_experience+1} | Epoch: {epoch+1}")
            for inputs, targets in tepoch:
                # `.to` is not guaranteed to be in place (it is not for plain
                # tensors), so keep the returned object.
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                self.optimizer.zero_grad()
                logits = self.forward(inputs)

                J = self.loss(logits, targets)
                J.backward()
                self.optimizer.step()

                epoch_loss.append(float(J))
                # sklearn's signature is (y_true, y_pred); balanced accuracy is
                # not symmetric, so the order matters.
                epoch_acc.append(balanced_accuracy_score(
                    targets.cpu().numpy(),
                    logits.argmax(dim=-1).cpu().numpy(),
                ))

                # NOTE(review): releases cached CUDA memory after every batch;
                # presumably to limit peak GPU memory, at some speed cost.
                torch.cuda.empty_cache()

                tepoch.set_postfix(
                    loss=round(np.mean(epoch_loss), 3),
                    acc=round(np.mean(epoch_acc), 3)
                )

    def run(self):
        """Train and evaluate over every experience, then dump results to JSON.

        ``self.res["current_vs_target"]["exp k"]`` holds the accuracy on each of
        experiences 1..k after training on experience k;
        ``self.res["current_vs_overall"]["exp k"]`` holds the result of
        ``_test_all_experiences`` at that point.
        """
        self.res = dict(
            current_vs_target=dict(),
            current_vs_overall=dict()
        )

        n_classes_per_experience = self.stream.n_classes_per_experience
        n_output_classes = n_classes_per_experience[0]
        # Shift by one so index `exp_id` gives the number of classes the NEXT
        # experience adds; the trailing 0 covers the last experience.
        n_classes_per_experience = list(n_classes_per_experience[1:]) + [0]

        for exp_id in range(self.stream.n_experiences):
            self.current_experience = exp_id

            exp_train_data = self.stream.train_stream[exp_id]

            if self.current_experience == 0:
                train_data = exp_train_data
                train_dataloader = self._get_dataloader(train_data, is_train=True)
            else:
                # NOTE: `train_data` from the previous iteration already holds the
                # earlier replay samples, so older experiences are subsampled
                # again at every step.
                train_data = merge_vl_datasets(
                    exp_train_data,
                    subsample_dataset(
                        train_data,
                        n_samples_per_class=self.n_samples_per_class,
                        percent_samples_per_class=self.percent_samples_per_class
                    )
                )
                # Weighted sampling (with replacement) over the merged set.
                weights = compute_dataset_sampling_weights(train_data)
                train_dataloader = self._get_dataloader(train_data, sampler=WeightedRandomSampler(weights, len(weights)))

            self.train(train_dataloader)

            # Test
            self.res["current_vs_target"][f"exp {exp_id+1}"] = []
            for target_exp_id in tqdm(range(self.current_experience+1)):
                acc = self._test_target_experience(target_exp_id)
                self.res["current_vs_target"][f"exp {exp_id+1}"].append(acc)

            self.res["current_vs_overall"][f"exp {exp_id+1}"] = self._test_all_experiences()

            # Grow the output layer for the classes of the next experience.
            n_output_classes += n_classes_per_experience[exp_id]
            self.model.adaptation(n_output_classes)

        # Use a context manager so the results file is flushed and closed.
        with open(f"./output/{self.output_filename}.json", "w") as f:
            json.dump(self.res, f)