import torch
import random
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torchaudio.transforms as T
from dataclasses import dataclass
from accelerate import Accelerator
from torch.utils.tensorboard import SummaryWriter
from ..utils.steg import single_val, single_multival, single_multival_cond
from ..utils.steg import single_step, single_multistep, single_multistep_cond
from ..logging.visualization import single_vis, single_multivis, single_multivis_cond
@dataclass
class StegLearner:
    """Single-length learner: forwards train/vis/val calls to ``..utils.steg``.

    Attributes:
        hide: First model handed to every helper (presumably the hiding /
            embedding network — confirm against ``..utils.steg``).
        find: Second model handed to every helper (presumably the recovery
            network — confirm).
        opt: Optimizer forwarded to ``single_step``.
        writer: TensorBoard ``SummaryWriter`` forwarded to every helper.
        accelerator: Optional ``accelerate.Accelerator`` forwarded as the
            ``accelerator=`` keyword to every helper.
        vis_batch: Fixed batch used by ``single_vis``; when ``None`` the batch
            supplied to ``single_vis`` is visualised instead.
    """
    hide: nn.Module
    find: nn.Module
    opt: optim.Optimizer
    writer: SummaryWriter
    accelerator: Accelerator = None
    vis_batch: list = None

    def single_step(self, batch, cur_log_step):
        """Run one training step through the module-level ``single_step``.

        Returns whatever ``single_step`` returns.
        """
        return single_step(
            self.hide, self.find, self.opt, self.writer,
            batch, cur_log_step,
            accelerator=self.accelerator,
        )

    @torch.no_grad()
    def single_vis(self, batch, cur_log_step):
        """Log visualisations without tracking gradients.

        Prefers the stored ``vis_batch``; falls back to ``batch`` when unset.
        Returns whatever ``single_vis`` returns.
        """
        chosen = batch if self.vis_batch is None else self.vis_batch
        return single_vis(
            self.hide, self.find, self.writer, chosen,
            cur_log_step,
            accelerator=self.accelerator,
        )

    @torch.no_grad()
    def single_val(self, val_dl, cur_log_step):
        """Run validation over ``val_dl`` without tracking gradients.

        Returns whatever ``single_val`` returns.
        """
        return single_val(
            self.hide, self.find, self.writer, val_dl,
            cur_log_step,
            accelerator=self.accelerator,
        )


@dataclass
class StegLearnerMulti:
    """Variable-length learner built on the ``single_multi*`` helpers.

    Training draws a random ``spec_len`` in ``[1, max_audio_len]`` per step;
    visualisation sweeps every ``spec_len`` in that same inclusive range.

    Attributes:
        hide: First model handed to every helper (presumably the hiding /
            embedding network — confirm against ``..utils.steg``).
        find: Second model handed to every helper (presumably the recovery
            network — confirm).
        max_audio_len: Inclusive upper bound for ``spec_len``.
        opt: Optimizer forwarded to ``single_multistep``.
        writer: TensorBoard ``SummaryWriter`` forwarded to every helper.
        accelerator: Optional ``accelerate.Accelerator`` forwarded as the
            ``accelerator=`` keyword to every helper.
        vis_batch: Fixed batch used by ``single_vis``; when ``None`` the batch
            supplied to ``single_vis`` is visualised instead.
    """
    hide: nn.Module
    find: nn.Module
    max_audio_len: int
    opt: optim.Optimizer
    writer: SummaryWriter
    accelerator: Accelerator = None
    vis_batch: list = None

    def single_step(self, batch, cur_log_step):
        """Run one training step with a uniformly sampled ``spec_len``.

        Returns whatever ``single_multistep`` returns.
        """
        # randint is inclusive on both ends, matching the vis sweep below.
        n_specs = random.randint(1, self.max_audio_len)
        return single_multistep(
            self.hide, self.find, self.opt, self.writer,
            batch, n_specs, cur_log_step,
            accelerator=self.accelerator,
        )

    @torch.no_grad()
    def single_vis(self, batch, cur_log_step):
        """Log visualisations for every ``spec_len`` in ``[1, max_audio_len]``.

        Prefers the stored ``vis_batch``; falls back to ``batch`` when unset.
        Returns ``None``.
        """
        chosen = batch if self.vis_batch is None else self.vis_batch
        all_lens = range(1, self.max_audio_len + 1)
        for n_specs in all_lens:
            single_multivis(
                self.hide, self.find, self.writer, chosen,
                n_specs, cur_log_step,
                accelerator=self.accelerator,
            )

    @torch.no_grad()
    def single_val(self, val_dl, cur_log_step, device):
        """Run validation over ``val_dl`` without tracking gradients.

        ``max_audio_len`` and ``device`` are forwarded to ``single_multival``;
        returns whatever it returns.
        """
        return single_multival(
            self.hide, self.find, self.writer, val_dl,
            self.max_audio_len, cur_log_step,
            accelerator=self.accelerator,
            device=device,
        )


