import torch
import copy
from abc import ABC, abstractmethod

from oucl.agents.memory_modules import (
                    ParallelReservoirSamplingMemory, 
                     ParallelSCALEMemory, 
                     ParallelQueueMemory, 
                     ParallelLossBufferMemory, 
                     ParallelRandomBufferMemory, 
                     ParallelRepulsionMemory,
                     ParallelRepulsionMemory2, 
                     ParallelRandomSamplingMemory,
                     ParallelFixedThreshMemory
                    )
from oucl.scenarios.transforms import FromNumpyMultiViewTransform, FromNumpyDefaultTransform, load_transform
from oucl.agents.encoders import load_encoder
from oucl.agents.losses import load_loss

from sklearn.neighbors import KNeighborsClassifier
from sklearn.cluster import KMeans

from collections import deque
import numpy as np
import wandb
from tqdm import tqdm


class BaseAgent(ABC):
    """Base class providing the evaluation routines shared by all agents.

    Subclasses are expected to set ``self.encoder`` and ``self.device``; the
    methods below read both. The encoder must provide ``eval()``, ``train()``,
    ``embed(x)`` and ``backbone.out_dim``.

    Every loader passed to the evaluation methods must yield 4-tuples whose
    first two items are the inputs and the label tensor; the remaining two
    items are ignored here.
    """

    def __init__(self, config):
        """Read the agent name and the linear-probe hyper-parameters from ``config.agent``."""
        self.name = config.agent.name

        self.linear_epochs = config.agent.linear_epochs
        self.linear_lr = config.agent.linear_lr
        self.linear_wd = config.agent.linear_wd

    def initialize(self, dataloader, config):
        """Optional hook; the base implementation does nothing."""
        return

    @abstractmethod
    def __call__(self, x):
        """Consume one batch of the stream. Must be implemented by subclasses."""
        pass

    def supervise(self, super_loader):
        """Fit k-NN classifiers (k = 1, 5, 20, 50) on embeddings of ``super_loader``.

        Also records ``self.y_offset`` (the smallest label seen) and
        ``self.num_classes`` (the number of distinct labels), which
        ``supervise_linear``, ``classify`` and ``cluster`` rely on.
        """
        self.encoder.eval()
        embeddings = []
        labels = []

        # Inference only: no autograd graph is needed for the embeddings.
        with torch.no_grad():
            for x, y, _, _ in super_loader:
                embeddings.append(self.encoder.embed(x).detach().cpu().numpy())
                labels.append(y)

        zs = np.concatenate(embeddings)
        ys = torch.cat(labels).numpy()
        self.y_offset = ys.min()
        self.num_classes = len(np.unique(ys))
        # Shift labels so the smallest one becomes 0.
        ys -= self.y_offset

        self.knn_1 = KNeighborsClassifier(1).fit(zs, ys)
        self.knn_5 = KNeighborsClassifier(5).fit(zs, ys)
        self.knn_20 = KNeighborsClassifier(20).fit(zs, ys)
        self.knn_50 = KNeighborsClassifier(50).fit(zs, ys)

        self.encoder.train()
        return

    def supervise_linear(self, super_loader, eval_step):
        """Train a fresh linear classifier on top of the encoder embeddings.

        ``supervise`` must have been called before, because this method reads
        ``self.num_classes`` and ``self.y_offset``. The mean loss of each
        epoch is logged to wandb under ``linear_loss_{eval_step}``.
        """
        self.encoder.eval()

        self.classifier = torch.nn.Linear(self.encoder.backbone.out_dim, self.num_classes).to(self.device)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.classifier.parameters(), self.linear_lr,
                                    weight_decay=self.linear_wd, momentum=0.9, nesterov=True)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0.0001, T_max=self.linear_epochs)

        for epoch in tqdm(range(self.linear_epochs)):
            epoch_loss = 0
            for x, y, _, _ in super_loader:
                optimizer.zero_grad()

                # Only the linear layer is optimised here, so the embedding is
                # computed without building a graph through the encoder.
                with torch.no_grad():
                    z = self.encoder.embed(x)
                y_out = self.classifier(z)
                loss = criterion(y_out, y.to(self.device) - self.y_offset)
                loss.backward()

                epoch_loss += loss.item()

                optimizer.step()

            scheduler.step()
            wandb.log({'epoch': epoch, f'linear_loss_{eval_step}': epoch_loss / len(super_loader)})

        self.encoder.train()
        return

    def classify(self, eval_loader):
        """Predict labels for ``eval_loader`` with the fitted classifiers.

        Returns ``(y_preds, ys)``. ``y_preds`` maps ``'1nn'``, ``'5nn'``,
        ``'20nn'`` and ``'50nn'`` to k-NN predictions, and additionally
        ``'linear'`` to the linear-probe predictions when a classifier has
        been trained with ``supervise_linear``. ``ys`` are the true labels
        shifted by ``self.y_offset``.

        ``supervise`` must have been called before (the k-NN models and
        ``self.y_offset`` are read here).
        """
        self.encoder.eval()
        # The linear probe is optional: without it only k-NN predictions are
        # returned instead of failing on a missing classifier.
        classifier = getattr(self, 'classifier', None)
        if classifier is not None:
            classifier.eval()

        embeddings = []
        labels = []
        linear_preds = []
        with torch.no_grad():
            for x, y, _, _ in eval_loader:
                z = self.encoder.embed(x).detach()
                embeddings.append(z.cpu().numpy())
                if classifier is not None:
                    linear_preds.append(torch.argmax(classifier(z), dim=1))
                labels.append(y)

        zs = np.concatenate(embeddings)
        ys = torch.cat(labels).numpy() - self.y_offset

        y_preds = {'1nn': self.knn_1.predict(zs),
                   '5nn': self.knn_5.predict(zs),
                   '20nn': self.knn_20.predict(zs),
                   '50nn': self.knn_50.predict(zs)}
        if classifier is not None:
            y_preds['linear'] = torch.cat(linear_preds).detach().cpu().numpy()

        self.encoder.train()
        return y_preds, ys

    def cluster(self, eval_loader):
        """Cluster the embeddings of ``eval_loader`` with k-means.

        Two clusterings are computed, with ``k`` and ``2*k`` clusters, where
        ``k`` is the number of distinct labels in ``eval_loader``. Returns
        ``(y_preds, ys)`` with keys ``'kmeans_1'`` and ``'kmeans_2'``; ``ys``
        are the true labels shifted by ``self.y_offset`` (set by ``supervise``).
        """
        self.encoder.eval()
        embeddings = []
        labels = []
        with torch.no_grad():
            for x, y, _, _ in eval_loader:
                embeddings.append(self.encoder.embed(x).detach().cpu().numpy())
                labels.append(y)

        zs = np.concatenate(embeddings)
        ys = torch.cat(labels).numpy() - self.y_offset
        k = len(np.unique(ys))

        y_preds = {'kmeans_1': KMeans(k).fit_predict(zs),
                   'kmeans_2': KMeans(2*k).fit_predict(zs)}

        self.encoder.train()
        return y_preds, ys

