import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
#from tqdm import tqdm_notebook as tqdm
from typing import List
import sys
from base import BaseTrainer
from utils import inf_loop, get_logger, Timer
from collections import OrderedDict
import argparse
from models.sup_con import LinearClassifier


class CE_Separate_Trainer(BaseTrainer):
    """
    Trainer class
    Note:
        Inherited from BaseTrainer.
    """ # Added ft_data_loader # Removed
    def __init__(self, model, train_criterion, metrics, optimizer, config, data_loader, ft_data_loader,
                 valid_data_loader=None, test_data_loader=None, lr_scheduler=None, len_epoch=None, val_criterion=None):
        model, ft_model = model # Added line
        super().__init__(model, metrics, optimizer, config, val_criterion)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader

        self.test_data_loader = test_data_loader
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # Define the linear classifier model and other related set up
        self.ft_data_loader = ft_data_loader
        # Newly added loss = ce
        self.ft_model = ft_model.to(self.device)
        
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_loss_list: List[float] = []
        self.val_loss_list: List[float] = []
        self.test_loss_list: List[float] = []

        self.train_criterion = train_criterion

        #Visdom visualization
        self.new_best_val = False
        self.val_acc = 0
        self.test_val_acc = 0

    def _eval_metrics(self, output, label):
        acc_metrics = np.zeros(len(self.metrics))
        for i, metric in enumerate(self.metrics):
            acc_metrics[i] += metric(output, label)
        return acc_metrics

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch
        :param epoch: Current training epoch.
        :return: A log that contains all information you want to save.
        Note:
            If you have additional information to record, for example:
                > additional_log = {"x": x, "y": y}
            merge it with log before return. i.e.
                > log = {**log, **additional_log}
                > return log
            The metrics in log must have the key 'metrics'.
        """
        self.model.train()
        self.ft_model.train()

        total_loss = 0
        total_metrics = np.zeros(len(self.metrics))

        with tqdm(self.data_loader) as progress:
            for batch_idx, (data, label) in enumerate(progress):
                progress.set_description_str(f'Train epoch {epoch}')
                
                data, label = data.to(self.device), label.to(self.device)

                features = self.model(data) # This features are after the projection head of CE
                logits = self.ft_model(features)

                loss = self.train_criterion(logits, label) 
                total_metrics += self._eval_metrics(logits, label) 
                    
                self.optimizer.zero_grad()
                #self.optim_ft.zero_grad()
                loss.backward()
                
                self.optimizer.step()
                #self.optim_ft.step()

                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx, epoch=epoch)
                self.writer.add_scalar({'loss': loss.item()})
                self.train_loss_list.append(loss.item())
                total_loss += loss.item()

                if batch_idx % self.log_step == 0:
                    progress.set_postfix_str(' {} Loss: {:.6f}'.format(
                        self._progress(batch_idx),
                        loss.item()))
                    #self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                if batch_idx == self.len_epoch:
                    break

        log = {
            'loss': total_loss / self.len_epoch,
            'metrics': (total_metrics / self.len_epoch).tolist(),
            'learning rate': self.lr_scheduler.get_lr()
        }
        
        # We dont't do validation and we do test every some time
        # This is because the test process is quite expansive
        if self.do_test:
            test_log = self._test_epoch(epoch)
            log.update(test_log)

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log
    
    def _test_epoch(self, epoch):
        """
        Test after training an epoch
        :return: A log that contains information about test
        Note:
            The Test metrics in log must have the key 'val_metrics'.
        """
        self.model.eval()
        self.ft_model.eval()
        total_test_loss = 0
        total_test_metrics = np.zeros(len(self.metrics))
        
        with torch.no_grad():
            with tqdm(self.test_data_loader) as progress:
                for batch_idx, (data, label) in enumerate(progress):
                    progress.set_description_str(f'Test epoch {epoch}')
                    data, label = data.to(self.device), label.to(self.device)
                    feature = self.model(data)
                    logits = self.ft_model(feature)
                    
                    loss = self.val_criterion(logits, label)

                    self.writer.set_step((epoch - 1) * len(self.test_data_loader) + batch_idx, epoch=epoch, mode = 'test')
                    self.writer.add_scalar({'loss': loss.item()})
                    self.test_loss_list.append(loss.item())
                    total_test_loss += loss.item()
                    total_test_metrics += self._eval_metrics(logits, label)
                    # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        # for name, p in self.model.named_parameters():
        #     self.writer.add_histogram(name, p, bins='auto')
        top_1_acc = (total_test_metrics / len(self.test_data_loader)).tolist()[0]
        self.writer.add_scalar({'Top-1': top_1_acc}, epoch = epoch)
        self.writer.add_scalar({'Top-5': (total_test_metrics / len(self.test_data_loader)).tolist()[1]}, epoch = epoch)

        return {
            'test_loss': total_test_loss / len(self.test_data_loader),
            'test_metrics': (total_test_metrics / len(self.test_data_loader)).tolist()
        }

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
