from collections import defaultdict
from datetime import datetime
from functools import partial
from pathlib import Path
import shutil
import sys
import time
from typing import Any, Dict

from diffusers.optimization import get_constant_schedule_with_warmup
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import trange
import wandb

from agent import Agent
from collect import make_collector
from data import BatchSampler, collate_segments_to_batch, DatasetTraverser, EpisodeDataset
from envs import Env, WorldModelEnv
from models.actor_critic import ActorCritic 
from models.diffuser import WorldModel
from rl_step import make_actor_critic_loss_computer
from utils import configure_optimizer, set_seed, try_until_no_except


class Trainer:
    def __init__(self, cfg: DictConfig) -> None:
        """Build everything needed for a training run from the configuration.

        In order: initialises wandb, prepares the run directory layout, creates the
        train/test datasets and collectors, instantiates the world model and the
        actor-critic, builds one optimizer / LR schedule / data loader per trained
        component, builds the RL environment (real or world-model based), and
        finally either loads the last checkpoint (``cfg.common.resume``) or saves
        an initial one for epoch 0.

        Args:
            cfg: Hydra/OmegaConf configuration. It is resolved in place, and
                ``cfg.common.seed`` is overwritten when it was left as ``None``.
        """
        torch.backends.cuda.matmul.allow_tf32 = True
        OmegaConf.resolve(cfg)
        self.cfg = cfg
        if cfg.common.seed is None:
            # No seed given: derive one from the current timestamp.
            cfg.common.seed = int(datetime.now().timestamp()) % 10 ** 5
        set_seed(cfg.common.seed)
        try_until_no_except(partial(wandb.init, config=OmegaConf.to_container(cfg, resolve=True), reinit=True, resume=True, **cfg.wandb))

        self.device = torch.device(cfg.common.device)
        if 'cuda' in cfg.common.device:
            torch.cuda.set_device(self.device) # (quick) fix compilation error on multi-gpu nodes

        # Progress counters; overwritten by load_checkpoint() when resuming.
        self.start_epoch = 1
        self.num_batch_denoiser = 0
        self.num_batch_rew_end = 0
        self.num_batch_ac = 0

        # All paths are relative to the current working directory.
        self.ckpt_dir = Path('checkpoints')
        self.media_dir = Path('media')
        self.episode_dir = self.media_dir / 'episodes'
        self.images_dir = self.media_dir / 'images'

        if cfg.collection.path_to_static_dataset is None:
            self.dataset_dir = Path('dataset')
        else:
            print('Using a static dataset, no collection.')
            self.dataset_dir = Path(cfg.collection.path_to_static_dataset).expanduser()
            assert (self.dataset_dir / 'train').is_dir() and (self.dataset_dir / 'test').is_dir(), \
                f'Static dataset at {self.dataset_dir} must contain "train" and "test" directories.'

        if not cfg.common.resume:
            # Fresh run: snapshot the config and the source tree next to the outputs.
            # Every mkdir/copytree below fails if its target already exists.
            config_dir = Path('config')
            config_path = config_dir / 'trainer.yaml'
            config_dir.mkdir(exist_ok=False, parents=False)
            shutil.copy('.hydra/config.yaml', config_path)
            wandb.save(str(config_path))
            shutil.copytree(src=(Path(hydra.utils.get_original_cwd()) / "src"), dst="./src")
            shutil.copytree(src=(Path(hydra.utils.get_original_cwd()) / "scripts"), dst="./scripts")
            self.ckpt_dir.mkdir(exist_ok=False, parents=False)
            self.media_dir.mkdir(exist_ok=False, parents=False)
            self.episode_dir.mkdir(exist_ok=False, parents=False)
            self.images_dir.mkdir(exist_ok=False, parents=False)

        self.train_dataset = EpisodeDataset(directory=self.dataset_dir / 'train', name='train_dataset', cache_in_ram=cfg.training.cache_in_ram)
        self.test_dataset  = EpisodeDataset(directory=self.dataset_dir / 'test' , name='test_dataset', cache_in_ram=False)

        train_env_fn_ = partial(instantiate, config=cfg.env.train)
        test_env_fn_  = partial(instantiate, config=cfg.env.test)
        # A single-env instance is created only to read the number of actions.
        num_actions = Env(train_env_fn_, num_envs=1).num_actions

        # assert self.cfg.training.should or self.cfg.evaluation.should
        # env = train_env if self.cfg.training.should else test_env

        c = instantiate(cfg.world_model)
        c.num_actions = num_actions
        wm = WorldModel(c).to(self.device)
        ac = ActorCritic(num_actions).to(self.device)

        self.agent = Agent(wm, ac)

        print(f'{sum(p.numel() for p in wm.denoiser.parameters())} parameters in denoiser')
        print(f'{sum(p.numel() for p in wm.rew_end_model.parameters())} parameters in reward/end model')
        print(f'{sum(p.numel() for p in ac.parameters())} parameters in actor_critic.')

        train_env_fn = partial(Env, env_fn=train_env_fn_, num_envs=cfg.collection.train.num_envs, device=self.device)
        test_env_fn = partial(Env, env_fn=test_env_fn_, num_envs=cfg.collection.test.num_envs, device=self.device)
        self.train_collector = make_collector(train_env_fn, model=ac, dataset=self.train_dataset, epsilon=cfg.collection.train.config.epsilon)
        self.test_collector  = make_collector(test_env_fn , model=ac, dataset=self.test_dataset, save_on_disk=False)

        # Factories shared by the per-component samplers/loaders below; all of them
        # read from the training dataset.
        make_batch_sampler = partial(BatchSampler, dataset=self.train_dataset)

        # Worker processes are only used when the dataset is not cached in RAM.
        make_data_loader = partial(DataLoader, dataset=self.train_dataset, collate_fn=collate_segments_to_batch,
                              num_workers=0 if cfg.training.cache_in_ram else 4, persistent_workers=not cfg.training.cache_in_ram,
                              pin_memory=True, pin_memory_device=str(self.device))

        ######################################################
        # Denoiser: optimizer, warmup schedule and data loader.

        c = cfg.training.denoiser

        self.opt_denoiser = configure_optimizer(wm.denoiser, learning_rate=float(c.learning_rate), weight_decay=float(c.weight_decay), eps=float(c.epsilon))
        self.lr_sched_denoiser = get_constant_schedule_with_warmup(optimizer=self.opt_denoiser, num_warmup_steps=c.lr_warmup_steps)
        bs_denoiser = make_batch_sampler(batch_size=c.batch_size, sequence_length=cfg.world_model.num_steps_conditioning + 1   , can_sample_beyond_end=True)
        self.data_loader_denoiser = make_data_loader(batch_sampler=bs_denoiser)

        ######################################################
        # Reward/end model: optimizer, warmup schedule and data loader.

        c = cfg.training.rew_end

        self.opt_rew_end  = configure_optimizer(wm.rew_end_model, learning_rate=float(c.learning_rate), weight_decay=float(c.weight_decay), eps=float(c.epsilon))
        self.lr_sched_rew_end  = get_constant_schedule_with_warmup(optimizer=self.opt_rew_end , num_warmup_steps=c.lr_warmup_steps)
        bs_rew_end  = make_batch_sampler(batch_size=c.batch_size , sequence_length=cfg.training.actor_critic.imagination_horizon, can_sample_beyond_end=True)
        self.data_loader_rew_end  = make_data_loader(batch_sampler=bs_rew_end)

        ######################################################
        # Actor-critic: optimizer, warmup schedule and the environment it learns in.

        c = cfg.training.actor_critic

        self.opt_ac = configure_optimizer(ac, learning_rate=float(c.learning_rate), weight_decay=float(c.weight_decay), eps=float(c.epsilon))
        self.lr_sched_ac = get_constant_schedule_with_warmup(optimizer=self.opt_ac, num_warmup_steps=c.lr_warmup_steps)

        if cfg.training.model_free:
            # Model-free: the actor-critic interacts with the real environment.
            env = Env(env_fn=train_env_fn_, num_envs=c.batch_size, device=self.device)

        else:
            # Otherwise it interacts with the world model, which gets its own
            # loader over the training dataset.
            bs = make_batch_sampler(batch_size=c.batch_size, sequence_length=cfg.world_model.num_steps_conditioning, can_sample_beyond_end=False)
            data_loader_wm_env = make_data_loader(batch_sampler=bs)
            env = WorldModelEnv(wm, data_loader_wm_env, horizon=c.imagination_horizon)

            # Compile (full length): run one dummy rollout of `imagination_horizon`
            # steps on all-zero inputs through the compiled step function.
            try:
                c_ = cfg.world_model
                env.step = torch.compile(env.step, mode='reduce-overhead')
                env.reset_from(
                    torch.zeros(c.batch_size, c_.num_steps_conditioning, c_.image_channels, c_.image_size, c_.image_size, device=self.device),
                    torch.zeros(c.batch_size, c_.num_steps_conditioning, dtype=torch.long, device=self.device))
                for _ in range(c.imagination_horizon):
                    env.step(torch.zeros(c.batch_size, dtype=torch.long, device=self.device))
            except Exception as exc:
                # Best effort: a failure here must not abort the run, but it is
                # reported. KeyboardInterrupt / SystemExit are deliberately not
                # caught so the process can still be stopped during compilation.
                print(f'Cannot compile. ({type(exc).__name__}: {exc})')

        self.rl_step = make_actor_critic_loss_computer(env, ac, c.gamma, c.lambda_, c.entropy_weight, clip_reward=True)

        if cfg.initialization.path_to_ckpt is not None:
            self.agent.load(**cfg.initialization)

        if cfg.common.resume:
            self.load_checkpoint()
        else:
            self.save_checkpoint(0)

    def run(self) -> None:
        """Run the main loop from ``self.start_epoch`` to ``cfg.training.epochs``.

        Each epoch optionally collects training data, trains the agent, collects
        test episodes, saves a checkpoint, and sends the gathered metric dicts to
        wandb tagged with the epoch number. ``finish()`` is called after the last
        epoch.
        """
        cfg = self.cfg

        for epoch in range(self.start_epoch, 1 + cfg.training.epochs):

            print(f"\nEpoch {epoch} / {self.cfg.training.epochs}\n")
            epoch_began = time.time()
            records = []

            if cfg.training.should:
                # Collection happens only with a world model, a non-static dataset,
                # and before the configured cut-off epoch.
                should_collect = (
                    not cfg.training.model_free
                    and cfg.collection.path_to_static_dataset is None
                    and epoch <= cfg.collection.train.stop_after_epochs
                )
                if should_collect:
                    collect_cfg = cfg.collection.train.config
                    if epoch == 1:
                        num_steps = collect_cfg.num_steps_first_epoch
                    else:
                        num_steps = collect_cfg.num_steps
                    records.extend(self.train_collector.send(num_steps))
                records.extend(self.train_agent(epoch))

            is_eval_epoch = cfg.evaluation.should and (epoch % cfg.evaluation.every == 0)
            if is_eval_epoch and cfg.collection.path_to_static_dataset is None:
                records.extend(self.test_collector.send(None))
                # to_log += self.eval_agent(epoch)

            if cfg.training.should:
                self.save_checkpoint(epoch)

            # Epoch wall-clock time, in hours.
            elapsed_hours = (time.time() - epoch_began) / 3600
            records.append({'duration': elapsed_hours})

            for record in records:
                wandb.log({'epoch': epoch, **record})

        self.finish()

    def train_agent(self, epoch: int) -> tuple:
        """Train each component that is active at ``epoch`` and log its metrics.

        The denoiser and the reward/end model are trained only when the run is not
        model-free; every component additionally waits until ``epoch`` exceeds its
        own ``start_after_epochs``. Metrics are sent to wandb right after each
        component finishes, with keys prefixed by ``<component>/train/``.

        Args:
            epoch: Current (1-based) epoch number.

        Returns:
            An empty tuple: all metrics are logged here directly, so there is
            nothing left for the caller to log.
        """
        self.agent.train()
        self.agent.zero_grad()

        def log_metrics(prefix, rows):
            # One wandb.log call per batch-level metrics dict.
            for row in rows:
                wandb.log({'epoch': epoch, **{prefix + k: v for k, v in row.items()}})

        training_cfg = self.cfg.training
        uses_world_model = not training_cfg.model_free

        if uses_world_model and epoch > training_cfg.denoiser.start_after_epochs:
            log_metrics('denoiser/train/', self.train_denoiser(epoch))
        if uses_world_model and epoch > training_cfg.rew_end.start_after_epochs:
            log_metrics('rew_end/train/', self.train_rew_end(epoch))
        if epoch > training_cfg.actor_critic.start_after_epochs:
            log_metrics('actor_critic/train/', self.train_ac(epoch))

        return ()

    def train_denoiser(self, epoch: int):
        """Train the denoiser for one epoch and return one metrics dict per batch.

        Gradients are accumulated over ``grad_acc_steps`` batches; clipping (if
        configured), the optimizer step and the LR-schedule step happen once per
        accumulation window.
        """
        cfg = self.cfg.training.denoiser
        model = self.agent.wm.denoiser
        model.train()

        batches = iter(self.data_loader_denoiser)
        history = []

        self.opt_denoiser.zero_grad()

        # The first epoch may use a different number of optimizer updates.
        if epoch > 1:
            updates = cfg.steps_per_epoch
        else:
            updates = cfg.steps_first_epoch
        total_batches = cfg.grad_acc_steps * updates

        for batch_idx in trange(total_batches, desc='Training denoiser'):

            batch = next(batches).to(self.device)
            loss, metrics = model.compute_loss(batch)
            loss.backward()

            metrics['num_batch_denoiser'] = self.num_batch_denoiser
            self.num_batch_denoiser += 1

            # The same dict object is stored, so the entries added below still
            # end up in the returned history.
            history.append(metrics)

            if (batch_idx + 1) % cfg.grad_acc_steps != 0:
                continue  # still accumulating gradients

            if cfg.max_grad_norm is not None:
                metrics['grad_norm_before_clip'] = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)

            self.opt_denoiser.step()
            self.opt_denoiser.zero_grad()

            if self.lr_sched_denoiser is not None:
                # Record the LR that was just applied, then advance the schedule.
                metrics['lr'] = self.lr_sched_denoiser.get_last_lr()[0]
                self.lr_sched_denoiser.step()

        return history
    
    def train_rew_end(self, epoch: int):
        """Train the reward/end model for one epoch and return per-batch metrics.

        Gradients are accumulated over ``grad_acc_steps`` batches; clipping (if
        configured), the optimizer step and the LR-schedule step happen once per
        accumulation window.
        """
        cfg = self.cfg.training.rew_end
        model = self.agent.wm.rew_end_model
        model.train()

        batches = iter(self.data_loader_rew_end)
        history = []

        self.opt_rew_end.zero_grad()

        # The first epoch may use a different number of optimizer updates.
        if epoch > 1:
            updates = cfg.steps_per_epoch
        else:
            updates = cfg.steps_first_epoch
        total_batches = cfg.grad_acc_steps * updates

        for batch_idx in trange(total_batches, desc='Training rew_end_model'):

            batch = next(batches).to(self.device)
            loss, metrics = model.compute_loss(batch)
            loss.backward()

            metrics['num_batch_rew_end'] = self.num_batch_rew_end
            self.num_batch_rew_end += 1

            # The same dict object is stored, so the entries added below still
            # end up in the returned history.
            history.append(metrics)

            if (batch_idx + 1) % cfg.grad_acc_steps != 0:
                continue  # still accumulating gradients

            if cfg.max_grad_norm is not None:
                metrics['grad_norm_before_clip'] = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)

            self.opt_rew_end.step()
            self.opt_rew_end.zero_grad()

            if self.lr_sched_rew_end is not None:
                # Record the LR that was just applied, then advance the schedule.
                metrics['lr'] = self.lr_sched_rew_end.get_last_lr()[0]
                self.lr_sched_rew_end.step()

        return history

    def train_ac(self, epoch: int):
        """Train the actor-critic for one epoch and return per-batch metrics.

        Each iteration asks ``self.rl_step`` for a loss via ``send(num_steps)``.
        Gradients are accumulated over ``grad_acc_steps`` iterations; clipping (if
        configured), the optimizer step and the LR-schedule step happen once per
        accumulation window.
        """
        cfg = self.cfg.training.actor_critic
        model = self.agent.ac
        model.train()

        history = []

        self.opt_ac.zero_grad()

        # The first epoch may use a different number of optimizer updates.
        if epoch > 1:
            updates = cfg.steps_per_epoch
        else:
            updates = cfg.steps_first_epoch
        total_batches = cfg.grad_acc_steps * updates

        for batch_idx in trange(total_batches, desc='Training actor_critic'):

            loss, metrics = self.rl_step.send(cfg.num_steps)
            loss.backward()

            metrics['num_batch_ac'] = self.num_batch_ac
            self.num_batch_ac += 1

            # The same dict object is stored, so the entries added below still
            # end up in the returned history.
            history.append(metrics)

            if (batch_idx + 1) % cfg.grad_acc_steps != 0:
                continue  # still accumulating gradients

            if cfg.max_grad_norm is not None:
                metrics['grad_norm_before_clip'] = torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)

            self.opt_ac.step()
            self.opt_ac.zero_grad()

            if self.lr_sched_ac is not None:
                # Record the LR that was just applied, then advance the schedule.
                metrics['lr'] = self.lr_sched_ac.get_last_lr()[0]
                self.lr_sched_ac.step()

        return history

    @torch.no_grad()
    def eval_agent(self, epoch: int) -> tuple:
        """Evaluate the agent (world-model evaluation is not implemented yet).

        Args:
            epoch: Current (1-based) epoch number.

        Returns:
            A 1-tuple holding the (currently always empty) dict of world-model
            evaluation metrics.

        Raises:
            NotImplementedError: If ``epoch`` is past
                ``cfg.evaluation.world_model.start_after_epochs``, i.e. as soon as
                world-model evaluation would actually have to run.
        """
        self.agent.eval()

        to_log_denoiser = {}

        c = self.cfg.evaluation.world_model

        if epoch > c.start_after_epochs:
            # `NotImplemented` is a comparison sentinel, not an exception; raising it
            # produces a TypeError. NotImplementedError is the intended signal.
            # TODO: compute the metrics through self.eval_component once implemented.
            raise NotImplementedError('World model evaluation is not implemented.')

        # if c.save_reconstructions:
        #     batch = next(iter()).to(self.agent.device)
        #     past_obs, past_act = batch.observations[:, :-1], batch.actions[:, :-1]
        #     _g = partial(self.agent.world_model.denoiser.generate_next_obs, past_obs=past_obs, past_act=past_act, generator=torch.manual_seed(self.cfg.common.seed))
        #     def g(*args, **kwargs): return _g(*args, **kwargs).cpu().unsqueeze(1)
        #     next_obs   = g()
        #     all_ = torch.cat((past_obs.cpu(), next_obs), dim=1)
        #     torchvision.utils.save_image(rearrange(all_, 'b t c h w -> (b t) c h w'), self.images_dir / f'images_{epoch:04d}.png', nrow=all_.size(1))

        return to_log_denoiser,

    @torch.no_grad()
    def eval_component(self, component: nn.Module, batch_num_samples: int, sequence_length: int, **kwargs_loss: Any) -> Dict[str, float]:
        """Placeholder for averaging a component's losses/metrics over the test set.

        Not implemented. The code that used to follow the unconditional ``raise``
        could never run, and it relied on ``tqdm``, which this module does not
        import, so it has been removed.

        Args:
            component: Module to evaluate.
            batch_num_samples: Number of samples per evaluation batch.
            sequence_length: Length of the sequences to evaluate on.
            **kwargs_loss: Extra keyword arguments intended for the component's
                loss computation.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError('Trainer.eval_component is not implemented.')

    def _save_training_state(self, epoch: int) -> None:
        """Write the full resumable training state to ``<ckpt_dir>/last.pt``.

        The state holds the epoch, the three batch counters, the agent weights and
        the optimizer / LR-scheduler states. Dataset info is saved as well unless
        a static dataset is in use.
        """
        def scheduler_state(scheduler):
            # A missing scheduler is stored as None.
            if scheduler is None:
                return None
            return scheduler.state_dict()

        state = {
            'epoch': epoch,
            'num_batch_denoiser': self.num_batch_denoiser,
            'num_batch_rew_end': self.num_batch_rew_end,
            'num_batch_ac': self.num_batch_ac,
            'agent': self.agent.state_dict(),
            'opt_denoiser': self.opt_denoiser.state_dict(),
            'opt_rew_end': self.opt_rew_end.state_dict(),
            'opt_ac': self.opt_ac.state_dict(),
            'lr_sched_denoiser': scheduler_state(self.lr_sched_denoiser),
            'lr_sched_rew_end': scheduler_state(self.lr_sched_rew_end),
            'lr_sched_ac': scheduler_state(self.lr_sched_ac),
        }
        torch.save(state, self.ckpt_dir / 'last.pt')

        if self.cfg.collection.path_to_static_dataset is not None:
            return
        for dataset in (self.train_dataset, self.test_dataset):
            dataset.save_info()
    
    def _save_agent(self, epoch: int):
        """Save the agent weights under ``<ckpt_dir>/all`` according to the save policy.

        ``cfg.training.save_agent_every`` must be ``-1`` or positive. With ``-1``
        the directory is wiped first, so only the current epoch's file remains;
        otherwise a file is written whenever ``epoch`` is a multiple of the value.
        """
        every = self.cfg.training.save_agent_every
        assert every == -1 or every > 0

        target_dir = self.ckpt_dir / 'all'

        if every == -1:
            # Keep-latest-only mode: drop whatever was saved before.
            shutil.rmtree(target_dir, ignore_errors=True)
        elif epoch % every != 0:
            return

        target_dir.mkdir(exist_ok=True)
        torch.save(self.agent.state_dict(), target_dir / f'agent_epoch_{epoch:05d}.pt')

    def _save_info(self, epoch: int):
        """Store the epoch and the wandb run name in ``info_for_import_script.pt``."""
        info = {'epoch': epoch, 'name': self.cfg.wandb.name}
        destination = self.ckpt_dir / 'info_for_import_script.pt'
        torch.save(info, destination)

    def save_checkpoint(self, epoch: int) -> None:
        """Save training state, agent weights and run info for ``epoch``.

        The current checkpoint directory is first copied to ``checkpoints_tmp``
        and that copy is removed only after all saves succeeded; presumably this
        keeps the previous checkpoint available if a save is interrupted.
        Note that the copy fails if ``checkpoints_tmp`` already exists.
        """
        backup_dir = Path('checkpoints_tmp')
        shutil.copytree(self.ckpt_dir, backup_dir)

        for save_part in (self._save_training_state, self._save_agent, self._save_info):
            save_part(epoch)

        shutil.rmtree(backup_dir)

    def load_checkpoint(self) -> None:
        """Restore the training state from ``<ckpt_dir>/last.pt``.

        Sets ``start_epoch`` to the epoch after the saved one and restores the
        batch counters, the agent weights, the optimizers and (where a scheduler
        exists) the LR schedulers.
        """
        ckpt = torch.load(self.ckpt_dir / 'last.pt', map_location=self.device)

        # Resume with the epoch following the last completed one.
        self.start_epoch = ckpt['epoch'] + 1
        self.num_batch_denoiser = ckpt['num_batch_denoiser']
        self.num_batch_rew_end = ckpt['num_batch_rew_end']
        self.num_batch_ac = ckpt['num_batch_ac']

        self.agent.load_state_dict(ckpt['agent'])

        self.opt_denoiser.load_state_dict(ckpt['opt_denoiser'])
        self.opt_rew_end.load_state_dict(ckpt['opt_rew_end'])
        self.opt_ac.load_state_dict(ckpt['opt_ac'])

        if self.lr_sched_denoiser is not None:
            self.lr_sched_denoiser.load_state_dict(ckpt['lr_sched_denoiser'])
        if self.lr_sched_rew_end is not None:
            self.lr_sched_rew_end.load_state_dict(ckpt['lr_sched_rew_end'])
        if self.lr_sched_ac is not None:
            self.lr_sched_ac.load_state_dict(ckpt['lr_sched_ac'])

        print(f'Successfully loaded checkpoint from {self.ckpt_dir.absolute()}.')

    def finish(self) -> None:
        """Close the wandb run."""
        wandb.finish()