class OfflineAgent(BaseAgent):
    """Agent that trains the encoder directly on each incoming batch, without replay memory."""

    def __init__(self, config):
        """Build the encoder and the loss named by ``config.agent.loss``."""
        super().__init__(config)

        self.encoder = load_encoder(config)
        self.device = config.device

        self.loss_fn = load_loss(config.agent.loss, config)

        self.step = 0
        self.current_task = 0

        self.eval_only = config.agent.eval_only

    def __call__(self, data, y, t, inds):
        """Run one optimisation step on ``data`` and log the loss to wandb.

        ``data`` is a pair whose first element is fed to the encoder; the
        second element is not used by this agent, and neither are ``t`` and
        ``inds``. ``y`` is moved to ``self.device`` and handed to the loss
        function under the ``'y'`` key of the encoder output. Does nothing
        when ``eval_only`` is set.
        """
        if self.eval_only:
            return

        self.encoder.train()
        x, _ = data

        self.encoder.optimizer.zero_grad()

        outs = self.encoder(x)
        outs['y'] = y.to(self.device)

        loss = self.loss_fn(outs)
        loss.backward()

        self.encoder.optimizer.step()
        # The learning-rate scheduler is optional.
        if self.encoder.lr_scheduler is not None:
            self.encoder.lr_scheduler.step()

        wandb.log({
                'loss_step' : self.step,
                'loss': loss.item(),
            })

        self.step += 1

