import json
import os
import warnings
from abc import abstractmethod

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.metrics import balanced_accuracy_score
from torch.utils.data import DataLoader
from tqdm import tqdm

from lib.data import collate_fn

# Silence every warning for the whole process: no category, module or message
# filter is given, so all warning categories raised after this import are hidden.
warnings.filterwarnings("ignore")

class BaseStrategy:
    """Train and evaluate a model on a stream of experiences, one at a time.

    The strategy reads the following from ``stream``: ``n_experiences``,
    ``n_classes_per_experience`` (a sequence of per-experience class counts),
    and the indexable ``train_stream`` / ``test_stream`` datasets. The model
    must expose ``adaptation(n_output_classes)``, which :meth:`run` calls
    after every experience.

    Subclasses must override :meth:`forward`. Note that ``@abstractmethod``
    is not enforced here because the class does not derive from ``abc.ABC``;
    the base implementation raises ``NotImplementedError`` instead.
    """

    def __init__(
        self,
        model: nn.Module,
        stream: object,
        n_epochs: int,
        lr: float,
        batch_size: int,
        device: torch.device,
        output_filename: str = None,
        num_workers: int = 12,
        **kwargs
    ):
        """Store the training configuration and move the model to ``device``.

        Args:
            model: Network to train; moved to ``device``.
            stream: Object providing the train/test experiences (see class doc).
            n_epochs: Number of epochs run for every experience.
            lr: Learning rate of the AdamW optimizer built in :meth:`train`.
            batch_size: Batch size used for training and test loaders.
            device: Device on which the model and the batches are placed.
            output_filename: Stem of the JSON file written by :meth:`run`
                under ``./output/``.
            num_workers: Worker processes used by every ``DataLoader``.
            **kwargs: Accepted and ignored, so subclasses can forward extra
                options through ``super().__init__``.
        """
        self.device = device
        self.model = model.to(self.device)
        self.stream = stream
        self.n_epochs = n_epochs
        self.lr = lr
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.loss = nn.CrossEntropyLoss()
        self.current_experience = None  # index of the experience being processed
        self.res = None  # filled by run()
        self.output_filename = output_filename

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Return the logits for a batch; must be overridden by subclasses."""
        raise NotImplementedError()

    def _get_dataloader(self, dataset, is_train: bool = True):
        """Wrap ``dataset`` in a ``DataLoader``; shuffle only when training."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=is_train,
            num_workers=self.num_workers,
            collate_fn=collate_fn,
        )

    def train(self, dataloader):
        """Train for ``n_epochs`` epochs on ``dataloader``.

        A fresh AdamW optimizer is created on every call, so optimizer state
        is not carried over between experiences.
        """
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        self.model.train()
        for epoch in range(self.n_epochs):
            self._training_step(epoch, dataloader)

    def _training_step(self, epoch, dataloader):
        """Run one epoch over ``dataloader`` and report running loss/accuracy.

        Logits of classes introduced by earlier experiences are masked out
        before the loss is computed, so only the remaining outputs receive
        probability mass.
        """
        epoch_loss = []
        epoch_acc = []

        # Number of output units that belong to previously seen experiences.
        n_seen_classes_until_now = sum(
            self.stream.n_classes_per_experience[:self.current_experience]
        )

        with tqdm(dataloader, unit="batch") as tepoch:
            for inputs, targets in tepoch:
                tepoch.set_description(
                    f"Exp: {self.current_experience+1} | Epoch: {epoch+1}"
                )

                # ``.to`` returns the moved object for tensors (it is not
                # in-place), so the result has to be rebound.
                inputs = inputs.to(self.device)

                self.optimizer.zero_grad()
                logits = self.forward(inputs)

                if n_seen_classes_until_now:
                    # Most negative finite value of the logits' own dtype:
                    # softmax still underflows to 0 without producing -inf.
                    logits[:, :n_seen_classes_until_now] = torch.finfo(logits.dtype).min

                J = self.loss(logits, targets.to(self.device))
                J.backward()
                self.optimizer.step()

                epoch_loss.append(float(J))
                # sklearn expects (y_true, y_pred), in that order.
                epoch_acc.append(
                    balanced_accuracy_score(
                        targets.cpu().numpy(),
                        logits.argmax(dim=-1).cpu().numpy(),
                    )
                )

                torch.cuda.empty_cache()

                tepoch.set_postfix(
                    loss=round(np.mean(epoch_loss), 3),
                    acc=round(np.mean(epoch_acc), 3)
                )

    @torch.no_grad()
    def _collect_predictions(self, exp_ids):
        """Return ``(targets, predictions)`` over the test sets of ``exp_ids``.

        Both are flat Python lists, concatenated in the order the experiences
        are iterated. No logit masking is applied at test time.
        NOTE: a variant masking the logits of non-target experiences was
        previously sketched here as commented-out code and never enabled.
        """
        self.model.eval()

        targets_ = []
        preds_ = []
        for exp_id in exp_ids:
            test_data = self.stream.test_stream[exp_id]
            test_dataloader = self._get_dataloader(test_data, False)

            for inputs, targets in test_dataloader:
                inputs = inputs.to(self.device)  # rebind: ``.to`` is not in-place
                logits = self.forward(inputs)
                targets_.extend(targets.tolist())
                preds_.extend(logits.argmax(dim=-1).tolist())

        return targets_, preds_

    @torch.no_grad()
    def _test_target_experience(self, target_exp_id):
        """Return the balanced accuracy on the test set of one experience."""
        targets_, preds_ = self._collect_predictions([target_exp_id])
        return balanced_accuracy_score(targets_, preds_)

    @torch.no_grad()
    def _test_all_experiences(self):
        """Return the balanced accuracy over all experiences seen so far.

        Predictions of experiences ``0..current_experience`` are pooled before
        the metric is computed.
        """
        seen = tqdm(range(self.current_experience+1))
        targets_, preds_ = self._collect_predictions(seen)
        return balanced_accuracy_score(targets_, preds_)

    def _save_results(self):
        """Write ``self.res`` as JSON to ``./output/<output_filename>.json``.

        The target directory is created when missing and the file handle is
        always closed.
        """
        path = f"./output/{self.output_filename}.json"
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        with open(path, "w") as fh:
            json.dump(self.res, fh)

    def run(self):
        """Train on every experience in order, evaluating after each one.

        After experience ``i`` the model is tested on each experience
        ``0..i`` separately (``res["current_vs_target"]``) and on all of them
        pooled (``res["current_vs_overall"]``). Then ``model.adaptation`` is
        called with the number of output classes needed for the next
        experience (unchanged after the last one). Results are finally dumped
        to JSON.
        """
        self.res = dict(
            current_vs_target=dict(),
            current_vs_overall=dict()
        )

        # ``list`` so tuples / arrays work as well as plain lists.
        class_counts = list(self.stream.n_classes_per_experience)
        n_output_classes = class_counts[0]

        for exp_id in range(self.stream.n_experiences):
            self.current_experience = exp_id

            train_data = self.stream.train_stream[exp_id]
            train_dataloader = self._get_dataloader(train_data, True)
            self.train(train_dataloader)

            # Test
            key = f"exp {exp_id+1}"
            self.res["current_vs_target"][key] = [
                float(self._test_target_experience(target_exp_id))
                for target_exp_id in tqdm(range(self.current_experience+1))
            ]
            self.res["current_vs_overall"][key] = float(self._test_all_experiences())

            # Grow the output layer by the class count of the next experience;
            # nothing is added once the last experience has been processed.
            if exp_id + 1 < len(class_counts):
                n_output_classes += class_counts[exp_id + 1]
            self.model.adaptation(n_output_classes)

        self._save_results()

    def plot_results(self, results = None):
        """Plot the per-experience heatmap and the overall accuracy curve.

        Args:
            results: Result dictionary shaped like ``self.res``; when ``None``
                the results of the last :meth:`run` are used. Nothing is
                plotted if no results are available.
        """
        res_ = self.res if results is None else results

        if res_ is None:
            return

        # Current Exp vs Target Exp
        rows = [list(r) for r in res_["current_vs_target"].values()]
        if rows:
            width = max(len(r) for r in rows)
            # Right-pad every row with zeros so the matrix is rectangular.
            mat = np.array([r + [0] * (width - len(r)) for r in rows])

            p = sns.heatmap(mat, annot=True, annot_kws={"fontsize": 12}, cmap="viridis", cbar=False)
            p.set_yticklabels([f"After Exp {i+1}" for i in range(len(rows))])
            p.set_xticklabels([f"Exp {i+1}" for i in range(width)])
            plt.yticks(rotation=0)
            plt.show()

        # Current Exp vs Overall Exp
        overall = res_["current_vs_overall"]
        frame = pd.DataFrame(
            dict(
                x=[s[:1].upper() + s[1:] for s in overall.keys()],
                y=list(overall.values()),
            )
        )

        p = sns.pointplot(
            data=frame,
            x="x", y="y",
            legend=None
        )
        p.set(xlabel=None)
        p.set(ylabel="Balanced Accuracy")
        plt.show()