import os.path
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import tqdm
import warnings
import wandb

class Trainer:
    """
    Generic training-loop driver.

    Subclasses implement ``train_step`` and ``eval_step`` (and optionally
    ``vis_step``, ``reset_metrics``, ``compute_metrics``, ``post_batch``,
    ``post_epoch`` and ``extra_save_state``). This class builds the train/test
    data loaders and an Adam optimizer with an optional ExponentialLR schedule
    (when ``args.gamma != 1.0``). It also handles resuming, periodic and
    best-score checkpointing, TensorBoard logging and, when ``args.use_wandb``
    is set, wandb logging.
    """

    # Key of the dict returned by ``compute_metrics`` that selects the "best"
    # checkpoint during full validation (higher is better). Subclasses may
    # override it.
    best_metric_key = "3D_miou"

    def __init__(self, net, train_dataset, test_dataset, args, conf, device=None):
        """
        :param net: model to train. If it has ``load_weights``, weights are
            loaded/saved through ``net.load_weights`` / ``net.save_weights``
            instead of plain ``state_dict`` files.
        :param train_dataset: dataset for the shuffled training loader.
        :param test_dataset: dataset for the evaluation/visualization loader.
        :param args: parsed arguments; reads batch_size, lr, gamma, epochs,
            iters, name, logs_path, checkpoints_path, visual_path, resume,
            use_wandb, calc_metrics, full_val and optionally fixed_test.
        :param conf: config exposing ``get_int(key[, default])`` for the
            save/print/vis/eval intervals, num_epoch_repeats and accu_grad.
        :param device: ``map_location`` used when loading checkpoint files.
        """
        self.args = args
        self.net = net
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset

        self.train_data_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=5,
            pin_memory=False,
            drop_last=True,
        )
        self.test_data_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=min(args.batch_size, 16),
            shuffle=True,
            num_workers=5,
            pin_memory=False,
            drop_last=True,
        )

        # NOTE: despite its name this is the number of training *samples*;
        # kept unchanged for code that may read it. ``start`` derives the real
        # per-epoch batch count from the data loader instead.
        self.num_total_batches = len(self.train_dataset)
        self.exp_name = args.name
        self.save_interval = conf.get_int("save_interval")
        self.print_interval = conf.get_int("print_interval")
        self.vis_interval = conf.get_int("vis_interval")
        print("vis interval", self.vis_interval)
        self.eval_interval = conf.get_int("eval_interval")
        self.num_epoch_repeats = conf.get_int("num_epoch_repeats", 1)
        self.num_epochs = args.epochs
        self.num_iters = args.iters
        self.accu_grad = conf.get_int("accu_grad", 1)
        self.summary_path = os.path.join(args.logs_path, args.name)
        os.makedirs(self.summary_path, exist_ok=True)
        self.writer = SummaryWriter(self.summary_path)

        self.fixed_test = getattr(args, "fixed_test", False)

        # Currently only Adam supported
        self.optim = torch.optim.Adam(net.parameters(), lr=args.lr)
        if args.gamma != 1.0:
            self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer=self.optim, gamma=args.gamma
            )
        else:
            self.lr_scheduler = None

        # Self-managed weights are loaded here regardless of args.resume.
        self.managed_weight_saving = hasattr(net, "load_weights")
        if self.managed_weight_saving:
            net.load_weights(self.args)

        def ckpt_path(fname):
            # Every checkpoint file lives in "<checkpoints_path>/<name>/".
            return "%s/%s/%s" % (self.args.checkpoints_path, self.args.name, fname)

        # Latest (periodic) checkpoint files.
        self.iter_state_path = ckpt_path("_iter")
        self.optim_state_path = ckpt_path("_optim")
        self.lrsched_state_path = ckpt_path("_lrsched")
        self.default_net_state_path = ckpt_path("net")
        # Best-score checkpoint files.
        self.best_iter_state_path = ckpt_path("best_iter")
        self.best_optim_state_path = ckpt_path("best_optim")
        self.best_lrsched_state_path = ckpt_path("best_lrsched")
        self.best_default_net_state_path = ckpt_path("best_net")
        self.best_score_path = ckpt_path("best_score")

        self.start_iter_id = 0
        self.best_score = 0.0
        if args.resume:
            self._load_resume_state(net, device)

        self.visual_path = os.path.join(self.args.visual_path, self.args.name)
        self.conf = conf

        if self.args.use_wandb:
            self._init_wandb(args, conf)

    def _load_resume_state(self, net, device):
        """
        Restore training state from the checkpoint paths, skipping missing files.

        Loads the optimizer, LR-scheduler and iteration state. It loads the
        network ``state_dict`` only when weights are not self-managed, and
        the best validation score (0.0 if absent). An optimizer state that
        fails to load is reported with a warning, and the fresh optimizer is
        kept.
        """
        if os.path.exists(self.optim_state_path):
            try:
                state = torch.load(self.optim_state_path, map_location=device)
                self.optim.load_state_dict(state)
            except Exception as err:
                # warnings.warn's 2nd positional argument is the *category*,
                # so the path must be part of the message itself.
                warnings.warn(
                    f"Failed to load optimizer state at {self.optim_state_path}: {err}"
                )
        if self.lr_scheduler is not None and os.path.exists(self.lrsched_state_path):
            self.lr_scheduler.load_state_dict(
                torch.load(self.lrsched_state_path, map_location=device)
            )
        if os.path.exists(self.iter_state_path):
            self.start_iter_id = torch.load(
                self.iter_state_path, map_location=device
            )["iter"]
        if not self.managed_weight_saving and os.path.exists(
            self.default_net_state_path
        ):
            net.load_state_dict(
                torch.load(self.default_net_state_path, map_location=device)
            )
        if os.path.exists(self.best_score_path):
            self.best_score = torch.load(
                self.best_score_path, map_location=device
            )["best_score"]
        else:
            self.best_score = 0.0

    def _init_wandb(self, args, conf):
        """
        Start a wandb run and declare step axes and min/max summaries.
        """
        wandb.init(config=conf, name=args.name)
        wandb.config.update(args)
        wandb.define_metric("train/step")
        wandb.define_metric("val/step")

        wandb.define_metric("train/loss", summary="min")
        wandb.define_metric("val/loss", summary="min")
        # Lower-is-better metrics.
        min_metrics = ["rc", "rf", "sc", "sf", "3D_L1_chamfer", "3D_L2_chamfer"]
        for m in min_metrics:
            wandb.define_metric("train/" + m, summary="min")
            wandb.define_metric("val/" + m, summary="min")

        if args.calc_metrics:
            # Higher-is-better metrics.
            metrics = ["acc", "macc", "miou"]
            max_metrics = ["2D_" + m for m in metrics] + ["3D_" + m for m in metrics]
            max_metrics += ["shape_" + m for m in max_metrics]
            max_metrics += [
                "3D_L1_f_tau", "3D_L1_f_2tau", "3D_L1_p_tau",
                "3D_L1_p_2tau", "3D_L1_r_tau", "3D_L1_r_2tau",
                "3D_L2_f_tau", "3D_L2_f_2tau", "3D_L2_p_tau",
                "3D_L2_p_2tau", "3D_L2_r_tau", "3D_L2_r_2tau",
            ]
            for m in max_metrics:
                wandb.define_metric("train/" + m, summary="max")
                wandb.define_metric("val/" + m, summary="max")

        # Plot every train/val metric against its own step counter.
        wandb.define_metric("train/*", step_metric="train/step")
        wandb.define_metric("val/*", step_metric="val/step")

    def post_batch(self, epoch, batch):
        """
        Hook run after each training batch (after any optimizer step).
        """
        pass

    def post_epoch(self, epoch):
        """
        Hook run after each epoch, before the LR scheduler steps.
        """
        pass

    def extra_save_state(self):
        """
        Hook run at the end of every checkpoint save (periodic and best).
        """
        pass

    def reset_metrics(self):
        """
        Hook run before a full validation pass to clear accumulated metrics.
        """
        pass

    def compute_metrics(self):
        """
        Return a dict of metrics accumulated over a full validation pass, or
        None. It should contain ``best_metric_key`` for best-checkpoint
        tracking.
        """
        return None

    def train_step(self, data, global_step):
        """
        Training step on one batch; return a dict of scalar losses.

        ``start`` calls ``optim.step()`` / ``zero_grad()`` itself, so this is
        presumably expected to compute gradients (``backward()``) without
        stepping.
        """
        raise NotImplementedError()

    def eval_step(self, data, global_step):
        """
        Evaluation step on one test batch; return a dict of losses/values.
        """
        raise NotImplementedError()

    def vis_step(self, data, global_step):
        """
        Visualization step; return ``(vis, vis_vals)``.

        ``vis`` is an image array scaled by 255 before saving (presumably
        float in [0, 1]) or None. ``vis_vals`` is a dict of scalars or None.
        """
        return None, None

    def start(self):
        """
        Run the training loop.

        Trains for ``num_epochs`` epochs, each being ``num_epoch_repeats``
        passes over the training loader. It stops early once the global step
        reaches ``num_iters``. Every ``print_interval`` / ``eval_interval`` /
        ``save_interval`` / ``vis_interval`` steps it prints, evaluates,
        checkpoints and visualizes respectively; the last three are skipped
        at step 0. The optimizer steps every ``accu_grad`` batches and on the
        last batch of each epoch.
        """
        test_data_iter = self._data_loop(self.test_data_loader)
        step_id = self.start_iter_id
        # Real batch count per epoch; used to flush a partially accumulated
        # gradient at the end of the epoch instead of carrying it over.
        batches_per_epoch = len(self.train_data_loader) * self.num_epoch_repeats

        progress = tqdm.tqdm(bar_format="[{rate_fmt}] ")
        try:
            for epoch in range(self.num_epochs):
                self.writer.add_scalar("lr", self._lr(), global_step=step_id)

                batch = 0
                for _ in range(self.num_epoch_repeats):
                    for data in self.train_data_loader:
                        losses = self.train_step(data, global_step=step_id)
                        self._log_wandb("train", losses, step_id)
                        if step_id % self.print_interval == 0:
                            print(
                                "E",
                                epoch,
                                "B",
                                batch,
                                self._fmt_loss_str(losses),
                                " lr",
                                self._lr(),
                            )

                        if step_id != 0:
                            if step_id % self.eval_interval == 0:
                                self._evaluate(
                                    epoch, batch, step_id, losses, test_data_iter
                                )
                            if step_id % self.save_interval == 0:
                                print("saving")
                                self._save_state(step_id)
                            if step_id % self.vis_interval == 0:
                                self._visualize(epoch, batch, step_id, test_data_iter)

                        if (
                            batch == batches_per_epoch - 1
                            or batch % self.accu_grad == self.accu_grad - 1
                        ):
                            self.optim.step()
                            self.optim.zero_grad()

                        self.post_batch(epoch, batch)
                        step_id += 1
                        batch += 1
                        progress.update(1)
                        if step_id >= self.num_iters:
                            return
                self.post_epoch(epoch)
                if self.lr_scheduler is not None:
                    self.lr_scheduler.step()
        finally:
            progress.close()

    def _evaluate(self, epoch, batch, step_id, train_losses, test_data_iter):
        """
        Evaluate (full pass or one test batch), then log train/test losses.
        """
        self.net.eval()
        with torch.no_grad():
            if self.args.full_val and self.args.calc_metrics:
                test_losses = self._full_validation(step_id)
            else:
                test_losses = self.eval_step(next(test_data_iter), global_step=step_id)
        self.net.train()

        self._log_wandb("val", test_losses, step_id)
        self.writer.add_scalars("train", train_losses, global_step=step_id)
        self.writer.add_scalars("test", test_losses, global_step=step_id)
        print(
            "*** Eval:",
            "E",
            epoch,
            "B",
            batch,
            self._fmt_loss_str(test_losses),
            " lr",
            self._lr(),
        )

    def _full_validation(self, step_id):
        """
        Evaluate the whole test loader and save a best checkpoint on improvement.

        Per-key values returned by ``eval_step`` are averaged over batches and
        merged with ``compute_metrics()``. If ``best_metric_key`` exceeds
        ``self.best_score``, the best checkpoint is written.

        :return: dict of averaged losses plus metrics.
        """
        self.reset_metrics()
        totals = {}
        num_batches = 0
        for test_data in tqdm.tqdm(self.test_data_loader):
            batch_losses = self.eval_step(test_data, global_step=step_id)
            for k, v in batch_losses.items():
                # Out-of-place add so values returned by eval_step are not
                # mutated in place.
                totals[k] = totals[k] + v if k in totals else v
            num_batches += 1
        test_losses = {k: v / num_batches for k, v in totals.items()}

        # Called once: the default implementation returns None, and the
        # metrics are needed both for logging and best-model selection.
        metrics = self.compute_metrics() or {}
        test_losses.update(metrics)

        if self.best_metric_key not in metrics:
            warnings.warn(
                f"compute_metrics() returned no '{self.best_metric_key}'; "
                "skipping best-checkpoint check"
            )
            return test_losses

        current = float(metrics[self.best_metric_key])
        print(f"best: {self.best_score}, current: {current}")
        if self.best_score < current:
            print("saving the best model!!!!! BOOM!")
            self.best_score = current
            self._save_state(step_id, best=True)
        return test_losses

    def _save_state(self, step_id, best=False):
        """
        Checkpoint network, optimizer, LR scheduler and next-iteration index.

        :param step_id: current global step; ``step_id + 1`` is stored so a
            resumed run continues with the following step.
        :param best: write to the ``best_*`` paths (and record
            ``self.best_score``) instead of the periodic-checkpoint paths.
        """
        if best:
            net_path = self.best_default_net_state_path
            optim_path = self.best_optim_state_path
            lrsched_path = self.best_lrsched_state_path
            iter_path = self.best_iter_state_path
        else:
            net_path = self.default_net_state_path
            optim_path = self.optim_state_path
            lrsched_path = self.lrsched_state_path
            iter_path = self.iter_state_path

        # All checkpoint files share one directory; make sure it exists.
        ckpt_dir = os.path.dirname(optim_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)

        if self.managed_weight_saving:
            if best:
                self.net.save_weights(self.args, best=True)
            else:
                self.net.save_weights(self.args)
        else:
            torch.save(self.net.state_dict(), net_path)
        torch.save(self.optim.state_dict(), optim_path)
        if best:
            torch.save({"best_score": self.best_score}, self.best_score_path)
        if self.lr_scheduler is not None:
            torch.save(self.lr_scheduler.state_dict(), lrsched_path)
        torch.save({"iter": step_id + 1}, iter_path)
        self.extra_save_state()

    def _visualize(self, epoch, batch, step_id, test_data_iter):
        """
        Run ``vis_step`` on a test batch, log its scalars and save its image.
        """
        print("generating visualization")
        if self.fixed_test:
            # NOTE(review): the test loader uses shuffle=True, so this first
            # batch is not actually fixed across calls; confirm intent.
            test_data = next(iter(self.test_data_loader))
        else:
            test_data = next(test_data_iter)
        self.net.eval()
        with torch.no_grad():
            vis, vis_vals = self.vis_step(test_data, global_step=step_id)
        if vis_vals is not None:
            self.writer.add_scalars("vis", vis_vals, global_step=step_id)
        self.net.train()
        if vis is not None:
            import imageio

            os.makedirs(self.visual_path, exist_ok=True)
            # Clip before the cast so out-of-range values saturate instead of
            # wrapping around in uint8.
            vis_u8 = np.clip(vis * 255, 0, 255).astype(np.uint8)
            imageio.imwrite(
                os.path.join(
                    self.visual_path,
                    "{:04}_{:04}_vis.png".format(epoch, batch),
                ),
                vis_u8,
            )

    def _log_wandb(self, prefix, values, step_id):
        """
        Log ``values`` to wandb under ``<prefix>/<key>`` (no-op if disabled).
        """
        if not self.args.use_wandb:
            return
        payload = {f"{prefix}/{k}": v for k, v in values.items()}
        payload[f"{prefix}/step"] = step_id
        wandb.log(payload)

    def _lr(self):
        """
        Current learning rate of the first optimizer parameter group.
        """
        return self.optim.param_groups[0]["lr"]

    @staticmethod
    def _fmt_loss_str(losses):
        """
        Format a loss dict as ``"loss k1:v1 k2:v2"`` with 4 decimals.
        """
        return "loss " + " ".join(f"{k}:{v:.4f}" for k, v in losses.items())

    @staticmethod
    def _data_loop(dl):
        """
        Yield items from ``dl`` forever, restarting after each full pass.

        :raises RuntimeError: if a full pass yields nothing, e.g. a test set
            smaller than one batch with ``drop_last=True``. Without this check
            the loop would spin forever.
        """
        while True:
            empty = True
            for x in dl:
                empty = False
                yield x
            if empty:
                raise RuntimeError("data loader yielded no batches; cannot loop over it")