import copy
from tqdm.contrib.logging import logging_redirect_tqdm
import numpy as np
import torch.nn as nn
import os
from tqdm import tqdm
import torch
import tasks
import torch.multiprocessing as mp
import torch.distributed as dist
import logging
from tensorboardX import SummaryWriter
from schedulers import build_scheduler
import torch_optimizer as torch_optim


log = logging.getLogger(__name__)

class Runner():
    """Drive training, validation, checkpointing and testing of a model on a task.

    Per-step work (train steps, validation metrics, model-weight
    (de)serialisation, batch iterators) is delegated to ``task``; this class
    owns the optimizer, scheduler, overall progress bar and the optional
    TensorBoard writer.
    """

    def __init__(self, cfg, task, model, criterion):
        """Set up device placement, optimizer, scheduler, progress bar and logging.

        Args:
            cfg: experiment config. Read here: ``device``, ``multi_gpu``,
                ``scheduler``, ``total_steps``, optional ``output_tb``
                (defaults to True) and optional ``start_from_ckpt``.
            task: task object used for batch iteration, train/valid steps,
                logging and model-weight (de)serialisation.
            model: the ``torch.nn.Module`` to train.
            criterion: loss function, passed through to the task.

        Raises:
            ValueError: if ``multi_gpu`` is requested together with ``device == 'cpu'``.
        """
        # Validate before any side effects (SummaryWriter files, DataParallel wrapping).
        if cfg.device == 'cpu' and cfg.multi_gpu:
            raise ValueError("multi_gpu=True is not supported with device='cpu'")

        self.cfg = cfg
        self.model = model
        self.task = task
        self.evaluator = None
        self.device = cfg.device
        self.criterion = criterion
        self.exp_dir = os.getcwd()
        self.output_tb = cfg.get("output_tb", True)
        self.logger = None
        if self.output_tb:
            self.logger = SummaryWriter(self.exp_dir)

        if cfg.multi_gpu:
            self.model = torch.nn.DataParallel(self.model)
            log.info(f'Use {torch.cuda.device_count()} GPUs')
        self.model.to(self.device)
        self.optim = self._init_optim(self.cfg)
        self.scheduler = build_scheduler(self.cfg.scheduler, self.optim)
        total_steps = self.cfg.total_steps
        self.progress = tqdm(total=total_steps, dynamic_ncols=True, desc="overall")

        if 'start_from_ckpt' in cfg:
            self.load_from_ckpt()

    def load_from_ckpt(self):
        """Restore model, optimizer and scheduler state from ``cfg.start_from_ckpt``.

        The checkpoint must contain ``'model'`` (handed to
        ``task.load_model_weights``), ``'optim'`` and ``'scheduler'`` entries.
        ``save_checkpoints`` writes the ``'optim'`` and ``'scheduler'`` entries;
        presumably ``task.save_model_weights`` writes ``'model'`` (verify).
        """
        ckpt_path = self.cfg.start_from_ckpt
        # map_location allows resuming a checkpoint written on a different device
        # (e.g. saved on GPU, resumed on CPU).
        init_state = torch.load(ckpt_path, map_location=self.device)
        self.task.load_model_weights(self.model, init_state['model'], self.cfg.multi_gpu)
        self.optim.load_state_dict(init_state["optim"])
        # NOTE(review): saving uses scheduler.get_state_dict() while loading uses
        # load_state_dict(); confirm the custom scheduler exposes both.
        self.scheduler.load_state_dict(init_state["scheduler"])

    def _init_optim(self, args):
        """Build the optimizer named by ``args.optim``.

        Supported names:
            ``SGD``            -- momentum 0.9.
            ``Adam``           -- weight decay 0.01.
            ``AdamW``          -- default hyper-parameters.
            ``AdamW_finetune`` -- ``linear_out`` trained at ``args.lr``, every
                                  other parameter at ``0.1 * args.lr``.
            ``LAMB``           -- from ``torch_optimizer``.

        Raises:
            ValueError: if ``args.optim`` is not one of the names above.
        """
        if args.optim == "SGD":
            optim = torch.optim.SGD(self.model.parameters(), lr=args.lr, momentum=0.9)
        elif args.optim == 'Adam':
            optim = torch.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=0.01)
        elif args.optim == 'AdamW':
            optim = torch.optim.AdamW(self.model.parameters(), lr=args.lr)
        elif args.optim == 'AdamW_finetune':
            core = self.model.module if self.cfg.multi_gpu else self.model
            # Materialise the parameters: they are used twice below (to collect
            # ids and as a param group). A bare generator would already be
            # exhausted, leaving the linear_out group empty.
            linear_out_params = list(core.linear_out.parameters())
            ignored_params = {id(p) for p in linear_out_params}
            base_params = [p for p in self.model.parameters() if id(p) not in ignored_params]

            optim = torch.optim.AdamW([
                        {'params': base_params},
                        {'params': linear_out_params, 'lr': args.lr}
                    ], lr=args.lr*0.1)
        elif args.optim == 'LAMB':
            optim = torch_optim.Lamb(self.model.parameters(), lr=args.lr)
        else:
            raise ValueError(f"no valid optim name: {args.optim!r}")
        return optim

    def output_logs(self, train_logging_outs, val_logging_outs):
        """Log lr/loss/grad_norm metrics and delegate task-specific logging.

        Adds the scheduler's current ``lr`` to ``train_logging_outs`` (in place).
        Logs the standard metrics with ``train_``/``val_`` prefixes to the
        console and, if enabled, to TensorBoard at the current global step.
        """
        global_step = self.progress.n
        train_logging_outs['lr'] = self.scheduler.get_lr()
        standard_metrics = ["lr", "loss", "grad_norm"]
        all_standard_metrics = {}
        def add_prefix(prefix, outs):
            for k, v in outs.items():
                if k in standard_metrics:
                    all_standard_metrics[f'{prefix}_{k}'] = v
        add_prefix('train', train_logging_outs)
        add_prefix('val', val_logging_outs)

        log.info(all_standard_metrics)

        if self.logger is not None:
            for k, v in all_standard_metrics.items():
                self.logger.add_scalar(k, v, global_step=global_step)
        self.task.output_logs(train_logging_outs, val_logging_outs, self.logger, global_step)

    def get_valid_outs(self):
        """Run the task's validation pass over ``task.valid_set`` and return its outputs."""
        valid_loader = self.get_batch_iterator(self.task.valid_set, self.cfg.valid_batch_size, shuffle=self.cfg.shuffle, num_workers=self.cfg.num_workers)
        valid_logging_outs = self.task.get_valid_outs(self.model, valid_loader, self.criterion, self.device)
        return valid_logging_outs

    def save_checkpoint_last(self, states, best_val=False):
        """Write ``states`` to ``checkpoint_best.pth`` (if ``best_val``) or
        ``checkpoint_last.pth`` in the current working directory."""
        cwd = os.getcwd()
        if best_val:
            save_path = os.path.join(cwd, 'checkpoint_best.pth')
        else:
            save_path = os.path.join(cwd, 'checkpoint_last.pth')
        log.info(f'Saving checkpoint to {save_path}')
        torch.save(states, save_path)
        log.info(f'Saved checkpoint to {save_path}')

    def save_checkpoints(self, best_val=False):
        """Save model weights, optimizer/scheduler state and the model config.

        Always writes the "last" checkpoint. When ``best_val`` is True, also
        writes the "best" checkpoint with identical contents.
        """
        all_states = {}
        all_states = self.task.save_model_weights(self.model, all_states, self.cfg.multi_gpu)
        all_states['optim'] = self.optim.state_dict()
        all_states['scheduler'] = self.scheduler.get_state_dict()
        if self.cfg.multi_gpu:
            all_states['model_cfg'] = self.model.module.cfg
        else:
            all_states['model_cfg'] = self.model.cfg
        self.save_checkpoint_last(all_states)
        if best_val:
            self.save_checkpoint_last(all_states, best_val)

    def run_epoch(self, train_loader, total_loss, best_state):
        """Train over one pass of ``train_loader``, logging and checkpointing as configured.

        Stops early once the progress bar reaches its total. On every
        ``cfg.log_step``-th step, and on the final step, it logs two things:
        the mean of the losses accumulated in ``total_loss``, and the
        validation metrics. It then resets the accumulator.

        If ``cfg.checkpoint_step > -1``, it writes a checkpoint on every
        ``cfg.checkpoint_step``-th step and on the final step. A deep copy of
        the model is kept whenever validation loss improves.

        Args:
            train_loader: iterable of training batches.
            total_loss: losses accumulated since the last log step (appended to).
            best_state: ``(best_model, best_val)`` carried across epochs.

        Returns:
            ``(total_loss, (best_model, best_val))`` with the updated values.
        """
        # Unpack up front so the return is well-defined even if the loader
        # yields no batches or the step budget is already exhausted.
        best_model, best_val = best_state
        for batch in train_loader:
            if self.progress.n >= self.progress.total:
                break
            self.model.train()
            logging_out = self.task.train_step(batch, self.model, self.criterion, self.optim, self.scheduler, self.device, self.cfg.grad_clip)
            total_loss.append(logging_out["loss"])

            step = self.progress.n
            last_step = step == self.progress.total - 1
            log_step = step % self.cfg.log_step == 0 or last_step
            ckpt_step = self.cfg.checkpoint_step > -1 and (step % self.cfg.checkpoint_step == 0 or last_step)

            valid_logging_outs = {}
            if ckpt_step or log_step:
                self.model.eval()
                valid_logging_outs = self.get_valid_outs()
            if log_step:
                logging_out["loss"] = np.mean(total_loss)
                self.output_logs(logging_out, valid_logging_outs)
                total_loss = []
            if ckpt_step:
                if valid_logging_outs["loss"] < best_val["loss"]:
                    self.save_checkpoints(best_val=True)
                    best_val = valid_logging_outs
                    best_model = copy.deepcopy(self.model)
                else:
                    self.save_checkpoints()
            self.progress.update(1)
        return total_loss, (best_model, best_val)

    def scheduler_step(self):
        """Placeholder hook; currently a no-op."""
        pass

    def train(self, train_loader=None):
        """Train until the progress bar reaches ``cfg.total_steps``.

        Args:
            train_loader: optional batch iterable. Defaults to a loader over
                ``task.train_set``.

        Returns:
            A deep copy of the model with the lowest validation loss seen at a
            checkpoint step. Returns ``None`` if no checkpoint step ever ran.

        Raises:
            RuntimeError: if ``train_loader`` yields no batches. Training could
                never progress, so the loop would otherwise spin forever.
        """
        if train_loader is None:
            train_loader = self.get_batch_iterator(self.task.train_set, self.cfg.train_batch_size, shuffle=self.cfg.shuffle, num_workers=self.cfg.num_workers, persistent_workers=self.cfg.num_workers>0)

        total_loss = []
        best_state = (None, {"loss": float("inf")})
        with logging_redirect_tqdm():
            try:
                if self.cfg.checkpoint_step > -1:
                    self.save_checkpoints()
                while self.progress.n < self.progress.total:
                    steps_before = self.progress.n
                    total_loss, best_state = self.run_epoch(train_loader, total_loss, best_state)
                    if self.progress.n == steps_before:
                        raise RuntimeError("train_loader yielded no batches; training cannot make progress")
            finally:
                self.progress.close()
        best_model, _ = best_state
        return best_model

    def test(self, best_model_weights, test_loader=None):
        """Evaluate on the test split (or ``test_loader``) and log the results.

        Args:
            best_model_weights: the model returned by ``train``. If it is not an
                ``nn.Module`` (e.g. ``None`` when no checkpoint step ran), the
                current ``self.model`` is evaluated instead.
            test_loader: optional batch iterable. Defaults to a loader over
                ``task.test_set``.

        Returns:
            The task's validation outputs computed on the test data.
        """
        if test_loader is None:
            test_loader = self.get_batch_iterator(self.task.test_set, self.cfg.valid_batch_size, shuffle=self.cfg.shuffle, num_workers=self.cfg.num_workers, persistent_workers=self.cfg.num_workers>0)

        model = best_model_weights if isinstance(best_model_weights, nn.Module) else self.model
        # Match run_epoch, which switches to eval mode before validation.
        model.eval()
        test_outs = self.task.get_valid_outs(model, test_loader, self.criterion, self.device)
        log.info(f"test_results {test_outs}")
        return test_outs

#    def train_test_cross_validate(self):
#        train_datasets, val_datasets = self.task.get_crossval_datasets()
#        #train_loaders = self.task.get_crossval_train_loaders()
#        #val_loaders = self.task.get_crossval_val_loaders()
#        for train_dataset, val_dataset in zip(train_datasets, val_datasets):
#            #make sure that the "save checkpoint" is false
#            train_loader = self.get_batch_iterator(train_dataset, self.cfg.train_batch_size, shuffle=self.cfg.shuffle, num_workers=self.cfg.num_workers, persistent_workers=self.cfg.num_workers>0)            
#            import pdb; pdb.set_trace()
#            #need some way to pass the test loader in here
#            best_model = self.train(train_loader)
#            import pdb; pdb.set_trace()
#            val_loader = self.get_batch_iterator(val_dataset, self.cfg.valid_batch_size, shuffle=self.cfg.shuffle, num_workers=self.cfg.num_workers, persistent_workers=self.cfg.num_workers>0)            
#            test_outs = self.test(best_model, val_loader)

    def get_batch_iterator(self, dataset, batch_size, **kwargs):
        """Thin wrapper around ``task.get_batch_iterator``."""
        return self.task.get_batch_iterator(dataset, batch_size, **kwargs)
