import os
import time
import torch
import shutil
import numpy as np
import pandas as pd
from tqdm import tqdm
from training.utils import seed_everything, AverageMeter
from models import get_model_from_config
from datasets import get_dataset_from_config

class Trainer:
    '''
    Generic supervised training loop around a model built from ``config``.

    Subclasses must implement :meth:`criterion` and may override
    :meth:`initialise_model` and the ``on_*`` hooks. The saved config,
    checkpoints and a ``logs.csv`` metric log are written under ``root_dir``.
    '''

    def __init__(self, config):
        self.config = config
        self.device = torch.device(self.config.device)

        seed_everything(self.config.seed)
        self.model = self._prepare_model()
        self.initialise_model()

        # The model may supply its own defaults; apply them before the
        # optimizer, scheduler and dataloaders read the config.
        if self.config.use_default_args:
            self.config = self.model.get_default_args(self.config)

        self.optimizer = self._prepare_optimizer()
        self.scheduler = self._prepare_scheduler()
        self.dataloaders = self._prepare_dataloaders()

        self.current_epoch = 0
        self.best_accuracy = 0
        # Set to True (e.g. from a hook) to stop fit() after the current epoch.
        self.stop_training = False
        self.train_loss_list = []
        self.valid_loss_list = []
        self.valid_accuracy_list = []
        self.model_budget_list = []
        self.root_dir = None
        self._checks()

        self.config.print_config()

    def __call__(self, x):
        return self.model(x)

    def initialise_model(self):
        '''
        Hook to initialise the model (e.g. from pretrained weights).

        Called once from ``__init__`` right after the model is built.
        '''
        pass

    def on_step_start(self):
        '''
        Hook called at the start of each training batch, before the forward pass.
        '''
        pass

    def on_step_end(self):
        '''
        Hook called at the end of each training batch, after the optimizer step.
        '''
        pass

    def on_train_start(self):
        '''
        Hook called at the start of each epoch, before the training pass.
        '''
        pass

    def on_train_end(self):
        '''
        Hook called after each epoch's training pass, before validation.
        '''
        pass

    def on_eval_end(self, valid_accuracy, valid_loss):
        '''
        Hook called at the end of each epoch, after validation, the scheduler
        step and checkpointing.
        '''
        pass

    def criterion(self, y_pred, y_true):
        '''
        Compute the loss for a batch; must be overridden by subclasses.

        The return value must be a single-element tensor: the training loop
        calls ``.item()`` and ``.backward()`` on it.
        '''
        raise NotImplementedError("Subclasses of Trainer must implement criterion()")

    def set_root_dir(self, root_dir=None):
        '''
        Set the directory where the config, checkpoints and logs are written.

        When ``root_dir`` is None, a default is built from the config as
        ``root_dir/model_name/dataset_name/experiment_name/<strategy>_<budget>``.
        '''
        if root_dir is None:
            # NOTE(review): ``model_required_bugdet`` (sic) presumably matches the
            # config attribute name; confirm before renaming it here.
            root_dir = os.path.join(
                self.config.root_dir,
                self.config.model_name,
                self.config.dataset_name,
                self.config.experiment_name,
                f"{self.config.model_compression_strategy}_{self.config.model_required_bugdet}",
            )
            print("==> Using default root directory for saving models:", root_dir)

        self.root_dir = root_dir

    def get_model_flops(self):
        return self.model.get_model_flops()

    def _load_weights(self, checkpoint_name):
        '''Load only the model weights from ``root_dir/checkpoint_name``.'''
        file = os.path.join(self.root_dir, checkpoint_name)
        state = torch.load(file, map_location=self.device)
        self.model.load_state_dict(state['state_dict'])

    def load_best_model(self):
        '''Load the weights of the best-validation-accuracy checkpoint.'''
        self._load_weights("model_best.pth")

    def load_current_model(self):
        '''Load the weights of the most recently saved epoch checkpoint.'''
        self._load_weights("checkpoint.pth")

    def _prepare_optimizer(self):
        '''
        Build the optimizer selected by ``config.optimizer_type``.

        Parameters whose name contains any substring in
        ``config.optimizer_no_decay`` get zero weight decay; all others use
        ``config.optimizer_weight_decay``. Note that ``"adam"`` maps to AdamW.

        Raises:
            NotImplementedError: for an unknown optimizer type.
        '''
        named_params = list(self.model.named_parameters())
        no_decay = self.config.optimizer_no_decay

        def skips_decay(name):
            return any(nd in name for nd in no_decay)

        optimizer_parameters = [
            {'params': [p for n, p in named_params if not skips_decay(n)],
             'weight_decay': self.config.optimizer_weight_decay},
            {'params': [p for n, p in named_params if skips_decay(n)],
             'weight_decay': 0.0},
        ]

        if self.config.optimizer_type == "sgd":
            return torch.optim.SGD(optimizer_parameters, lr=self.config.optimizer_lr,
                                   momentum=self.config.optimizer_momentum)
        if self.config.optimizer_type == "adam":
            return torch.optim.AdamW(optimizer_parameters, lr=self.config.optimizer_lr)
        raise NotImplementedError

    def _prepare_scheduler(self):
        '''
        Build the LR scheduler selected by ``config.scheduler_type``.

        Returns None for ``"none"``. Raises NotImplementedError for unknown types.
        '''
        if self.config.scheduler_type == "multi_step":
            return torch.optim.lr_scheduler.MultiStepLR(
                self.optimizer,
                milestones=self.config.scheduler_milestones,
                gamma=self.config.scheduler_gamma,
            )
        if self.config.scheduler_type == "none":
            return None
        raise NotImplementedError

    def _prepare_dataloaders(self):
        '''
        Build the train/valid/test DataLoaders from the config's datasets.

        Validation and test share the evaluation batch size and shuffle flag.
        '''
        datasets = get_dataset_from_config(self.config)
        loader_settings = {
            'train': (self.config.training_batch_size, self.config.train_shuffle),
            'valid': (self.config.test_batch_size, self.config.test_shuffle),
            'test': (self.config.test_batch_size, self.config.test_shuffle),
        }
        return {
            phase: torch.utils.data.DataLoader(
                datasets[phase],
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=self.config.num_workers,
                pin_memory=self.config.pin_memory,
            )
            for phase, (batch_size, shuffle) in loader_settings.items()
        }

    def _prepare_model(self):
        model = get_model_from_config(self.config)
        model.to(self.device)
        return model

    def _checks(self):
        '''
        Ensure ``root_dir`` is set and exists, then save the config into it.

        An existing directory is reused; files in it may be overwritten.
        '''
        if self.root_dir is None:
            self.set_root_dir()
        if os.path.exists(self.root_dir):
            print("==> Root directory already exists. Overwriting...")
        os.makedirs(self.root_dir, exist_ok=True)
        self.config.save(self.root_dir)

    def _resume(self, resume_training):
        '''
        Restore training state from ``root_dir/checkpoint.pth``.

        Does nothing unless ``resume_training`` is truthy and the checkpoint
        exists. Restores the epoch counter, best accuracy, metric lists, and
        model and optimizer state. It also restores the LR scheduler state when
        a scheduler is configured and the checkpoint contains one.
        '''
        if not resume_training:
            return
        file = os.path.join(self.root_dir, "checkpoint.pth")
        if not os.path.exists(file):
            print(f"==> No checkpoint found at {file}; starting from scratch")
            return

        state = torch.load(file, map_location=self.device)
        self.current_epoch = state["epoch"]
        self.best_accuracy = state["accuracy"]
        self.train_loss_list = state["train_loss_list"]
        self.valid_loss_list = state["valid_loss_list"]
        self.valid_accuracy_list = state["valid_accuracy_list"]
        self.model_budget_list = state["model_budget_list"]
        self.model.load_state_dict(state['state_dict'])
        self.optimizer.load_state_dict(state['optimizer'])
        # Without this, the scheduler would restart its step count from zero and
        # apply its LR decays at the wrong epochs. Older checkpoints have no
        # "scheduler" entry, hence .get().
        scheduler_state = state.get("scheduler")
        if self.scheduler is not None and scheduler_state is not None:
            self.scheduler.load_state_dict(scheduler_state)
        print(f"==> Loaded checkpoint {file} (Epoch: {self.current_epoch},  Best Accuracy: {self.best_accuracy})")

    def _train_epoch(self):
        '''
        Run one training pass over the train loader.

        Returns:
            The sample-weighted average training loss for the epoch.
        '''
        self.model.train()
        loader = self.dataloaders['train']
        pbar = tqdm(loader, total=len(loader))
        pbar.set_description(f"[{self.current_epoch+1}/{self.config.epochs}] Train")
        running_loss = AverageMeter()
        for x, y_true in pbar:
            self.on_step_start()

            x = x.to(device=self.device)
            y_true = y_true.to(device=self.device)
            bs = x.size(0)

            y_pred = self.model(x)
            loss = self.criterion(y_pred, y_true)

            running_loss.update(loss.item(), bs)
            pbar.set_postfix(loss=running_loss.avg)

            self.optimizer.zero_grad()
            loss.backward()
            if self.config.clip_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip_grad_norm)
            self.optimizer.step()

            self.on_step_end()
        return running_loss.avg

    def evaluate(self, phase):
        '''
        Evaluate the model on ``self.dataloaders[phase]`` without gradients.

        Accuracy compares the argmax over dim 1 of the model output with the
        targets.

        Returns:
            (accuracy, loss) averaged over samples.
        '''
        self.model.eval()
        loader = self.dataloaders[phase]
        pbar = tqdm(loader, total=len(loader))
        pbar.set_description(f"[{self.current_epoch+1}/{self.config.epochs}] {phase.capitalize()}")
        running_loss = AverageMeter()
        running_acc = AverageMeter()

        with torch.no_grad():
            for x, y_true in pbar:
                x = x.to(device=self.device)
                y_true = y_true.to(device=self.device)
                bs = x.size(0)

                y_pred = self.model(x)
                loss = self.criterion(y_pred, y_true)
                running_loss.update(loss.item(), bs)

                predicted = y_pred.argmax(dim=1)
                # Use a builtin float, not a NumPy scalar. These values end up in
                # checkpoints, and torch.load(weights_only=True), the default in
                # recent PyTorch, rejects NumPy scalars.
                batch_acc = (predicted == y_true).double().mean().item()
                running_acc.update(batch_acc, bs)
                pbar.set_postfix(loss=running_loss.avg, accuracy=running_acc.avg)
        return running_acc.avg, running_loss.avg

    def _save_model(self, checkpoint_name):
        '''
        Save full training state to ``root_dir/checkpoint_name``.

        ``epoch`` is stored as the next epoch to run, so ``_resume`` can use it
        directly as the start epoch.
        '''
        torch.save({
                "epoch": self.current_epoch + 1,
                "state_dict": self.model.state_dict(),
                "accuracy": self.best_accuracy,
                "optimizer": self.optimizer.state_dict(),
                "scheduler": self.scheduler.state_dict() if self.scheduler is not None else None,
                "train_loss_list": self.train_loss_list,
                "valid_loss_list": self.valid_loss_list,
                "valid_accuracy_list": self.valid_accuracy_list,
                "model_budget_list": self.model_budget_list,
            }, os.path.join(self.root_dir, checkpoint_name))

    def _save_and_log(self, accuracy):
        '''
        Checkpoint the current epoch and rewrite the metric log.

        Always writes ``checkpoint.pth``. Also writes ``model_ep{N}.pth`` for
        epochs in ``config.save_at_epochs``, and copies the checkpoint to
        ``model_best.pth`` when ``accuracy`` beats ``best_accuracy``.
        '''
        epoch_number = self.current_epoch + 1

        ## Update all state before saving, so checkpoints hold the current best
        ## accuracy and a model_budget_list as long as the other metric lists.
        self.model_budget_list.append(self.get_model_flops().item())

        improved = accuracy > self.best_accuracy
        if improved:
            print(f"==> Best Accuracy improved to {accuracy} from {self.best_accuracy}")
            self.best_accuracy = accuracy

        ## Save model
        if epoch_number in self.config.save_at_epochs:
            self._save_model(f"model_ep{epoch_number}.pth")
            print(f"==> Saved checkpoint {epoch_number}")

        self._save_model("checkpoint.pth")
        if improved:
            shutil.copyfile(os.path.join(self.root_dir, "checkpoint.pth"),
                            os.path.join(self.root_dir, "model_best.pth"))

        ## Log results
        self._write_logs()

    def _write_logs(self):
        '''Rewrite ``root_dir/logs.csv`` with one row per logged epoch.'''
        columns = {
            'train_losses': self.train_loss_list,
            'valid_losses': self.valid_loss_list,
            'valid_accuracy': self.valid_accuracy_list,
            'model_budget': self.model_budget_list,
        }
        # Building from Series tolerates lists of unequal length (e.g. after
        # resuming an older checkpoint): missing cells become NaN instead of the
        # conversion failing.
        df = pd.DataFrame({name: pd.Series(values, dtype=float) for name, values in columns.items()})
        df.to_csv(os.path.join(self.root_dir, "logs.csv"))

    def fit(self, resume_training):
        '''
        Train until ``config.epochs`` (or ``stop_training``), then test the best model.

        Each epoch runs training, validation, an optional scheduler step, and
        checkpointing/logging, wrapped by the ``on_*`` hooks. When
        ``resume_training`` is truthy, state is first restored from
        ``checkpoint.pth`` if it exists.
        '''
        self._resume(resume_training)
        start_epoch = self.current_epoch
        for self.current_epoch in range(start_epoch, self.config.epochs):

            self.on_train_start()
            training_loss = self._train_epoch()
            self.on_train_end()
            valid_accuracy, valid_loss = self.evaluate("valid")

            self.train_loss_list.append(training_loss)
            self.valid_loss_list.append(valid_loss)
            self.valid_accuracy_list.append(valid_accuracy)

            if self.scheduler is not None:
                self.scheduler.step()

            self._save_and_log(valid_accuracy)

            self.on_eval_end(valid_accuracy, valid_loss)

            if self.stop_training:
                break

        # model_best.pth is only written when validation accuracy improves, so it
        # can be missing (e.g. accuracy never exceeded 0); fall back to the
        # current weights rather than crashing.
        if os.path.exists(os.path.join(self.root_dir, "model_best.pth")):
            self.load_best_model()
        else:
            print("==> No best model checkpoint found; testing the current model")
        test_accuracy, _ = self.evaluate("test")
        print(f"Test Accuracy: {test_accuracy} | Valid Accuracy: {self.best_accuracy}")
