import copy
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter
from ..tasks import build_task
from ..models import build_model
from . import register_pipe
from .base_pipe import BasePipe
from ..utils import Diffusion, plot_tsne
from ..utils.early_stop import EarlyStopping
from ..utils import plot_confusion_matrix


class ClassifierDict(nn.Module):
    """Bank of linear classification heads, one per (timestep, feature name) pair.

    Each head is an ``nn.Linear`` applied to the feature tensor that
    ``feat_func(x, time)[name]`` returns.  Every head owns an AdamW optimizer
    that updates both the head and the parameters of ``model``, plus a
    ``ReduceLROnPlateau`` scheduler driven by the training loss.

    Args:
        model: Module whose parameters are optimised together with each head.
        feat_func: Callable ``(x, time) -> {name: tensor}``; dimension 1 of
            each tensor is used as the head's input width.
        time_list: Timesteps to build heads for.
        name_list: Feature names to build heads for, or ``None`` to use every
            name returned by ``feat_func``.
        base_lr: Initial learning rate of every optimizer.
        epoch: Unused; kept for interface compatibility.
        signal_length: Last dimension of the ``(1, 2, signal_length)`` zero
            tensor used to probe feature sizes.
        num_classes: Output width of every head.
        device: Device on which the probe tensor is created.
    """

    def __init__(self, model, feat_func, time_list, name_list, base_lr, epoch, signal_length, num_classes, device):
        super(ClassifierDict, self).__init__()
        self.best_loss = None
        # Snapshot of the heads at the best loss seen by `loss_step`.
        self.best_classifier = None
        self.feat_func = feat_func
        self.times = time_list
        self.names = name_list
        self.classifiers = nn.ModuleDict()
        self.optims = {}
        self.schedulers = {}
        self.loss_fn = nn.CrossEntropyLoss()
        # Most recent training loss per key; lets `schedule_step` run without
        # an explicit metric.
        self._last_losses = {}

        for time in self.times:
            # The probe pass only reads output shapes, so no graph is needed.
            with torch.no_grad():
                feats = self.feat_func(torch.zeros(1, 2, signal_length, device=device), time)
            if self.names is None:
                self.names = list(feats.keys())  # all available names

            for name in self.names:
                key = self.make_key(time, name)
                layers = nn.Linear(feats[name].shape[1], num_classes)
                optimizer = torch.optim.AdamW([{"params": model.parameters()},
                                               {"params": layers.parameters()}], lr=base_lr)
                # NOTE(review): `verbose` is deprecated in recent PyTorch
                # releases; kept so LR reductions are still announced.
                scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.8, patience=300, verbose=True, min_lr=1e-12)
                self.classifiers[key] = layers
                self.optims[key] = optimizer
                self.schedulers[key] = scheduler

    def train(self, x=True, y=None):
        """Run one optimisation step per head, or switch the module mode.

        This method shares its name with ``nn.Module.train``.  When ``y`` is
        omitted the call is forwarded to ``nn.Module.train(x)`` so that
        ``.train()``, ``.train(mode)`` and ``.eval()`` keep working.

        Args:
            x: Input batch, or the boolean ``mode`` when ``y`` is omitted.
            y: Target class indices for ``nn.CrossEntropyLoss``.

        Returns:
            ``(outputs, loss)`` where ``outputs`` maps each key to the argmax
            predictions and ``loss`` is the loss tensor of the last key
            processed (``None`` if there are no keys).  In mode-switch form,
            returns ``self``.
        """
        if y is None:
            return super().train(x)

        self.classifiers.train()
        outputs = {}
        loss = None
        for time in self.times:
            for name in self.names:
                key = self.make_key(time, name)
                # Features are recomputed for every key: each optimizer step
                # changes the shared model weights, so a graph built before
                # the step cannot be back-propagated through again.
                feats = self.feat_func(x, time)
                logit = self.classifiers[key](feats[name])
                outputs[key] = logit.argmax(dim=-1)
                loss = self.loss_fn(logit, y)
                optimizer = self.optims[key]
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                self.optimizer = optimizer
                self._last_losses[key] = loss.item()

        # Each scheduler sees the loss of its own head.
        for key, value in self._last_losses.items():
            self.schedulers[key].step(value)
        if loss is not None:
            print(f"---loss:{loss.item()}---lr:{self.optimizer.param_groups[0]['lr']}---")
        return outputs, loss

    def test(self, x, y):
        """Evaluate every head on a batch without tracking gradients.

        Returns:
            ``(outputs, loss)`` with the same layout as the training form of
            ``train``; ``loss`` belongs to the last key processed.
        """
        outputs = {}
        loss = None
        with torch.no_grad():
            self.classifiers.eval()
            for time in self.times:
                feats = self.feat_func(x, time)
                for name in self.names:
                    key = self.make_key(time, name)
                    logit = self.classifiers[key](feats[name].detach())
                    outputs[key] = logit.argmax(dim=-1)
                    loss = self.loss_fn(logit, y)
        return outputs, loss

    def make_key(self, t, n):
        """Return the dictionary key ``"<t>/<n>"`` for a timestep and name."""
        return str(t) + '/' + n

    def get_lr(self):
        """Return the learning rate of the first head's first parameter group."""
        key = self.make_key(self.times[0], self.names[0])
        return self.optims[key].param_groups[0]['lr']

    def schedule_step(self, loss=None):
        """Step every scheduler.

        ``ReduceLROnPlateau.step`` needs a metric.  When ``loss`` is given it
        is used for all heads; otherwise each head uses the loss recorded by
        its last training step, and heads with no recorded loss are skipped.
        """
        if isinstance(loss, torch.Tensor):
            loss = loss.item()
        for key, scheduler in self.schedulers.items():
            metric = self._last_losses.get(key) if loss is None else loss
            if metric is not None:
                scheduler.step(metric)

    def loss_step(self, loss):
        """Record ``loss`` and snapshot the heads when it is a new minimum."""
        if isinstance(loss, torch.Tensor):
            loss = loss.item()
        if self.best_loss is None or loss < self.best_loss:
            self.best_loss = loss
            self.save_model()

    def save_model(self):
        """Store a deep copy of the current heads as the best snapshot."""
        self.best_classifier = copy.deepcopy(self.classifiers)

    def load_model(self):
        """Restore the best snapshot into the live heads and return it.

        Returns:
            The snapshot ``nn.ModuleDict``, or ``None`` if nothing was saved.
        """
        if self.best_classifier is not None:
            self.classifiers.load_state_dict(self.best_classifier.state_dict())
        return self.best_classifier

