import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
from tqdm import tqdm
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from ..tasks import build_task
from ..models import build_model
from . import register_pipe
from .base_pipe import BasePipe
from ..utils.early_stop import EarlyStopping
from ..utils.plot import plot_confusion_matrix, plot_tsne
from ..utils import SignalDiffusion
from sklearn.metrics import roc_auc_score

@register_pipe("modulation_classification")
class ModulationClassification(BasePipe):
    """Train/evaluate pipeline registered as ``"modulation_classification"``.

    The pipeline builds a backbone with ``build_model``, optionally attaches a
    two-output head (``RF_Diffusion`` and ``SpectrumFM``), and reports ROC-AUC
    per group key.  Every batch yielded by the loaders is unpacked as
    ``(batch_x, batch_stft, batch_y, batch_SIR)``; ``batch_SIR`` supplies the
    group key of each sample and the key ``-1`` is used as the reference group
    that every other group is scored against.

    NOTE(review): early stopping only checkpoints/restores ``self.model``; the
    auxiliary heads created here (``classifier``, ``gru``, ``fc``) are not part
    of that checkpoint -- confirm whether that is intended.
    """

    # Group key whose samples are merged into every other group before AUC.
    _CLEAN_KEY = -1
    # Attribute names of modules that may exist outside ``self.model``.
    _HEAD_NAMES = ("classifier", "gru", "fc", "dropout")

    def __init__(self, args):
        """Build the model, optional heads, optimizer, loaders and scheduler.

        Args:
            args: Run configuration namespace.  Attributes read here include
                ``model``, ``device``, ``load_from_pretrained``,
                ``use_distribute``, ``optimizer``, ``lr``, ``weight_decay``,
                ``batch_size``, ``output_dir`` and, depending on the model,
                ``hidden_dim``, ``extra_dim``, ``max_step``, ``blur_noise``,
                ``min_noise``, ``max_noise``.  ``compile_flag`` and ``plot``
                are optional and default to ``False``.
        """
        super(ModulationClassification, self).__init__(args)
        self.model_name = args.model
        self.model = build_model(args.model).build_model_from_args(args).to(args.device)
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f'{total_params:,} total parameters.')
        total_trainable_params = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f'{total_trainable_params:,} training parameters.')
        if args.load_from_pretrained:
            self.load_from_pretrained()
        if getattr(args, "compile_flag", False):
            self.compile()
        if args.use_distribute:
            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[args.device])

        print("-----------------------load model done-----------------------")

        task_name = "modulation_classification"
        self.task = build_task(args, task_name)
        self.loss_fn = self.task.get_loss_func()
        train_dataset, val_dataset, test_dataset = self.task.get_data()

        # One optimizer group for the backbone plus one per trainable head.
        param_groups = [{"params": self.model.parameters()}]
        if args.model == "RF_Diffusion":
            self.diffusion = SignalDiffusion(args.batch_size, args.extra_dim, args.max_step,
                                             args.blur_noise, args.min_noise, args.max_noise)
            self.classifier = nn.Linear(args.hidden_dim * 2, 2).to(args.device)
            param_groups.append({"params": self.classifier.parameters()})
        elif args.model == "SpectrumFM":
            self.gru = nn.GRU(args.hidden_dim, args.hidden_dim,
                              batch_first=True, bidirectional=True).to(args.device)
            self.fc = nn.Linear(args.hidden_dim * 2, 2).to(args.device)
            self.dropout = nn.Dropout(0.2).to(args.device)
            param_groups.append({"params": self.gru.parameters()})
            param_groups.append({"params": self.fc.parameters()})
        self.optimizer = self.candidate_optimizer[args.optimizer](
            param_groups, lr=args.lr, weight_decay=args.weight_decay)

        self.train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False)
        self.val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)
        self.test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)

        self.scaler = torch.cuda.amp.GradScaler()
        self.scheduler = ReduceLROnPlateau(self.optimizer, 'min', factor=0.5, patience=3, verbose=True, min_lr=5e-5)
        self.SNR_list = self.task.get_snr()
        # NOTE(review): the step methods iterate ``self.SIR_list`` but nothing
        # in this file assigns it.  Fall back to the keys from ``get_snr()``
        # unless the base class already provided the attribute -- confirm that
        # these are the intended group keys.
        if not hasattr(self, "SIR_list"):
            self.SIR_list = self.SNR_list
        self.output_dir = args.output_dir
        self.plot = getattr(args, "plot", False)

    def train(self):
        """Run the training loop with early stopping, then evaluate on test.

        Validation runs every ``args.evaluate_interval`` epochs and drives the
        early stopper.  After training, the stopper reloads its checkpoint
        into ``self.model``, the test split is evaluated and the metrics are
        appended to ``performance.txt`` in the working directory.

        Returns:
            dict: Test metrics keyed by group, plus ``"Avg"``.
        """
        import json

        stopper = EarlyStopping(self.args.patience, self._checkpoint,
                                getattr(self.args, "compile_flag", False),
                                self.args.use_distribute)
        early_stop = False
        for epoch in range(self.args.num_epochs):
            train_loss, train_acc, _, _ = self._train_step()
            print(f"Epoch:{epoch}, train_loss={train_loss}, train_acc={train_acc['Avg']}")
            print(train_acc)

            if epoch % self.args.evaluate_interval == 0:
                loss, acc, _, _, _ = self._test_step("val")
                print(f"Epoch:{epoch}, val_loss={loss}, val_acc={acc['Avg']}")
                early_stop = stopper.loss_step(loss, self.model)

            if early_stop:
                print("Early Stop!\tEpoch:" + str(epoch))
                break

        stopper.load_model(self.model)
        _, test_acc, _, _, _ = self._test_step("test")
        performance = test_acc["Avg"]
        with open("performance.txt", "a", encoding="utf-8") as f:
            f.write(self.model_name + "\n")
            # Stringify keys so non-builtin key types cannot break json.dump.
            json.dump({str(k): v for k, v in test_acc.items()}, f,
                      ensure_ascii=False, indent=4)
            f.write("\n")
        print(f"test acc={performance}")
        print(test_acc)

        return test_acc

    def _head_modules(self):
        """Return the auxiliary ``nn.Module`` heads that exist on this pipe."""
        heads = (getattr(self, name, None) for name in self._HEAD_NAMES)
        return [module for module in heads if isinstance(module, nn.Module)]

    def _set_training(self, training):
        """Switch the backbone and every auxiliary head to train/eval mode.

        The heads are not sub-modules of ``self.model``, so toggling only the
        backbone would leave e.g. dropout active during evaluation.
        """
        self.model.train(training)
        for module in self._head_modules():
            module.train(training)

    def _pad_to_batch_size(self, batch_x, num_sample):
        """Zero-pad ``batch_x`` along dim 0 up to ``args.batch_size`` rows.

        Only short (last) batches are padded; callers slice the output back to
        ``num_sample`` rows.  Presumably the backbone expects a fixed batch
        size -- TODO confirm.
        """
        missing = self.args.batch_size - num_sample
        if missing <= 0:
            return batch_x
        padding = batch_x.new_zeros((missing, *batch_x.shape[1:]))
        return torch.cat((batch_x, padding), dim=0)

    def _forward(self, batch_x, batch_stft, num_sample):
        """Run the model-specific forward pass for one batch.

        Returns:
            tuple: ``(batch_out, rec_loss)`` where ``rec_loss`` is the MSE
            reconstruction loss for ``DAE`` and ``None`` for other models.
        """
        name = self.model_name
        if name == "IQFormer":
            # IQFormer consumes the untransposed signal together with its STFT.
            batch_out = self.model((batch_x.to(self.device), batch_stft.to(self.device)))
            return batch_out, None

        batch_x = batch_x.transpose(1, 2).to(self.device)

        if name == "RF_Diffusion":
            # Always condition on the last diffusion step.
            t = torch.tensor([self.args.max_step - 1], dtype=torch.int64)
            x = self.model.p_embed(self._pad_to_batch_size(batch_x, num_sample))
            t = self.model.t_embed(t)
            for block in self.model.blocks:
                x = block(x, t)
            feats = x.transpose(2, 1)
            # NOTE(review): the same slice is concatenated twice; kept as in
            # the original -- confirm whether index 1 was intended.
            feats = torch.concat([feats[:, 0], feats[:, 0]], dim=1)
            return self.classifier(feats)[:num_sample], None

        if name == "SpectrumFM":
            emb = self.model(self._pad_to_batch_size(batch_x, num_sample))
            output, _ = self.gru(emb)
            # Use the last time step of the bidirectional GRU output.
            output = self.dropout(output[:, -1, :])
            return self.fc(output)[:num_sample], None

        if name == "DAE":
            batch_out, rec = self.model(batch_x)
            return batch_out, F.mse_loss(rec, batch_x)

        return self.model(batch_x), None

    @staticmethod
    def _record_scores(groups, batch_out, batch_y, batch_keys):
        """Append positive-class scores and labels to each sample's group.

        The score is the softmax probability of class 1, so that it depends on
        both outputs rather than on the class-1 output alone.
        """
        scores = torch.softmax(batch_out.detach().float(), dim=1)[:, 1].cpu().tolist()
        labels = batch_y.detach().reshape(-1).cpu().tolist()
        for score, label, key in zip(scores, labels, batch_keys):
            if isinstance(key, list):
                # Keys arriving as one-element lists are unwrapped.
                key = key[0]
            for bucket in (groups[key], groups["Avg"]):
                bucket["pred"].append(score)
                bucket["label"].append(label)

    @staticmethod
    def _safe_auc(labels, scores):
        """Return ROC-AUC, or ``nan`` when fewer than two classes are present."""
        labels = np.asarray(labels)
        if labels.size == 0 or np.unique(labels).size < 2:
            return float("nan")
        return roc_auc_score(labels, np.asarray(scores))

    def _summarize(self, groups):
        """Compute per-group and overall ROC-AUC from accumulated scores.

        Each non-reference key in ``self.SIR_list`` is scored on its own
        samples merged with the reference (``-1``) samples; the reference key
        itself keeps its initial value ``0``.
        """
        empty = {"pred": [], "label": []}
        # ``.get`` avoids creating entries in the defaultdict as a side effect.
        clean = groups.get(self._CLEAN_KEY, empty)
        result = {key: 0 for key in self.SIR_list}
        for key in self.SIR_list:
            if key == self._CLEAN_KEY:
                continue
            entry = groups.get(key, empty)
            result[key] = self._safe_auc(list(clean["label"]) + list(entry["label"]),
                                         list(clean["pred"]) + list(entry["pred"]))
        overall = groups.get("Avg", empty)
        result["Avg"] = self._safe_auc(overall["label"], overall["pred"])
        return result

    def _run_epoch(self, loader, training):
        """Iterate ``loader`` once, optionally updating the parameters.

        Args:
            loader: Iterable of ``(batch_x, batch_stft, batch_y, batch_SIR)``.
            training: When true, back-propagate and step the optimizer.

        Returns:
            tuple: ``(sample-weighted mean loss, metrics dict)``.

        Raises:
            ValueError: If the loader yields no samples.
        """
        groups = defaultdict(lambda: {"pred": [], "label": []})
        num_total = 0
        loss_sum = 0.0
        for batch_x, batch_stft, batch_y, batch_SIR in loader:
            num_sample = batch_x.size(0)
            num_total += num_sample
            batch_keys = batch_SIR.numpy().tolist()
            batch_y = batch_y.to(self.device)

            batch_out, rec_loss = self._forward(batch_x, batch_stft, num_sample)
            batch_loss = self.loss_fn(batch_out, batch_y)
            if rec_loss is not None:
                # DAE: equal weighting of classification and reconstruction.
                batch_loss = batch_loss * 0.5 + 0.5 * rec_loss

            self._record_scores(groups, batch_out, batch_y, batch_keys)
            loss_sum += batch_loss.item() * num_sample

            if training:
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

        if num_total == 0:
            raise ValueError("data loader produced no samples")
        return loss_sum / num_total, self._summarize(groups)

    def _train_step(self):
        """Train for one epoch over ``self.train_loader``.

        Returns:
            tuple: ``(loss, metrics, y_true, y_pred)``.  The last two are
            always empty lists, kept for interface compatibility.
        """
        self._set_training(True)
        loss, SIR = self._run_epoch(tqdm(self.train_loader), training=True)
        return loss, SIR, [], []

    def _test_step(self, mode):
        """Evaluate on the validation or test split without gradients.

        Args:
            mode: A string containing ``"val"`` or ``"test"``.

        Returns:
            tuple: ``(loss, metrics, y_true, y_pred, eval_SIR)``.  The last
            three are always empty lists, kept for interface compatibility.

        Raises:
            ValueError: If ``mode`` names neither split.
        """
        is_val = "val" in mode
        if is_val:
            loader = self.val_loader
        elif "test" in mode:
            loader = self.test_loader
        else:
            raise ValueError(f"unknown evaluation mode: {mode!r}")

        self._set_training(False)
        with torch.no_grad():
            loss, SIR = self._run_epoch(loader, training=False)

        # Only the validation loss should drive the learning-rate schedule.
        if is_val:
            self.scheduler.step(loss)

        return loss, SIR, [], [], []