class SSLERAgent(BaseAgent):
    """Agent that trains on the incoming batch together with samples replayed from a single memory."""

    def __init__(self, config):
        """Build the encoder, the loss and the replay memory selected by ``config.agent.memory_type``.

        Raises:
            ValueError: if ``config.agent.memory_type`` is not one of the
                supported memory types.
        """
        super().__init__(config)

        self.encoder = load_encoder(config)
        self.device = config.device
        self.use_memory = config.agent.use_memory
        self.eval_only = config.agent.eval_only

        self.samples_per_step = config.agent.samples_per_step
        self.mem_batch_size = config.agent.mem_batch_size
        self.batch_size = config.scenario.batch_size
        self.memory_type = config.agent.memory_type

        self.loss_fn = load_loss(config.agent.loss, config)

        self.memory = self._build_memory(config)

        self.current_task = 0
        self.step = 0
        self.unique_labels = set()

    @staticmethod
    def _build_memory(config):
        """Instantiate the replay memory named by ``config.agent.memory_type``."""
        agent_cfg = config.agent
        memory_type = agent_cfg.memory_type

        # Every memory takes the same leading and trailing arguments; only the
        # class and the arguments in between differ.
        if memory_type == 'reservoir-buffer':
            memory_cls, extra_args = ParallelReservoirSamplingMemory, ()
        elif memory_type == 'queue-buffer':
            memory_cls, extra_args = ParallelQueueMemory, ()
        elif memory_type == 'random-buffer':
            memory_cls, extra_args = ParallelRandomBufferMemory, (agent_cfg.k,)
        elif memory_type == 'loss-buffer':
            memory_cls, extra_args = ParallelLossBufferMemory, (agent_cfg.beta, agent_cfg.k)
        else:
            raise ValueError("memory_type must be one of 'reservoir-buffer', 'queue-buffer', "
                             f"'random-buffer' or 'loss-buffer', found {memory_type}")

        transform = FromNumpyMultiViewTransform(load_transform(agent_cfg.memory_transform,
                                                               config.dataset.img_size,
                                                               config.dataset.mean,
                                                               config.dataset.std,
                                                               agent_cfg))
        return memory_cls(config.dataset.img_size,
                          agent_cfg.mem_capacity,
                          agent_cfg.mem_batch_size,
                          transform,
                          *extra_args,
                          agent_cfg.num_workers, agent_cfg.parallel_queue_size)

    def training_step(self, x):
        """Run one optimiser step on ``x``.

        Returns the pair produced by the loss function when called with
        ``return_ind=True``: the loss that was back-propagated and a second
        value, presumably the per-sample losses (TODO confirm against the
        loss implementation).
        """
        self.encoder.optimizer.zero_grad()

        outs = self.encoder(x)

        loss, ind_loss = self.loss_fn(outs, return_ind=True)
        loss.backward()

        self.encoder.optimizer.step()
        # The learning-rate scheduler is optional.
        if self.encoder.lr_scheduler is not None:
            self.encoder.lr_scheduler.step()

        return loss, ind_loss

    def __call__(self, data, y, t, inds):
        """Train on one stream batch, replaying memory samples once the memory is filled enough.

        ``data`` is a pair ``(x, raw)``: ``x`` is fed to the encoder and
        ``raw`` is what gets stored in the memory. ``y`` is a CPU label
        tensor and ``t`` the task index; ``inds`` is unused. Does nothing
        when ``eval_only`` is set.
        """
        if self.eval_only:
            return

        self.unique_labels.update(y.numpy().tolist())
        self.encoder.train()
        x, raw = data

        if int(self.memory.n_items.value) >= 2*self.mem_batch_size:
            total_loss = 0.0
            for _ in range(self.samples_per_step):
                mem_batch = self.memory.sample_memory()
                # NOTE(review): concatenation along dim=1 presumably means
                # dim 0 indexes the augmented views and dim 1 the samples - confirm.
                xs = torch.cat((x, mem_batch), dim=1)

                step_loss, l_ind = self.training_step(xs)
                # Accumulate plain floats so no graph-attached tensors are
                # kept alive or handed to the logger.
                total_loss += step_loss.item()

            mean_loss = total_loss / self.samples_per_step
        else:
            # Only one step is taken here, so its loss is logged as is.
            step_loss, l_ind = self.training_step(x)
            mean_loss = step_loss.item()

        wandb.log({
            'loss_step' : self.step,
            'loss': mean_loss,
            'qsize': self.memory.queue.qsize()
        })

        if self.use_memory:
            if self.memory_type == 'loss-buffer':
                # The first len(raw) entries are taken as the values for the
                # current batch, which was concatenated before the memory samples.
                self.memory.add(raw, y, l_ind.detach().cpu().numpy()[:len(raw)].tolist())
            else:
                self.memory.add(raw, y)

            # NOTE(review): the constant 20 looks like a hard-coded number of
            # classes per task - confirm against the scenario configuration.
            self.memory.log_class_balance(max(self.unique_labels)+1 - (t*20))

        self.step += 1