@register_pipe("ddae_trainer")
class DDAETrainer(BasePipe):
    """Two-phase trainer: noise-prediction pretraining, then classification.

    Phase one trains ``self.model`` to predict the noise that
    ``Diffusion.q_sample`` adds to a signal.  Phase two trains a
    ``ClassifierDict`` on the model's intermediate activations; its
    optimizers also update the model.  Accuracy is reported per SNR level
    in ``self.SNR_list`` and as an overall average.
    """

    # Number of passes over the pretraining loader in phase one.
    PRETRAIN_EPOCHS = 15

    def __init__(self, args):
        super(DDAETrainer, self).__init__(args)
        self.model = build_model("DDAE_Network").build_model_from_args(args).to(args.device)
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f'{total_params:,} total parameters.')
        total_trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f'{total_trainable_params:,} training parameters.')
        if args.load_from_pretrained:
            self.load_from_pretrained()
        if getattr(args, "compile_flag", False):
            self.compile()
        if args.use_distribute:
            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[args.device])
        print("-----------------------load model done-----------------------")
        self.max_step = args.max_step
        self.optimizer = self.candidate_optimizer[args.optimizer](self.model.parameters(),
                                                                  lr=args.lr, weight_decay=args.weight_decay)

        # Phase one: noise-prediction data.
        self.pretrain_task = build_task(args, "signal_prediction")
        self.pretrain_loss_fn = self.pretrain_task.get_loss_func()
        pretrain_dataset = self.pretrain_task.get_pretrain_data()
        self.dataloader = DataLoader(pretrain_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     pin_memory=args.device != "cpu",
                                     sampler=DistributedSampler(pretrain_dataset) if args.use_distribute else None,
                                     drop_last=True)

        # Phase two: labelled classification data.
        self.task = build_task(args, "modulation_classification")
        self.loss_fn = self.task.get_loss_func()
        train_dataset, val_dataset, test_dataset = self.task.get_data()
        self.classes = self.task.get_classes()
        self.train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False)
        self.val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)
        self.test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)
        self.diffusion = Diffusion(self.max_step, self.args.min_noise, self.args.max_noise, args.device)

        # Built after `self.diffusion`: the constructor calls `get_feature`
        # to probe feature sizes.
        self.classifier = ClassifierDict(self.model, self.get_feature, [args.linear["timestep"]], [args.linear["blockname"]],
                                         args.linear["lrate"], args.linear["n_epoch"], args.signal_length,
                                         len(self.classes), self.device).to(self.device)

        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min', factor=0.5, patience=250, verbose=True, min_lr=1e-12)
        self.SNR_list = list(range(args.minSNR, args.maxSNR + 1, 2))
        self.output_dir = args.output_dir
        self.plot = getattr(args, "plot", False)

    def train(self):
        """Run pretraining, classifier training, then evaluate on the test set."""
        # `compile_flag` is optional on args (see __init__), so read it safely.
        compile_flag = getattr(self.args, "compile_flag", False)
        self._pretrain_denoiser(compile_flag)
        stopper = self._fit_classifier(compile_flag)

        stopper.load_model(self.model)
        # Copy the best heads back so the test pass uses them rather than the
        # weights of the last epoch.
        best_heads = self.classifier.load_model()
        if best_heads is not None:
            self.classifier.classifiers.load_state_dict(best_heads.state_dict())

        test_loss, test_acc, test_true, test_pred, test_SNR = self._test_step("test")
        performance = test_acc["Avg"]
        print(f"test acc={performance}")
        print(test_acc)
        if self.plot:
            self._plot_per_snr(test_SNR, test_true, test_pred)

    def _pretrain_denoiser(self, compile_flag):
        """Phase one: train the model to predict the noise added by q_sample."""
        stopper = EarlyStopping(self.args.patience, self._checkpoint, compile_flag, self.args.use_distribute)
        self.model.train()
        sampler = self.dataloader.sampler
        early_stop = False
        for epoch in range(self.PRETRAIN_EPOCHS):
            # Without set_epoch a DistributedSampler repeats the same ordering
            # every epoch.
            if isinstance(sampler, DistributedSampler):
                sampler.set_epoch(epoch)
            for i, (data, _, _, _) in enumerate(tqdm(self.dataloader)):
                data = data.to(self.device, non_blocking=True)
                t = torch.randint(0, self.max_step, (data.shape[0],), dtype=torch.int64).to(self.device)
                x_noised, noise = self.diffusion.q_sample(data, t)

                out = self.model(x_noised, t / self.max_step)
                loss = self.pretrain_loss_fn(noise, out)
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(parameters=self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
                self.scheduler.step(loss.item())

                if i % self.args.evaluate_interval == 0:
                    print(f"---loss:{loss.item()}---lr:{self.optimizer.param_groups[0]['lr']}---")
                early_stop = bool(stopper.loss_step(loss, self.model))
                if early_stop:
                    break

            if early_stop:
                break

    def _fit_classifier(self, compile_flag):
        """Phase two: train the heads with periodic validation; return the stopper."""
        stopper = EarlyStopping(self.args.linear["patience"], self._checkpoint, compile_flag, self.args.use_distribute)
        early_stop = False
        for epoch in range(self.args.linear["n_epoch"]):
            # `_test_step` leaves the model in eval mode, so re-enable training.
            self.model.train()
            train_loss, train_acc, train_true, train_pred = self._train_step()
            print(f"Epoch:{epoch}, train_loss={train_loss}, train_acc={train_acc['Avg']}")
            print(train_acc)
            if self.plot:
                plot_confusion_matrix(train_true, train_pred, self.dataset_name, "all", self.output_dir, self.classes)

            if epoch % self.args.evaluate_interval == 0:
                val_loss, val_acc, val_true, val_pred, _ = self._test_step("val")
                print(f"Epoch:{epoch}, val_loss={val_loss}, val_acc={val_acc['Avg']}")
                early_stop = stopper.loss_step(val_loss, self.model)
                self.classifier.loss_step(val_loss)
                if self.plot:
                    plot_confusion_matrix(val_true, val_pred, self.dataset_name, "all", self.output_dir, self.classes)

            if early_stop:
                print("Early Stop!\tEpoch:" + str(epoch))
                break
        return stopper

    def _plot_per_snr(self, snr_values, y_true, y_pred):
        """Draw a confusion matrix and a t-SNE plot for each SNR level with data."""
        snr_arr = np.asarray(snr_values)
        true_arr = np.asarray(y_true)
        pred_arr = np.asarray(y_pred)
        for snr in self.SNR_list:
            mask = snr_arr == snr
            if not mask.any():
                continue  # no test samples at this SNR
            snr_true = true_arr[mask]
            snr_pred = pred_arr[mask]
            plot_confusion_matrix(snr_true.tolist(), snr_pred.tolist(), self.dataset_name, snr,
                                  self.output_dir, self.classes)
            # NOTE(review): this passes predicted labels, as the original call
            # did; confirm whether plot_tsne expects feature vectors instead.
            plot_tsne(snr_pred, snr_true, self.dataset_name, snr, self.output_dir, self.classes)

    @staticmethod
    def _scalar_snr(value):
        """Unwrap an SNR that `tolist()` produced as a one-element list."""
        return value[0] if isinstance(value, list) else value

    @staticmethod
    def _last_prediction(outputs):
        """Return the predictions of the last key in `outputs` as a Python list."""
        return list(outputs.values())[-1].cpu().detach().numpy().tolist()

    def _new_snr_counters(self):
        """Return zeroed (sample count, correct count) dicts keyed by SNR."""
        return {snr: 0 for snr in self.SNR_list}, {snr: 0 for snr in self.SNR_list}

    @staticmethod
    def _count_hits(totals, hits, batch_snr, batch_true, batch_pred):
        """Add one batch to the per-SNR sample and correct counters."""
        for snr, truth, guess in zip(batch_snr, batch_true, batch_pred):
            totals[snr] += 1
            if guess == truth:
                hits[snr] += 1

    def _snr_accuracy(self, totals, hits):
        """Turn counters into accuracies per SNR plus an ``'Avg'`` entry.

        Levels with no samples get ``nan`` rather than a division error.
        """
        nan = float("nan")
        accuracy = {snr: (hits[snr] / float(totals[snr]) if totals[snr] else nan)
                    for snr in self.SNR_list}
        n_all = sum(totals[snr] for snr in self.SNR_list)
        n_hit = sum(hits[snr] for snr in self.SNR_list)
        accuracy['Avg'] = n_hit / float(n_all) if n_all else nan
        return accuracy

    def _train_step(self):
        """Train the classifier for one epoch over ``self.train_loader``.

        Returns:
            ``(mean_loss, accuracy, y_true, y_pred)``; ``accuracy`` maps each
            SNR level and ``'Avg'`` to an accuracy.
        """
        totals, hits = self._new_snr_counters()
        y_true = []
        y_pred = []
        num_total = 0
        loss = 0.0
        for batch_x, _, batch_y, batch_SNR in tqdm(self.train_loader):
            num_sample = batch_x.size(0)
            num_total += num_sample
            batch_snr = [self._scalar_snr(s) for s in batch_SNR.numpy().tolist()]
            batch_y = batch_y.to(self.device)
            batch_x = batch_x.to(self.device)

            outputs, batch_loss = self.classifier.train(batch_x, batch_y)
            batch_pred = self._last_prediction(outputs)
            batch_true = batch_y.cpu().detach().numpy().tolist()
            y_true.extend(batch_true)
            y_pred.extend(batch_pred)
            self._count_hits(totals, hits, batch_snr, batch_true, batch_pred)

            loss += batch_loss.item() * num_sample

        loss = loss / num_total if num_total else float("nan")
        return loss, self._snr_accuracy(totals, hits), y_true, y_pred

    def _test_step(self, mode):
        """Evaluate the classifier on the validation or test loader.

        Args:
            mode: A string containing ``"val"`` or ``"test"``.

        Returns:
            ``(mean_loss, accuracy, y_true, y_pred, snr_per_sample)``.

        Raises:
            ValueError: If ``mode`` names neither split.
        """
        self.model.eval()
        if "val" in mode:
            loader = self.val_loader
        elif "test" in mode:
            loader = self.test_loader
        else:
            raise ValueError(f"mode must contain 'val' or 'test', got {mode!r}")

        totals, hits = self._new_snr_counters()
        y_true = []
        y_pred = []
        eval_SNR = []
        num_total = 0
        loss = 0.0

        with torch.no_grad():
            for batch_x, _, batch_y, batch_SNR in loader:
                num_sample = batch_x.size(0)
                num_total += num_sample
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)
                # Flatten before storing so callers can compare entries to
                # plain SNR values.
                batch_snr = [self._scalar_snr(s) for s in batch_SNR.numpy().tolist()]

                outputs, batch_loss = self.classifier.test(batch_x, batch_y)
                batch_pred = self._last_prediction(outputs)
                batch_true = batch_y.cpu().detach().numpy().tolist()
                y_true.extend(batch_true)
                y_pred.extend(batch_pred)
                eval_SNR.extend(batch_snr)
                self._count_hits(totals, hits, batch_snr, batch_true, batch_pred)

                loss += batch_loss.item() * num_sample

        loss = loss / num_total if num_total else float("nan")
        return loss, self._snr_accuracy(totals, hits), y_true, y_pred, eval_SNR

    def get_feature(self, x, t, name=None, norm=False, use_amp=False):
        ''' Get the network's intermediate activations for a noised input.

            Args:
                x: Input batch; it is noised with `Diffusion.q_sample` before
                    the forward pass.
                t: Timestep; wrapped into a length-1 int64 tensor and passed
                    to the model as `t / self.max_step`.
                name: If given, return only that block's features.
                norm: L2-normalise each pooled feature vector.
                use_amp: Run the forward pass under CUDA autocast.
            Returns:
                A {name: tensor} dict of activations averaged over every
                dimension after the first two, or a single tensor when
                `name` is given.
        '''
        t = torch.tensor([t], dtype=torch.int64).to(self.device)
        x_noised, _ = self.diffusion.q_sample(x, t)

        def gap_and_norm(act, norm=False):
            # Flatten everything after (B, C) and average it away.
            act = act.view(act.shape[0], act.shape[1], -1).float()
            act = torch.mean(act, dim=2)
            if norm:
                act = torch.nn.functional.normalize(act)
            return act

        with torch.autocast("cuda", enabled=use_amp):
            _, acts = self.model(x_noised, t / self.max_step, ret_activation=True)
        all_feats = {blockname: gap_and_norm(acts[blockname], norm) for blockname in acts}

        if name is not None:
            return all_feats[name]
        return all_feats
        