import warnings
import torch
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import balanced_accuracy_score
import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data import DataLoader
from lib.data import collate_fn
from lib.data.utils import (
    merge_vl_datasets,
    compute_dataset_sampling_weights
)
from lib.models.upper_bound import FlavaUpperBoundCL
import json

# Silence every warning, process-wide, as soon as this module is imported.
# No category or module filter is given, so nothing is exempt.
warnings.filterwarnings("ignore")

class FlavaUpperBoundStrategy():
    """Joint-training strategy over a stream of experiences.

    For every experience in ``stream`` a fresh ``FlavaUpperBoundCL`` model is
    created (with one output per class seen so far), trained on the merged
    training data of all experiences up to and including the current one,
    and then evaluated on the test sets of those same experiences. The
    collected balanced accuracies are written to
    ``./output/<output_filename>.json``.
    """

    # Worker processes used by every DataLoader built by this strategy.
    _NUM_WORKERS = 12

    def __init__(self,
        stream: object,
        n_epochs: int,
        lr: float,
        batch_size: int,
        output_filename: str,
        device: torch.device,
        **kwargs
    ):
        """Store the training configuration.

        Args:
            stream: Object exposing ``train_stream``, ``test_stream``,
                ``n_experiences`` and ``n_classes_per_experience``.
            n_epochs: Number of passes over the training data per experience.
            lr: Learning rate handed to AdamW.
            batch_size: Batch size for both training and test loaders.
            output_filename: Stem of the JSON results file (no extension).
            device: Device the model and the batches are moved to.
            **kwargs: Accepted and ignored.
        """
        self.device = device
        self.stream = stream
        self.n_epochs = n_epochs
        self.lr = lr
        self.batch_size = batch_size
        self.loss = nn.CrossEntropyLoss()
        self.current_experience = None
        self.res = None
        self.output_filename = output_filename
        # Created per experience by run(); declared here so the attribute
        # always exists.
        self.model = None

    def forward(self, inputs):
        """Return the current model's output for ``inputs``."""
        return self.model(inputs)

    def _to_device(self, inputs):
        """Move a batch to ``self.device`` and return the batch to use.

        ``torch.Tensor.to`` returns a new tensor instead of moving in place,
        so the return value must not be discarded. Containers whose ``to``
        works in place and returns ``None`` are handled by falling back to
        the original object.
        """
        moved = inputs.to(self.device)
        return inputs if moved is None else moved

    def _get_dataloader(self, dataset, is_train: bool = True, sampler: WeightedRandomSampler = None):
        """Build a DataLoader over ``dataset``.

        When ``sampler`` is given it drives the sampling order; otherwise the
        data is shuffled only if ``is_train`` is true.
        """
        loader_kwargs = dict(
            batch_size=self.batch_size,
            num_workers=self._NUM_WORKERS,
            collate_fn=collate_fn,
        )
        # DataLoader rejects `shuffle=True` together with a sampler, so only
        # one of the two is ever passed.
        if sampler is None:
            loader_kwargs["shuffle"] = is_train
        else:
            loader_kwargs["sampler"] = sampler
        return DataLoader(dataset, **loader_kwargs)

    def train(self, dataloader):
        """Train the current model for ``n_epochs`` with a fresh AdamW optimizer."""
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        self.model.train()
        for epoch in range(self.n_epochs):
            self._training_step(epoch, dataloader)

    def _training_step(self, epoch, dataloader):
        """Run one epoch of optimisation, showing running loss/accuracy.

        The progress bar displays the mean of the per-batch losses and the
        mean of the per-batch balanced accuracies seen so far in the epoch.
        """
        epoch_loss = []
        epoch_acc = []
        with tqdm(dataloader, unit="batch") as tepoch:
            for inputs, targets in tepoch:
                tepoch.set_description(f"Exp: {self.current_experience+1} | Epoch: {epoch+1}")

                inputs = self._to_device(inputs)

                self.optimizer.zero_grad()
                logits = self.forward(inputs)

                J = self.loss(logits, targets.to(self.device))
                J.backward()
                self.optimizer.step()

                epoch_loss.append(float(J))
                # sklearn's signature is (y_true, y_pred); the metric is not
                # symmetric, so the order matters.
                epoch_acc.append(
                    balanced_accuracy_score(
                        targets.cpu().numpy(),
                        logits.argmax(dim=-1).cpu().numpy(),
                    )
                )

                torch.cuda.empty_cache()

                tepoch.set_postfix(
                    loss=round(np.mean(epoch_loss), 3),
                    acc=round(np.mean(epoch_acc), 3)
                )

    @torch.no_grad()
    def _collect_predictions(self, dataloader):
        """Return ``(targets, predictions)`` as flat lists for ``dataloader``."""
        all_targets = []
        all_preds = []
        for inputs, targets in dataloader:
            logits = self.forward(self._to_device(inputs))
            all_targets += targets.tolist()
            all_preds += logits.argmax(dim=-1).tolist()
        return all_targets, all_preds

    @torch.no_grad()
    def _test_all_experiences(self):
        """Return balanced accuracy over the pooled test sets seen so far.

        Predictions from the test sets of experiences
        ``0..current_experience`` are concatenated before scoring.
        """
        self.model.eval()

        targets_ = []
        preds_ = []

        for exp_id in tqdm(range(self.current_experience+1)):
            test_data = self.stream.test_stream[exp_id]
            test_dataloader = self._get_dataloader(test_data, False)
            exp_targets, exp_preds = self._collect_predictions(test_dataloader)
            targets_ += exp_targets
            preds_ += exp_preds

        return balanced_accuracy_score(targets_, preds_)

    @torch.no_grad()
    def _test_target_experience(self, target_exp_id):
        """Return balanced accuracy on the test set of one experience.

        NOTE: logits are not masked to the target experience's classes; the
        prediction is the argmax over every output of the current model.
        """
        self.model.eval()

        test_data = self.stream.test_stream[target_exp_id]
        test_dataloader = self._get_dataloader(test_data, False)
        targets_, preds_ = self._collect_predictions(test_dataloader)

        return balanced_accuracy_score(targets_, preds_)

    def run(self):
        """Train and evaluate on every experience, then dump results to JSON.

        ``self.res`` ends up with two dictionaries keyed by ``"exp <k>"``
        (1-based): ``current_vs_target`` holds a list with one accuracy per
        experience seen so far, ``current_vs_overall`` holds the accuracy on
        their pooled test sets.
        """
        import os

        self.res = dict(
            current_vs_target=dict(),
            current_vs_overall=dict()
        )

        # Accumulates the training data of all experiences seen so far.
        train_data = None

        for exp_id in range(self.stream.n_experiences):

            # A new model per experience, sized for every class seen so far.
            model = FlavaUpperBoundCL(n_output_classes=sum(self.stream.n_classes_per_experience[:(exp_id+1)]))
            self.model = model.to(self.device)

            self.current_experience = exp_id

            exp_train_data = self.stream.train_stream[exp_id]

            if train_data is None:
                train_data = exp_train_data
            else:
                train_data = merge_vl_datasets(exp_train_data, train_data)

            # Presumably per-sample weights used to rebalance the merged
            # dataset; verify against compute_dataset_sampling_weights.
            weights = compute_dataset_sampling_weights(train_data)
            train_dataloader = self._get_dataloader(train_data, sampler=WeightedRandomSampler(weights, len(weights)))
            self.train(train_dataloader)

            # Test
            exp_key = f"exp {exp_id+1}"
            self.res["current_vs_target"][exp_key] = []
            for target_exp_id in tqdm(range(self.current_experience+1)):
                acc = self._test_target_experience(target_exp_id)
                self.res["current_vs_target"][exp_key].append(acc)

            self.res["current_vs_overall"][exp_key] = self._test_all_experiences()

        # Create the directory up front so a missing folder cannot discard
        # the results of a finished run, and close the file deterministically.
        os.makedirs("./output", exist_ok=True)
        with open(f"./output/{self.output_filename}.json", "w") as out_file:
            json.dump(self.res, out_file)
            

    