@dataclass
class StegLearnerMultiCond:
    """Variable-length, condition-dependent learner built on ``single_multi*_cond``.

    Training draws a random ``spec_len`` in ``[1, max_audio_len]`` and a
    condition ``10 ** u`` with ``u`` uniform over ``cond_range`` (i.e. a
    log-uniform condition). Visualisation and validation sweep
    ``num_eval_conds`` log-spaced condition values over the same range.

    Attributes:
        hide: First model handed to every helper (presumably the hiding /
            embedding network — confirm against ``..utils.steg``).
        find: Second model handed to every helper (presumably the recovery
            network — confirm).
        max_audio_len: Inclusive upper bound for ``spec_len``.
        cond_range: ``(low, high)`` base-10 exponents bounding the condition.
        opt: Optimizer forwarded to ``single_multistep_cond``.
        writer: TensorBoard ``SummaryWriter`` forwarded to every helper.
        accelerator: Optional ``accelerate.Accelerator`` forwarded as the
            ``accelerator=`` keyword to every helper.
        vis_batch: Fixed batch used by ``single_vis``; when ``None`` the batch
            supplied to ``single_vis`` is visualised instead.
        num_eval_conds: How many log-spaced condition values ``single_vis``
            and ``single_val`` sweep (default 5). Appended last so existing
            positional construction is unaffected.
    """
    hide: nn.Module
    find: nn.Module
    max_audio_len: int
    cond_range: tuple
    opt: optim.Optimizer
    writer: SummaryWriter
    accelerator: Accelerator = None
    vis_batch: list = None
    num_eval_conds: int = 5

    def _eval_conds(self):
        """Return ``num_eval_conds`` condition values log-spaced over ``cond_range``."""
        return 10.0**np.linspace(*self.cond_range, self.num_eval_conds)

    def single_step(self, batch, cur_log_step):
        """Run one training step with a random ``spec_len`` and condition.

        Returns whatever ``single_multistep_cond`` returns.
        """
        # cur_cond is sampled from a log-uniform distribution over
        # [10**cond_range[0], 10**cond_range[1]].
        cur_cond = 10.0**random.uniform(*self.cond_range)
        # randint is inclusive on both ends, matching the vis sweep below.
        spec_len = random.randint(1, self.max_audio_len)
        return single_multistep_cond(self.hide, self.find, self.opt, self.writer,
            batch, spec_len, cur_cond, cur_log_step, accelerator=self.accelerator)

    @torch.no_grad()
    def single_vis(self, batch, cur_log_step):
        """Log visualisations for every ``spec_len`` x evaluation condition.

        Prefers the stored ``vis_batch``; falls back to ``batch`` when unset.
        Returns ``None``.
        """
        vis_batch = self.vis_batch
        if vis_batch is None:
            vis_batch = batch

        # The condition sweep does not depend on spec_len: compute it once.
        conds = self._eval_conds()
        for spec_len in range(1, self.max_audio_len + 1):
            for cur_cond in conds:
                single_multivis_cond(self.hide, self.find, self.writer, vis_batch,
                    spec_len, cur_cond, cur_log_step, accelerator=self.accelerator)

    @torch.no_grad()
    def single_val(self, val_dl, cur_log_step, device):
        """Run validation over ``val_dl`` for every evaluation condition.

        ``max_audio_len`` and ``device`` are forwarded to
        ``single_multival_cond``; returns whatever it returns.
        """
        conds = self._eval_conds()
        return single_multival_cond(self.hide, self.find, self.writer, val_dl,
            conds, self.max_audio_len, cur_log_step, accelerator=self.accelerator,
            device=device)