class SSLER2Agent(BaseAgent):
    """Agent with two replay memories: a short-term one (STM) and a long-term one (LTM).

    New samples go into the STM; whatever the STM evicts is passed on to the
    LTM. Training batches combine the stream batch with samples from the STM
    and, once it holds enough items, from the LTM as well.
    """

    # LTM types that receive a copy of the encoder through ``update_encoder``.
    _ENCODER_LTM_TYPES = ('fixed', 'repulse', 'repulse2')

    def __init__(self, config):
        """Build the encoder, the loss, the STM and the LTM selected by ``config.agent.memory_type``.

        Raises:
            ValueError: if ``config.agent.memory_type`` is not one of
                'repulse', 'repulse2', 'reservoir', 'random' or 'fixed'.
        """
        super().__init__(config)

        self.encoder = load_encoder(config)
        self.device = config.device
        self.use_memory = config.agent.use_memory
        self.memory_type = config.agent.memory_type
        self.eval_only = config.agent.eval_only

        self.samples_per_step = config.agent.samples_per_step
        self.batch_method = config.agent.batch_method

        self.stm_batch_size = config.agent.stm_batch_size
        self.ltm_batch_size = config.agent.ltm_batch_size

        self.stm_capacity = config.agent.stm_capacity
        self.ltm_capacity = config.agent.ltm_capacity

        self.update_interval = config.agent.update_interval

        self.loss_fn = load_loss(config.agent.loss, config)

        self.stm = ParallelLossBufferMemory(config.dataset.img_size,
                                            config.agent.stm_capacity,
                                            config.agent.stm_batch_size,
                                            self._multi_view_transform(config),
                                            config.agent.alpha, config.agent.t,
                                            config.agent.num_workers, config.agent.parallel_queue_size)

        if config.agent.memory_type == 'repulse':
            self.ltm = ParallelRepulsionMemory(config.dataset.img_size,
                                               config.agent.ltm_capacity,
                                               config.agent.ltm_batch_size,
                                               self._multi_view_transform(config),
                                               self._default_transform(config),
                                               config.agent.beta,
                                               config.agent.k,
                                               config.agent.prune,
                                               config.agent.num_workers, config.agent.parallel_queue_size)
        elif config.agent.memory_type == 'repulse2':
            self.ltm = ParallelRepulsionMemory2(config.dataset.img_size,
                                                config.agent.ltm_capacity,
                                                config.agent.ltm_batch_size,
                                                self._multi_view_transform(config),
                                                self._default_transform(config),
                                                config.agent.beta,
                                                config.agent.k,
                                                config.agent.prune,
                                                config.agent.num_workers, config.agent.parallel_queue_size,
                                                len(config.scenario.task_classes), len(config.scenario.task_classes[0]))
        elif config.agent.memory_type == 'reservoir':
            self.ltm = ParallelReservoirSamplingMemory(config.dataset.img_size,
                                                       config.agent.ltm_capacity,
                                                       config.agent.ltm_batch_size,
                                                       self._multi_view_transform(config),
                                                       config.agent.num_workers, config.agent.parallel_queue_size)
        elif config.agent.memory_type == 'random':
            self.ltm = ParallelRandomSamplingMemory(config.dataset.img_size,
                                                    config.agent.ltm_capacity,
                                                    config.agent.ltm_batch_size,
                                                    self._multi_view_transform(config),
                                                    config.agent.k,
                                                    config.agent.num_workers, config.agent.parallel_queue_size)
        elif config.agent.memory_type == 'fixed':
            self.ltm = ParallelFixedThreshMemory(config.dataset.img_size,
                                                 config.agent.ltm_capacity,
                                                 config.agent.ltm_batch_size,
                                                 self._multi_view_transform(config),
                                                 self._default_transform(config),
                                                 config.agent.beta,
                                                 config.agent.k,
                                                 config.agent.fixed_thresh,
                                                 config.agent.num_workers, config.agent.parallel_queue_size)
        else:
            # Fail here instead of later with a missing ``self.ltm`` attribute.
            raise ValueError("memory_type must be one of 'repulse', 'repulse2', 'reservoir', "
                             f"'random' or 'fixed', found {config.agent.memory_type}")

        if self.memory_type in self._ENCODER_LTM_TYPES:
            # These memories are given their own copy of the encoder.
            self.ltm.update_encoder(copy.deepcopy(self.encoder))

        self.current_task = 0
        self.step = 0
        self.unique_labels = set()
        self.loss_history = deque(maxlen=config.agent.window_len)
        self.window_len = config.agent.window_len
        self.first_update = False

        self.logging_transform = self._default_transform(config)

    @staticmethod
    def _multi_view_transform(config):
        """Return a new multi-view transform built from ``config.agent.memory_transform``."""
        return FromNumpyMultiViewTransform(load_transform(config.agent.memory_transform,
                                                          config.dataset.img_size,
                                                          config.dataset.mean,
                                                          config.dataset.std,
                                                          config.agent))

    @staticmethod
    def _default_transform(config):
        """Return a new default transform for the dataset's image size, mean and std."""
        return FromNumpyDefaultTransform(config.dataset.img_size,
                                         config.dataset.mean,
                                         config.dataset.std)

    def training_step(self, x):
        """Run one optimiser step on ``x``.

        Returns the pair produced by the loss function when called with
        ``return_ind=True``: the loss that was back-propagated and a second
        value, presumably the per-sample losses (TODO confirm against the
        loss implementation).
        """
        self.encoder.optimizer.zero_grad()

        outs = self.encoder(x)

        loss, ind = self.loss_fn(outs, return_ind=True)

        loss.backward()

        self.encoder.optimizer.step()

        # The learning-rate scheduler is optional.
        if self.encoder.lr_scheduler is not None:
            self.encoder.lr_scheduler.step()

        return loss, ind

    def __call__(self, data, y, t, inds):
        """Train on one stream batch and update both memories.

        ``data`` is a pair ``(x, raw)``: ``x`` is fed to the encoder and
        ``raw`` is what gets stored in the STM. ``y`` is a CPU label tensor
        and ``t`` the task index; ``inds`` is unused. Does nothing when
        ``eval_only`` is set.
        """
        if self.eval_only:
            return

        self.encoder.train()
        x, raw = data

        self.unique_labels.update(y.numpy().tolist())

        if t > self.current_task:
            # Not every LTM implements ``update_p``; call it only where present.
            if hasattr(self.ltm, 'update_p'):
                self.ltm.update_p()
            self.current_task += 1

        if int(self.stm.n_items.value) >= 2*self.stm_batch_size:
            total_loss = 0.0
            for _ in range(self.samples_per_step):
                if int(self.ltm.n_items.value) > 2*self.ltm_batch_size:
                    stm_batch = self.stm.sample_memory()
                    ltm_batch = self.ltm.sample_memory()
                    mem_batch = torch.cat((stm_batch, ltm_batch), dim=1)
                else:
                    mem_batch = self.stm.sample_memory()

                # NOTE(review): concatenation along dim=1 presumably means
                # dim 0 indexes the augmented views and dim 1 the samples - confirm.
                xs = torch.cat((x, mem_batch), dim=1)

                step_loss, l_ind = self.training_step(xs)
                # Accumulate plain floats so no graph-attached tensors are
                # kept alive or handed to the logger.
                total_loss += step_loss.item()

            mean_loss = total_loss / self.samples_per_step
        else:
            # Only one step is taken here, so its loss is logged as is.
            step_loss, l_ind = self.training_step(x)
            mean_loss = step_loss.item()

        wandb.log({
            'loss_step' : self.step,
            'loss': mean_loss,
        })

        if self.use_memory:
            # The first len(raw) entries are taken as the values for the
            # current batch, which was concatenated before the memory samples.
            evicted_x, evicted_y = self.stm.add(raw, y, l_ind.detach().cpu().numpy()[:len(raw)].tolist())
            if len(evicted_y) > 0:
                self.ltm.add(evicted_x, evicted_y)
                if self.memory_type in self._ENCODER_LTM_TYPES:
                    # Refresh the LTM's encoder copy every ``update_interval``
                    # steps, and on the first eviction.
                    if ((self.step+1) % self.update_interval == 0) or not self.first_update:
                        self.ltm.update_encoder(copy.deepcopy(self.encoder))
                        self.first_update = True

            wandb.log({
                'ltm_q_size': self.ltm.queue.qsize(),
                })

            if (self.step + 1) % 250 == 0:
                self.ltm.log_class_balance(max(self.unique_labels)+1, self.current_task)

        self.step += 1

def load_agent(config):
    """Instantiate the agent selected by ``config.agent.name``.

    Supported names are ``'ssl-er'``, ``'ssl-er2'`` and ``'offline'``.

    Raises:
        ValueError: if ``config.agent.name`` is none of the supported names
            (instead of silently returning ``None``).
    """
    name = config.agent.name
    if name == 'ssl-er':
        return SSLERAgent(config)
    elif name == 'ssl-er2':
        return SSLER2Agent(config)
    elif name == 'offline':
        return OfflineAgent(config)

    raise ValueError(f"agent name must be one of 'ssl-er', 'ssl-er2' or 'offline', found {name}")


