import copy
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from collections import defaultdict
from tqdm import tqdm
from torch.optim.lr_scheduler import CyclicLR, ReduceLROnPlateau
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter
from sklearn.metrics import roc_auc_score, roc_curve
from ..tasks import build_task
from ..dataset import build_dataset
from ..models import build_model
from . import register_pipe
from .base_pipe import BasePipe
from ..utils import Time_Freq_Diffusion, plot_tsne
from ..utils.early_stop import EarlyStopping
from ..utils import plot_confusion_matrix
from tqdm import tqdm
from ..utils import plot_signal_spectrum, plot_signal_time, plot_signal_time_freq

@register_pipe("anomaly")
class anomaly(BasePipe):
    def __init__(self, args):
        """Assemble every component of the anomaly pipe from ``args``.

        Creates the backbone model, the ``epsilon``/``eta`` noise heads (one
        ``LayerNorm`` and one ``Linear(hidden_dim, 2)`` each, plus a shared
        ``SiLU``), the pre-training optimizer over all of them, the ICARUS
        train/valid/test loaders, the time-frequency diffusion process, the
        two-class classification head with its own optimizer, and one
        ``ReduceLROnPlateau`` scheduler per optimizer.

        Args:
            args: run configuration. ``compile_flag`` and ``plot`` are
                optional attributes; every other attribute read below must
                be present.
        """
        super(anomaly, self).__init__(args)

        # Backbone first, then the heads: creation order is kept stable so the
        # random initialisation sequence does not change.
        self.model = build_model(args.model).build_model_from_args(args).to(args.device)
        device, width = args.device, args.hidden_dim
        self.norm_epsilon = nn.LayerNorm(width).to(device)
        self.norm_eta = nn.LayerNorm(width).to(device)
        self.act = nn.SiLU().to(device)
        self.epsilon = nn.Linear(width, 2).to(device)
        self.eta = nn.Linear(width, 2).to(device)

        # Parameter statistics cover the backbone only (heads are not counted).
        backbone_params = list(self.model.parameters())
        total_params = sum(p.numel() for p in backbone_params)
        print(f'{total_params:,} total parameters.')
        total_trainable_params = sum(p.numel() for p in backbone_params if p.requires_grad)
        print(f'{total_trainable_params:,} training parameters.')

        if args.load_from_pretrained:
            self.load_from_pretrained()
        if getattr(args, "compile_flag", False):
            self.compile()
        if args.use_distribute:
            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[device])
        print("-----------------------load model done-----------------------")
        self.max_step = args.max_step

        # Pre-training optimizer: one parameter group per module.
        pretrain_modules = (self.model, self.norm_epsilon, self.norm_eta, self.epsilon, self.eta)
        self.optimizer = self.candidate_optimizer[args.optimizer](
            [{"params": module.parameters()} for module in pretrain_modules],
            lr=args.lr, weight_decay=args.weight_decay)

        # Data: the factory returned by build_dataset is called once without a
        # split to fetch the SIR values, then once per split.
        dataset = build_dataset("ICARUS", self.args.test_size, self.args.dataset_path)
        _, _, SIR = dataset().get_pretrain_data
        splits = [dataset(split_name) for split_name in ("train", "valid", "test")]
        self.classes = 2

        loader_options = {"batch_size": args.batch_size, "drop_last": False}
        self.train_loader = DataLoader(splits[0], shuffle=True, **loader_options)
        self.val_loader = DataLoader(splits[1], shuffle=False, **loader_options)
        self.test_loader = DataLoader(splits[2], shuffle=False, **loader_options)
        self.diffusion = Time_Freq_Diffusion(self.max_step, self.args.min_noise, self.args.max_noise,
                                             args.ratio, device)

        # Classification head: one average pool per backbone output, then an
        # MLP over the concatenated pooled features.
        mlp = nn.Sequential(
            nn.Linear(2 * width, width),
            nn.Dropout(0.2),
            nn.PReLU(),
            nn.Linear(width, 2))
        self.classifier = nn.ModuleList(
            [nn.AdaptiveAvgPool1d(1), nn.AdaptiveAvgPool1d(1), mlp]).to(device)

        # Fine-tuning optimizer updates the backbone together with the head.
        self.classifier_optimizer = self.candidate_optimizer[args.optimizer](
            [{"params": self.model.parameters()},
             {"params": self.classifier.parameters()}],
            lr=1e-3, weight_decay=args.weight_decay)
        self.scheduler = ReduceLROnPlateau(
            self.optimizer, 'min', factor=0.5, patience=250, verbose=True, min_lr=1e-12)
        self.classifier_scheduler = ReduceLROnPlateau(
            self.classifier_optimizer, 'min', factor=0.5, patience=3, verbose=True, min_lr=1e-6)

        self.SIR_list = np.unique(SIR)
        self.output_dir = args.output_dir
        self.checkpoint = osp.join(self.args.output_dir, f"{self.args.model}_Anomaly_pretrain.pt")
        # Plotting is opt-in: absent attribute means no plots.
        self.plot = getattr(args, "plot", False)

    def get_fft_input(self, data):
        I = data[:, 0, :]
        Q = data[:, 1, :]
        s = torch.complex(I, Q)
        freq = torch.fft.fft(s)
        re = torch.real(freq)
        im = torch.imag(freq)
        fft_data = torch.concat([re.unsqueeze(1), im.unsqueeze(1)], dim=1)
        return fft_data

    def train(self, pretrain_epochs=15, finetune_epochs=100, finetune_patience=25):
        """Pre-train the denoiser, fine-tune the classifier, then run the test split.

        Phase 1 trains the backbone and the noise heads on the diffusion
        objective with early stopping on the training loss. Phase 2 fine-tunes
        backbone and classifier with early stopping on the validation loss.
        After phase 2 the backbone is reloaded from the early-stopping
        checkpoint and the classifier is restored to the weights it had at the
        best validation loss, so both come from the same epoch.

        Args:
            pretrain_epochs: maximum number of phase-1 epochs.
            finetune_epochs: maximum number of phase-2 epochs.
            finetune_patience: early-stopping patience used in phase 2.

        Returns:
            The metric dict returned by ``_test_step("test")``.
        """
        # ``compile_flag`` is optional in __init__, so treat it as optional here too.
        compile_flag = getattr(self.args, "compile_flag", False)
        # NOTE(review): ``_checkpoint`` is not assigned in this class; presumably
        # the base pipe provides it. Fall back to ``self.checkpoint`` otherwise.
        finetune_checkpoint = getattr(self, "_checkpoint", self.checkpoint)

        self._pretrain_denoiser(pretrain_epochs, compile_flag)
        self._finetune_classifier(finetune_epochs, finetune_patience, compile_flag, finetune_checkpoint)

        classifier_path = osp.join(osp.dirname(finetune_checkpoint),
                                   "classifier_" + self.args.dataset[0] + ".pt")
        torch.save(self.classifier.state_dict(), classifier_path)

        test_loss, test_acc, test_true, test_pred, test_SIR = self._test_step("test")
        performance = test_acc["Avg"]
        print(f"test acc={performance}")
        print(test_acc)

        if self.plot:
            self._plot_per_sir(test_true, test_pred, test_SIR)
        return test_acc

    def _pretrain_denoiser(self, num_epochs, compile_flag):
        """Phase 1: fit backbone and noise heads to the combined injected noise."""
        stopper = EarlyStopping(self.args.patience, self.checkpoint,
                                compile_flag, self.args.use_distribute)
        self.model.train()
        stop = False
        for _ in range(num_epochs):
            for step, (data, _, _, _) in enumerate(tqdm(self.train_loader)):
                data = data.to(self.device, non_blocking=True)
                # Peak-normalise each channel; an all-zero channel is divided by 1
                # instead of 0 so it stays zero rather than turning into NaN.
                max_abs = data.abs().amax(dim=-1, keepdim=True)
                data = data / torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))

                t = torch.randint(0, self.max_step, (data.shape[0],), dtype=torch.int64).to(self.device)
                x_noised, epsilon, eta, ratio = self.diffusion.q_sample(data, t)
                fft_data = self.get_fft_input(x_noised)
                out1, out2 = self.model(x_noised, fft_data, t / self.max_step)
                pred_epsilon = self.epsilon(self.act(self.norm_epsilon(out1))).transpose(1, 2)
                pred_eta = self.eta(self.act(self.norm_eta(out2))).transpose(1, 2)
                loss = F.mse_loss(epsilon + ratio * eta, pred_epsilon + ratio * pred_eta)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # The plateau scheduler is driven per step by the batch loss.
                self.scheduler.step(loss.item())

                if step % self.args.evaluate_interval == 0:
                    print(f"---loss:{loss.item()}---lr:{self.optimizer.param_groups[0]['lr']}---")
                stop = bool(stopper.loss_step(loss, self.model))
                if stop:
                    break
            if stop:
                break
        stopper.load_model(self.model)

    def _finetune_classifier(self, num_epochs, patience, compile_flag, checkpoint_path):
        """Phase 2: fine-tune backbone + classifier with validation early stopping."""
        stopper = EarlyStopping(patience, checkpoint_path, compile_flag, self.args.use_distribute)
        best_loss = None
        best_classifier_state = None
        for epoch in range(num_epochs):
            self.model.train()
            self.classifier.train()
            train_loss, train_acc, train_true, train_pred = self._train_step()
            print(f"Epoch:{epoch}, train_loss={train_loss}, train_acc={train_acc['Avg']}")
            print(train_acc)
            if self.plot and len(train_true) > 0:
                plot_confusion_matrix(train_true, train_pred, self.dataset_name, "all",
                                      self.output_dir, self.classes)

            if epoch % self.args.evaluate_interval != 0:
                continue

            loss, acc, true, pred, _ = self._test_step("val")
            print(f"Epoch:{epoch}, val_loss={loss}, val_acc={acc['Avg']}")
            stop = stopper.loss_step(loss, self.model)
            if isinstance(loss, torch.Tensor):
                loss = loss.item()
            if best_loss is None or loss < best_loss:
                best_loss = loss
                # Snapshot the weights only; the module object itself must stay
                # the one registered with the optimizer.
                best_classifier_state = copy.deepcopy(self.classifier.state_dict())

            if self.plot and len(true) > 0:
                plot_confusion_matrix(true, pred, self.dataset_name, "all",
                                      self.output_dir, self.classes)
            if stop:
                print("Early Stop!\tEpoch:" + str(epoch))
                break

        stopper.load_model(self.model)
        if best_classifier_state is not None:
            # Keep the classifier in sync with the reloaded backbone.
            self.classifier.load_state_dict(best_classifier_state)

    def _plot_per_sir(self, y_true, y_pred, sirs):
        """Draw one confusion matrix (and a t-SNE for 2-D outputs) per SIR value."""

        def to_numpy(values):
            if isinstance(values, torch.Tensor):
                return values.detach().cpu().numpy()
            if len(values) > 0 and isinstance(values[0], torch.Tensor):
                return torch.stack(list(values)).detach().cpu().numpy()
            return np.asarray(values)

        sirs = to_numpy(sirs)
        y_true = to_numpy(y_true)
        y_pred = to_numpy(y_pred)
        if sirs.size == 0 or not (len(sirs) == len(y_true) == len(y_pred)):
            print("skip per-SIR plots: no aligned per-sample outputs available")
            return

        for sir in self.SIR_list:
            mask = sirs == sir
            if not mask.any():
                continue
            sir_true, sir_pred = y_true[mask], y_pred[mask]
            # NOTE(review): 2-D outputs are assumed to be per-class scores, so
            # hard labels are taken with argmax - confirm against _test_step.
            hard_pred = sir_pred if sir_pred.ndim == 1 else sir_pred.argmax(axis=1)
            plot_confusion_matrix(sir_true.tolist(), hard_pred.tolist(), self.dataset_name, sir,
                                  self.output_dir, self.classes)
            if sir_pred.ndim == 2:
                # NOTE(review): presumably plot_tsne needs one vector per sample;
                # it is skipped when only class indices are available.
                plot_tsne(sir_pred, sir_true, self.dataset_name, sir, self.output_dir, self.classes)

    def _train_step(self):
        """Run one fine-tuning epoch of backbone + classifier on the train loader.

        Every batch is peak-normalised, noised at the fixed step
        ``self.args.timestep``, passed through the backbone and the
        classification head, and optimised with cross-entropy. The method
        does not switch train/eval mode itself.

        Returns:
            tuple ``(loss, auc, y_true, y_pred)`` where ``loss`` is the
            sample-weighted mean cross-entropy, ``auc`` maps every value of
            ``self.SIR_list`` (the ``-1`` entry is left at 0) and ``"Avg"``
            to a ROC-AUC, and ``y_true`` / ``y_pred`` are per-sample labels
            and argmax predictions. A ROC-AUC is ``nan`` when fewer than two
            classes are present in its group.
        """

        def auc_or_nan(labels, scores):
            # roc_auc_score raises when only one class is present.
            labels = np.asarray(labels)
            if np.unique(labels).size < 2:
                return float("nan")
            return roc_auc_score(labels, np.asarray(scores))

        auc_by_sir = {key: 0 for key in self.SIR_list}
        groups = defaultdict(lambda: {"pred": [], "label": []})
        y_true = []
        y_pred = []
        num_total = 0
        loss = 0.0
        # The diffusion step is the same for every batch.
        t = torch.tensor([self.args.timestep], dtype=torch.int64).to(self.device)

        for step, (batch_x, _, batch_y, batch_SIR) in enumerate(tqdm(self.train_loader)):
            num_sample = batch_x.size(0)
            num_total += num_sample
            # One scalar SIR per sample, also when the loader yields [batch, 1].
            sir_values = batch_SIR.reshape(num_sample, -1)[:, 0].tolist()
            batch_y = batch_y.to(self.device)
            batch_x = batch_x.to(self.device)
            # Peak-normalise; all-zero channels are divided by 1 to avoid 0/0.
            max_abs = batch_x.abs().amax(dim=-1, keepdim=True)
            batch_x = batch_x / torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))

            x_noised, _, _, _ = self.diffusion.q_sample(batch_x, t)
            fft_data = self.get_fft_input(x_noised)
            out1, out2 = self.model(x_noised, fft_data, t / self.max_step)
            pooled1 = self.classifier[0](out1.transpose(1, 2)).squeeze(-1)
            pooled2 = self.classifier[1](out2.transpose(1, 2)).squeeze(-1)
            logits = self.classifier[2](torch.concat([pooled1, pooled2], dim=-1))
            batch_loss = F.cross_entropy(logits, batch_y)

            # Single device-to-host copy per batch instead of one per sample.
            logits_cpu = logits.detach().cpu()
            labels = batch_y.detach().cpu().tolist()
            # The logit margin is monotone in P(class 1), unlike the raw class-1
            # logit, so it is the score that ranks samples correctly for ROC-AUC.
            scores = (logits_cpu[:, 1] - logits_cpu[:, 0]).tolist()
            y_true.extend(labels)
            y_pred.extend(logits_cpu.argmax(dim=1).tolist())
            for score, label, sir in zip(scores, labels, sir_values):
                groups[sir]["pred"].append(score)
                groups[sir]["label"].append(label)
                groups["Avg"]["pred"].append(score)
                groups["Avg"]["label"].append(label)

            loss += (batch_loss.item() * num_sample)
            self.classifier_optimizer.zero_grad()
            batch_loss.backward()
            self.classifier_optimizer.step()
            # ``step`` is the batch index; it must not be reused as a loop variable above.
            if step % self.args.evaluate_interval == 0:
                print(f"---loss:{batch_loss.item()}---lr:{self.classifier_optimizer.param_groups[0]['lr']}---")

        loss /= max(num_total, 1)

        # Each SIR group is scored against the samples tagged with SIR == -1.
        clean = groups[-1]
        for key in self.SIR_list:
            if key == -1:
                continue
            group = groups[key]
            auc_by_sir[key] = auc_or_nan(clean["label"] + group["label"],
                                         clean["pred"] + group["pred"])
        auc_by_sir["Avg"] = auc_or_nan(groups["Avg"]["label"], groups["Avg"]["pred"])

        return loss, auc_by_sir, y_true, y_pred

    def _test_step(self, mode):
        """Evaluate backbone + classifier on the validation or test loader.

        Puts the model and classifier in eval mode, runs the loader without
        gradients, and computes the mean cross-entropy plus ROC-AUC per SIR
        value. The classifier LR scheduler is stepped with the loss only when
        evaluating the validation split.

        Args:
            mode: string containing ``"val"`` or ``"test"``; ``"val"`` takes
                precedence when both occur.

        Returns:
            tuple ``(loss, auc, y_true, y_pred, eval_SIR)`` where ``auc`` maps
            ``str(sir)`` for every value of ``self.SIR_list`` (the ``-1``
            entry is left at 0) and ``"Avg"`` to a ROC-AUC (``nan`` when a
            group holds fewer than two classes), and the three lists hold the
            per-sample label, argmax prediction and SIR value.

        Raises:
            ValueError: if ``mode`` names neither split.
        """
        is_val = "val" in mode
        if is_val:
            loader = self.val_loader
        elif "test" in mode:
            loader = self.test_loader
        else:
            raise ValueError(f"mode must contain 'val' or 'test', got {mode!r}")

        def auc_or_nan(labels, scores):
            # roc_auc_score raises when only one class is present.
            labels = np.asarray(labels)
            if np.unique(labels).size < 2:
                return float("nan")
            return roc_auc_score(labels, np.asarray(scores))

        self.model.eval()
        self.classifier.eval()
        auc_by_sir = {str(key): 0 for key in self.SIR_list}
        # Groups are keyed by the SIR value itself so int/float formatting
        # differences between the loader and SIR_list cannot split a group.
        groups = defaultdict(lambda: {"pred": [], "label": []})
        y_true = []
        y_pred = []
        eval_SIR = []
        num_total = 0
        loss = 0.0
        # The diffusion step is the same for every batch.
        t = torch.tensor([self.args.timestep], dtype=torch.int64).to(self.device)

        with torch.no_grad():
            for batch_x, _, batch_y, batch_SIR in loader:
                num_sample = batch_x.size(0)
                num_total += num_sample
                # One scalar SIR per sample, also when the loader yields [batch, 1].
                sir_values = batch_SIR.reshape(num_sample, -1)[:, 0].tolist()
                batch_y = batch_y.to(self.device)
                batch_x = batch_x.to(self.device)
                # Peak-normalise; all-zero channels are divided by 1 to avoid 0/0.
                max_abs = batch_x.abs().amax(dim=-1, keepdim=True)
                batch_x = batch_x / torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))

                x_noised, _, _, _ = self.diffusion.q_sample(batch_x, t)
                fft_data = self.get_fft_input(x_noised)
                out1, out2 = self.model(x_noised, fft_data, t / self.max_step)
                pooled1 = self.classifier[0](out1.transpose(1, 2)).squeeze(-1)
                pooled2 = self.classifier[1](out2.transpose(1, 2)).squeeze(-1)
                logits = self.classifier[2](torch.concat([pooled1, pooled2], dim=-1))
                batch_loss = F.cross_entropy(logits, batch_y)

                logits_cpu = logits.cpu()
                labels = batch_y.cpu().tolist()
                # The logit margin is monotone in P(class 1), unlike the raw
                # class-1 logit, so it ranks samples correctly for ROC-AUC.
                scores = (logits_cpu[:, 1] - logits_cpu[:, 0]).tolist()
                y_true.extend(labels)
                y_pred.extend(logits_cpu.argmax(dim=1).tolist())
                eval_SIR.extend(sir_values)
                for score, label, sir in zip(scores, labels, sir_values):
                    groups[sir]["pred"].append(score)
                    groups[sir]["label"].append(label)
                    groups["Avg"]["pred"].append(score)
                    groups["Avg"]["label"].append(label)

                loss += (batch_loss.item() * num_sample)

        loss /= max(num_total, 1)
        if is_val:
            # Only validation loss may influence the learning-rate schedule.
            self.classifier_scheduler.step(loss)

        # Each SIR group is scored against the samples tagged with SIR == -1.
        clean = groups[-1]
        for key in self.SIR_list:
            if key == -1:
                continue
            group = groups[key]
            auc_by_sir[str(key)] = auc_or_nan(clean["label"] + group["label"],
                                              clean["pred"] + group["pred"])
        auc_by_sir["Avg"] = auc_or_nan(groups["Avg"]["label"], groups["Avg"]["pred"])

        return loss, auc_by_sir, y_true, y_pred, eval_SIR
