from mia.conversion import Converter
import mia.losses as losses
import mia.metalearning as metalearning

import torch
import wandb
from tqdm import tqdm
from matplotlib import pyplot as plt
import time
from typing import Dict
from torch.distributions.beta import Beta

from mia.utils import (
    sample_nmr,
    generate_random_masks,
    to_device
)

class BaseTrainer:
    def __init__(
        self,
        func_rep,
        converter: Dict[str, Converter],
        args,
        train_dataset,
        test_dataset,
        model_path = "",
    ):
        self.func_rep = func_rep
        self.converter = converter
        self.args = args
        self.modes = args.modes
        self.num_modes = len(args.modes)
        self.sample_modes = args.sample_modes
        self.num_sample_modes = len(args.sample_modes)

        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self._process_datasets()

        if self.train_dataset is not None:
            inr_params = self.func_rep.get_inr_params()
            self.outer_optimizer_inr = torch.optim.Adam([
                { 'params' : inr_params.values(), 'lr' : args.outer_lr },
            ], lr = args.outer_lr)

            enc_params = self.func_rep.get_non_inr_params()
            logvar_params = self.func_rep.get_logvars()
            self.outer_optimizer_enc = torch.optim.Adam([
                { 'params' : enc_params.values(), 'lr' : args.outer_lr * args.encoder_lr_ratio},
                { 'params' : logvar_params.values(), 'lr' : args.outer_lr * args.logvar_lr_ratio},
            ], lr = args.outer_lr)

        self.model_path = model_path
        self.step = 0
        self.stime = time.time()

    def load_pretrained(
        self,
        states = None,
        step = None,
        filename = None,
        model_path = None,
        strict=False,
        load_step=True,
        load_optimizer=False,
        load_inr_only=False
    ):
        if states is None:
            if step is not None:
                model_path = self.args.log_dir / 'ckpt' / f'{step:010d}.pt'

            if filename is not None:
                model_path = self.args.log_dir / 'ckpt' / f'{filename}'

            states = torch.load(model_path, map_location='cpu')

        if load_inr_only:
            states['state_dict'] = {k:v for k,v in states['state_dict'].items() if 'grad' not in k}

        self.func_rep.load_state_dict(states['state_dict'], strict=strict)

        if load_step:
            self.step = states['step']

        if load_optimizer:
            self.outer_optimizer_inr.load_state_dict(states['state_dict_optim_inr'])
            self.outer_optimizer_enc.load_state_dict(states['state_dict_optim_enc'])

    def _process_datasets(self):
        """Create dataloaders for datasets based on self.args."""
        self.train_dataloader = torch.utils.data.DataLoader(
            self.train_dataset,
            shuffle=True,
            batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pin_memory=self.args.num_workers > 0,
            drop_last=True,
        ) if self.train_dataset is not None else None

        self.test_dataloader = torch.utils.data.DataLoader(
            self.test_dataset,
            shuffle=False,
            batch_size=self.args.validation_batch_size,
            num_workers=self.args.num_workers,
        ) if self.test_dataset is not None else None

    @torch.no_grad()
    def create_supports_and_querys(self, data_dict, distR_dict, distM_dict):
        """Split one batch into support, query and target sets for every mode.

        Args:
            data_dict: Mapping from mode name to that mode's batch tensor.
            distR_dict: Per-mode distribution used to draw how many points go
                into the support set; ``None`` for a mode means "always use the
                maximum" (``int(n * Rmax[mode])``).
            distM_dict: Per-mode distribution forwarded to
                ``generate_random_masks`` as ``dist``.

        Returns:
            A 12-tuple: ``batch_size`` (taken from the last sampled mode),
            then the support dicts (xs, ys, masks), ``ids_restore_dict``,
            ``ids_sample_dict``, the query dicts (xs, ys, masks) and the target
            dicts (xs, ys, masks). Targets alias the query dicts when
            ``args.meta_target == 'query'`` and the support dicts otherwise.
        """
        support_x, support_y, support_m = {}, {}, {}
        restore_ids, sample_ids = {}, {}
        query_x, query_y, query_m = {}, {}, {}
        target_x, target_y, target_m = {}, {}, {}

        ratio_min, ratio_max = self.args.Rmin, self.args.Rmax

        for mode in self.sample_modes:
            batch = data_dict[mode].to(self.args.device)
            coords, feats = self.converter[mode].to_coordinates_and_features(batch)
            # Collapse every dimension between the first and the last into one.
            coords = coords.flatten(1, -2)
            feats = feats.flatten(1, -2)
            batch_size, n = coords.shape[:2]

            if 'uncorrelated' in self.args and self.args.uncorrelated:
                # Shuffle this mode along the batch dimension; a fresh
                # permutation is drawn for every mode.
                shuffle = torch.randperm(coords.shape[0])
                coords, feats = coords[shuffle], feats[shuffle]

            # Query set is always full coordinates, with an all-zero mask.
            query_x[mode] = coords
            query_y[mode] = feats
            query_m[mode] = torch.zeros_like(coords[..., 0])

            fewest = int(n * ratio_min[mode])
            most = int(n * ratio_max[mode])

            ratio_dist = distR_dict[mode]
            if ratio_dist is not None:
                drawn = ratio_dist.sample((1,)).item() * (most + 1 - fewest) + fewest
                num_samples = min(n, int(drawn))
            else:
                num_samples = most

            picked = sample_nmr(coords, feats, num_samples)
            support_x[mode], support_y[mode], restore_ids[mode], sample_ids[mode] = picked
            support_m[mode] = generate_random_masks(
                support_x[mode],
                num_sample_min = fewest,
                num_sample_max = num_samples,
                dist = distM_dict[mode],
            )

        # Modes that are only sampled (not modelled) are removed again from the
        # support and query coordinate/feature/mask dicts listed here.
        for mode in self.sample_modes:
            if mode in self.modes:
                continue
            for store in (support_m, support_x, support_y, query_x, query_y):
                store.pop(mode)

        for mode in self.modes:
            if self.args.meta_target == 'query':
                src_x, src_y, src_m = query_x, query_y, query_m
            else:
                src_x, src_y, src_m = support_x, support_y, support_m
            target_x[mode] = src_x[mode]
            target_y[mode] = src_y[mode]
            target_m[mode] = src_m[mode]

        return (
            batch_size,
            support_x, support_y, support_m, restore_ids, sample_ids,
            query_x, query_y, query_m,
            target_x, target_y, target_m,
        )

    @torch.no_grad()
    def visualize(self, *args, **kwargs):
        """Hook for subclasses to render reconstructions for logging.

        The base implementation draws nothing and returns ``None``;
        ``validation`` treats a ``None`` result as "no image to log".
        Subclasses may return either a single image or a dict of images keyed
        by mode, which is how the callers in this class consume the result.
        """
        # Explicit None instead of the former bare `NotImplemented` expression,
        # which was a no-op and did not raise.
        return None

    def train_epoch(self):
        """Train the model for a single epoch.

        For every training batch this
          1. builds the support / query / target sets,
          2. adapts the latents for ``args.inner_steps`` inner steps, recording
             the target loss at step 0 and at every step in ``args.outer_steps``,
          3. sums the recorded losses (weighted by the learned log-variances
             unless ``args.loss_weight_mode == 'none'``) and takes one step of
             both outer optimizers,
          4. logs, validates and checkpoints at the configured intervals.
        """
        self.func_rep.train()
        inner_steps = self.args.inner_steps

        # A mode with Ra >= 100 gets no ratio distribution (None), which makes
        # the sampler always use the maximum support size.
        # NOTE(review): these dicts are keyed by self.modes, while
        # create_supports_and_querys indexes them by self.sample_modes; a
        # sample mode that is not in self.modes would raise KeyError there.
        distR_dict = { mode : Beta(self.args.Ra[mode], self.args.Rb[mode]) if self.args.Ra[mode] < 100 else None for mode in self.modes }
        distM_dict = { mode : Beta(self.args.Ma[mode], self.args.Mb[mode]) for mode in self.modes }

        for data_dict, _ in tqdm(self.train_dataloader):
            self.step += 1

            # 1. Preprocessing
            (batch_size, xs_support_dict, ys_support_dict, ms_support_dict, ids_restore_dict, _,
             xs_query_dict, ys_query_dict, _,
             xs_target_dict, ys_target_dict, _) = self.create_supports_and_querys(data_dict, distR_dict, distM_dict)

            num_vis = min(batch_size, self.args.num_vis)
            log_images = self.args.use_wandb and self.step % self.args.log_image_interval == 0

            # 2. Inner Loop Adaptation
            loss_dict = {}
            metric_dict = {}
            ys_recon_dict = {}

            inner_lr_dict = self.func_rep.meta_lr
            latent_prior_dict = self.func_rep.init_latent(batch_size)
            latent_init_dict = self.func_rep.pool_latent(latent_prior_dict, xs_support_dict, ys_support_dict, ms_support_dict, None)

            for inner_step in range(inner_steps + 1):
                if inner_step == 0:
                    latent_dict = { mode : latent_init_dict[mode] for mode in self.modes }
                else:
                    latent_dict, _  = metalearning.inner_loss_step(
                        func_rep = self.func_rep,
                        modes = self.modes,
                        xs_support_dict = xs_support_dict,
                        ys_support_dict = ys_support_dict,
                        ms_support_dict = ms_support_dict,
                        inner_lr_dict = inner_lr_dict,
                        latent_dict = latent_dict,
                        is_train = True,
                    )

                # Step 0 is always recorded; later steps only when listed in
                # args.outer_steps. Gradients are tracked only for listed steps.
                in_outer_steps = inner_step in self.args.outer_steps
                if inner_step != 0 and not in_outer_steps:
                    continue

                loss_dict[inner_step] = {}
                metric_dict[inner_step] = {}
                if log_images:
                    ys_recon_dict[inner_step] = {}

                for mode in self.modes:
                    with torch.set_grad_enabled(in_outer_steps):
                        ys_recon = self.func_rep.modulated_forward_single(xs_target_dict[mode], latent_dict[mode], mode)

                        per_example_loss = losses.batch_loss_fn(ys_recon, ys_target_dict[mode], mode)
                        per_example_metric = losses.batch_metric_fn(ys_recon, ys_target_dict[mode], mode)

                        loss_dict[inner_step][mode] = per_example_loss.mean()
                        metric_dict[inner_step][mode] = per_example_metric.mean()

                    if log_images:
                        # Images are drawn on the query coordinates; when the
                        # target already is the query set, reuse the prediction.
                        if self.args.meta_target != 'query':
                            with torch.no_grad():
                                ys_recon = self.func_rep.modulated_forward_single(xs_query_dict[mode], latent_dict[mode], mode)

                        ys_recon_dict[inner_step][mode] = ys_recon[:num_vis].detach().cpu()

            # Last inner step for which a loss was recorded. This equals
            # `inner_steps` whenever `inner_steps` is in args.outer_steps and
            # avoids a KeyError in the logging below when it is not.
            final_step = max(loss_dict)

            # 3. Outer update
            total_loss = 0.0
            for inner_step in loss_dict:
                for mode in self.modes:
                    if self.args.loss_weight_mode == 'none':
                        total_loss += loss_dict[inner_step][mode]
                    else:
                        # Log-variance weighting: 'semseg' divides by
                        # exp(logvar), every other mode by 2 * exp(logvar).
                        logvar_mode = self.func_rep.logvars[mode].squeeze()
                        if mode in ['semseg']:
                            total_loss += loss_dict[inner_step][mode] / (1 * logvar_mode.exp()) + logvar_mode / 2
                        else:
                            total_loss += loss_dict[inner_step][mode] / (2 * logvar_mode.exp()) + logvar_mode / 2

            self.outer_optimizer_inr.zero_grad()
            self.outer_optimizer_enc.zero_grad()
            total_loss.backward(create_graph=False)

            self.outer_optimizer_inr.step()
            self.outer_optimizer_enc.step()

            # Keep the learned inner-loop learning rates inside [0, meta_sgd_lr_max].
            for mode in self.modes:
                self.func_rep.meta_lr[mode].data.clamp_(0.0, self.args.meta_sgd_lr_max)

            # 4. Logging
            if self.step % self.args.log_interval == 0:
                print(f'Step {self.step}, Total Loss {total_loss:.5f}')
                for mode in self.modes:
                    print(f'{mode:>10s}: (0-step) Loss {loss_dict[0][mode]:.5f}, Metric ({losses.metric_name[mode]}) {metric_dict[0][mode]:.5f}')
                    print(f'{mode:>10s}: ({final_step}-step) Loss {loss_dict[final_step][mode]:.5f}, Metric ({losses.metric_name[mode]}) {metric_dict[final_step][mode]:.5f}')

            if self.args.use_wandb and self.step % self.args.log_interval == 0:
                log_dict = {}

                for mode in self.modes:
                    log_dict[f"inner_lr-{mode}"] = self.func_rep.meta_lr[mode].mean().item()

                log_dict[f"train-loss-avg"] = total_loss.item()

                for inner_step in loss_dict:
                    for mode in self.modes:
                        log_dict[f"train-loss-{mode}-in_step:{inner_step}"] = loss_dict[inner_step][mode].item()
                        log_dict[f"train-metric-{mode}-in_step:{inner_step}"] = metric_dict[inner_step][mode].item()

                if self.args.loss_weight_mode == 'uncertainty':
                    for mode in self.modes:
                        log_dict[f"train-logvar-{mode}"] = self.func_rep.logvars[mode].item()

                wandb.log(log_dict, step=self.step)

            if log_images:
                vis = self.visualize(
                    num_vis = num_vis,
                    xs_support_dict = xs_support_dict,
                    ys_support_dict = ys_support_dict,
                    ms_support_dict = ms_support_dict,
                    ids_restore_dict = ids_restore_dict,
                    xs_query_dict = xs_query_dict,
                    ys_query_dict = ys_query_dict,
                    ys_recon_dict = ys_recon_dict[final_step],
                    ys_recon_init_dict = ys_recon_dict[0],
                    data_dict = data_dict
                )

                # A visualizer that returns None has nothing to log; skip it
                # the same way validation() does.
                if vis is not None:
                    log_dict = {}
                    if isinstance(vis, dict):
                        for mode in self.modes:
                            log_dict[f'train-recon-{mode}'] = wandb.Image(vis[mode], caption=f'mode:{mode}')
                    else:
                        log_dict[f'train-recon'] = wandb.Image(vis, caption=f'modes:{self.modes}')

                    wandb.log(log_dict, step=self.step)

                plt.close()

            # 5. Validation and step-numbered checkpoint
            if self.step % self.args.validate_every == 0:
                torch.cuda.empty_cache()
                for valid_inner_steps in self.args.validation_inner_steps:
                    self.validation(valid_inner_steps)
                    torch.cuda.empty_cache()

                # validation() switches func_rep to eval mode and does not
                # switch back, so restore training mode for the rest of the epoch.
                self.func_rep.train()

                model_path = self.args.log_dir / 'ckpt' / f'{self.step:010d}.pt'
                model_path.parent.mkdir(parents=True, exist_ok=True)
                self._write_checkpoint(model_path)

            # 6. Time-based autosave, every 1200 seconds. Skipped when no
            # model_path was configured (the constructor default is "").
            if self.model_path and time.time() - self.stime > 1200:
                print(f'Step: {self.step} -- Save model to {self.model_path}')
                self._write_checkpoint(self.model_path)
                self.stime = time.time()

    def _write_checkpoint(self, path):
        """Save args, model weights, both optimizer states and the step to ``path``."""
        torch.save(
            {
                "args": self.args,
                "state_dict": self.func_rep.state_dict(),
                "state_dict_optim_inr": self.outer_optimizer_inr.state_dict(),
                "state_dict_optim_enc": self.outer_optimizer_enc.state_dict(),
                "step": self.step,
            },
            path,
        )

    @torch.no_grad()
    def validation(self, inner_steps):
        """Evaluate on the test dataloader after ``inner_steps`` adaptation steps.

        Per-sample losses and metrics are collected over the whole test set,
        averaged (NaN-aware) per mode and additionally per support-ratio
        bucket from ``args.Rrange_lists``. Results are printed, optionally sent
        to wandb together with a visualization of the last batch, and returned.

        Note: this switches ``func_rep`` to eval mode and does not restore the
        previous mode before returning.

        Args:
            inner_steps: Number of inner-loop steps applied to the latents
                before evaluating on the query points.

        Returns:
            Dict of logged values keyed by name (scalars, plus
            ``wandb.Image`` entries when a visualization was produced).
        """
        print(f"\nValidation, Step {self.step}:")
        print(f"Inner steps {inner_steps}")
        for mode in self.modes:
            print(f"{mode:>10s}: Rmin:{self.args.Rmin[mode]:.3f}-Rmax:{self.args.Rmax[mode]:.3f}")

        self.func_rep.eval()
        # Suffix that tags metric names with the number of inner steps used.
        log_meta = f"in_step:{inner_steps}"

        # Per-mode aggregates (filled after the loop) and last-batch reconstructions.
        loss_dict = { k : 0 for k in self.modes }
        metric_dict = { k : 0 for k in self.modes }
        ys_recon_dict = {}
        ys_recon_init_dict = {}

        # Per-mode lists of per-batch tensors; concatenated after the loop.
        loss_dict_per_sample = { k : [] for k in self.modes }
        metric_dict_per_sample = { k : [] for k in self.modes }
        Rs_support_dict_per_sample = { k : [] for k in self.modes }

        # Same sampling distributions as in train_epoch: Ra >= 100 disables the
        # ratio distribution for that mode (None).
        distR_dict = { mode : Beta(self.args.Ra[mode], self.args.Rb[mode]) if self.args.Ra[mode] < 100 else None for mode in self.modes }
        distM_dict = { mode : Beta(self.args.Ma[mode], self.args.Mb[mode]) for mode in self.modes }


        for data_dict, info_dict in tqdm(self.test_dataloader):
            # only synthetic data has info_dict, otherwise `None`

            # Two batch layouts are supported: one that already carries the
            # support/query split under fixed keys, and a raw one that is split
            # here with create_supports_and_querys.
            if 'xs_support_dict' in data_dict:
                data_dict = to_device(data_dict, self.args.device)
                batch_size = data_dict[self.modes[0]].shape[0]
                xs_support_dict = data_dict.pop('xs_support_dict')
                ys_support_dict = data_dict.pop('ys_support_dict')
                ms_support_dict = data_dict.pop('ms_support_dict')
                xs_query_dict = data_dict.pop('xs_query_dict')
                ys_query_dict = data_dict.pop('ys_query_dict')
                ids_restore_dict = data_dict.pop('ids_restore_dict')
                if self.args.dataset_config['name'] in ['synthetic']:
                    info_dict = data_dict.pop('info_dict')
            else:
                """ 1. Preprocessing """
                (batch_size, xs_support_dict, ys_support_dict, ms_support_dict, ids_restore_dict, _, \
                xs_query_dict, ys_query_dict, _, \
                _, _, _) = self.create_supports_and_querys(data_dict, distR_dict, distM_dict)

            num_vis = min(self.args.num_vis, batch_size)

            """ 2. Inner Loop Adaptation """
            inner_lr_dict = self.func_rep.meta_lr

            # NOTE(review): the whole method runs under @torch.no_grad(), yet the
            # latents are marked requires_grad below; presumably
            # metalearning.inner_loss_step re-enables gradients itself - confirm.
            # with torch.enable_grad():
            latent_init_dict = self.func_rep.init_latent(batch_size)
            latent_init_dict = self.func_rep.pool_latent(latent_init_dict, xs_support_dict, ys_support_dict, ms_support_dict, None)
            latent_init_dict = {
                mode : v.requires_grad_(True) for mode, v in latent_init_dict.items()
            }

            for inner_step in range(inner_steps + 1):
                if inner_step == 0:
                    latent_dict = latent_init_dict
                else:
                    latent_dict, _  = metalearning.inner_loss_step(
                        func_rep = self.func_rep,
                        modes = self.modes,
                        xs_support_dict = xs_support_dict,
                        ys_support_dict = ys_support_dict,
                        ms_support_dict = ms_support_dict,
                        inner_lr_dict = inner_lr_dict,
                        latent_dict = latent_dict,
                        is_train = False,
                    )

                # Keep a detached copy of the un-adapted latents so the
                # step-0 reconstruction can be visualized later.
                if inner_step == 0 and self.args.use_wandb:
                    latent0_dict = {
                        mode : v.detach() for mode, v in latent_dict.items()
                    }

            """ 3. Evaluate on Query points """
            for mode in self.modes:
                ys_recon = self.func_rep.modulated_forward_single(xs_query_dict[mode], latent_dict[mode], mode)

                # When the batch comes with an info dict, its per-mode 'a' entry
                # squared is passed as `norm`; otherwise the norm is 1.0.
                scale = info_dict[mode]['a'].to(self.args.device).view(-1, 1, 1) ** 2 if isinstance(info_dict, dict) else 1.0
                per_example_loss = losses.batch_loss_fn(ys_recon, ys_query_dict[mode], mode, norm = scale)
                per_example_metric = losses.batch_metric_fn(ys_recon, ys_query_dict[mode], mode, norm = scale)

                loss_dict_per_sample[mode] += [per_example_loss.detach().cpu()]
                metric_dict_per_sample[mode] += [per_example_metric.detach().cpu()]
                # Support ratio per sample: (support length - mask sum) divided
                # by the number of query points. Presumably mask == 1 marks a
                # dropped support point - confirm against generate_random_masks.
                Rs_support_dict_per_sample[mode] += [(ms_support_dict[mode].shape[1] - ms_support_dict[mode].sum(1)).div(xs_query_dict[mode].shape[1]).detach().cpu()]

                # Overwritten every batch, so only the last batch is kept.
                if self.args.use_wandb:
                    ys_recon_dict[mode] = ys_recon[:num_vis].detach().cpu()

                    ys_recon_init = self.func_rep.modulated_forward_single(xs_query_dict[mode], latent0_dict[mode], mode)
                    ys_recon_init_dict[mode] = ys_recon_init[:num_vis].detach().cpu()

        # Collapse the per-batch lists into one tensor per mode and take the
        # NaN-ignoring mean as the overall loss / metric.
        for mode in self.modes:
            loss_dict_per_sample[mode] = torch.cat(loss_dict_per_sample[mode])
            metric_dict_per_sample[mode] = torch.cat(metric_dict_per_sample[mode])
            Rs_support_dict_per_sample[mode] = torch.cat(Rs_support_dict_per_sample[mode])

            loss_dict[mode] = loss_dict_per_sample[mode].nanmean().item()
            metric_dict[mode] = metric_dict_per_sample[mode].nanmean().item()

        # Bucket the samples by support ratio (inclusive on both ends); a
        # bucket whose mean is NaN is stored as 0.
        metric_log_dict = { mode : {} for mode in self.modes }
        loss_log_dict = { mode : {} for mode in self.modes }
        for mode in self.modes:
            for Rrange in self.args.Rrange_lists[mode]:
                Rmin, Rmax = Rrange
                log_name = f"Rmin:{Rmin:.3f}-Rmax:{Rmax:.3f}"

                mask = ((Rs_support_dict_per_sample[mode] <= Rmax) & (Rmin <= Rs_support_dict_per_sample[mode]))

                metric = metric_dict_per_sample[mode][mask].nanmean()
                metric_log_dict[mode][log_name] = 0. if metric.isnan() else metric.item()

                loss = loss_dict_per_sample[mode][mask].nanmean()
                loss_log_dict[mode][log_name] = 0. if loss.isnan() else loss.item()

        """ Logging """
        for mode in self.modes:
            print(f"{mode:>10s}: Loss {loss_dict[mode]:.5f} Metric ({losses.metric_name[mode]}) {metric_dict[mode]:.5f}")

        log_dict = {}

        # First set of keys: tagged with the number of inner steps.
        log_dict[f"val-loss-avg-{log_meta}-avg"] = sum(loss_dict.values()) / len(loss_dict)

        for mode in self.modes:
            log_dict[f"val-loss-{mode}-{log_meta}-avg"] = loss_dict[mode]
            log_dict[f"val-metric-{mode}-{log_meta}-avg"] = metric_dict[mode]

            # Buckets whose metric is not strictly positive are not logged
            # (this includes the buckets that were stored as 0 above).
            for Rrange in metric_log_dict[mode]:
                if metric_log_dict[mode][Rrange] > 0:
                    log_dict[f"val-loss-{mode}-{log_meta}-{Rrange}"] = loss_log_dict[mode][Rrange]
                    log_dict[f"val-metric-{mode}-{log_meta}-{Rrange}"] = metric_log_dict[mode][Rrange]

        # Second set of keys: same values without the inner-step tag.
        log_dict[f"val-loss-avg-avg"] = sum(loss_dict.values()) / len(loss_dict)

        for mode in self.modes:
            log_dict[f"val-loss-{mode}-avg"] = loss_dict[mode]
            log_dict[f"val-metric-{mode}-avg"] = metric_dict[mode]

            for Rrange in metric_log_dict[mode]:
                if metric_log_dict[mode][Rrange] > 0:
                    log_dict[f"val-loss-{mode}-{Rrange}"] = loss_log_dict[mode][Rrange]
                    log_dict[f"val-metric-{mode}-{Rrange}"] = metric_log_dict[mode][Rrange]

        if self.args.use_wandb and wandb.run:
            # TODO
            # Visualize samples
            # The arguments below are the loop variables left over from the
            # last test batch.
            vis = self.visualize(
                num_vis = num_vis,
                xs_support_dict = xs_support_dict,
                ys_support_dict = ys_support_dict,
                ms_support_dict = ms_support_dict,
                ids_restore_dict = ids_restore_dict,
                xs_query_dict = xs_query_dict,
                ys_query_dict = ys_query_dict,
                ys_recon_dict = ys_recon_dict,
                ys_recon_init_dict = ys_recon_init_dict,
                data_dict = data_dict,
            )

            # None means the visualizer produced nothing; a dict is logged per
            # mode, anything else as a single image.
            if vis is None:
                pass
            elif isinstance(vis, dict):
                for mode in self.modes:
                    log_dict[f'val-recon-{mode}'] = wandb.Image(vis[mode], caption=f'mode:{mode}')
            else:
                log_dict[f'val-recon'] = wandb.Image(vis, caption=f'modes:{self.modes}')

            # Logged without an explicit step, unlike the calls in train_epoch.
            wandb.log(log_dict)

        plt.close()

        return log_dict
